This repository has been archived on 2022-07-14. You can view files and clone it, but cannot push or open issues or pull requests.
NJIT-2021GXS-Detection/src/vision/nn/scaled_l2_norm.py

20 lines
594 B
Python
Raw Normal View History

2022-03-02 19:32:08 +08:00
import torch
import torch.nn as nn
import torch.nn.functional as F
class ScaledL2Norm(nn.Module):
    """L2-normalize along dim 1, then apply a learnable per-channel scale.

    Each vector taken along dimension 1 of the input is divided by its
    Euclidean norm and the result is multiplied by ``scale``, a learnable
    parameter with one entry per channel.

    Args:
        in_channels: Size of dimension 1 of the inputs; also the number of
            entries in ``scale``.
        initial_scale: Value every entry of ``scale`` is set to by
            :meth:`reset_parameters`.
    """

    def __init__(self, in_channels: int, initial_scale: float) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.initial_scale = initial_scale
        # Allocated uninitialized on purpose: reset_parameters() fills it
        # right away, so no stale memory is ever observed.
        self.scale = nn.Parameter(torch.empty(in_channels))
        self.reset_parameters()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` normalized over dim 1 and scaled per channel.

        Args:
            x: Tensor of shape ``(N, C, *)`` with ``C == in_channels`` and
                any number of trailing dimensions (including none).

        Returns:
            A tensor with the same shape as ``x``.
        """
        # Reshape scale to (1, C, 1, ..., 1) with one singleton per trailing
        # dimension of x. A fixed 4-D reshape would broadcast against the
        # wrong axes for (N, C) or (N, C, L) inputs.
        trailing_ones = (1,) * (x.dim() - 2)
        scale = self.scale.view(1, -1, *trailing_ones)
        return F.normalize(x, p=2, dim=1) * scale

    def reset_parameters(self) -> None:
        """Reset every entry of ``scale`` to ``initial_scale``."""
        # nn.init performs the in-place fill without recording it in
        # autograd, avoiding direct access to the parameter's ``.data``.
        nn.init.constant_(self.scale, self.initial_scale)

    def extra_repr(self) -> str:
        """Return the constructor arguments for the module's ``repr``."""
        return (
            f"in_channels={self.in_channels}, "
            f"initial_scale={self.initial_scale}"
        )