Continuous Self-Attention Models with Neural ODE Networks

Authors

  • Jing Zhang Tianjin University
  • Peng Zhang Tianjin University
  • Baiwen Kong Tianjin University
  • Junqiu Wei Huawei Noah’s Ark Lab
  • Xin Jiang Huawei Noah’s Ark Lab

Keywords:

Language Models

Abstract

Stacked self-attention models have received widespread attention due to their ability to capture global dependencies among words. However, stacking many layers and components produces a huge number of parameters, leading to low parameter efficiency. In response to this issue, we propose a lightweight architecture named Continuous Self-Attention models with neural ODE networks (CSAODE). In CSAODE, continuous dynamical models (i.e., neural ODEs) are coupled with our proposed self-attention block to form a self-attention ODE solver. This solver continuously calculates and optimizes the hidden states using only one layer of parameters, which improves parameter efficiency. In addition, we design a novel accelerated continuous dynamical model to reduce computing costs and integrate it into CSAODE. Moreover, since the original self-attention ignores local information, CSAODE uses N-gram convolution to encode local representations, and a fusion layer with only two trainable scalars is designed to generate sentence vectors. We perform a series of experiments on text classification, natural language inference (NLI), and text matching tasks. With fewer parameters, CSAODE outperforms state-of-the-art models on text classification tasks (e.g., a 1.3% accuracy improvement on the SUBJ task) and achieves competitive performance on NLI and text matching tasks as well.
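The abstract only outlines the architecture, so the following is a minimal plain-Python sketch of the general idea under stated assumptions, not the authors' implementation. It treats the hidden states as an ODE state h(t) whose derivative is a single self-attention block, dh/dt = SelfAttn(h; θ), and integrates it with fixed-step explicit Euler. Because every step reuses the same θ, adding steps adds no parameters. Assumptions (none of these come from the paper): Euler stands in for the paper's solver and its accelerated dynamical model; the attention is single-head; the N-gram convolution uses one scalar tap per window position; sentence vectors come from mean pooling; and the two-scalar fusion is read as a weighted sum alpha * global + beta * local. All function names are hypothetical.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    # Plain-Python matrix product: (n x k) @ (k x m) -> (n x m).
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def softmax(row: List[float]) -> List[float]:
    # Numerically stable softmax over one row of attention scores.
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention(h: Matrix, params: Tuple[Matrix, Matrix, Matrix]) -> Matrix:
    # Single-head scaled dot-product self-attention (assumption: one head).
    # (wq, wk, wv) are the only trainable weights of the ODE function.
    wq, wk, wv = params
    q, k, v = matmul(h, wq), matmul(h, wk), matmul(h, wv)
    scale = math.sqrt(len(q[0]))
    k_t = [list(col) for col in zip(*k)]
    weights = [softmax([s / scale for s in row]) for row in matmul(q, k_t)]
    return matmul(weights, v)


def euler_solve(h0: Matrix, params, t0: float = 0.0, t1: float = 1.0,
                steps: int = 10) -> Matrix:
    # Integrates dh/dt = SelfAttn(h; params) with fixed-step explicit Euler
    # (a stand-in solver). Every step reuses the same params, so increasing
    # `steps` deepens the computation without adding parameters.
    dt = (t1 - t0) / steps
    h = [row[:] for row in h0]
    for _ in range(steps):
        dh = self_attention(h, params)
        h = [[x + dt * d for x, d in zip(hr, dr)] for hr, dr in zip(h, dh)]
    return h


def ngram_conv(h: Matrix, kernel: List[float]) -> Matrix:
    # Simplified N-gram convolution (N = len(kernel)): each output row is a
    # weighted sum of N consecutive token vectors, one scalar tap per position.
    n = len(kernel)
    return [
        [sum(kernel[j] * h[i + j][c] for j in range(n)) for c in range(len(h[0]))]
        for i in range(len(h) - n + 1)
    ]


def mean_pool(h: Matrix) -> List[float]:
    # Average over tokens to obtain one vector per sentence.
    return [sum(col) / len(h) for col in zip(*h)]


def fuse(global_vec: List[float], local_vec: List[float],
         alpha: float, beta: float) -> List[float]:
    # Fusion with exactly two trainable scalars (hypothetical weighted sum).
    return [alpha * g + beta * l for g, l in zip(global_vec, local_vec)]


if __name__ == "__main__":
    identity = [[1.0, 0.0], [0.0, 1.0]]
    params = (identity, identity, identity)
    tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
    global_vec = mean_pool(euler_solve(tokens, params, steps=10))
    local_vec = mean_pool(ngram_conv(tokens, [0.5, 0.5]))
    print(fuse(global_vec, local_vec, alpha=0.7, beta=0.3))
```

A quick sanity check on the dynamics: if all token vectors are identical and the weights are identity matrices, attention returns h itself, so dh/dt = h and ten Euler steps over [0, 1] scale h by 1.1^10 ≈ 2.594 (the exact solution would scale it by e ≈ 2.718). A more accurate or adaptive solver narrows that gap at extra compute cost, which is the kind of trade-off an accelerated dynamical model targets.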

Published

2021-05-18

How to Cite

Zhang, J., Zhang, P., Kong, B., Wei, J., & Jiang, X. (2021). Continuous Self-Attention Models with Neural ODE Networks. Proceedings of the AAAI Conference on Artificial Intelligence, 35(16), 14393-14401. Retrieved from https://ojs.aaai.org/index.php/AAAI/article/view/17692

Section

AAAI Technical Track on Speech and Natural Language Processing III