"""Learning-rate schedulers (module ``pytorch_lightning_spells.lr_schedulers``)."""

import weakref
from functools import wraps
from typing import Sequence, Union

import torch
import numpy as np
from torch.optim.lr_scheduler import _LRScheduler, CosineAnnealingLR
from torch.optim import Optimizer

# Public names exported by ``from pytorch_lightning_spells.lr_schedulers import *``.
__all__ = [
    "BaseLRScheduler",
    "LinearLR",
    "ExponentialLR",
    "MultiStageScheduler",
    "CosineAnnealingScheduler",
]


class BaseLRScheduler(_LRScheduler):
    """An ``_LRScheduler`` whose bound optimizer can be detached or replaced."""

    def switch_optimizer(self, optimizer):
        """Bind ``optimizer`` to this scheduler.

        The scheduler's own ``_step_count`` is copied onto the newly bound
        optimizer as its ``_step_count`` attribute.

        Args:
            optimizer: the optimizer to bind.
        """
        self.optimizer = optimizer
        optimizer._step_count = self._step_count

    def clear_optimizer(self):
        """Detach the optimizer (``self.optimizer`` becomes ``None``)."""
        self.optimizer = None
class CosineAnnealingScheduler(CosineAnnealingLR, BaseLRScheduler):
    """``torch.optim.lr_scheduler.CosineAnnealingLR`` combined with the
    optimizer switching/clearing helpers of ``BaseLRScheduler``."""
class LinearLR(BaseLRScheduler):
    """Linearly increases or decreases the learning rate between two boundaries
    over a number of iterations.

    The boundaries are ``min_lr_ratio * base_lr`` and ``base_lr`` for every
    parameter group.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        min_lr_ratio: float,
        total_epochs: float,
        upward: bool = True,
        last_epoch: int = -1,
    ):
        """Initialize a scheduler.

        Args:
            optimizer (Union[torch.optim.Optimizer, apex.fp16_utils.fp16_optimizer.FP16_Optimizer]):
                the optimizer whose learning rates are scheduled.
            min_lr_ratio: min_lr_ratio * base_lr is the lower boundary. It is the
                starting learning rate when ``upward`` is True and the final
                learning rate otherwise. Must be smaller than 1.
            total_epochs: the total number of "steps" in this run.
            upward: whether the learning rate goes up or down. Defaults to True.
            last_epoch: the index of last epoch. Defaults to -1.
        """
        assert min_lr_ratio < 1
        self.upward = upward
        self.min_lr_ratio = min_lr_ratio
        self.total_epochs = total_epochs - 1  # starts from zero
        super(LinearLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate of every parameter group for ``self.last_epoch``.

        The rate moves linearly from ``min_lr_ratio * base_lr`` to ``base_lr``
        (or the reverse when ``upward`` is False) over steps
        ``0 .. total_epochs - 1`` and stays at the boundary afterwards.
        """
        current_epoch = self.last_epoch
        # A one-step run (total_epochs == 1) leaves a zero span; treat it as one
        # step so that the constructor's initial get_lr() does not divide by zero.
        span = self.total_epochs if self.total_epochs > 0 else 1
        if self.upward:
            progress = 1 - current_epoch / span  # 1 to 0
        else:
            progress = current_epoch / span  # 0 to 1
        # safety measure
        progress = max(min(progress, 1.0), 0.0)
        return [
            base_lr - progress * (base_lr - self.min_lr_ratio * base_lr)
            for base_lr in self.base_lrs
        ]
class ExponentialLR(BaseLRScheduler):
    """Exponentially increases the learning rate between two boundaries over a
    number of iterations. Mainly used by LR finders.

    The boundaries are ``min_lr_ratio * base_lr`` (at step 0) and ``base_lr``
    (at step ``total_epochs - 1``) for every parameter group.
    """

    def __init__(self, optimizer, min_lr_ratio, total_epochs, last_epoch=-1):
        """Initialize a scheduler.

        Parameters
        ----------
        optimizer : Union[torch.optim.Optimizer, apex.fp16_utils.fp16_optimizer.FP16_Optimizer]
        min_lr_ratio : float
            min_lr_ratio * base_lr will be the starting learning rate.
        total_epochs : int
            the total number of "steps" in this run.
        last_epoch : int, optional
            the index of last epoch, by default -1.
        """
        assert min_lr_ratio < 1
        self.min_lr_ratio = min_lr_ratio
        self.total_epochs = total_epochs - 1  # start from zero
        super(ExponentialLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return ``base_lr * min_lr_ratio ** progress`` for every parameter group.

        ``progress`` goes from 1 at step 0 to 0 at step ``total_epochs - 1``, so
        the first step uses exactly ``min_lr_ratio * base_lr`` and the last one
        exactly ``base_lr``. It is clamped to [0, 1] so the learning rate never
        leaves the two boundaries (same safety measure as ``LinearLR``).
        """
        # last_epoch is already 0 after the constructor's initial step, so no
        # extra offset is applied (an extra +1 shifted the whole schedule).
        current_epoch = self.last_epoch
        # Guard the one-step run (total_epochs == 1), whose span would be zero.
        span = self.total_epochs if self.total_epochs > 0 else 1
        progress = 1 - current_epoch / span  # 1 to 0
        progress = max(min(progress, 1.0), 0.0)
        return [base_lr * (self.min_lr_ratio) ** progress for base_lr in self.base_lrs]
class MultiStageScheduler(_LRScheduler):
    """Runs several learning-rate schedulers one after another.

    Each sub-scheduler is paired with the epoch at which it takes over. On every
    :meth:`step`, the sub-scheduler with the largest starting epoch that has been
    reached is stepped, with its ``last_epoch`` re-based so that it counts from
    its own starting epoch.

    ``_LRScheduler.__init__`` is deliberately not called; the learning rates are
    set by the sub-schedulers.
    """

    def __init__(
        self, schedulers: Sequence, start_at_epochs: Sequence[int], last_epoch: int = -1
    ):
        """Initialize the scheduler.

        Args:
            schedulers: the sub-schedulers to chain. ``self.optimizer`` is taken
                from the one with the latest starting epoch (presumably they all
                wrap the same optimizer; this is not checked here).
            start_at_epochs: the epoch at which each of ``schedulers`` takes over.
                Must have the same length as ``schedulers``.
            last_epoch: the index of last epoch. Defaults to -1.
        """
        assert len(schedulers) == len(start_at_epochs)
        schedulers, start_at_epochs = (np.array(schedulers), np.array(start_at_epochs))
        # sort starting epochs in descending order, so that step() can pick the
        # first sub-scheduler whose starting epoch has been reached
        idx = np.flip(np.argsort(start_at_epochs))
        self.schedulers = schedulers[idx]
        self.start_at_epochs = start_at_epochs[idx]
        # Explicitly run step(). Otherwise the initial LR will be initialized by
        # the last sub-scheduler. step(epoch) sets self.last_epoch = epoch - 1,
        # so passing last_epoch + 1 honours the requested last_epoch
        # (step(0) for the default of -1).
        self.step(last_epoch + 1)
        self.optimizer = self.schedulers[0].optimizer

    def step(self, epoch=None):
        """Advance the schedule and step the currently active sub-scheduler.

        Args:
            epoch: if given, ``self.last_epoch`` is set to ``epoch - 1``;
                otherwise it is incremented by one.

        Returns:
            Whatever the active sub-scheduler's ``step()`` returns, or ``None``
            when no sub-scheduler has reached its starting epoch yet.
        """
        if epoch is None:
            self.last_epoch = self.last_epoch + 1
        else:
            self.last_epoch = epoch - 1
        for scheduler, starting_epoch in zip(self.schedulers, self.start_at_epochs):
            if self.last_epoch + 1 >= starting_epoch:
                # re-base the sub-scheduler; its own step() then increments it
                # to (self.last_epoch + 1 - starting_epoch)
                scheduler.last_epoch = self.last_epoch - starting_epoch
                return scheduler.step()

    def switch_optimizer(self, optimizer):
        """Bind ``optimizer`` to every sub-scheduler and to this scheduler.

        Each sub-scheduler's ``_step_count`` is copied onto the optimizer.
        """
        for scheduler in self.schedulers:
            scheduler.optimizer = optimizer
            scheduler.optimizer._step_count = scheduler._step_count
        self.optimizer = self.schedulers[0].optimizer

    def clear_optimizer(self):
        """Detach the optimizer from every sub-scheduler and from this scheduler."""
        for scheduler in self.schedulers:
            scheduler.optimizer = None
        self.optimizer = None

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ except the
        optimizer and the ``schedulers`` array. The latter is replaced by one
        ``"schedulers_<i>"`` entry per sub-scheduler, holding that
        sub-scheduler's own ``state_dict()``.
        """
        results = {
            key: value for key, value in self.__dict__.items() if key != "optimizer"
        }
        del results["schedulers"]
        for i, scheduler in enumerate(self.schedulers):
            results["schedulers_" + str(i)] = scheduler.state_dict()
        return results

    def load_state_dict(self, state_dict):
        """Loads the schedulers state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`. It is not modified, so the
                same dict can be loaded more than once.
        """
        # Work on a shallow copy so the caller's dict keeps its
        # "schedulers_<i>" entries.
        own_state = dict(state_dict)
        for i, scheduler in enumerate(self.schedulers):
            scheduler.load_state_dict(own_state.pop("schedulers_" + str(i)))
        self.__dict__.update(own_state)
        # The state never carries an optimizer; re-bind the current one to all
        # sub-schedulers. Skip when detached via clear_optimizer(), since
        # switch_optimizer(None) would fail on None._step_count.
        if self.optimizer is not None:
            self.switch_optimizer(self.optimizer)