This is the official repository for Improved Natural Language Generation via Loss Truncation.
We provide code for loss dropping.
Neural language models are typically trained via log loss. While straightforward to optimize, even small fractions of noisy data (e.g., misannotations and hallucinated facts) can degrade the performance of log loss. As an alternative, prior work has shown that minimizing the distinguishability of generated samples is a principled and robust loss that can handle invalid references. However, distinguishability has not been used in practice due to challenges in optimization and estimation.
Loss truncation is a simple and scalable procedure that adaptively removes high log loss examples as a way to optimize for distinguishability. We demonstrate that loss truncation outperforms existing baselines on distinguishability on a summarization task, and show that samples generated by the loss truncation model have factual accuracy ratings that exceed those of baselines and match human references.
See our paper for full details.
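To make the idea of removing high log loss examples concrete, here is a minimal pure-Python sketch. It is not the loss_dropper implementation: the helper name `truncation_mask` is hypothetical, the cutoff is computed from a single list of losses rather than adaptively during training, and treating `dropc` as the fraction of examples to drop is an assumption made for this illustration.

```python
def truncation_mask(losses, dropc):
    """Return 1.0 for kept examples and 0.0 for the highest-loss ones.

    `dropc` is assumed here to be the fraction of examples to drop.
    """
    n_drop = int(len(losses) * dropc)
    if n_drop == 0:
        return [1.0] * len(losses)
    # Smallest loss among the examples that will be dropped.
    cutoff = sorted(losses)[len(losses) - n_drop]
    # Ties at the cutoff are dropped too, so slightly more than
    # n_drop examples can be masked out.
    return [1.0 if loss < cutoff else 0.0 for loss in losses]


# Per-sequence losses for a toy batch of five examples.
losses = [0.5, 2.1, 0.7, 9.3, 1.0]
mask = truncation_mask(losses, dropc=0.4)

# Zero out the dropped examples, then average over the batch.
truncated = sum(l * m for l, m in zip(losses, mask)) / len(losses)
```

With these toy values the two largest losses (2.1 and 9.3) are masked out, so they contribute nothing to the averaged loss.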
We require Python 3.5+ and torch 1.0+.
To install loss_dropper, in a virtual environment of your choice, run:
pip install -U git+https://github.com/ddkang/loss_dropper.git
- Import loss_dropper:
from loss_dropper import LossDropper
- Initialize LossDropper:
self.dropper = LossDropper(dropc=dropc)
- Initialize your loss:
self.criterion = nn.NLLLoss(weight, reduction='none')
IMPORTANT: loss truncation performs dropping at the sequence level. Reductions other than none (such as mean or sum) will aggregate over the wrong dimensions for truncation.
- Do loss dropping:
loss = loss.view(-1, batch_size) # view by batch size
loss = loss.mean(dim=0) # aggregate by sequence
mask = self.dropper(loss) # The dropper returns a mask of 0s where data should be dropped.
loss *= mask # Mask out the high losses
loss = loss.mean() # Aggregate
IMPORTANT: depending on how your loss is computed (for example, its tensor layout or padding), you may have to aggregate in different ways.
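As one illustration of aggregating differently, the sketch below handles a batch-first layout with padded sequences. It makes several assumptions that are not part of loss_dropper: the helper `per_sequence_loss` and the padding index `PAD_ID` are hypothetical, the inputs are random toy tensors, and the mask is a simple stand-in for what `self.dropper(loss)` would return.

```python
import torch
import torch.nn as nn

PAD_ID = 0  # hypothetical padding index, chosen for this sketch


def per_sequence_loss(log_probs, targets, pad_id=PAD_ID):
    """Mean token-level NLL per sequence, ignoring padding.

    log_probs: (batch, seq_len, vocab) log-probabilities
    targets:   (batch, seq_len) token ids
    returns:   (batch,) one loss value per sequence
    """
    batch_size, seq_len, vocab = log_probs.shape
    criterion = nn.NLLLoss(reduction='none', ignore_index=pad_id)
    # reduction='none' yields one value per token; padding positions get 0.
    token_loss = criterion(log_probs.reshape(-1, vocab), targets.reshape(-1))
    token_loss = token_loss.view(batch_size, seq_len)  # batch-first layout
    # Divide by the number of real tokens so short sequences are not favored.
    n_tokens = (targets != pad_id).sum(dim=1).clamp(min=1)
    return token_loss.sum(dim=1) / n_tokens


torch.manual_seed(0)
log_probs = torch.log_softmax(torch.randn(4, 5, 7), dim=-1)
targets = torch.randint(1, 7, (4, 5))
targets[0, 3:] = PAD_ID  # pad the tail of the first sequence

seq_loss = per_sequence_loss(log_probs, targets)  # shape: (4,)
# Stand-in for `mask = self.dropper(seq_loss)`: drop the single highest loss.
mask = (seq_loss < seq_loss.max()).type(seq_loss.dtype)
loss = (seq_loss * mask).mean()  # aggregate over the batch
```

The key difference from the time-major snippet above is that the per-token loss is reshaped to (batch, seq_len) and reduced over dim=1, so each entry passed to the dropper still corresponds to exactly one sequence.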
If you find this useful in your research, please consider citing:
@article{kang2020improved,
title={Improved Natural Language Generation via Loss Truncation},
author={Daniel Kang and Tatsunori Hashimoto},
journal={ACL},
year={2020}
}