Neural Collapse Under MSE Loss: Proximity to and Dynamics on the Central Path

X.Y. Han · Vardan Papyan · David Donoho

Keywords: [ dynamics ] [ deep learning ] [ gradient flow ] [ neural collapse ] [ deep learning theory ] [ invariance ] [ inductive bias ] [ adversarial robustness ]

Award: Outstanding Paper
Poster (Spot E0 in Virtual World): Wed 27 Apr, 6:30 p.m. to 8:30 p.m. PDT
Oral presentation (Oral 2: Understanding Deep Learning): Tue 26 Apr, 1 a.m. to 2:45 a.m. PDT


The recently discovered Neural Collapse (NC) phenomenon occurs pervasively in today's deep net training paradigm of driving cross-entropy (CE) loss towards zero. During NC, last-layer features collapse to their class-means, both classifiers and class-means collapse to the same Simplex Equiangular Tight Frame, and classifier behavior collapses to the nearest-class-mean decision rule. Recent works demonstrated that deep nets trained with mean squared error (MSE) loss perform comparably to those trained with CE. As a preliminary, we empirically establish that NC emerges in such MSE-trained deep nets as well through experiments on three canonical networks and five benchmark datasets. We provide, in a Google Colab notebook, PyTorch code for reproducing MSE-NC and CE-NC. The analytically tractable MSE loss offers more mathematical opportunities than the hard-to-analyze CE loss, inspiring us to leverage MSE loss towards the theoretical investigation of NC. We develop three main contributions: (I) We show a new decomposition of the MSE loss into (A) terms that are directly interpretable through the lens of NC and that assume the last-layer classifier is exactly the least-squares classifier; and (B) a term capturing the deviation from this least-squares classifier. (II) We exhibit experiments on canonical datasets and networks demonstrating that term (B) is negligible during training. This motivates us to introduce a new theoretical construct: the central path, where the linear classifier stays MSE-optimal for feature activations throughout the dynamics. (III) By studying renormalized gradient flow along the central path, we derive exact dynamics that predict NC.
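The decomposition in contribution (I) can be illustrated in a simplified setting. The sketch below is not the paper's formulation: it assumes no bias term and no weight decay, uses random features in place of a trained network, and the function name `mse_decomposition` is hypothetical. Under those assumptions, the MSE loss at any classifier W splits exactly into the loss at the least-squares classifier plus a term measuring the deviation of W from it, because the least-squares residual is orthogonal to the features.

```python
import numpy as np


def mse_decomposition(W, H, Y):
    """Split the MSE loss 0.5/N * ||W H - Y||_F^2 into two parts.

    H: (d, N) last-layer features, Y: (C, N) one-hot targets,
    W: (C, d) linear classifier.
    Simplifying assumptions (not from the paper): no bias, no weight decay.
    """
    N = H.shape[1]
    # Least-squares classifier for the given features: W_LS = Y H^+.
    W_ls = Y @ np.linalg.pinv(H)
    # Total loss at the actual classifier W.
    total = 0.5 / N * np.linalg.norm(W @ H - Y) ** 2
    # Part (A): loss if the classifier were exactly least-squares optimal.
    ls_part = 0.5 / N * np.linalg.norm(W_ls @ H - Y) ** 2
    # Part (B): penalty for deviating from the least-squares classifier.
    deviation = 0.5 / N * np.linalg.norm((W - W_ls) @ H) ** 2
    return total, ls_part, deviation


# Toy data: random features stand in for a network's last-layer activations.
rng = np.random.default_rng(0)
C, d, n_per_class = 4, 16, 25
labels = np.repeat(np.arange(C), n_per_class)
Y = np.eye(C)[:, labels]                    # one-hot targets, shape (C, N)
H = rng.normal(size=(d, C * n_per_class))   # features, shape (d, N)
W = rng.normal(size=(C, d))                 # arbitrary classifier

total, ls_part, deviation = mse_decomposition(W, H, Y)
print(np.isclose(total, ls_part + deviation))
```

In this simplified picture, a classifier that stays equal to the least-squares classifier has a deviation term of zero, which is the condition the abstract uses to define the central path.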
