

Poster

Learning Hierarchical Polynomials with Three-Layer Neural Networks

Zihao Wang · Eshaan Nichani · Jason Lee

Halle B #129

Abstract: We study the problem of learning hierarchical polynomials over the standard Gaussian distribution with three-layer neural networks. We specifically consider target functions of the form $h = g \circ p$, where $p : \mathbb{R}^d \to \mathbb{R}$ is a degree-$k$ polynomial and $g : \mathbb{R} \to \mathbb{R}$ is a degree-$q$ polynomial. This function class generalizes the single-index model, which corresponds to $k = 1$, and is a natural class of functions possessing an underlying hierarchical structure. Our main result shows that for a large subclass of degree-$k$ polynomials $p$, a three-layer neural network trained via layerwise gradient descent on the square loss learns the target $h$ up to vanishing test error in $\widetilde{O}(d^k)$ samples and polynomial time. This is a strict improvement over kernel methods, which require $\widetilde{\Theta}(d^{kq})$ samples, as well as over existing guarantees for two-layer networks, which require the target function to be low-rank. Our result also generalizes prior works on three-layer neural networks, which were restricted to the case of $p$ being a quadratic. When $p$ is indeed a quadratic, we achieve the information-theoretically optimal sample complexity $\widetilde{O}(d^2)$, an improvement over prior work (Nichani et al., 2023) requiring a sample size of $\widetilde{\Theta}(d^4)$. Our proof proceeds by showing that, during the initial stage of training, the network performs feature learning to recover the feature $p$ with $\widetilde{O}(d^k)$ samples. This work demonstrates the ability of three-layer neural networks to learn complex features and, as a result, to learn a broad class of hierarchical functions.
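To make the setup concrete, below is a minimal, illustrative sketch of the problem and the two-stage layerwise training it describes. Everything here is a hypothetical toy instance, not the paper's actual architecture, initialization, or training schedule: $p$ is taken to be a quadratic $x^\top A x$ ($k = 2$) with a randomly drawn $A$, $g$ a cubic ($q = 3$), the network is a plain three-layer ReLU MLP, and the stage lengths and learning rates are arbitrary. The sketch mirrors the structure of the result: stage 1 trains only the inner layers (the feature-learning phase that, in the paper's analysis, recovers $p$), and stage 2 fits only the outer layer on top of the learned features.

```python
# Hypothetical sketch under the assumptions stated above; not the paper's method.
import torch

torch.manual_seed(0)
d, n = 16, 4096

A = torch.randn(d, d) / d      # hypothetical quadratic feature p(x) = x^T A x

def p(x):                      # degree-2 polynomial feature (k = 2)
    return ((x @ A) * x).sum(dim=1)

def g(z):                      # degree-3 link polynomial (q = 3)
    return z ** 3 - 2.0 * z

X = torch.randn(n, d)          # standard Gaussian inputs
y = g(p(X))                    # hierarchical target h = g . p

# Three-layer ReLU network f(x) = a^T relu(W2 relu(W1 x)).
m1, m2 = 256, 256
W1 = (torch.randn(m1, d) / d ** 0.5).requires_grad_()
W2 = (torch.randn(m2, m1) / m1 ** 0.5).requires_grad_()
a = (torch.randn(m2) / m2 ** 0.5).requires_grad_()

def f(x):
    return torch.relu(torch.relu(x @ W1.T) @ W2.T) @ a

def sq_loss():
    return ((f(X) - y) ** 2).mean()

# Stage 1: train only the inner layers -- the feature-learning stage.
opt = torch.optim.SGD([W1, W2], lr=1e-2)
for _ in range(500):
    opt.zero_grad()
    sq_loss().backward()
    opt.step()

# Stage 2: freeze the learned features; fit only the outer layer,
# i.e. learn the univariate link on top of the recovered feature.
opt = torch.optim.SGD([a], lr=1e-2)
for _ in range(500):
    opt.zero_grad()
    sq_loss().backward()
    opt.step()

print(f"final train square loss: {sq_loss().item():.4f}")
```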
