Welcome to PyTorch Lightning Spells’ documentation!
This package contains some useful plugins for PyTorch Lightning. Many of them are based on others’ implementations; I only made some adaptations to make them work with PyTorch Lightning. Please let me know (ceshine at veritable.pw) if you feel any original authors are not credited adequately in my code and documentation.
The following is a categorized list of available classes and functions:
CV
Augmentation
Warning
The following three callbacks require MixupSoftmaxLoss to be used. The 1-D target tensor is converted to a 2-D one after the callback is applied, and MixupSoftmaxLoss calculates the correct cross-entropy loss from that 2-D tensor (see the sketch below).
A notebook is available on Kaggle demonstrating the effect of MixUp, CutMix, and SnapMix.
RandomAugmentationChoiceCallback randomly picks one of the given callbacks for each batch. It also supports a no-op warmup period and setting a no-op probability.
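To see why the target becomes 2-D, here is a minimal MixUp sketch. It is not the library’s actual callback; the helper name and arguments are made up for illustration:

```python
import torch
import torch.nn.functional as F

def mixup_batch(inputs, targets, num_classes, alpha=0.4):
    """Hypothetical MixUp helper (not the library's callback).

    Blends pairs of examples and returns 2-D soft targets, which is why a
    MixUp-style callback needs a softmax-based loss instead of a plain
    cross-entropy on 1-D integer labels.
    """
    # Sample the mixing coefficient from a Beta distribution.
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    perm = torch.randperm(inputs.size(0))
    # Mix the inputs with a randomly permuted copy of the batch.
    mixed_inputs = lam * inputs + (1 - lam) * inputs[perm]
    # Turn 1-D integer labels into 2-D soft targets of shape (batch, num_classes).
    onehot = F.one_hot(targets, num_classes).float()
    mixed_targets = lam * onehot + (1 - lam) * onehot[perm]
    return mixed_inputs, mixed_targets
```

Because the returned targets are per-class weights rather than integer labels, a softmax-based loss such as MixupSoftmaxLoss is required downstream.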
NLP
The training and inference speed of NLP models can be improved by sorting the input examples by their lengths. This reduces the average number of padding tokens per batch (i.e., the input matrices become smaller). Two samplers are provided to achieve this goal (a sketch of the underlying idea follows the list):
SortSampler: Suitable for validation and test datasets, where the order of input examples doesn’t matter.
SortishSampler: Suitable for training datasets, where we want to add some randomness in the order of input examples between epochs.
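The idea behind both samplers can be sketched with a plain PyTorch Sampler that orders examples by token count. This is only an illustration, not the library’s implementation; SortishSampler additionally injects randomness between epochs, as described above.

```python
import torch
from torch.utils.data import DataLoader, Sampler

class LengthSortedSampler(Sampler):
    """Hypothetical sampler: yields indices from longest to shortest example,
    so each batch contains examples of similar length and needs less padding."""

    def __init__(self, lengths):
        self.lengths = lengths  # e.g., number of tokens per example

    def __len__(self):
        return len(self.lengths)

    def __iter__(self):
        order = sorted(
            range(len(self.lengths)),
            key=lambda i: self.lengths[i],
            reverse=True,
        )
        return iter(order)

# Assumed usage, with lengths precomputed from a tokenized dataset:
# lengths = [len(tokens) for tokens in tokenized_dataset]
# loader = DataLoader(dataset, batch_size=32, sampler=LengthSortedSampler(lengths))
```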
Optimization
set_trainable: a function that freezes or unfreezes a layer group (an nn.Module or nn.ModuleList).
freeze_layers: a function that freezes or unfreezes a list of layer groups.
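Freezing a layer group boils down to toggling requires_grad on its parameters. A minimal sketch with a hypothetical helper (not the library’s set_trainable):

```python
import torch.nn as nn

def set_requires_grad(module: nn.Module, trainable: bool) -> None:
    """Freeze or unfreeze a layer group by toggling gradient tracking."""
    for param in module.parameters():
        param.requires_grad = trainable

# Example: freeze a backbone while keeping the classifier head trainable.
# set_requires_grad(model.backbone, False)
# set_requires_grad(model.head, True)
```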
Lookahead
Lookahead: A PyTorch optimizer wrapper to implement the lookahead mechanism.
LookaheadCallback: A callback that switches the model parameters to the slow ones before a validation round starts and switches back to the fast ones after it ends.
LookaheadModelCheckpoint: A combination of LookaheadCallback and ModelCheckpoint, so the slow parameters are kept in the checkpoints instead of the fast ones.
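For reference, a simplified sketch of the lookahead update rule these classes build on: the inner optimizer updates the fast weights every step, and every k steps the slow weights move a fraction alpha toward them and are synced back. This is an illustration only, not the library’s Lookahead class (which exposes a full optimizer interface):

```python
import torch

class SimpleLookahead:
    """Minimal lookahead illustration (hypothetical, not the library's class)."""

    def __init__(self, optimizer, k: int = 5, alpha: float = 0.5):
        self.optimizer = optimizer
        self.k = k
        self.alpha = alpha
        self.step_count = 0
        # Slow weights start as a copy of the current (fast) weights.
        self.slow_params = [
            [p.detach().clone() for p in group["params"]]
            for group in optimizer.param_groups
        ]

    def step(self):
        self.optimizer.step()  # fast update by the inner optimizer
        self.step_count += 1
        if self.step_count % self.k == 0:
            for group, slow_group in zip(self.optimizer.param_groups, self.slow_params):
                for fast, slow in zip(group["params"], slow_group):
                    # Move the slow weights toward the fast ones, then sync back.
                    slow.add_(fast.detach() - slow, alpha=self.alpha)
                    fast.data.copy_(slow)
```

The callbacks above exist because evaluation and checkpointing should see the slow weights, not whatever intermediate fast weights the model happens to hold.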
Learning Rate Schedulers
MultiStageScheduler: Allows you to combine several schedulers (e.g., linear warmup and cosine decay).
LinearLR: Can be used to achieve both linear warmups and linear decays.
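For example, a linear warmup followed by a cosine decay can be expressed as a single schedule. The sketch below uses a plain torch.optim.lr_scheduler.LambdaLR instead of the library’s MultiStageScheduler, with made-up step counts:

```python
import math
import torch

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

warmup_steps, total_steps = 100, 1000  # assumed values for illustration

def warmup_then_cosine(step: int) -> float:
    # Linear warmup from 0 to 1, then cosine decay from 1 to 0.
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1 + math.cos(math.pi * progress))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, warmup_then_cosine)
# In a LightningModule, return the scheduler with {"interval": "step"} so it
# advances every training step rather than every epoch.
```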
Metrics
PyTorch Lightning did not originally implement metrics that require predictions for the entire dataset (e.g., AUC and the Spearman correlation); some of them have since been implemented in the newer TorchMetrics package.
GlobalMetric: Extend this class to create new metrics of this kind.
Warning
These metrics require the entire set of labels and predictions to be stored in memory. You might encounter out-of-memory errors if your target tensor is relatively large (e.g., in semantic segmentation tasks) or your validation/test dataset is too large. You’ll have to use some approximation techniques in those cases.
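The pattern GlobalMetric supports can be illustrated with a self-contained AUC example that accumulates everything and computes the score once. This is a hypothetical class, not the library’s API, and it assumes scikit-learn is available:

```python
import torch
from sklearn.metrics import roc_auc_score

class GlobalAUC:
    """Illustrative 'global' metric: AUC cannot be averaged over batches, so
    predictions and labels for the whole validation set are kept in memory
    and the score is computed once at the end."""

    def __init__(self):
        self.preds, self.targets = [], []

    def update(self, preds: torch.Tensor, targets: torch.Tensor) -> None:
        # Detach and move to CPU so the accumulated tensors do not hold GPU memory.
        self.preds.append(preds.detach().cpu())
        self.targets.append(targets.detach().cpu())

    def compute(self) -> float:
        preds = torch.cat(self.preds).numpy()
        targets = torch.cat(self.targets).numpy()
        return roc_auc_score(targets, preds)
```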
Utility
BaseModule: A boilerplate Lightning Module to be extended upon.
ScreenLogger: A logger that prints metrics to the screen.
TelegramCallback: Sends you a Telegram message when training starts, when it ends, and when a validation round finishes.
EMATracker: An exponential moving average aggregator.
count_parameters: A function that returns the total number of parameters in a model.
separate_parameters: A function that splits the parameters of a module into two groups (BatchNorm/GroupNorm/LayerNorm parameters and all others), so you can apply weight decay to only one of them (see the sketch below).
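A minimal sketch of that split, assuming only the normalization layers listed above should be excluded from weight decay (a hypothetical helper, not the library’s separate_parameters):

```python
import torch.nn as nn

NORM_LAYERS = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm, nn.LayerNorm)

def split_decay_groups(model: nn.Module, weight_decay: float = 1e-2):
    """Return optimizer param groups: norm-layer parameters get no weight decay."""
    decay, no_decay = [], []
    for module in model.modules():
        for param in module.parameters(recurse=False):
            if not param.requires_grad:
                continue  # skip frozen parameters
            (no_decay if isinstance(module, NORM_LAYERS) else decay).append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

# Example usage:
# optimizer = torch.optim.AdamW(split_decay_groups(model), lr=1e-3)
```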
Contents:
- Welcome to PyTorch Lightning Spells’ documentation!
- Indices and tables
- pytorch_lightning_spells package
- Submodules
- pytorch_lightning_spells.callbacks module
- pytorch_lightning_spells.cutmix_utils module
- pytorch_lightning_spells.loggers module
- pytorch_lightning_spells.losses module
- pytorch_lightning_spells.lr_schedulers module
- pytorch_lightning_spells.metrics module
- pytorch_lightning_spells.optimizers module
- pytorch_lightning_spells.samplers module
- pytorch_lightning_spells.snapmix_utils module
- pytorch_lightning_spells.utils module