pytorch_lightning_spells.callbacks module

Classes:

CutMixCallback([alpha, softmax_target, minmax])

Callback that performs CutMix augmentation on the input batch.

LookaheadCallback()

Switch to the slow weights before evaluation and switch back after.

LookaheadModelCheckpoint([dirpath, ...])

Combines LookaheadCallback and ModelCheckpoint.

MixUpCallback([alpha, softmax_target])

Callback that performs MixUp augmentation on the input batch.

RandomAugmentationChoiceCallback(callbacks, p)

Randomly pick an augmentation callback to use for each batch.

SnapMixCallback(model, image_size[, half, ...])

Callback that performs SnapMix augmentation on the input batch.

TelegramCallback(token, chat_id, name[, ...])

A Telegram notification callback.

class pytorch_lightning_spells.callbacks.CutMixCallback(alpha=0.4, softmax_target=False, minmax=None)

Bases: Callback

Callback that performs CutMix augmentation on the input batch.

Assumes the first dimension is the batch dimension.

Reference: rwightman/pytorch-image-models/

Parameters:
  • alpha (float) –

  • softmax_target (bool) –

  • minmax (Tuple[float, float] | None) –

on_train_batch_start(trainer, pl_module, batch, batch_idx)
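
The core of CutMix can be illustrated without any framework. The following is a minimal, stdlib-only sketch under stated assumptions, not the callback's implementation: it mixes a single pair of images stored as nested lists, samples λ from Beta(`alpha`, `alpha`) and cuts a box with sides scaled by sqrt(1 − λ) as in the CutMix paper, and ignores `minmax` and `softmax_target`. The helper names `sample_cut_box` and `cutmix_pair` are hypothetical.

```python
import math
import random
from typing import List, Optional, Tuple

Image = List[List[float]]  # hypothetical: one H x W channel stored as nested lists


def sample_cut_box(height: int, width: int, lam: float,
                   rng: random.Random) -> Tuple[int, int, int, int]:
    """Sample a box covering roughly (1 - lam) of the image; returns (y0, y1, x0, x1)."""
    cut_ratio = math.sqrt(1.0 - lam)
    cut_h, cut_w = int(height * cut_ratio), int(width * cut_ratio)
    # The box centre is uniform over the image; edges are clipped to the borders.
    cy, cx = rng.randrange(height), rng.randrange(width)
    y0, y1 = max(cy - cut_h // 2, 0), min(cy + cut_h // 2, height)
    x0, x1 = max(cx - cut_w // 2, 0), min(cx + cut_w // 2, width)
    return y0, y1, x0, x1


def cutmix_pair(img_a: Image, img_b: Image, alpha: float = 0.4,
                rng: Optional[random.Random] = None) -> Tuple[Image, float]:
    """Paste a random box from img_b into a copy of img_a.

    Returns the mixed image and lam, the weight of img_a's label; the mixed
    target is lam * y_a + (1 - lam) * y_b.
    """
    rng = rng or random.Random()
    height, width = len(img_a), len(img_a[0])
    lam = rng.betavariate(alpha, alpha)
    y0, y1, x0, x1 = sample_cut_box(height, width, lam, rng)
    mixed = [row[:] for row in img_a]  # leave the input untouched
    for y in range(y0, y1):
        mixed[y][x0:x1] = img_b[y][x0:x1]
    # Clipping can shrink the box, so recompute lam from the area actually pasted.
    lam = 1.0 - (y1 - y0) * (x1 - x0) / (height * width)
    return mixed, lam
```

Across a batch, the partner image is commonly taken from a shuffled copy of the same batch; this sketch handles one pair only.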

class pytorch_lightning_spells.callbacks.LookaheadCallback

Bases: Callback

Switch to the slow weights before evaluation and switch back after.

on_validation_end(trainer, pl_module)
on_validation_start(trainer, pl_module)
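
For context, Lookahead (Zhang et al., 2019) keeps a copy of "slow" weights alongside the "fast" weights updated by the inner optimizer: every `k` steps the slow weights move a fraction `alpha` toward the fast weights, and the fast weights restart from them. The framework-free sketch below uses a hypothetical `LookaheadState` class holding plain float lists to illustrate that update and the swap described above. It is an illustration under those assumptions, not the library's optimizer or callback.

```python
from typing import List


class LookaheadState:
    """Hypothetical holder for Lookahead's fast and slow weights (plain floats)."""

    def __init__(self, weights: List[float], k: int = 5, alpha: float = 0.5) -> None:
        self.fast = list(weights)  # what the model uses; updated every inner step
        self.slow = list(weights)  # updated only every k steps
        self.k = k
        self.alpha = alpha
        self.steps = 0

    def after_inner_step(self) -> None:
        """Call after each inner-optimizer step has updated self.fast."""
        self.steps += 1
        if self.steps % self.k == 0:
            # slow <- slow + alpha * (fast - slow), then restart fast from slow
            self.slow = [s + self.alpha * (f - s) for f, s in zip(self.fast, self.slow)]
            self.fast = list(self.slow)

    def swap(self) -> None:
        """Exchange fast and slow weights; calling it twice restores the original state."""
        self.fast, self.slow = self.slow, self.fast
```

In these terms, `on_validation_start` corresponds to one `swap()` so that evaluation sees the slow weights, and `on_validation_end` to a second `swap()` so that training resumes from the fast weights.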

class pytorch_lightning_spells.callbacks.LookaheadModelCheckpoint(dirpath=None, filename=None, monitor=None, verbose=False, save_last=None, save_top_k=1, save_weights_only=False, mode='min', auto_insert_metric_name=True, every_n_train_steps=None, train_time_interval=None, every_n_epochs=None, save_on_train_epoch_end=None, enable_version_counter=True)

Bases: ModelCheckpoint

Combines LookaheadCallback and ModelCheckpoint.

Parameters:
  • dirpath (str | Path | None) –

  • filename (str | None) –

  • monitor (str | None) –

  • verbose (bool) –

  • save_last (bool | None) –

  • save_top_k (int) –

  • save_weights_only (bool) –

  • mode (str) –

  • auto_insert_metric_name (bool) –

  • every_n_train_steps (int | None) –

  • train_time_interval (timedelta | None) –

  • every_n_epochs (int | None) –

  • save_on_train_epoch_end (bool | None) –

  • enable_version_counter (bool) –

on_validation_end(trainer, pl_module)
on_validation_start(trainer, pl_module)

class pytorch_lightning_spells.callbacks.MixUpCallback(alpha=0.4, softmax_target=False)

Bases: Callback

Callback that performs MixUp augmentation on the input batch.

Assumes the first dimension is the batch dimension.

Works best with pytorch_lightning_spells.losses.MixupSoftmaxLoss.

Reference: Fast.ai’s implementation

Parameters:
  • alpha (float) –

  • softmax_target (bool) –

on_train_batch_start(trainer, pl_module, batch, batch_idx)
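
A minimal, stdlib-only sketch of MixUp on a small batch of feature vectors with one-hot targets, under stated assumptions: λ is drawn per sample from Beta(`alpha`, `alpha`); each sample is paired with a partner from a shuffled copy of the batch; and, as in fast.ai's implementation, λ is replaced by max(λ, 1 − λ) so the original sample always dominates. The function name `mixup_batch` is hypothetical, and the sketch mixes the targets directly rather than using `MixupSoftmaxLoss`.

```python
import random
from typing import List, Optional, Sequence, Tuple

Vector = List[float]


def mixup_batch(xs: Sequence[Vector], ys: Sequence[Vector], alpha: float = 0.4,
                rng: Optional[random.Random] = None
                ) -> Tuple[List[Vector], List[Vector], List[float]]:
    """Mix each sample with a partner from a shuffled copy of the batch.

    ys are one-hot (or soft) targets. Returns mixed inputs, mixed targets,
    and the per-sample weights of the original samples.
    """
    rng = rng or random.Random()
    partners = list(range(len(xs)))
    rng.shuffle(partners)
    mixed_x, mixed_y, lams = [], [], []
    for i, j in enumerate(partners):
        lam = rng.betavariate(alpha, alpha)
        lam = max(lam, 1.0 - lam)  # keep the original sample dominant
        mixed_x.append([lam * a + (1.0 - lam) * b for a, b in zip(xs[i], xs[j])])
        mixed_y.append([lam * a + (1.0 - lam) * b for a, b in zip(ys[i], ys[j])])
        lams.append(lam)
    return mixed_x, mixed_y, lams
```

With a cross-entropy loss, training on the mixed target is equivalent to weighting the losses for the two original targets by λ and 1 − λ, because cross-entropy is linear in the target distribution.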

class pytorch_lightning_spells.callbacks.RandomAugmentationChoiceCallback(callbacks, p, no_op_warmup=0, no_op_prob=0)

Bases: Callback

Randomly pick an augmentation callback to use for each batch.

Also supports a no-op warmup period and a no-op probability.

Parameters:
  • callbacks (Sequence[Callback]) – A sequence of callbacks to choose from.

  • p (Sequence[float]) – A sequence of probabilities, one per callback.

  • no_op_warmup (int, optional) – The number of initial steps that apply no augmentation. Defaults to 0.

  • no_op_prob (float, optional) – The probability that a step applies no augmentation. Defaults to 0.

get_callback()
on_train_batch_start(trainer, pl_module, batch, batch_idx)
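
The selection behaviour described by the parameters can be sketched as follows. This is a hypothetical, stdlib-only reading of the documentation: the function `choose_augmentation` and the order of the checks (warmup first, then the no-op draw, then a weighted pick using `p`) are assumptions, not the library's `get_callback`.

```python
import random
from typing import Optional, Sequence, TypeVar

T = TypeVar("T")


def choose_augmentation(callbacks: Sequence[T], p: Sequence[float], step: int,
                        no_op_warmup: int = 0, no_op_prob: float = 0.0,
                        rng: Optional[random.Random] = None) -> Optional[T]:
    """Return the augmentation to apply at this step, or None for a no-op step."""
    rng = rng or random.Random()
    if step < no_op_warmup:
        return None  # warmup: no augmentation at all
    if rng.random() < no_op_prob:
        return None  # randomly skip augmentation for this step
    # Weighted pick; random.choices normalises the weights, so p need not sum to 1.
    return rng.choices(list(callbacks), weights=list(p), k=1)[0]
```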

class pytorch_lightning_spells.callbacks.SnapMixCallback(model, image_size, half=False, minmax=None, cutmix_bbox=False, alpha=0.4, softmax_target=True)

Bases: Callback

Callback that performs SnapMix augmentation on the input batch.

Reference: Shaoli-Huang/SnapMix

Warning

  1. Requires the model to have implemented `extract_features` and `get_fc` methods.

  2. Can only run in CUDA-enabled environments.

Parameters:
  • half (bool) –

  • minmax (Tuple[float, float] | None) –

  • cutmix_bbox (bool) –

  • alpha (float) –

  • softmax_target (bool) –

on_train_batch_start(trainer, pl_module, batch, batch_idx)
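
SnapMix differs from CutMix mainly in how the label weights are computed: instead of using box areas, it measures how much of each image's class activation map (CAM) falls inside the cut and pasted boxes, so the two weights need not sum to 1. The stdlib-only sketch below illustrates that step under simplifying assumptions: feature maps and CAMs are nested lists at a single resolution, a CAM is the FC-weighted sum of feature channels shifted to be non-negative, and box sampling, patch resizing, and the model hooks (`extract_features`, `get_fc`) are left out. All names are hypothetical, and the official Shaoli-Huang/SnapMix code may differ in details such as how CAMs are resized and normalised.

```python
from typing import List, Sequence, Tuple

Grid = List[List[float]]
Box = Tuple[int, int, int, int]  # (y0, y1, x0, x1), end-exclusive


def class_activation_map(features: Sequence[Grid], class_weights: Sequence[float]) -> Grid:
    """CAM for one class: FC-weighted sum over feature channels, shifted so its minimum is 0."""
    height, width = len(features[0]), len(features[0][0])
    cam = [[sum(w * fmap[y][x] for w, fmap in zip(class_weights, features))
            for x in range(width)]
           for y in range(height)]
    low = min(min(row) for row in cam)
    return [[v - low for v in row] for row in cam]


def box_share(cam: Grid, box: Box) -> float:
    """Fraction of the CAM's total mass that lies inside the box."""
    y0, y1, x0, x1 = box
    total = sum(map(sum, cam))
    inside = sum(cam[y][x] for y in range(y0, y1) for x in range(x0, x1))
    return inside / total if total > 0 else 0.0


def snapmix_weights(cam_a: Grid, box_a: Box, cam_b: Grid, box_b: Box) -> Tuple[float, float]:
    """Label weights when box_b of image b is pasted over box_a of image a.

    Image a keeps the semantic mass outside box_a; image b contributes the
    mass inside box_b. The two weights are not forced to sum to 1.
    """
    lam_a = 1.0 - box_share(cam_a, box_a)
    lam_b = box_share(cam_b, box_b)
    return lam_a, lam_b
```

The mixed target is then lam_a * y_a + lam_b * y_b.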

class pytorch_lightning_spells.callbacks.TelegramCallback(token, chat_id, name, report_evals=False)

Bases: Callback

A Telegram notification callback.

Reference: huggingface/knockknock

Parameters:
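  • token (str) – The Telegram bot token used to authenticate with the Bot API.

  • chat_id (int) – The ID of the Telegram chat that receives the notifications.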

  • name (str) –

  • report_evals (bool) – Whether to also send a notification at the end of each validation run.

DATE_FORMAT = '%Y-%m-%d %H:%M:%d'
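
In `strftime` format strings, `%d` is the zero-padded day of the month and `%S` is seconds, so the trailing `%d` in `DATE_FORMAT` prints the day again rather than the seconds. A quick check with the standard library:

```python
from datetime import datetime

DATE_FORMAT = "%Y-%m-%d %H:%M:%d"  # copied from the attribute above

stamp = datetime(2024, 3, 7, 14, 5, 59)
print(stamp.strftime(DATE_FORMAT))          # 2024-03-07 14:05:07
print(stamp.strftime("%Y-%m-%d %H:%M:%S"))  # 2024-03-07 14:05:59
```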
on_train_end(trainer, pl_module)
Parameters:
  • trainer (Trainer) –

  • pl_module (LightningModule) –

on_train_start(trainer, pl_module)
Parameters:
  • trainer (Trainer) –

  • pl_module (LightningModule) –

on_validation_end(trainer, pl_module)
Parameters:
  • trainer (Trainer) –

  • pl_module (LightningModule) –

send_message(text)
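
For reference, a text notification can be delivered through the Telegram Bot API's `sendMessage` method, which takes `chat_id` and `text` as parameters and is called with the bot `token` embedded in the URL. The stdlib-only sketch below builds and sends such a request; the helper names are hypothetical, and the callback itself may rely on a client library rather than raw HTTP.

```python
import json
import urllib.parse
import urllib.request

API_URL = "https://api.telegram.org/bot{token}/sendMessage"


def build_send_message_request(token: str, chat_id: int, text: str) -> urllib.request.Request:
    """Build a form-encoded POST request for the Bot API sendMessage method."""
    data = urllib.parse.urlencode({"chat_id": chat_id, "text": text}).encode("utf-8")
    return urllib.request.Request(API_URL.format(token=token), data=data, method="POST")


def send_message(token: str, chat_id: int, text: str, timeout: float = 10.0) -> dict:
    """Send the request and return the decoded JSON reply (performs network I/O)."""
    request = build_send_message_request(token, chat_id, text)
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return json.loads(response.read().decode("utf-8"))
```

Keeping request construction separate from sending makes the message format easy to check without contacting Telegram.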