Transformer Uncertainty Estimation with Hierarchical Stochastic Attention

Authors

  • Jiahuan Pei, University of Amsterdam, Amsterdam, the Netherlands; Amazon Development Center Germany GmbH, Berlin, Germany
  • Cheng Wang, Amazon Development Center Germany GmbH, Berlin, Germany
  • György Szarvas, Amazon Development Center Germany GmbH, Berlin, Germany

DOI:

https://doi.org/10.1609/aaai.v36i10.21364

Keywords:

Speech & Natural Language Processing (SNLP)

Abstract

Transformers are state-of-the-art in a wide range of NLP tasks and have also been applied in many real-world products. Understanding the reliability and certainty of transformer models is crucial for building trustworthy machine learning applications, e.g., medical diagnosis. Although many recent transformer extensions have been proposed, uncertainty estimation for transformer models remains under-explored. In this work, we propose a novel way to enable transformers to estimate uncertainty while retaining their original predictive performance. This is achieved by learning a hierarchical stochastic self-attention that attends to values and to a set of learnable centroids, respectively. New attention heads are then formed as mixtures of sampled centroids using the Gumbel-Softmax trick. We theoretically show that the self-attention approximation obtained by sampling from a Gumbel distribution is upper bounded. We empirically evaluate our model on two text classification tasks with both in-domain (ID) and out-of-domain (OOD) datasets. The experimental results demonstrate that our approach: (1) achieves the best predictive-uncertainty trade-off among the compared methods; (2) exhibits very competitive (in most cases, better) predictive performance on ID datasets; (3) is on par with Monte Carlo dropout and ensemble methods in uncertainty estimation on OOD datasets.
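To illustrate the mechanism the abstract describes, the following PyTorch sketch shows one way a hierarchical stochastic attention head could be built: value vectors attend to a set of learnable centroids and are replaced by Gumbel-Softmax-sampled mixtures of those centroids, and the query-key attention weights are likewise sampled with Gumbel-Softmax. This is a minimal sketch, not the authors' released implementation; the class name, centroid count, and temperature below are illustrative assumptions.

    # Minimal single-head sketch of hierarchical stochastic attention
    # (assumed names and hyperparameters, not the authors' code).
    import math
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class HierarchicalStochasticAttention(nn.Module):
        def __init__(self, d_model: int, n_centroids: int = 32, tau: float = 1.0):
            super().__init__()
            self.q_proj = nn.Linear(d_model, d_model)
            self.k_proj = nn.Linear(d_model, d_model)
            self.v_proj = nn.Linear(d_model, d_model)
            # Learnable centroids that the value vectors attend to.
            self.centroids = nn.Parameter(torch.randn(n_centroids, d_model))
            self.tau = tau

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # x: (batch, seq_len, d_model)
            q, k, v = self.q_proj(x), self.k_proj(x), self.v_proj(x)

            # Level 1: each value attends to the centroids; sampling the
            # mixture weights with Gumbel-Softmax makes the new values stochastic.
            c_logits = v @ self.centroids.t()                        # (B, T, n_centroids)
            c_weights = F.gumbel_softmax(c_logits, tau=self.tau, dim=-1)
            v_stochastic = c_weights @ self.centroids                # (B, T, d_model)

            # Level 2: scaled dot-product scores, with the attention
            # distribution over positions also sampled via Gumbel-Softmax.
            scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))  # (B, T, T)
            attn = F.gumbel_softmax(scores, tau=self.tau, dim=-1)
            return attn @ v_stochastic

    # Repeated stochastic forward passes yield a predictive distribution whose
    # spread can serve as an uncertainty estimate.
    layer = HierarchicalStochasticAttention(d_model=64)
    x = torch.randn(2, 10, 64)
    samples = torch.stack([layer(x) for _ in range(8)])  # (8, 2, 10, 64)
    print(samples.mean(0).shape, samples.var(0).mean().item())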

Published

2022-06-28

How to Cite

Pei, J., Wang, C., & Szarvas, G. (2022). Transformer Uncertainty Estimation with Hierarchical Stochastic Attention. Proceedings of the AAAI Conference on Artificial Intelligence, 36(10), 11147-11155. https://doi.org/10.1609/aaai.v36i10.21364

Section

AAAI Technical Track on Speech and Natural Language Processing