DAPS

Published as a conference paper at ICLR 2026

Discrete Variational Autoencoding via Policy Search

Train discrete autoregressive encoders without straight-through gradient estimators! DAPS uses an ESS-based trust region and a weighted maximum-likelihood update, resulting in stable training and high performance on high-dimensional reconstruction tasks.

  • Reinforcement Learning
  • Variational Inference
  • Discrete Autoregressive Modeling
Teaser figure (placeholder): DAPS policy at a glance, showing autoregressive encoding of the latent sequence.

Overview

Discrete VAEs often rely on biased surrogate gradients (e.g., straight-through estimators) or continuous relaxations with temperature-sensitive bias–variance tradeoffs. DAPS replaces these with a policy-search-style update: sample discrete latents, compute advantages, form an optimal KL-regularized target $q^*$, and update the encoder via weighted MLE.

$$ q^*(z \mid x) \propto q_\theta(z \mid x)\,\exp\!\left(\frac{A(x,z) - \beta \log q_\theta(z \mid x)}{\eta + \beta}\right) $$

Optimal non-parametric target distribution under a KL trust region with entropy regularization.

Results

ImageNet-256 validation reconstructions

Figure (placeholder): three panels showing DAPS reconstructions, FSQ reconstructions, and ground-truth images from ImageNet-256.

All methods share the same bottleneck capacity (1.28 KB per image), so every reconstruction is decoded from a highly compressed, compact discrete latent.
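The following sketch makes the budget concrete. It assumes fixed-length codes, in which case the size of a discrete bottleneck in bits is the sequence length times $\log_2$ of the codebook size. The helper `bottleneck_bytes` and the 1024-token / 1024-entry configuration are hypothetical. They were chosen only because they land exactly on 1.28 KB (decimal) and are not necessarily the configuration used in the paper.

```python
import math

def bottleneck_bytes(seq_len: int, vocab_size: int) -> float:
    """Bytes needed to store seq_len discrete codes, each drawn from vocab_size symbols.

    Assumes a fixed-length code per token (log2(vocab_size) bits each).
    """
    bits_per_token = math.log2(vocab_size)
    return seq_len * bits_per_token / 8  # total bits -> bytes

# Hypothetical configuration (not taken from the paper): 1024 tokens, 1024-way codebook.
print(bottleneck_bytes(seq_len=1024, vocab_size=1024))  # 1280.0 bytes = 1.28 KB (decimal)
```

Other (length, codebook) pairs reach the same budget, for example fewer tokens with a larger codebook. Capacity alone therefore does not pin down the latent layout.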

Codebook utilization

Figure (placeholder): codebook utilization on MNIST and CIFAR-10.

Robotics

Videos (placeholder): Dancing, Running, Walking.

We use DAPS as a compact command space for goal-conditioned robot control. A high-level policy generates discrete latent codes autoregressively from (i) a language prompt and (ii) a desired center-of-mass (COM) velocity or trajectory. A low-level imitation policy then decodes these latents into physically consistent torques in simulation (implemented with LocoMuJoCo).

Method Overview

Stochastic rollouts

We estimate the value function on the fly from $K$ independent latent samples $z^1,\dots,z^K$ drawn from the encoder policy $q_\theta(z\mid x)$.
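Below is a minimal sketch of drawing the $K$ rollouts, assuming an autoregressive categorical encoder. `logits_fn` is a hypothetical stand-in for the encoder network that returns next-token logits given the prefix, with the conditioning on $x$ folded into it. The accumulated per-token log-probabilities give the $\log q_\theta(z^k\mid x)$ terms needed by the updates below. A real implementation would batch this on an accelerator; the loop is only for clarity.

```python
import math
import random

def sample_latents(logits_fn, seq_len, num_samples, rng):
    """Draw num_samples latent sequences z^1..z^K from an autoregressive policy.

    logits_fn(prefix) -> list of unnormalized log-probabilities for the next token
    (hypothetical stand-in for q_theta(z_t | z_<t, x)).
    Returns (samples, log_probs) with log_probs[k] = log q_theta(z^k | x).
    """
    samples, log_probs = [], []
    for _ in range(num_samples):
        prefix, logp = [], 0.0
        for _ in range(seq_len):
            logits = logits_fn(prefix)
            m = max(logits)
            lse = m + math.log(sum(math.exp(l - m) for l in logits))  # stable log-normalizer
            probs = [math.exp(l - lse) for l in logits]
            token = rng.choices(range(len(logits)), weights=probs)[0]  # categorical draw
            logp += logits[token] - lse  # add log q(z_t | z_<t, x)
            prefix.append(token)
        samples.append(prefix)
        log_probs.append(logp)
    return samples, log_probs

# Toy policy over a 4-symbol vocabulary that prefers repeating the previous token.
def toy_logits(prefix):
    return [2.0 if prefix and v == prefix[-1] else 0.0 for v in range(4)]

z, logq = sample_latents(toy_logits, seq_len=5, num_samples=3, rng=random.Random(0))
```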

Rewards and advantages

Score each sampled latent with the reconstruction log-likelihood, then form (baseline-subtracted) advantages.

$$ \mathcal{R}^k(x,z^k)=\log p_\phi(x\mid z^k)+c, \qquad A(x,z^k)=\mathcal{R}^k-\log\sum_{j=1}^{K}\exp(\mathcal{R}^j). $$
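The advantage is the reward minus a log-sum-exp baseline over the $K$ samples. The plain-Python sketch below uses a hypothetical function name and makes two properties easy to check. First, the offset $c$ cancels. Second, $\exp A(x,z^k)$ is a softmax over the samples, so the exponentiated advantages sum to one.

```python
import math

def advantages(log_liks, c=0.0):
    """A(x, z^k) = R^k - log sum_j exp(R^j), with R^k = log p_phi(x | z^k) + c."""
    rewards = [ll + c for ll in log_liks]
    m = max(rewards)
    baseline = m + math.log(sum(math.exp(r - m) for r in rewards))  # stable log-sum-exp
    return [r - baseline for r in rewards]

A = advantages([-120.0, -118.5, -125.0])
# exp(A) is a softmax over the K samples, so this sums to 1 up to floating-point error.
print(sum(math.exp(a) for a in A))
```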

Optimal policy $q^*$ (derivation)

Solve the optimization problem to get a closed-form target distribution.


Optimal non-parametric target \(q^*\). For each \(x\), we solve:

$$ \begin{aligned} \max_{q}\;& \sum_{z} q(z\mid x)\,A(x,z) \;+\; \beta\,\mathcal{H}\!\left(q(\cdot\mid x)\right) \\ \text{s.t.}\;& \mathrm{KL}\!\left(q(\cdot\mid x)\,\|\,q_\theta(\cdot\mid x)\right) \le \epsilon_\eta, \quad \sum_{z} q(z\mid x)=1, \end{aligned} $$

where \(\mathcal{H}(q)=-\sum_z q(z\mid x)\log q(z\mid x)\). Introducing Lagrange multipliers \(\eta\) and \(\lambda(x)\), the Lagrangian is:

$$ \begin{aligned} \mathcal{L}(q,\eta,\lambda) &= \sum_z q(z\mid x)\,A(x,z) \;+\; \beta\,\mathcal{H}(q) \;-\;\eta\!\left(\sum_z q(z\mid x)\log\frac{q(z\mid x)}{q_\theta(z\mid x)} - \epsilon_\eta\right) \;+\;\lambda(x)\!\left(1-\sum_z q(z\mid x)\right). \end{aligned} $$

Setting \(\partial \mathcal{L}/\partial q=0\) yields:

$$ \begin{aligned} 0 &= A(x,z) \;-\;(\eta+\beta)\big(\log q(z\mid x)+1\big) \;+\;\eta\log q_\theta(z\mid x)\;-\;\lambda(x). \end{aligned} $$

Rearranging gives an unnormalized form for \(q^*(z\mid x)\):

$$ \begin{aligned} q^*(z\mid x) &\propto \exp\!\left(\frac{A(x,z) + \eta \log q_\theta(z\mid x)}{\eta+\beta}\right) \\ &= q_\theta(z\mid x)\, \exp\!\left(\frac{A(x,z) - \beta \log q_\theta(z\mid x)}{\eta+\beta}\right). \end{aligned} $$

Normalizing over \(z\) (using the freedom in \(\lambda(x)\)) then gives the unique distribution satisfying \(\sum_z q^*(z\mid x)=1\).
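On a small, enumerable latent space the closed form can be evaluated directly. The sketch below uses hypothetical helper names and plain Python. It implements both lines of the final equation and normalizes each. The two forms agree, a large $\eta$ keeps $q^*$ close to $q_\theta$, and with $\beta=0$ a small $\eta$ concentrates mass on the highest-advantage latent. Real latent sequences have a support far too large to enumerate, which is why the parameter updates work with sampled latents instead.

```python
import math

def normalize_log(logits):
    """Softmax of a list of unnormalized log-probabilities (stable)."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(l - m) for l in logits))
    return [math.exp(l - lse) for l in logits]

def q_star(q_theta, A, eta, beta):
    """First form: q* proportional to exp((A + eta * log q_theta) / (eta + beta))."""
    return normalize_log([(a + eta * math.log(q)) / (eta + beta) for q, a in zip(q_theta, A)])

def q_star_tilted(q_theta, A, eta, beta):
    """Second form: q_theta * exp((A - beta * log q_theta) / (eta + beta)), then normalized."""
    return normalize_log([math.log(q) + (a - beta * math.log(q)) / (eta + beta)
                          for q, a in zip(q_theta, A)])

q_theta = [0.5, 0.3, 0.2]
A = [0.0, 1.0, -1.0]
print(q_star(q_theta, A, eta=1.0, beta=0.1))
print(q_star_tilted(q_theta, A, eta=1.0, beta=0.1))  # same distribution
```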

Parameter updates

We update the encoder by minimizing $\mathrm{KL}(q^*\,\|\,q_\theta)$ (a weighted MLE objective), update the decoder by maximum likelihood, and adapt $\eta$ via effective sample size (ESS). Here $N$ is the minibatch size and $K$ is the number of latent samples per datapoint.

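Encoder update (forward KL): self-normalized importance weights $\tilde w_{ik}=w_{ik}/\sum_{j=1}^{K} w_{ij}$ with $w_{ik}=q^*(z_i^k\mid x_i;\eta)/q_\theta(z_i^k\mid x_i)$. Because the weights are normalized over the $K$ samples of each datapoint, the unknown normalizer of $q^*$ cancels.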

$$ \mathcal{L}(\theta) \;=\; \mathbb{E}_{x\sim p(x)}\!\left[\mathrm{KL}\!\left(q^*(\cdot\mid x)\,\|\,q_\theta(\cdot\mid x)\right)\right] \;\approx\; -\frac{1}{N}\sum_{i=1}^{N}\sum_{k=1}^{K}\tilde w_{ik}\,\log q_\theta(z_i^k\mid x_i) \;+\;\text{const}. $$
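The following sketches the encoder objective for one minibatch. It assumes the advantages and log-probabilities have already been computed for $K$ samples per datapoint; function names are hypothetical. From the closed form, $w_{ik}\propto\exp\big((A(x_i,z_i^k)-\beta\log q_\theta(z_i^k\mid x_i))/(\eta+\beta)\big)$ up to a per-datapoint constant, which self-normalization removes. In an autodiff framework the weights would be treated as constants (stop-gradient) and only $\log q_\theta$ differentiated. This plain-Python version only evaluates the loss value.

```python
import math

def self_normalized_weights(A_i, logq_i, eta, beta):
    """w~_ik for one datapoint: w_ik proportional to exp((A - beta*log q_theta)/(eta+beta)),
    normalized over the K samples so the unknown normalizer of q* cancels."""
    log_w = [(a - beta * lq) / (eta + beta) for a, lq in zip(A_i, logq_i)]
    m = max(log_w)
    w = [math.exp(l - m) for l in log_w]
    s = sum(w)
    return [wk / s for wk in w]

def encoder_loss(A, logq, eta, beta):
    """-(1/N) sum_i sum_k w~_ik log q_theta(z_i^k | x_i): weighted-MLE estimate of
    E_x[KL(q* || q_theta)] up to a constant. A and logq are N lists of K floats."""
    total = 0.0
    for A_i, logq_i in zip(A, logq):
        w_tilde = self_normalized_weights(A_i, logq_i, eta, beta)
        total += sum(wk * lq for wk, lq in zip(w_tilde, logq_i))
    return -total / len(A)

print(encoder_loss([[0.0, 0.0]], [[-1.0, -3.0]], eta=1.0, beta=0.0))  # equal weights -> 2.0
```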

Decoder update: standard maximum likelihood under $z_i^k\sim q_\theta(\cdot\mid x_i)$.

$$ \mathcal{L}(\phi) \;\approx\; -\frac{1}{N}\sum_{i=1}^{N}\frac{1}{K}\sum_{k=1}^{K}\log p_\phi(x_i\mid z_i^k), \qquad z_i^k\sim q_\theta(\cdot\mid x_i). $$

Step-size adaptation ($\eta$): choose $\eta$ so that the estimated ESS matches a target level $\mathrm{ESS}_{\text{target}}$ (a hyperparameter).

$$ \mathcal{L}(\eta) \;=\; \left(\widehat{\mathrm{ESS}}_\eta-\mathrm{ESS}_{\text{target}}\right)^2, \qquad \widehat{\mathrm{ESS}}_\eta \;=\; \frac{1}{N}\sum_{i=1}^N \frac{\Big(\sum_{k=1}^K w_{ik}\Big)^2}{\sum_{k=1}^K w_{ik}^2}. $$
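The sketch below computes $\widehat{\mathrm{ESS}}_\eta$ from the same unnormalized log-weights and then picks $\eta$. The page states the objective as a squared error. As a swapped-in, simpler solver, the sketch uses bisection on $\log\eta$. This works because a larger $\eta$ flattens the weights and therefore raises the ESS, which lies between 1 and $K$ per datapoint. `fit_eta`, its search bounds, and the iteration count are hypothetical choices, not the authors' procedure.

```python
import math

def ess_per_datapoint(log_w):
    """(sum_k w_k)^2 / sum_k w_k^2 from log-weights; invariant to rescaling the weights."""
    m = max(log_w)
    w = [math.exp(l - m) for l in log_w]
    return sum(w) ** 2 / sum(wk * wk for wk in w)

def batch_ess(A, logq, eta, beta):
    """ESS_hat(eta): minibatch average of the per-datapoint ESS of w_ik."""
    vals = []
    for A_i, logq_i in zip(A, logq):
        log_w = [(a - beta * lq) / (eta + beta) for a, lq in zip(A_i, logq_i)]
        vals.append(ess_per_datapoint(log_w))
    return sum(vals) / len(vals)

def fit_eta(A, logq, beta, ess_target, lo=1e-6, hi=1e6, iters=100):
    """Hypothetical helper: drive (ESS_hat - ESS_target)^2 toward zero by bisection on log(eta).
    Relies on ESS_hat being nondecreasing in eta."""
    log_lo, log_hi = math.log(lo), math.log(hi)
    for _ in range(iters):
        mid = 0.5 * (log_lo + log_hi)
        if batch_ess(A, logq, math.exp(mid), beta) < ess_target:
            log_lo = mid  # weights too peaked (ESS below target): increase eta
        else:
            log_hi = mid  # weights too flat (ESS at or above target): decrease eta
    return math.exp(0.5 * (log_lo + log_hi))

A = [[0.0, -1.0, -2.0, -3.0]]
logq = [[-1.0] * 4]
eta = fit_eta(A, logq, beta=0.0, ess_target=2.0)
print(eta, batch_ess(A, logq, eta, 0.0))  # ESS close to the target of 2.0
```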

Citation

@inproceedings{drolet2026daps,
  title     = {Discrete Variational Autoencoding via Policy Search},
  author    = {Drolet, Michael and Al-Hafez, Firas and Bhatt, Aditya and Peters, Jan and Arenz, Oleg},
  booktitle = {International Conference on Learning Representations (ICLR)},
  year      = {2026},
  url       = {https://www.drolet.io/daps/}
}