pytorch_lightning_spells.optimizers module

Classes:

Lookahead(optimizer[, alpha, k, ...])

Lookahead Wrapper

RAdam(params[, lr, betas, eps, ...])

RAdam optimizer, a theoretically sound variant of Adam.

class pytorch_lightning_spells.optimizers.Lookahead(optimizer, alpha=0.5, k=6, pullback_momentum='none')

Bases: torch.optim.optimizer.Optimizer

Lookahead Wrapper

Works best with LookaheadCallback or LookaheadModelCheckpoint.

Parameters
  • optimizer (Optimizer) – The inner optimizer.

  • alpha (float, optional) – The linear interpolation factor for the slow-weight update. 1.0 recovers the inner optimizer. Defaults to 0.5.

  • k (int, optional) – The number of inner-optimizer steps between slow-weight (lookahead) updates. Defaults to 6.

  • pullback_momentum (str, optional) – How the inner optimizer's momentum buffer is changed when the slow weights are updated (e.g. “pullback”, as in the example below). Defaults to “none”.
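The roles of alpha and k can be illustrated with a minimal plain-Python sketch of the Lookahead rule. It assumes plain SGD (no momentum) as the inner optimizer and a toy quadratic objective; the helper names lookahead_sgd and grad are hypothetical and not part of this module's API. Every k inner steps, the slow weights move a fraction alpha of the way toward the fast weights, and the fast weights restart from the new slow weights.

```python
def lookahead_sgd(grad_fn, theta0, lr=0.1, alpha=0.5, k=6, num_steps=120):
    """Hypothetical helper: plain SGD as the inner optimizer, wrapped by Lookahead."""
    fast = list(theta0)  # fast weights, updated by the inner optimizer every step
    slow = list(theta0)  # slow weights, updated once every k inner steps
    for step in range(1, num_steps + 1):
        grads = grad_fn(fast)
        fast = [w - lr * g for w, g in zip(fast, grads)]  # one inner SGD step
        if step % k == 0:
            # Move the slow weights a fraction alpha of the way toward the fast weights...
            slow = [s + alpha * (f - s) for s, f in zip(slow, fast)]
            # ...then restart the fast weights from the updated slow weights.
            fast = list(slow)
    return slow


def grad(w):
    # Gradient of the toy objective f(w) = sum((w_i - 3)^2).
    return [2.0 * (x - 3.0) for x in w]


print(lookahead_sgd(grad, [0.0, 10.0]))
```

With alpha=1.0 the slow weights simply copy the fast weights at every synchronization, so the loop reduces to plain SGD, which is what "1.0 recovers the inner optimizer" means.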

Note

Currently pullback_momentum only supports SGD optimizers with momentum.
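The effect of pullback_momentum on the momentum buffer can be sketched as below. This assumes the scheme used in a common open-source Lookahead implementation, in which “pullback” interpolates the buffer with the same alpha as the weights and a third mode, “reset”, zeroes it. The “reset” mode and the helper pullback_momentum_update are assumptions for illustration, not confirmed for this class.

```python
def pullback_momentum_update(buf, cached_buf, alpha, mode="none"):
    """Hypothetical helper: adjust SGD momentum buffers at a slow-weight update.

    `buf` is the inner optimizer's current momentum buffer; `cached_buf` is the
    buffer saved at the previous slow-weight update (assumed to start at zeros).
    Returns (new_buf, new_cached_buf).
    """
    if mode == "pullback":
        # Interpolate the momentum with the same alpha used for the weights.
        new_buf = [alpha * b + (1.0 - alpha) * c for b, c in zip(buf, cached_buf)]
        return new_buf, list(new_buf)
    if mode == "reset":
        # Assumed mode: start the next k inner steps with zero momentum.
        return [0.0] * len(buf), cached_buf
    if mode == "none":
        # Leave the inner optimizer's momentum untouched.
        return buf, cached_buf
    raise ValueError(f"unknown pullback_momentum mode: {mode!r}")


buf, cached = pullback_momentum_update([1.0, -2.0], [0.0, 0.0], alpha=0.5, mode="pullback")
print(buf, cached)
```

Because only SGD with momentum keeps a single momentum buffer per parameter, a rule like this has a natural target there, which is consistent with the restriction in the note above.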

Raises

ValueError – If alpha (the slow update rate) or k (the number of lookahead steps) is invalid.

Example

>>> import torch
>>> from pytorch_lightning_spells.optimizers import Lookahead
>>> model = torch.nn.Linear(10, 1)
>>> optimizer = Lookahead(
...     torch.optim.SGD(model.parameters(), momentum=0.9, lr=0.1),
...     alpha=0.5, k=6, pullback_momentum="pullback")
...
>>> for _ in range(10):
...     optimizer.zero_grad()
...     loss = model(torch.rand(10)).sum()  # reduce the output to a scalar loss
...     loss.backward()
...     optimizer.step()
...

load_state_dict(state_dict)

state_dict()

step(closure=None)

Performs a single Lookahead optimization step.

zero_grad()

class pytorch_lightning_spells.optimizers.RAdam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, degenerated_to_sgd=True)

Bases: torch.optim.optimizer.Optimizer

RAdam (Rectified Adam) optimizer, a variant of Adam that rectifies the variance of the adaptive learning rate during the early steps of training.

Source: LiyuanLucasLiu/RAdam

Licensed under the Apache License 2.0.
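The rectification and the role of degenerated_to_sgd can be sketched for a single scalar parameter. This is a minimal sketch modeled on the reference LiyuanLucasLiu/RAdam update (weight_decay omitted); radam_step and its dict-based state are hypothetical, not this class's API. While the approximated simple-moving-average length rho_t is below the threshold, the variance of the adaptive learning rate is treated as intractable, and the update either falls back to SGD with bias-corrected momentum (degenerated_to_sgd=True) or is skipped.

```python
import math


def radam_step(param, grad, state, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
               degenerated_to_sgd=True):
    """Hypothetical helper: one RAdam update for a single scalar parameter.

    `state` is a dict holding the step count and both moment estimates.
    Returns the updated parameter value.
    """
    beta1, beta2 = betas
    state["step"] = state.get("step", 0) + 1
    t = state["step"]
    # Exponential moving averages of the gradient and the squared gradient.
    state["exp_avg"] = beta1 * state.get("exp_avg", 0.0) + (1 - beta1) * grad
    state["exp_avg_sq"] = beta2 * state.get("exp_avg_sq", 0.0) + (1 - beta2) * grad * grad
    beta2_t = beta2 ** t
    rho_inf = 2.0 / (1.0 - beta2) - 1.0  # maximum length of the approximated SMA
    rho_t = rho_inf - 2.0 * t * beta2_t / (1.0 - beta2_t)  # SMA length at step t
    bias_correction1 = 1.0 - beta1 ** t
    if rho_t >= 5:  # the reference code's conservative stand-in for rho_t > 4
        # Variance rectification term r_t.
        r_t = math.sqrt((rho_t - 4) * (rho_t - 2) * rho_inf
                        / ((rho_inf - 4) * (rho_inf - 2) * rho_t))
        # r_t combined with both bias corrections, as in the reference step size.
        step_size = r_t * math.sqrt(1.0 - beta2_t) / bias_correction1
        denom = math.sqrt(state["exp_avg_sq"]) + eps
        return param - lr * step_size * state["exp_avg"] / denom
    if degenerated_to_sgd:
        # Adaptive term is unreliable this early: SGD with bias-corrected momentum.
        return param - lr * state["exp_avg"] / bias_correction1
    return param  # otherwise skip the update for this step


x, state = 0.0, {}
for _ in range(3000):
    x = radam_step(x, 2.0 * (x - 3.0), state, lr=0.05)  # minimize (x - 3)^2
print(x)
```

With the default betas, rho_t stays below the threshold for the first few steps, so degenerated_to_sgd only affects the very beginning of training; after that, r_t grows toward 1 and the update approaches a bias-corrected Adam step.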

step(closure=None)