# Source code for pytorch_lightning_spells.snapmix_utils

"""SnapMix Utility Functions

Reference: `Shaoli-Huang/SnapMix <https://github.com/Shaoli-Huang/SnapMix/>`_
"""
import torch
import torch.nn.functional as F


def get_spm(input_tensor, target, model, image_size, half: bool = False):
    """Compute a normalized class-activation map for every sample in a batch.

    The model's final linear classifier is applied to the extracted feature
    maps as a 1x1 convolution, the channel belonging to each sample's target
    class is selected, the result is bilinearly resized to ``image_size`` and
    finally min-shifted and scaled so every map is non-negative and sums to 1
    (the "semantic percent map" used by SnapMix).

    Args:
        input_tensor: Batch of model inputs; its first dimension is the
            batch size.
        target: Per-sample class indices, indexable as ``target[i]``.
        model: Object exposing ``get_fc()`` (returning a layer with a 2-D
            ``weight`` and a ``bias``) and ``extract_features(x)`` (returning
            4-D feature maps whose channel count matches ``weight.size(1)``).
        image_size: Output spatial size, passed as ``size`` to
            ``F.interpolate``.
        half: When true, ``input_tensor`` is cast to half precision before
            feature extraction.

    Returns:
        A tensor with one map per sample, shaped ``(batch, *spatial)`` where
        ``spatial`` is the resized spatial extent. Each map is non-negative
        and sums to 1. A sample whose activation map is constant (so nothing
        remains after the min-shift) yields a uniform map instead of NaNs.
    """
    bs = input_tensor.size(0)
    with torch.no_grad():
        fc = model.get_fc()
        if half:
            input_tensor = input_tensor.half()
        fms = model.extract_features(input_tensor.to(fc.weight.device))

        # A linear layer applied independently at every spatial location is
        # a 1x1 convolution: reshape (out, in) -> (out, in, 1, 1).
        kernel = fc.weight.view(fc.weight.size(0), fc.weight.size(1), 1, 1)
        activations = F.conv2d(fms, kernel, bias=fc.bias)

        # Keep only the activation channel of each sample's own target class.
        spm = torch.stack([activations[i, target[i]] for i in range(bs)])

        # interpolate() needs a channel dimension; add it and drop it again.
        spm = F.interpolate(
            spm.unsqueeze(1), size=image_size, mode='bilinear',
            align_corners=False,
        ).squeeze(1)

        # Shift every map so its minimum is zero, then scale it to sum to 1.
        spm = spm - spm.flatten(1).min(dim=1)[0].view(bs, 1, 1)
        totals = spm.flatten(1).sum(dim=1).view(bs, 1, 1)

        # A constant map is all zeros after the shift; dividing by its zero
        # sum would produce NaNs, so fall back to a uniform distribution.
        # NaN totals compare unequal to 0 and therefore still propagate.
        degenerate = totals == 0
        safe_totals = torch.where(degenerate, torch.ones_like(totals), totals)
        uniform = torch.full_like(spm, 1.0 / spm[0].numel())
        spm = torch.where(degenerate, uniform, spm / safe_totals)
    return spm