pytorch_lightning_spells.callbacks module
Classes:

- CutMixCallback: Callback that performs CutMix augmentation on the input batch.
- LookaheadCallback: Switch to the slow weights before evaluation and switch back after.
- LookaheadModelCheckpoint: Combines LookaheadCallback and ModelCheckpoint.
- MixUpCallback: Callback that performs MixUp augmentation on the input batch.
- RandomAugmentationChoiceCallback: Randomly pick an augmentation callback to use for each batch.
- SnapMixCallback: Callback that performs SnapMix augmentation on the input batch.
- TelegramCallback: A Telegram notification callback.
- class pytorch_lightning_spells.callbacks.CutMixCallback(alpha=0.4, softmax_target=False, minmax=None)[source]
Bases: Callback
Callback that performs CutMix augmentation on the input batch.
Assumes the first dimension is batch.
Reference: rwightman/pytorch-image-models/
- Parameters:
alpha (float) –
softmax_target (bool) –
minmax (Tuple[float, float] | None) –
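The box sampling behind CutMix can be sketched in plain Python. This is a minimal illustration of the usual CutMix recipe (cut a box whose area fraction is roughly `1 - lam`, then correct `lam` to the clipped box's true area), not the callback's actual implementation, which follows rwightman/pytorch-image-models and also supports the `minmax` ratio mode:

```python
import random

def rand_bbox(width, height, lam):
    """Sample a box covering roughly (1 - lam) of the image area.

    Simplified sketch of the standard CutMix box sampling; the real
    callback works on batched tensors.
    """
    cut_ratio = (1.0 - lam) ** 0.5                 # side-length ratio
    cut_w, cut_h = int(width * cut_ratio), int(height * cut_ratio)
    cx, cy = random.randrange(width), random.randrange(height)  # box center
    x1, y1 = max(cx - cut_w // 2, 0), max(cy - cut_h // 2, 0)
    x2, y2 = min(cx + cut_w // 2, width), min(cy + cut_h // 2, height)
    return x1, y1, x2, y2

def corrected_lam(width, height, box):
    """Recompute lambda from the clipped box's actual area."""
    x1, y1, x2, y2 = box
    return 1.0 - (x2 - x1) * (y2 - y1) / (width * height)
```

Because the box is clipped at the image border, the corrected lambda, not the one drawn from the Beta distribution, is what the mixed targets should use.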
- class pytorch_lightning_spells.callbacks.LookaheadCallback[source]
Bases: Callback
Switch to the slow weights before evaluation and switch back after.
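For context, the Lookahead optimizer keeps "fast" weights that the inner optimizer updates every step and "slow" weights that trail them; this callback swaps the slow weights in before evaluation so validation metrics reflect the averaged parameters. A toy sketch of that bookkeeping (hypothetical class, scalar weights; the real callback delegates the swap to the Lookahead optimizer's own buffers):

```python
class LookaheadSketch:
    """Toy Lookahead bookkeeping: fast weights train, slow weights trail."""

    def __init__(self, weights, alpha=0.5, k=5):
        self.fast = list(weights)   # weights the inner optimizer updates
        self.slow = list(weights)   # trailing copy, synced every k steps
        self.alpha, self.k, self.step_count = alpha, k, 0

    def step(self, grads, lr=0.1):
        self.fast = [w - lr * g for w, g in zip(self.fast, grads)]
        self.step_count += 1
        if self.step_count % self.k == 0:
            # pull the slow weights a fraction alpha toward the fast ones
            self.slow = [s + self.alpha * (f - s)
                         for s, f in zip(self.slow, self.fast)]
            self.fast = list(self.slow)  # restart fast from slow

    def swap(self):
        """What the callback does before evaluation (and again after)."""
        self.fast, self.slow = self.slow, self.fast
```

Calling `swap()` twice restores the original state, which is why the callback can switch back after evaluation without losing training progress.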
- class pytorch_lightning_spells.callbacks.LookaheadModelCheckpoint(dirpath=None, filename=None, monitor=None, verbose=False, save_last=None, save_top_k=1, save_weights_only=False, mode='min', auto_insert_metric_name=True, every_n_train_steps=None, train_time_interval=None, every_n_epochs=None, save_on_train_epoch_end=None, enable_version_counter=True)[source]
Bases: ModelCheckpoint
Combines LookaheadCallback and ModelCheckpoint
- Parameters:
dirpath (str | Path | None) –
filename (str | None) –
monitor (str | None) –
verbose (bool) –
save_last (bool | None) –
save_top_k (int) –
save_weights_only (bool) –
mode (str) –
auto_insert_metric_name (bool) –
every_n_train_steps (int | None) –
train_time_interval (timedelta | None) –
every_n_epochs (int | None) –
save_on_train_epoch_end (bool | None) –
enable_version_counter (bool) –
- class pytorch_lightning_spells.callbacks.MixUpCallback(alpha=0.4, softmax_target=False)[source]
Bases: Callback
Callback that performs MixUp augmentation on the input batch.
Assumes the first dimension is batch.
Works best with pytorch_lightning_spells.losses.MixupSoftmaxLoss
Reference: Fast.ai’s implementation
- Parameters:
alpha (float) –
softmax_target (bool) –
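The mixing rule can be sketched in plain Python on scalar "images": draw a mixing coefficient from a Beta(alpha, alpha) distribution and blend each example with a shuffled partner from the same batch. The `softmax_target` branch below illustrates one way of pairing the two targets with the coefficient for a MixupSoftmaxLoss-style criterion; it is a sketch of the idea, not the callback's exact tensor layout:

```python
import random

def mixup_batch(xs, ys, alpha=0.4, softmax_target=False):
    """Mix each example with a randomly chosen partner from the batch."""
    lam = random.betavariate(alpha, alpha)
    lam = max(lam, 1 - lam)  # common trick (used by fast.ai): keep lam >= 0.5
    perm = list(range(len(xs)))
    random.shuffle(perm)
    mixed_x = [lam * x + (1 - lam) * xs[j] for x, j in zip(xs, perm)]
    if softmax_target:
        # keep both targets plus lam for a MixupSoftmaxLoss-style criterion
        mixed_y = [(y, ys[j], lam) for y, j in zip(ys, perm)]
    else:
        mixed_y = [lam * y + (1 - lam) * ys[j] for y, j in zip(ys, perm)]
    return mixed_x, mixed_y, lam
```

Keeping the raw target pair instead of a blended one-hot vector is why this works best with pytorch_lightning_spells.losses.MixupSoftmaxLoss: the loss can weight the two cross-entropy terms by `lam` and `1 - lam` itself.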
- class pytorch_lightning_spells.callbacks.RandomAugmentationChoiceCallback(callbacks, p, no_op_warmup=0, no_op_prob=0)[source]
Bases: Callback
Randomly pick an augmentation callback to use for each batch.
Also supports no-op warmups and no-op probability.
- Parameters:
callbacks (Sequence[Callback]) – A sequence of callbacks to choose from.
p (Sequence[float]) – A sequence of probabilities for the callbacks.
no_op_warmup (int, optional) – the number of initial steps that should not have any augmentation. Defaults to 0.
no_op_prob (float, optional) – the probability of a step that has no augmentation. Defaults to 0.
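The selection logic these parameters describe can be sketched as a small function (hypothetical name): during the first `no_op_warmup` steps no augmentation is applied, afterwards a batch is left untouched with probability `no_op_prob`, and otherwise one callback is drawn according to the weights in `p`:

```python
import random

def pick_augmentation(callbacks, p, step, no_op_warmup=0, no_op_prob=0.0):
    """Return the callback to apply for this step, or None for a no-op.

    Sketch of the per-batch choice; the real callback then invokes the
    winner's batch hooks.
    """
    if step < no_op_warmup:            # warm up without any augmentation
        return None
    if random.random() < no_op_prob:   # occasionally pass a clean batch
        return None
    return random.choices(callbacks, weights=p, k=1)[0]
```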
- class pytorch_lightning_spells.callbacks.SnapMixCallback(model, image_size, half=False, minmax=None, cutmix_bbox=False, alpha=0.4, softmax_target=True)[source]
Bases: Callback
Callback that performs SnapMix augmentation on the input batch.
Reference: Shaoli-Huang/SnapMix
Warning
Requires the model to have implemented `extract_features` and `get_fc` methods.
Can only run in CUDA-enabled environments.
- Parameters:
half (bool) –
minmax (Tuple[float, float] | None) –
cutmix_bbox (bool) –
alpha (float) –
softmax_target (bool) –
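Unlike CutMix, SnapMix weights the mixed labels by "semantic percent" derived from class activation maps (CAMs) rather than by raw box area. A simplified sketch of that idea with 2-D lists as normalized CAMs (hypothetical function; the real callback builds the CAMs from the model's `extract_features` and `get_fc` outputs):

```python
def semantic_percents(cam_a, cam_b, box):
    """Label weights from CAM mass kept/contributed by each image.

    Sketch of the SnapMix label weighting: the base image's label weight
    is the CAM mass outside the pasted box (what it keeps); the pasted
    image's is the CAM mass inside the box (what it contributes).
    `box` is (x1, y1, x2, y2) in cell coordinates.
    """
    def mass(cam, inside):
        total = 0.0
        for y, row in enumerate(cam):
            for x, v in enumerate(row):
                hit = box[0] <= x < box[2] and box[1] <= y < box[3]
                if hit == inside:
                    total += v
        return total

    total_a = mass(cam_a, True) + mass(cam_a, False)
    total_b = mass(cam_b, True) + mass(cam_b, False)
    lam_a = mass(cam_a, False) / total_a  # base image keeps what is outside
    lam_b = mass(cam_b, True) / total_b   # patch contributes what is inside
    return lam_a, lam_b
```

Because the weights come from activation mass, a box that covers a background region costs the base label little, which is the motivation for SnapMix over area-based CutMix.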
- class pytorch_lightning_spells.callbacks.TelegramCallback(token, chat_id, name, report_evals=False)[source]
Bases: Callback
A Telegram notification callback.
Reference: huggingface/knockknock
- Parameters:
token (str) –
chat_id (int) –
name (str) –
report_evals (bool) –
- DATE_FORMAT = '%Y-%m-%d %H:%M:%d'
- on_train_end(trainer, pl_module)[source]
- Parameters:
trainer (Trainer) –
pl_module (LightningModule) –
- on_train_start(trainer, pl_module)[source]
- Parameters:
trainer (Trainer) –
pl_module (LightningModule) –