MASKER: Masked Keyword Regularization for Reliable Text Classification

Authors

  • Seung Jun Moon, KAIST
  • Sangwoo Mo, KAIST
  • Kimin Lee, UC Berkeley
  • Jaeho Lee, KAIST
  • Jinwoo Shin, KAIST

DOI:

https://doi.org/10.1609/aaai.v35i15.17601

Keywords:

Text Classification & Sentiment Analysis

Abstract

Pre-trained language models have achieved state-of-the-art accuracies on various text classification tasks, e.g., sentiment analysis, natural language inference, and semantic textual similarity. However, the reliability of the fine-tuned text classifiers is an often overlooked performance criterion. For instance, one may desire a model that can detect out-of-distribution (OOD) samples (drawn far from the training distribution) or be robust against domain shifts. We claim that one central obstacle to reliability is the model's over-reliance on a limited number of keywords, instead of looking at the whole context. In particular, we find that (a) OOD samples often contain in-distribution keywords, while (b) cross-domain samples may not always contain keywords; over-relying on keywords is problematic in both cases. In light of this observation, we propose a simple yet effective fine-tuning method, coined masked keyword regularization (MASKER), that facilitates context-based prediction. MASKER regularizes the model to reconstruct the keywords from the rest of the words and to make low-confidence predictions when given insufficient context. We demonstrate that, when applied to various pre-trained language models (e.g., BERT, RoBERTa, and ALBERT), MASKER improves OOD detection and cross-domain generalization without degrading classification accuracy. Code is available at https://github.com/alinlab/MASKER.
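
The abstract describes two regularization terms added on top of ordinary fine-tuning: reconstructing masked-out keywords from the surrounding context, and forcing low-confidence predictions when only the keywords (without context) are given. The PyTorch sketch below illustrates how such terms could be combined with the classification loss; the encoder/head interfaces, the precomputed keyword set, and the loss weights are illustrative assumptions for this sketch, not the authors' implementation (see the linked repository for that).

```python
# Minimal sketch of the two MASKER-style regularizers described in the abstract,
# written against a generic PyTorch encoder. All names (encoder, mlm_head, cls_head,
# keyword_ids, mask_id, lambda_*) are illustrative assumptions, not the authors' API.
import torch
import torch.nn.functional as F

def masker_losses(encoder, mlm_head, cls_head,
                  input_ids, attention_mask, labels,
                  keyword_ids, mask_id,
                  lambda_mkr=0.001, lambda_mer=0.001):
    """Classification loss plus two regularizers:
    (1) masked keyword reconstruction: recover masked-out keywords from the context,
    (2) masked entropy regularization: keywords alone should yield low confidence.
    How keyword_ids are selected is outside this sketch."""
    is_keyword = torch.isin(input_ids, keyword_ids)            # positions holding keywords

    # --- ordinary fine-tuning loss on the unmodified input ---
    h = encoder(input_ids, attention_mask=attention_mask)       # assumed (B, L, D) token features
    loss_cls = F.cross_entropy(cls_head(h[:, 0]), labels)       # predict from the [CLS] position

    # --- (1) mask the keywords, reconstruct them from the remaining context ---
    ctx_ids = input_ids.masked_fill(is_keyword, mask_id)
    mlm_logits = mlm_head(encoder(ctx_ids, attention_mask=attention_mask))  # (B, L, V)
    mlm_targets = input_ids.masked_fill(~is_keyword, -100)      # score only the keyword slots
    loss_mkr = F.cross_entropy(mlm_logits.flatten(0, 1),
                               mlm_targets.flatten(), ignore_index=-100)

    # --- (2) keep only the keywords; push the prediction toward uniform ---
    kw_ids = input_ids.masked_fill(~is_keyword, mask_id)
    h_kw = encoder(kw_ids, attention_mask=attention_mask)
    log_probs = F.log_softmax(cls_head(h_kw[:, 0]), dim=-1)
    loss_mer = -log_probs.mean()                                 # cross-entropy to the uniform distribution

    return loss_cls + lambda_mkr * loss_mkr + lambda_mer * loss_mer
```

The small default weights on the regularizers are placeholders; in practice they would be tuned per task.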

Published

2021-05-18

How to Cite

Moon, S. J., Mo, S., Lee, K., Lee, J., & Shin, J. (2021). MASKER: Masked Keyword Regularization for Reliable Text Classification. Proceedings of the AAAI Conference on Artificial Intelligence, 35(15), 13578-13586. https://doi.org/10.1609/aaai.v35i15.17601

Issue

Vol. 35 No. 15 (2021)

Section

AAAI Technical Track on Speech and Natural Language Processing II