# Source code for pytorch_lightning_spells.utils

import math
import warnings
from typing import Sequence, Union, Iterable, Optional, List

import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
import numpy as np

# Type alias: a freeze/unfreeze target may be a single module or a ModuleList.
Layer = Union[torch.nn.Module, torch.nn.ModuleList]


class EMATracker:
    """Keeps the exponential moving average for a single series.

    Args:
        alpha (float, optional): the weight of the new value, by default 0.05

    Examples:
        >>> tracker = EMATracker(0.1)
        >>> tracker.update(1.)
        >>> tracker.value
        1.0
        >>> tracker.update(2.)
        >>> tracker.value  # 1 * 0.9 + 2 * 0.1
        1.1
        >>> tracker.update(float('nan'))  # this won't have any effect
        >>> tracker.value
        1.1
    """

    def __init__(self, alpha: float = 0.05):
        super().__init__()
        self.alpha = alpha
        # None until the first non-NaN value arrives.
        self._value: Optional[Union[float, torch.Tensor]] = None

    def update(self, new_value: Union[float, torch.Tensor]):
        """Adds a new value to the tracker.

        It will ignore NaNs and raise a warning in those cases.

        Args:
            new_value (Union[float, torch.Tensor]): the incoming value.
        """
        is_tensor = isinstance(new_value, torch.Tensor)
        if is_tensor:
            new_value = new_value.detach()
        # NaNs are skipped so a single bad step can't poison the average.
        found_nan = bool(torch.isnan(new_value)) if is_tensor else math.isnan(new_value)
        if found_nan:
            warnings.warn("NaN encountered as training loss!")
            return
        previous = self._value
        if previous is None:
            self._value = new_value
        else:
            self._value = new_value * self.alpha + previous * (1 - self.alpha)

    @property
    def value(self):
        """The smoothed value."""
        return self._value
def count_parameters(parameters: Iterable[Union[torch.Tensor, Parameter]]) -> int:
    """Count the number of parameters.

    Args:
        parameters (Iterable[Union[torch.Tensor, Parameter]]): parameters you
            want to count.

    Returns:
        int: the number of parameters counted.

    Example:
        >>> count_parameters([torch.rand(100), torch.rand(10)])
        110
        >>> count_parameters([torch.rand(100, 2), torch.rand(10, 3)])
        230
    """
    # numel() already returns int, so the builtin sum suffices — no need to
    # materialize a list and round-trip through np.sum (which returns 0.0 on
    # empty input; builtin sum returns 0, so behavior is preserved).
    return sum(p.numel() for p in parameters)
# ----------------------------------- # Layer freezing from fast.ai v1 # ----------------------------------- def _children(m): return m if isinstance(m, (list, tuple)) else list(m.children()) def _set_trainable_attr(m, b): if isinstance(m, torch.nn.Parameter): m.requires_grad = b return m.trainable = b for p in m.parameters(): p.requires_grad = b def _apply_leaf(m, f): if isinstance(m, (torch.nn.Module, torch.nn.Parameter)): f(m) if isinstance(m, torch.nn.Parameter): return c = _children(m) if len(c) > 0: for l in c: _apply_leaf(l, f)
def set_trainable(layer: Layer, trainable: bool):
    """Freeze or unfreeze all parameters in the layer.

    Args:
        layer (Union[torch.nn.Module, torch.nn.ModuleList]): the target layer
        trainable (bool): True to unfreeze; False to freeze

    Example:
        >>> model = nn.Sequential(nn.Linear(10, 100), nn.Linear(100, 1))
        >>> model[0].weight.requires_grad
        True
        >>> set_trainable(model, False)
        >>> model[0].weight.requires_grad
        False
        >>> set_trainable(model, True)
        >>> model[0].weight.requires_grad
        True
    """
    def _mark(m):
        _set_trainable_attr(m, trainable)

    _apply_leaf(layer, _mark)
def freeze_layers(layer_groups: Sequence[Layer], freeze_flags: Sequence[bool]):
    """Freeze or unfreeze groups of layers

    Args:
        layer_groups (Sequence[Layer]): the target lists of layers
        freeze_flags (Sequence[bool]): the corresponding trainable flags

    .. warning:: The value in `freeze_flag` has the opposite meaning as in
        `trainable` of `set_trainable`. Set True to freeze; False to unfreeze.

    Examples:
        >>> model = nn.Sequential(nn.Linear(10, 100), nn.Linear(100, 1))
        >>> freeze_layers([model[0], model[1]], [True, False])
        >>> model[0].weight.requires_grad
        False
        >>> model[1].weight.requires_grad
        True
        >>> freeze_layers([model[0], model[1]], [False, True])
        >>> model[0].weight.requires_grad
        True
        >>> model[1].weight.requires_grad
        False
    """
    assert len(freeze_flags) == len(layer_groups)
    # freeze=True means trainable=False, hence the negation.
    for group, should_freeze in zip(layer_groups, freeze_flags):
        set_trainable(group, not should_freeze)
# -----------------------------------------------------------
# Separate BatchNorm2d and GroupNorm parameters from others
# -----------------------------------------------------------
def separate_parameters(
    module: Union[Parameter, nn.Module, List[nn.Module]],
    skip_list: Sequence[str] = ("bias",),
):
    """Separate BatchNorm2d, GroupNorm, and LayerNorm parameters from others.

    Normalization-layer parameters (and any parameter or submodule whose name
    is in ``skip_list``) go to the no-decay group; everything else goes to the
    decay group, so the two can be given different ``weight_decay`` values.

    Args:
        module (Union[Parameter, nn.Module, List[nn.Module]]): to be separated.
        skip_list (Sequence[str], optional): parameter/submodule names always
            placed in the no-decay group, by default ("bias",).

    Returns:
        Tuple[List[Parameter], List[Parameter]]: lists of decay and no-decay
        parameters.

    Example:
        >>> model = nn.Sequential(nn.Linear(100, 10, bias=True), nn.BatchNorm1d(10))
        >>> _ = nn.init.constant_(model[0].weight, 2.)
        >>> _ = nn.init.constant_(model[0].bias, 1.)
        >>> _ = nn.init.constant_(model[1].weight, 1.)
        >>> _ = nn.init.constant_(model[1].bias, 1.)
        >>> decay, no_decay = separate_parameters(model)  # separate the parameters
        >>> np.sum([x.sum().detach().numpy() for x in decay])  # nn.Linear
        2000.0
        >>> np.sum([x.sum().detach().numpy() for x in no_decay])  # nn.BatchNorm1d
        30.0
        >>> optimizer = torch.optim.AdamW([{
        ...     "params": decay, "weight_decay": 0.1
        ... }, {
        ...     "params": no_decay, "weight_decay": 0
        ... }], lr = 1e-3)
    """
    decay, no_decay = [], []
    if isinstance(module, list):
        # Recurse into each entry and merge the two groups.
        for entry in module:
            sub_decay, sub_no_decay = separate_parameters(entry, skip_list=skip_list)
            decay.extend(sub_decay)
            no_decay.extend(sub_no_decay)
    elif isinstance(module, torch.nn.Parameter):
        # A bare parameter carries no layer context; treat it as no-decay.
        no_decay.append(module)
    else:
        if isinstance(module, torch.nn.Identity):
            # Identity has no parameters; nothing to classify.
            return decay, no_decay
        if isinstance(
            module,
            (
                torch.nn.GroupNorm,
                torch.nn.BatchNorm2d,
                torch.nn.LayerNorm,
                torch.nn.BatchNorm1d,
            ),
        ):
            # All normalization parameters belong to the no-decay group.
            no_decay.extend(module.parameters())
        else:
            for module_name, submodule in module.named_children():
                if module_name in skip_list:
                    no_decay.extend(submodule.parameters())
                else:
                    sub_decay, sub_no_decay = separate_parameters(
                        submodule, skip_list=skip_list
                    )
                    decay.extend(sub_decay)
                    no_decay.extend(sub_no_decay)
            # Only this module's direct parameters: children were handled
            # above (recurse=False replaces the old `"." in name` filter,
            # which skipped the dotted names of child parameters).
            for name, parameter in module.named_parameters(recurse=False):
                if name in skip_list:
                    no_decay.append(parameter)
                else:
                    decay.append(parameter)
    return decay, no_decay