Poster

Faster Diffusion Sampling with Randomized Midpoints: Sequential and Parallel

Shivam Gupta · Linda Cai · Sitan Chen

Hall 3 + Hall 2B #165
Fri 25 Apr 7 p.m. PDT — 9:30 p.m. PDT

Abstract: Sampling algorithms play an important role in controlling the quality and runtime of diffusion model inference. In recent years, a number of works (Chen et al., 2023c;b; Benton et al., 2023; Lee et al., 2022) have analyzed algorithms for diffusion sampling with provable guarantees; these works show that for essentially any data distribution, one can approximately sample in polynomial time given a sufficiently accurate estimate of its score functions at different noise levels. In this work, we propose a new scheme inspired by Shen and Lee's randomized midpoint method for log-concave sampling (Shen & Lee, 2019). We prove that this approach achieves the best known dimension dependence for sampling from arbitrary smooth distributions in total variation distance ($\widetilde{O}(d^{5/12})$ compared to $\widetilde{O}(d)$ from prior work). We also show that our algorithm can be parallelized to run in only $\widetilde{O}(\log^2 d)$ parallel rounds, constituting the first provable guarantees for parallel sampling with diffusion models. As a byproduct of our methods, for the well-studied problem of log-concave sampling in total variation distance, we give an algorithm and simple analysis achieving dimension dependence $\widetilde{O}(d^{5/12})$ compared to $\widetilde{O}(d)$ from prior work.
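The discretization idea the abstract builds on, Shen and Lee's randomized midpoint method, evaluates the drift at a uniformly random point inside each step rather than at its start, which makes the per-step error unbiased in expectation. Below is a minimal Python sketch of that idea for a generic ODE drift $f(x, t)$ (in the diffusion setting, $f$ would involve the estimated score); the function names, the coarse Euler midpoint estimate, and the usage loop are illustrative assumptions, not the paper's actual sampler.

```python
import numpy as np

def randomized_midpoint_step(x, t, h, f, rng):
    """One step of size h from time t for dx/dt = f(x, t),
    using a uniformly random midpoint (Shen & Lee, 2019 style)."""
    alpha = rng.uniform(0.0, 1.0)      # random fraction of the step
    # Coarse (Euler) estimate of the state at the random intermediate time.
    x_mid = x + alpha * h * f(x, t)
    # Full step using the drift at the random midpoint. Averaged over
    # alpha, this approximates the exact integral of f along the step,
    # up to the error of the coarse midpoint estimate.
    return x + h * f(x_mid, t + alpha * h)

def sample(x0, f, N=100, seed=0):
    """Hypothetical usage: integrate from t = 1 down to t = 0 in N steps."""
    rng = np.random.default_rng(seed)
    x, t, h = x0, 1.0, -1.0 / N        # negative h: reverse time
    for _ in range(N):
        x = randomized_midpoint_step(x, t, h, f, rng)
        t += h
    return x
```

The random evaluation point is what distinguishes this from a deterministic midpoint rule: it trades a small variance for the removal of systematic discretization bias, which is the mechanism behind the improved $\widetilde{O}(d^{5/12})$ dimension dependence claimed in the abstract.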