Multi-Proxy Wasserstein Classifier for Image Classification

Authors

  • Benlin Liu UCLA
  • Yongming Rao Tsinghua University
  • Jiwen Lu Tsinghua University
  • Jie Zhou Tsinghua University
  • Cho-Jui Hsieh UCLA

DOI:

https://doi.org/10.1609/aaai.v35i10.17045

Keywords:

Classification and Regression

Abstract

Most widely used convolutional neural networks (CNNs) end with a global average pooling layer followed by a fully-connected layer. In this pipeline, each class is represented by a single template vector stored in the feature bank of the fully-connected layer. Yet a class may have multiple properties useful for recognition, while this formulation captures only one of them. It is therefore desirable to represent a class by multiple proxies. However, directly adding multiple linear layers turns out to be a trivial solution, as no improvement is observed. To tackle this problem, we adopt optimal transport theory to calculate a non-uniform matching flow between the elements of a sample's feature map and the proxies of a class in a closed way. This enables partial matching, since both the feature map and the proxy set can now focus on a subset of the elements of their counterpart. The formulation also embeds samples into the Wasserstein metric space, which has many advantages over the original Euclidean space. It can be computed by a lightweight iterative algorithm that is easily embedded into automatic differentiation frameworks. Empirical studies on two widely used classification benchmarks, CIFAR and ILSVRC2012, show substantial improvements, demonstrating the effectiveness of our method.
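The abstract does not spell out the cost function, the marginals, or the iterative solver, so the following is only a minimal NumPy sketch under stated assumptions, not the authors' implementation. It assumes entropic-regularized optimal transport solved with Sinkhorn iterations, uniform mass on the feature-map positions and on the proxies, a cosine cost (1 minus cosine similarity), and a class score equal to the expected similarity under the transport plan. All function names (sinkhorn_plan, wasserstein_logits, uniform_plan_score) and sizes are hypothetical. The last helper illustrates why stacking linear heads is trivial: a uniform matching flow collapses to the dot product of the mean feature with the mean proxy, i.e. a single template vector.

```python
import numpy as np

def l2_normalize(x, axis=-1):
    # Unit-normalize vectors so dot products become cosine similarities.
    return x / np.linalg.norm(x, axis=axis, keepdims=True)

def sinkhorn_plan(cost, eps=0.1, n_iters=100):
    """Entropic OT plan for an (n, k) cost matrix (assumed uniform marginals)."""
    n, k = cost.shape
    a = np.full(n, 1.0 / n)  # mass on feature-map positions (assumption)
    b = np.full(k, 1.0 / k)  # mass on class proxies (assumption)
    K = np.exp(-cost / eps)  # Gibbs kernel
    u, v = np.ones(n), np.ones(k)
    for _ in range(n_iters):
        u = a / (K @ v)    # rescale rows toward marginal a
        v = b / (K.T @ u)  # rescale columns to match marginal b exactly
    return u[:, None] * K * v[None, :]  # non-uniform matching flow

def wasserstein_logits(feature_map, proxies, eps=0.1, n_iters=100):
    """feature_map: (n, d) spatial features of one sample; proxies: (C, k, d)."""
    f = l2_normalize(feature_map)
    logits = []
    for class_proxies in proxies:
        p = l2_normalize(class_proxies)
        sim = f @ p.T  # (n, k) cosine similarities
        plan = sinkhorn_plan(1.0 - sim, eps, n_iters)
        # Expected similarity under the plan (= 1 minus the transport cost).
        logits.append(np.sum(plan * sim))
    return np.array(logits)

def uniform_plan_score(feature_map, class_proxies):
    """Score with a uniform flow; reduces to mean(feature) . mean(proxies)."""
    n, k = feature_map.shape[0], class_proxies.shape[0]
    plan = np.full((n, k), 1.0 / (n * k))
    return np.sum(plan * (feature_map @ class_proxies.T))

rng = np.random.default_rng(0)
feat = rng.normal(size=(49, 64))        # e.g. a 7x7 map with 64 channels
proxies = rng.normal(size=(10, 4, 64))  # 10 classes, 4 proxies each
print(wasserstein_logits(feat, proxies).shape)  # (10,)
# A uniform flow is equivalent to one template vector per class.
assert np.isclose(uniform_plan_score(feat, proxies[0]),
                  feat.mean(0) @ proxies[0].mean(0))
```

Because every step is a differentiable array operation, the same loop can be written in an automatic differentiation framework and trained end to end; the number of Sinkhorn iterations and the regularization strength eps are free choices in this sketch.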


Published

2021-05-18

How to Cite

Liu, B., Rao, Y., Lu, J., Zhou, J., & Hsieh, C.-J. (2021). Multi-Proxy Wasserstein Classifier for Image Classification. Proceedings of the AAAI Conference on Artificial Intelligence, 35(10), 8618-8626. https://doi.org/10.1609/aaai.v35i10.17045

Issue

Vol. 35 No. 10 (2021)

Section

AAAI Technical Track on Machine Learning III