# References:
# https://www.math.uwaterloo.ca/~hwolkowi/matrixcookbook.pdf
# https://math.stackexchange.com/questions/3250317/derivative-of-axb-with-respect-to-x
import numpy as np
import torch
import torch.nn as nn
class FocalLoss(nn.Module):
    """Sigmoid focal loss computed from raw logits.

    Element-wise loss is ``-alpha_t * (1 - p_t) ** gamma * log(p_t)`` where
    ``p_t = sigmoid(x)`` for elements whose target equals 1 and
    ``p_t = 1 - sigmoid(x)`` for every other element; ``alpha_t`` is
    ``alpha`` for target == 1 and ``1 - alpha`` otherwise.

    Args:
        alpha: Weight applied to elements with ``target == 1``; all other
            elements are weighted by ``1 - alpha``.
            NOTE(review): with the default ``alpha=1`` every element whose
            target is not 1 gets weight 0 and contributes no loss — confirm
            this default is intended (the focal-loss paper uses 0.25).
        gamma: Focusing exponent applied to ``(1 - p_t)``.
        reduction: ``'mean'``, ``'sum'`` or ``'none'`` (no reduction).

    Raises:
        NotImplementedError: If ``reduction`` is not a supported value.
    """

    def __init__(self, alpha=1, gamma=2, reduction: str = 'mean'):
        super().__init__()
        if reduction not in ['mean', 'none', 'sum']:
            raise NotImplementedError('Reduction {} not implemented.'.format(reduction))
        self.reduction = reduction
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, x, target):
        """Return the focal loss of logits ``x`` against ``target``.

        Args:
            x: Raw (pre-sigmoid) logits.
            target: Tensor broadcastable against ``x``; elements equal to 1
                are positives, any other value is treated as a negative.

        Returns:
            The loss reduced according to ``self.reduction``.
        """
        # Negate the logit for negatives so that p_t == sigmoid(z) in both
        # cases. Working in logit space keeps log(p_t) and (1 - p_t) exact
        # when the sigmoid saturates; log(sigmoid(x) + eps) would cap the
        # loss of confidently wrong predictions at -log(eps) and give them
        # zero gradient.
        z = torch.where(target == 1, x, -x)
        log_p_t = nn.functional.logsigmoid(z)
        one_minus_p_t = torch.sigmoid(-z)
        fl = -(one_minus_p_t ** self.gamma) * log_p_t
        fl = torch.where(target == 1, fl * self.alpha, fl * (1 - self.alpha))
        return self._reduce(fl)

    def _reduce(self, x):
        """Apply ``self.reduction`` ('mean', 'sum', otherwise none) to ``x``."""
        if self.reduction == 'mean':
            return x.mean()
        elif self.reduction == 'sum':
            return x.sum()
        else:
            return x