I'm confused about the function of `ComplementCrossEntropyLoss`. What's the difference between `ComplementCrossEntropyLoss` and `nn.CrossEntropyLoss()`?
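A minimal sketch of the difference, assuming a current PyTorch release and toy logits (not taken from the repository). `nn.CrossEntropyLoss` computes `-log p_k` for the target class `k`, which pushes `p_k` toward 1. The complement loss computes `-log(1 - p_k)`, which pushes `p_k` toward 0. The helper `grad_on_target_logit` is a hypothetical name used only for this illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 3)           # 4 samples, 3 classes (toy values)
target = torch.tensor([0, 2, 1, 0])  # class index k for each sample

probs = F.softmax(logits, dim=1)
p_k = probs.gather(1, target.unsqueeze(1)).squeeze(1)  # probability of class k per sample

# Standard cross entropy: mean of -log p_k (cross_entropy == nll_loss(log_softmax(x), k))
ce = F.cross_entropy(logits, target)
ce_manual = -torch.log(p_k).mean()

# Complement cross entropy: mean of -log(1 - p_k), with the same 1e-4 as the quoted code
cce = F.nll_loss(torch.log(1. - probs + 1e-4), target)
cce_manual = -torch.log(1. - p_k + 1e-4).mean()


def grad_on_target_logit(loss_fn):
    """Hypothetical helper: gradient of loss_fn w.r.t. the logit of class k, per sample."""
    z = logits.clone().requires_grad_(True)
    loss_fn(z).backward()
    return z.grad.gather(1, target.unsqueeze(1)).squeeze(1)


ce_grad = grad_on_target_logit(lambda z: F.cross_entropy(z, target))
cce_grad = grad_on_target_logit(
    lambda z: F.nll_loss(torch.log(1. - F.softmax(z, dim=1) + 1e-4), target))

print(torch.allclose(ce, ce_manual), torch.allclose(cce, cce_manual))  # True True
# Gradient descent raises logit k under CE (negative gradient) and lowers it
# under the complement loss (positive gradient).
print((ce_grad < 0).all().item(), (cce_grad > 0).all().item())  # True True
```

In short, both losses look at the same probability `p_k`, but they move it in opposite directions.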
In ComplementCrossEntropyLoss.py, when calling `torch.nn.functional.nll_loss`, why is the input `torch.log(1. - torch.nn.functional.softmax(input) + 1e-4)`? What does the `1 -` mean?
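On the `1 -` part: because the softmax outputs sum to 1, `1 - p_k` equals the sum of the probabilities of all the other classes, which matches the docstring's "sum of all probabilities of other indices". The `1e-4` appears to be there to keep the log finite. A small check, assuming a current PyTorch release and made-up logits:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, -1.0]])  # one sample, 3 classes (toy values)
probs = F.softmax(logits, dim=1)
k = 0

complement = 1. - probs[0, k]  # what "1. - softmax(input)" gives at index k
others = probs[0, 1:].sum()    # sum of the probabilities of every class except k (k = 0 here)
print(torch.allclose(complement, others))  # True

# Why "+ 1e-4": if the model is (numerically) certain of class k, p_k rounds
# to exactly 1.0 in float32, so 1 - p_k is 0 and log(0) is -inf.
confident = F.softmax(torch.tensor([[100.0, 0.0, 0.0]]), dim=1)
no_eps = torch.log(1. - confident[0, k])
with_eps = torch.log(1. - confident[0, k] + 1e-4)
print(no_eps.item(), with_eps.item())  # -inf, then roughly -9.21
```

So `nll_loss` on `log(1 - softmax)` picks out `-log(1 - p_k)` for each sample's target `k`, and the epsilon caps that value at about `-log(1e-4)` instead of letting it become infinite.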
Thanks for sharing your code. Here is the class my questions refer to:
```python
import torch
import torch.nn.functional as F


class ComplementCrossEntropyLoss(torch.nn.Module):
    """
    Note: This is the cross entropy of the sum of all probabilities of other indices, except for the
    This is used in Bayesian GAN semi-supervised learning
    """

    def __init__(self, except_index=None, weight=None, ignore_index=-100, size_average=True, reduce=True):
        super(ComplementCrossEntropyLoss, self).__init__()
        self.except_index = except_index
        self.weight = weight
        self.ignore_index = ignore_index
        self.size_average = size_average
        self.reduce = reduce  # stored, but never passed to nll_loss below

    def forward(self, input, target=None):
        # Use target if not None, else use self.except_index
        if target is not None:
            assert not target.requires_grad, "target should not require gradients"
        else:
            assert self.except_index is not None
            # Every sample in the batch gets the same target: except_index
            target = torch.full((input.shape[0],), self.except_index,
                                dtype=torch.long, device=input.device)
        # 1 - softmax(input) at index k is the total probability of all other classes;
        # the 1e-4 keeps log() finite when softmax(input) at k is exactly 1
        log_complement = torch.log(1. - F.softmax(input, dim=1) + 1e-4)
        result = F.nll_loss(
            log_complement, target,
            weight=self.weight,
            reduction='mean' if self.size_average else 'sum',
            ignore_index=self.ignore_index)
        return result
```
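For the Bayesian GAN use case, one plausible reading (an assumption, not confirmed by the code) is a discriminator with K real classes plus one extra "fake" class at `except_index`, as in common K+1-class semi-supervised GANs. With `target=None`, the loss is then `-log(1 - p_fake)`, i.e. `-log P(real)`, applied to unlabeled real data. `NUM_REAL_CLASSES`, `FAKE_INDEX` and `complement_ce` below are hypothetical names; `complement_ce` mirrors `forward()` without the module wrapper.

```python
import torch
import torch.nn.functional as F

NUM_REAL_CLASSES = 3           # hypothetical: classes 0..2 are real labels
FAKE_INDEX = NUM_REAL_CLASSES  # hypothetical: index 3 is the "fake" class


def complement_ce(logits, except_index, eps=1e-4):
    """Hypothetical helper mirroring forward() with target=None."""
    target = torch.full((logits.shape[0],), except_index,
                        dtype=torch.long, device=logits.device)
    return F.nll_loss(torch.log(1. - F.softmax(logits, dim=1) + eps), target)


torch.manual_seed(0)
# Stand-in for discriminator logits on a batch of unlabeled real images
d_logits_unlabeled = torch.randn(8, NUM_REAL_CLASSES + 1)

loss = complement_ce(d_logits_unlabeled, FAKE_INDEX)

# Same quantity written as -log P(real): sum of the real-class probabilities
p_real = F.softmax(d_logits_unlabeled, dim=1)[:, :NUM_REAL_CLASSES].sum(dim=1)
loss_manual = -torch.log(p_real + 1e-4).mean()
print(torch.allclose(loss, loss_manual, atol=1e-5))  # True
```

Under this reading, minimizing the loss tells the discriminator "this sample belongs to some real class" without saying which one, which is why no label is needed.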