Welcome to PyTorch Lightning Spells’ documentation!

This package contains some useful plugins for PyTorch Lightning. Many of them are based on other people’s implementations; I only adapted them to work with PyTorch Lightning. Please let me know (ceshine at veritable.pw) if you feel any original authors are not adequately credited in my code and documentation.

The following is a categorized list of available classes and functions:




The following three callbacks must be used together with MixupSoftmaxLoss. They convert the 1-D target tensor into a 2-D one, from which MixupSoftmaxLoss calculates the correct cross-entropy loss.
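
To see why the target needs a second dimension, here is a minimal plain-Python sketch of a mixup-style cross-entropy. It assumes a hypothetical target layout in which each row holds (label_a, label_b, lam): the two labels being mixed and the mixing weight. The function names and that layout are illustrative assumptions, not the actual format or API of MixupSoftmaxLoss.

```python
import math
from typing import List, Sequence, Tuple


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """Cross-entropy of one example: log-sum-exp(logits) - logits[target]."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]


def mixup_cross_entropy(
    logits: List[Sequence[float]],
    targets: List[Tuple[int, int, float]],
) -> float:
    """Mean mixup loss; each (hypothetical) target row is (label_a, label_b, lam)."""
    total = 0.0
    for row_logits, (label_a, label_b, lam) in zip(logits, targets):
        # weight the loss of each original label by its share in the mixed input
        total += lam * cross_entropy(row_logits, label_a)
        total += (1.0 - lam) * cross_entropy(row_logits, label_b)
    return total / len(logits)
```

Because the loss is linear in lam, setting lam to 1.0 reduces it to the ordinary cross-entropy for label_a; a single 1-D label column cannot carry the second label and the weight.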

A notebook is available on Kaggle demonstrating the effect of MixUp, CutMix, and SnapMix.

RandomAugmentationChoiceCallback randomly picks one of the given callbacks for each batch. It also supports a warmup period during which no augmentation is applied, and a probability of skipping augmentation (a no-op) for any given batch.
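
The per-batch selection logic can be sketched in a few lines of plain Python. RandomChoiceSketch, its warmup_steps and no_op_prob parameters, and the choose method are hypothetical names chosen for illustration; the real callback's constructor arguments may differ.

```python
import random
from typing import Any, Optional, Sequence


class RandomChoiceSketch:
    """Hypothetical stand-in for the selection logic described above."""

    def __init__(self, options: Sequence[Any], warmup_steps: int = 0,
                 no_op_prob: float = 0.0, seed: Optional[int] = None) -> None:
        self.options = list(options)      # e.g. MixUp, CutMix, SnapMix callbacks
        self.warmup_steps = warmup_steps  # number of initial batches left untouched
        self.no_op_prob = no_op_prob      # chance of skipping augmentation per batch
        self.rng = random.Random(seed)

    def choose(self, step: int) -> Optional[Any]:
        """Return the augmentation to apply at `step`, or None for a no-op batch."""
        if step < self.warmup_steps:
            return None  # warmup period: no augmentation yet
        if self.rng.random() < self.no_op_prob:
            return None  # randomly leave this batch unaugmented
        return self.rng.choice(self.options)
```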


The training and inference speed of NLP models can be improved by sorting the input examples by their lengths. This reduces the average number of padding tokens per batch (i.e., the input matrices are smaller). Two samplers are provided to achieve this goal:

  • SortSampler: Suitable for validation and test datasets, where the order of input examples doesn’t matter.

  • SortishSampler: Suitable for training datasets, where we want to add some randomness in the order of input examples between epochs.
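
Both strategies can be sketched over a plain list of example lengths. sort_indices, sortish_indices, padding_tokens and the chunk_batches parameter are hypothetical helpers; the shuffle, chunk, sort-within-chunk recipe in sortish_indices is one common way to add randomness and is assumed here rather than taken from the library's source.

```python
import random
from typing import List, Sequence


def sort_indices(lengths: Sequence[int]) -> List[int]:
    """Deterministic order, longest examples first (SortSampler-style)."""
    return sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)


def sortish_indices(lengths: Sequence[int], batch_size: int,
                    chunk_batches: int = 50, seed: int = 0) -> List[int]:
    """Roughly sorted order with randomness between epochs (SortishSampler-style)."""
    rng = random.Random(seed)  # pass a different seed each epoch
    idx = list(range(len(lengths)))
    rng.shuffle(idx)
    chunk = batch_size * chunk_batches
    batches: List[List[int]] = []
    for start in range(0, len(idx), chunk):
        # sort within each large chunk so examples in a batch have similar lengths
        part = sorted(idx[start:start + chunk], key=lambda i: lengths[i], reverse=True)
        batches.extend(part[j:j + batch_size] for j in range(0, len(part), batch_size))
    rng.shuffle(batches)  # shuffle the batch order; each batch stays homogeneous
    return [i for batch in batches for i in batch]


def padding_tokens(lengths: Sequence[int], order: Sequence[int], batch_size: int) -> int:
    """Count pad tokens when each batch is padded to its longest example."""
    total = 0
    for j in range(0, len(order), batch_size):
        batch = [lengths[i] for i in order[j:j + batch_size]]
        total += max(batch) * len(batch) - sum(batch)
    return total
```

For lengths [5, 1, 9, 3, 7, 2, 8, 4] and a batch size of 2, padding_tokens counts 19 padding tokens in the original order and 5 in the sort_indices order.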


  • RAdam

  • set_trainable: a function that freezes or unfreezes a layer group (an nn.Module or nn.ModuleList).

  • freeze_layers: a function that freezes or unfreezes a list of layer groups.
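
RAdam's distinguishing step is a rectification term that controls the variance of the adaptive learning rate early in training. The sketch below computes that term from the formulas in the RAdam paper; radam_rectification and its threshold argument are illustrative assumptions, and implementations differ slightly in the exact threshold they use.

```python
import math
from typing import Optional


def radam_rectification(step: int, beta2: float = 0.999,
                        threshold: float = 4.0) -> Optional[float]:
    """Return RAdam's rectification term r_t, or None when the approximated
    SMA length rho_t is too small for an adaptive update."""
    rho_inf = 2.0 / (1.0 - beta2) - 1.0  # maximum length of the approximated SMA
    beta2_t = beta2 ** step
    rho_t = rho_inf - 2.0 * step * beta2_t / (1.0 - beta2_t)
    if rho_t <= threshold:
        return None  # fall back to an un-adapted, momentum-only update
    return math.sqrt(
        (rho_t - 4.0) * (rho_t - 2.0) * rho_inf
        / ((rho_inf - 4.0) * (rho_inf - 2.0) * rho_t)
    )
```

While the function returns None, RAdam takes a plain momentum step; afterwards the adaptive step is scaled by the returned factor, which grows toward 1 as training progresses.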


  • Lookahead: A PyTorch optimizer wrapper to implement the lookahead mechanism.

  • LookaheadCallback: A callback that switches the model parameters to the slow ones before a validation round starts and switches back to the fast ones after it ends.

  • LookaheadModelCheckpoint: A combination of LookaheadCallback and ModelCheckpoint, so the slow parameters are kept in the checkpoints instead of the fast ones.
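
The lookahead mechanism itself is short: an inner optimizer updates the fast weights, and every k steps the slow weights move a fraction alpha toward them, after which the fast weights are reset to the slow ones. The sketch below shows this on a flat list of floats; LookaheadSketch, the toy inner optimizer sgd_on_square, and the defaults k=5 and alpha=0.5 are assumptions for illustration, not the signature of the Lookahead class documented here.

```python
from typing import Callable, List


class LookaheadSketch:
    """Lookahead over a flat list of float parameters (illustration only)."""

    def __init__(self, params: List[float], inner_step: Callable[[List[float]], None],
                 k: int = 5, alpha: float = 0.5) -> None:
        self.fast = params          # updated in place by the inner optimizer
        self.slow = list(params)    # copy of the initial weights
        self.inner_step = inner_step
        self.k = k
        self.alpha = alpha
        self.steps = 0

    def step(self) -> None:
        self.inner_step(self.fast)  # one update of the fast weights
        self.steps += 1
        if self.steps % self.k == 0:
            for i, (f, s) in enumerate(zip(self.fast, self.slow)):
                s = s + self.alpha * (f - s)  # move slow weights toward fast ones
                self.slow[i] = s
                self.fast[i] = s              # reset fast weights to the slow ones


def sgd_on_square(params: List[float], lr: float = 0.1) -> None:
    """Hypothetical inner optimizer: gradient descent on f(x) = x ** 2."""
    for i, x in enumerate(params):
        params[i] = x - lr * 2 * x
```

In this sketch, slow holds the weights that the callbacks above switch to for validation and keep in checkpoints.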

Learning Rate Schedulers


PyTorch Lightning did not implement metrics that require predictions for the entire dataset (e.g., AUC and the Spearman correlation). Some of them have since been implemented in the new TorchMetrics package.


These metrics require the entire set of labels and predictions to be stored in memory. You might encounter out-of-memory errors if your target tensor is relatively large (e.g., in semantic segmentation tasks) or your validation/test dataset is too large. You’ll have to use some approximation techniques in those cases.
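
A minimal sketch of such a dataset-level metric accumulates predictions batch by batch and computes ROC AUC once at the end of the epoch. EpochAUC and its methods are hypothetical names, and the pairwise comparison is chosen for readability (it is O(P*N) in the numbers of positives and negatives); a production version would sort the scores instead.

```python
from typing import List


class EpochAUC:
    """Accumulate predictions over an epoch, then compute ROC AUC once."""

    def __init__(self) -> None:
        self.preds: List[float] = []
        self.targets: List[int] = []

    def update(self, preds: List[float], targets: List[int]) -> None:
        # called once per batch; everything stays in memory until compute()
        self.preds.extend(preds)
        self.targets.extend(targets)

    def compute(self) -> float:
        pos = [p for p, t in zip(self.preds, self.targets) if t == 1]
        neg = [p for p, t in zip(self.preds, self.targets) if t == 0]
        if not pos or not neg:
            raise ValueError("AUC needs at least one positive and one negative example")
        # probability that a random positive outranks a random negative (ties count 1/2)
        wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
        return wins / (len(pos) * len(neg))

    def reset(self) -> None:
        self.preds.clear()
        self.targets.clear()
```

For example, feeding the batches ([0.1, 0.4], [0, 0]) and ([0.35, 0.8], [1, 1]) gives an AUC of 0.75, even though each batch on its own contains only one class and has no defined AUC.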


  • BaseModule: A boilerplate Lightning Module to be extended upon.

  • ScreenLogger: A logger that prints metrics to the screen.

  • TelegramCallback: Sends you a Telegram message when training starts, when it ends, and when a validation round finishes.

  • EMATracker: An exponential moving average aggregator.

  • count_parameters: A function that returns the total number of parameters in a model.

  • separate_parameters: A function that splits the parameters of a module into two groups (BatchNorm/GroupNorm/LayerNorm parameters and all others), so you can apply weight decay to only one of them.
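
An exponential moving average aggregator such as EMATracker is useful for smoothing noisy values like the per-batch training loss. Below is a minimal sketch; EMASketch, the alpha parameter (the weight given to each new value) and the rule that the first observation initialises the average are assumptions, not the library's API.

```python
from typing import Optional


class EMASketch:
    """Exponential moving average: value = alpha * new + (1 - alpha) * value."""

    def __init__(self, alpha: float = 0.05) -> None:
        if not 0.0 < alpha <= 1.0:
            raise ValueError("alpha must be in (0, 1]")
        self.alpha = alpha
        self.value: Optional[float] = None

    def update(self, new_value: float) -> float:
        if self.value is None:
            self.value = new_value  # first observation initialises the average
        else:
            self.value = self.alpha * new_value + (1.0 - self.alpha) * self.value
        return self.value
```

With alpha=0.5, feeding 10.0, 20.0 and 20.0 yields 10.0, 15.0 and 17.5; a smaller alpha smooths more aggressively but reacts more slowly to real changes.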
