pytorch_lightning_spells package
Classes:
BaseModule – A boilerplate module with some sensible defaults.
- class pytorch_lightning_spells.BaseModule(ema_alpha=0.02)
Bases: LightningModule
A boilerplate module with some sensible defaults.
It logs the exponentially smoothed training loss and the validation metrics.
You need to implement the training_step() and validation_step() methods. Please refer to the training_step_end() and validation_step_end() methods for the expected output format.
- Parameters:
ema_alpha (float, optional) – the weight of the new training loss for the EMA aggregator. Defaults to 0.02.
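The smoothing controlled by ema_alpha can be sketched as below. This is a minimal illustration, not the package's own tracker: the class name EMASketch is hypothetical, and it assumes the first observed loss seeds the average, after which each update applies ema = ema_alpha * new_loss + (1 - ema_alpha) * ema.

```python
from typing import Optional


class EMASketch:
    """Hypothetical exponential moving average of a scalar loss."""

    def __init__(self, alpha: float = 0.02) -> None:
        if not 0.0 < alpha <= 1.0:
            raise ValueError("alpha must be in (0, 1]")
        self.alpha = alpha
        self._value: Optional[float] = None

    def update(self, new_value: float) -> None:
        if self._value is None:
            # Assumption: the first loss seeds the average directly.
            self._value = new_value
        else:
            # The new loss gets weight alpha; the running history keeps 1 - alpha.
            self._value = self.alpha * new_value + (1.0 - self.alpha) * self._value

    @property
    def value(self) -> Optional[float]:
        return self._value


tracker = EMASketch(alpha=0.5)
for loss in [1.0, 3.0, 2.0]:
    tracker.update(loss)
print(tracker.value)  # 2.0
```

Each update moves the average by alpha times the gap between the new loss and the current average, so with the default ema_alpha=0.02 the logged curve is smooth but reacts slowly to sudden changes in the loss.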
Example
>>> import torch
>>> from pytorch_lightning_spells import BaseModule
>>> class TestModule(BaseModule):
...     def __init__(self):
...         super().__init__(ema_alpha=0.02)
...
...     def training_step(self, batch, batch_idx):
...         return {
...             "loss": torch.tensor(1.0),
...             "log": batch_idx % self.trainer.accumulate_grad_batches == 0
...         }
...
...     def validation_step(self, batch, batch_idx):
...         return {
...             "loss": torch.tensor(1.0),
...             "pred": torch.ones_like(batch[1]),
...             "target": batch[1]
...         }
...
>>> module = TestModule()
- test_step_end(outputs)
The same as the .validation_step_end() method, but the logged values use a different prefix.
- training_step_end(outputs)
This method logs the training loss for you.
The logging frequency follows the log_every_n_steps attribute of the associated Trainer.
The output from the .training_step() method must contain these two entries:
loss: the training loss.
log: a boolean value indicating if this is a loggable step.
A loggable step is a step that involves an optimizer step. The opposite is a step that only accumulates gradients without updating the parameters (e.g., in gradient accumulation).
- Parameters:
outputs (Dict) – the output from .training_step() method.
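One way to combine the log flag with the Trainer's log_every_n_steps is sketched below. This is an illustration under stated assumptions, not the package's implementation: the helper names are hypothetical, an optimizer step is assumed to happen on the last batch of each accumulation window, and a step is assumed to reach the logger only when the resulting optimizer-step count is a multiple of log_every_n_steps. The Example above flags the first batch of each window instead; either choice marks exactly one batch per optimizer step.

```python
def is_optimizer_step(batch_idx: int, accumulate_grad_batches: int) -> bool:
    """Hypothetical: True when this batch completes an accumulation window."""
    return (batch_idx + 1) % accumulate_grad_batches == 0


def should_log(batch_idx: int, accumulate_grad_batches: int, log_every_n_steps: int) -> bool:
    """Hypothetical: log only loggable steps that land on the logging interval."""
    if not is_optimizer_step(batch_idx, accumulate_grad_batches):
        # Gradient-only step: parameters are not updated, so nothing is logged.
        return False
    # Number of optimizer steps taken once this batch has been processed.
    optimizer_steps = (batch_idx + 1) // accumulate_grad_batches
    return optimizer_steps % log_every_n_steps == 0


logged = [i for i in range(16) if should_log(i, accumulate_grad_batches=4, log_every_n_steps=2)]
print(logged)  # [7, 15]
```

With four batches per optimizer step, optimizer steps happen after batches 3, 7, 11 and 15; logging every second optimizer step keeps only batches 7 and 15.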
- validation_step_end(outputs)
This method logs the validation loss and metrics for you.
The output from the .validation_step() method must contain these three entries:
loss: the validation loss.
pred: the predicted labels or values.
target: the ground-truth labels or values.
- Parameters:
outputs (Dict) – the output from .validation_step() method.
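The three-entry contract can be illustrated with the framework-free sketch below. It assumes plain Python lists in place of tensors, uses accuracy as a stand-in metric, and averages the loss over batches; the function name summarize_eval_outputs and the val_/test_ key prefixes are hypothetical. Passing a different prefix mirrors the way test_step_end() reuses the validation logic under another name.

```python
from typing import Any, Dict, Sequence


def summarize_eval_outputs(outputs: Sequence[Dict[str, Any]], prefix: str = "val") -> Dict[str, float]:
    """Hypothetical: average the loss and compute accuracy across batches."""
    if not outputs:
        raise ValueError("no outputs to summarize")
    required = ("loss", "pred", "target")
    total_loss = 0.0
    correct = 0
    count = 0
    for batch in outputs:
        # Every batch must follow the loss / pred / target contract.
        missing = [key for key in required if key not in batch]
        if missing:
            raise KeyError(f"evaluation output is missing {missing}")
        total_loss += float(batch["loss"])
        preds = list(batch["pred"])
        targets = list(batch["target"])
        correct += sum(p == t for p, t in zip(preds, targets))
        count += len(targets)
    return {
        f"{prefix}_loss": total_loss / len(outputs),
        f"{prefix}_accuracy": correct / count,
    }


batches = [
    {"loss": 0.5, "pred": [1, 0, 1], "target": [1, 1, 1]},
    {"loss": 0.25, "pred": [0, 0], "target": [0, 0]},
]
print(summarize_eval_outputs(batches))  # {'val_loss': 0.375, 'val_accuracy': 0.8}
print(summarize_eval_outputs(batches, prefix="test"))  # {'test_loss': 0.375, 'test_accuracy': 0.8}
```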
Submodules
- pytorch_lightning_spells.callbacks module
- pytorch_lightning_spells.cutmix_utils module
- pytorch_lightning_spells.loggers module
- pytorch_lightning_spells.losses module
- pytorch_lightning_spells.lr_schedulers module
- pytorch_lightning_spells.metrics module
- pytorch_lightning_spells.optimizers module
- pytorch_lightning_spells.samplers module
- pytorch_lightning_spells.snapmix_utils module
- pytorch_lightning_spells.utils module