pytorch_lightning_spells package




class pytorch_lightning_spells.BaseModule(ema_alpha=0.02)

Bases: pytorch_lightning.core.lightning.LightningModule

A boilerplate module with some sensible defaults.

It logs the exponentially smoothed training loss and the validation metrics.

You need to implement the training_step() and validation_step() methods. Please refer to the training_step_end() and validation_step_end() methods for the expected output format.


Parameters

ema_alpha (float, optional) – the weight given to the newest training loss by the exponential moving average (EMA) aggregator. Defaults to 0.02.
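To make the role of ema_alpha concrete, here is a minimal sketch of the smoothing it controls. It assumes the standard EMA update (new value weighted by alpha, history by 1 - alpha) and that the first loss seeds the average. EMASketch is a hypothetical illustration, not the aggregator class pytorch_lightning_spells actually uses.

```python
from typing import Optional


class EMASketch:
    """Hypothetical exponential moving average tracker (illustration only)."""

    def __init__(self, alpha: float = 0.02):
        self.alpha = alpha
        self.value: Optional[float] = None  # no losses observed yet

    def update(self, new_value: float) -> float:
        if self.value is None:
            # Assumption: the first observed loss seeds the average.
            self.value = new_value
        else:
            # The newest loss gets weight alpha; the history keeps 1 - alpha.
            self.value = self.alpha * new_value + (1 - self.alpha) * self.value
        return self.value


tracker = EMASketch(alpha=0.5)
tracker.update(1.0)  # seeds the average at 1.0
tracker.update(3.0)  # 0.5 * 3.0 + 0.5 * 1.0 = 2.0
```

With the default alpha of 0.02, each new loss moves the smoothed value only 2% of the way toward it, which keeps the logged curve stable at the cost of reacting slowly to real changes.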


>>> import torch
>>> from pytorch_lightning_spells import BaseModule
>>> class TestModule(BaseModule):
...     def __init__(self):
...         super().__init__(ema_alpha=0.02)
...     def training_step(self, batch, batch_idx):
...         return {
...             "loss": torch.tensor(1.0),
...             # loggable only when this batch triggers an optimizer step
...             "log": (batch_idx + 1) % self.trainer.accumulate_grad_batches == 0
...         }
...     def validation_step(self, batch, batch_idx):
...         return {
...             'loss': torch.tensor(1.0),
...             'pred': torch.ones_like(batch[1]),
...             'target': batch[1]
...         }
>>> module = TestModule()

Removes v_num from the progress bar.

test_step(batch, batch_idx)

Simply defers to the .validation_step() method.


Basically the same as the .validation_step_end() method, but with a different prefix for the logged names.
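The "same logic, different prefix" design can be sketched without Lightning. EvalLoggingSketch, its logged dictionary, and the "val"/"test" prefixes are hypothetical placeholders: the real class logs through Lightning rather than into a dict, and its actual prefixes may differ.

```python
from typing import Callable, Dict, List, Sequence, Tuple

# A metric maps (predictions, targets) to a single number.
Metric = Callable[[Sequence[float], Sequence[float]], float]


class EvalLoggingSketch:
    """Hypothetical stand-in: one shared logging path used with two prefixes."""

    def __init__(self, metrics: List[Tuple[str, Metric]]):
        self.metrics = metrics
        self.logged: Dict[str, float] = {}

    def _log_eval(self, outputs: Dict[str, list], prefix: str) -> None:
        losses = outputs["loss"]
        self.logged[f"{prefix}_loss"] = sum(losses) / len(losses)
        for name, metric in self.metrics:
            self.logged[f"{prefix}_{name}"] = metric(outputs["pred"], outputs["target"])

    def validation_step_end(self, outputs: Dict[str, list]) -> None:
        self._log_eval(outputs, prefix="val")

    def test_step_end(self, outputs: Dict[str, list]) -> None:
        # Identical to validation; only the name prefix changes.
        self._log_eval(outputs, prefix="test")


def accuracy(pred: Sequence[float], target: Sequence[float]) -> float:
    return sum(p == t for p, t in zip(pred, target)) / len(target)


module = EvalLoggingSketch(metrics=[("acc", accuracy)])
module.test_step_end({"loss": [0.5, 1.5], "pred": [1, 0, 1], "target": [1, 1, 1]})
# module.logged now holds "test_loss" and "test_acc", but no "val_*" entries.
```

Routing both hooks through one private method keeps validation and test metrics computed identically, so the two sets of numbers stay directly comparable.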

training: bool

This method logs the training loss for you.

It follows the log_every_n_steps attribute of the associated Trainer.

The output from the .training_step() method must contain these two entries:

  1. loss: the training loss.

  2. log: a boolean value indicating if this is a loggable step.

A loggable step is one that involves an optimizer step, as opposed to a step that only accumulates gradients without updating the parameters (e.g., during gradient accumulation).


Parameters

outputs (Dict) – the output from the .training_step() method.
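The loggable-step rule can be simulated in plain Python. This sketch assumes Lightning's convention that with accumulate_grad_batches = k the optimizer steps after every k-th batch, i.e. when (batch_idx + 1) % k == 0. It also assumes that log_every_n_steps counts optimizer steps and that every batch's loss feeds the EMA. Both functions are hypothetical helpers, not part of pytorch_lightning_spells.

```python
from typing import List, Tuple


def is_loggable_step(batch_idx: int, accumulate_grad_batches: int) -> bool:
    """True when this batch completes an accumulation cycle (an optimizer step)."""
    return (batch_idx + 1) % accumulate_grad_batches == 0


def simulate_training_logging(
    losses: List[float],
    accumulate_grad_batches: int,
    log_every_n_steps: int,
    alpha: float = 0.02,
) -> List[Tuple[int, float]]:
    """Return the (optimizer_step, smoothed_loss) pairs that would be logged."""
    ema = None
    optimizer_steps = 0
    logged = []
    for batch_idx, loss in enumerate(losses):
        # Assumption: every batch's loss updates the EMA, loggable or not.
        ema = loss if ema is None else alpha * loss + (1 - alpha) * ema
        if not is_loggable_step(batch_idx, accumulate_grad_batches):
            continue  # gradients accumulated, parameters not updated yet
        optimizer_steps += 1
        if optimizer_steps % log_every_n_steps == 0:
            logged.append((optimizer_steps, ema))
    return logged


# With 4-batch accumulation, only batch indices 3 and 7 trigger optimizer steps.
print([i for i in range(8) if is_loggable_step(i, 4)])
```

Gating on the optimizer step keeps one logged point per parameter update, so the loss curve's x-axis means the same thing regardless of the accumulation setting.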


This method logs the validation loss and metrics for you.

The output from the .validation_step() method must contain these three entries:

  1. loss: the validation loss.

  2. pred: the predicted labels or values.

  3. target: the ground truth labels or values.


Parameters

outputs (Dict) – the output from the .validation_step() method.
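Both step-end methods read fixed entry names, so a small mismatch (for example returning "preds" instead of "pred") is easy to make. The sketch below checks a step's output against the entries documented above before it is returned. check_step_outputs and the two key tuples are hypothetical helpers, not part of pytorch_lightning_spells.

```python
from typing import Any, Dict, Iterable

# Entry names as documented for the training and validation step outputs.
TRAINING_KEYS = ("loss", "log")
VALIDATION_KEYS = ("loss", "pred", "target")


def check_step_outputs(outputs: Dict[str, Any], required: Iterable[str]) -> None:
    """Raise KeyError naming every required entry that is missing from outputs."""
    missing = [key for key in required if key not in outputs]
    if missing:
        raise KeyError(f"step output is missing entries: {missing}")


# Passes: all three documented validation entries are present.
check_step_outputs({"loss": 0.1, "pred": [1], "target": [1]}, VALIDATION_KEYS)

# Raises KeyError listing ['pred'], because "preds" is not a documented entry.
try:
    check_step_outputs({"loss": 0.1, "preds": [1], "target": [1]}, VALIDATION_KEYS)
except KeyError as err:
    print(err)
```

Reporting every missing key at once, rather than failing on the first, makes a misnamed output easier to fix in a single pass.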