MLE-Guided Parameter Search for Task Loss Minimization in Neural Sequence Modeling

Authors

  • Sean Welleck, New York University
  • Kyunghyun Cho, New York University

DOI:

https://doi.org/10.1609/aaai.v35i16.17652

Keywords:

Learning & Optimization for SNLP

Abstract

Neural autoregressive sequence models are used to generate sequences in a variety of natural language processing (NLP) tasks, where they are evaluated according to sequence-level task losses. These models are typically trained with maximum likelihood estimation, which ignores the task loss, yet empirically performs well as a surrogate objective. Typical approaches to directly optimizing the task loss, such as policy gradient and minimum risk training, are based on sampling in the sequence space to obtain candidate update directions that are scored by the loss of a single sequence. In this paper, we develop an alternative method based on random search in the parameter space that leverages access to the maximum likelihood gradient. We propose maximum likelihood guided parameter search (MGS), which samples from a distribution over update directions that is a mixture of random search around the current parameters and around the maximum likelihood gradient, with each direction weighted by its improvement in the task loss. MGS shifts sampling to the parameter space, and scores candidates using losses that are pooled from multiple sequences. Our experiments show that MGS is capable of optimizing sequence-level losses, with substantial reductions in repetition and non-termination in sequence completion, and similar improvements to those of minimum risk training in machine translation.
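The proposal-and-weighting scheme described above can be illustrated with a small Python sketch. This is a simplified reading of the abstract, not the authors' implementation: the mixture probability `mix_prob`, perturbation scale `sigma`, MLE step size `lr`, and the softmax-over-improvements weighting with `temperature` are all assumptions, and the toy functions `pooled_task_loss` and `mle_gradient` are hypothetical stand-ins for a real sequence-level loss computed on decoded sequences and a real maximum likelihood gradient.

```python
import math
import random
from typing import Callable, List, Tuple

Vector = List[float]


def mgs_step(
    theta: Vector,
    mle_grad: Vector,
    task_loss: Callable[[Vector], float],
    rng: random.Random,
    num_candidates: int = 16,
    mix_prob: float = 0.5,      # assumed P(perturb around zero) vs. around the MLE step
    sigma: float = 0.05,        # assumed std of the Gaussian perturbations
    lr: float = 0.1,            # assumed scale of the MLE-gradient step
    temperature: float = 0.1,   # assumed softmax temperature over improvements
) -> Vector:
    """One simplified MGS-style update (illustrative sketch, not the paper's estimator)."""
    base_loss = task_loss(theta)
    mle_step = [-lr * g for g in mle_grad]  # descent direction on the MLE loss
    directions: List[Vector] = []
    improvements: List[float] = []
    for _ in range(num_candidates):
        # Mixture proposal: random search around zero or around the MLE step.
        center = [0.0] * len(theta) if rng.random() < mix_prob else mle_step
        delta = [c + rng.gauss(0.0, sigma) for c in center]
        candidate = [t + d for t, d in zip(theta, delta)]
        directions.append(delta)
        # Positive value means the candidate lowers the (pooled) task loss.
        improvements.append(base_loss - task_loss(candidate))
    # Softmax weights over improvements (max subtracted for numerical stability).
    best = max(improvements)
    weights = [math.exp((imp - best) / temperature) for imp in improvements]
    total = sum(weights)
    update = [
        sum(w * d[i] for w, d in zip(weights, directions)) / total
        for i in range(len(theta))
    ]
    return [t + u for t, u in zip(theta, update)]


# Hypothetical toy stand-ins: each "sequence" contributes a per-example loss,
# and candidates are scored by the loss pooled (averaged) over the batch.
BATCH = [[0.9, 1.0, 1.1], [1.1, 1.0, 0.9]]


def pooled_task_loss(theta: Vector) -> float:
    per_sequence = [sum((t - y) ** 2 for t, y in zip(theta, ys)) for ys in BATCH]
    return sum(per_sequence) / len(per_sequence)


def mle_gradient(theta: Vector) -> Vector:
    # Surrogate whose optimum (0.8 per coordinate) differs from the task optimum (1.0).
    return [2.0 * (t - 0.8) for t in theta]


def run_demo(steps: int = 50, seed: int = 0) -> Tuple[float, float]:
    rng = random.Random(seed)
    theta = [0.0, 0.0, 0.0]
    initial = pooled_task_loss(theta)
    for _ in range(steps):
        theta = mgs_step(theta, mle_gradient(theta), pooled_task_loss, rng)
    return initial, pooled_task_loss(theta)


if __name__ == "__main__":
    start, end = run_demo()
    print(f"pooled task loss: {start:.3f} -> {end:.3f}")
```

In this toy setup the surrogate gradient alone would settle at 0.8 in every coordinate while the pooled task loss is minimized near 1.0; the zero-centred random-search component lets directions that improve the task loss, but are not suggested by the MLE gradient, still receive weight.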

Published

2021-05-18

How to Cite

Welleck, S., & Cho, K. (2021). MLE-Guided Parameter Search for Task Loss Minimization in Neural Sequence Modeling. Proceedings of the AAAI Conference on Artificial Intelligence, 35(16), 14032-14040. https://doi.org/10.1609/aaai.v35i16.17652

Section

AAAI Technical Track on Speech and Natural Language Processing III