A simple, scalable way to learn convex conjugates in high dimensions.
DLT trains a neural network to approximate the convex conjugate $f^\ast$ of a differentiable convex function $f$, using only evaluations of $f$ and $\nabla f$.
- Overview
- Mathematical Primer
- Method (DLT) in One Look
- Approximate Inverse Sampling
- Certificates: A‑Posteriori Error Estimator
- Applications
- Results at a Glance
- Minimal Working Example (PyTorch)
- Project Structure
- Citation
- License
- Contact
- Acknowledgments
## Overview

Deep Legendre Transform (DLT) is a learning framework for computing convex conjugates in high dimensions. Classic grid methods for computing $f^\ast$ suffer from the curse of dimensionality, while sup‑smoothing methods still require costly integration loops. DLT avoids both by training on exact targets derived from the implicit Legendre–Fenchel identity

$$
f^\ast(\nabla f(x)) \;=\; \langle x, \nabla f(x)\rangle \;-\; f(x).
$$
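As a small worked illustration (not taken from the paper): for a positive‑definite quadratic the identity reproduces the known closed‑form conjugate. With $f(x) = \tfrac{1}{2}\,x^\top A x$ and $A \succ 0$, we have $\nabla f(x) = Ax$, so writing $y = Ax$ (i.e., $x = A^{-1}y$) the identity gives

$$
f^\ast(y) \;=\; \langle A^{-1}y,\, y\rangle \;-\; \tfrac{1}{2}\,(A^{-1}y)^\top A\, (A^{-1}y)
\;=\; \tfrac{1}{2}\, y^\top A^{-1} y ,
$$

which matches the classical conjugate of a quadratic.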
### Highlights

- Scales to high‑D: Works with MLPs / ResNets / ICNNs / KANs; demonstrated up to $d = 200$.
- Convex outputs (optional): Use an ICNN to guarantee convexity of the learned $g_\theta \approx f^\ast$.
- No closed‑form dual needed: Targets come from $f$ and $\nabla f$ only.
- Built‑in validation: A Monte‑Carlo estimator certifies the $L^2$ approximation error of $g_\theta$ to $f^\ast$.
- Symbolic recovery: With KANs, DLT can rediscover exact closed‑form conjugates in low dimension.
## Mathematical Primer

Legendre–Fenchel transform:
$$
f^\ast(y) \;=\; \sup_{x \in C}\big\{\langle x, y\rangle - f(x)\big\}.
$$

Legendre (gradient) form on $D = \nabla f(C)$: when $f$ is differentiable and strictly convex on $C$, the supremum is attained at the unique $x$ with $\nabla f(x) = y$, so $f^\ast(y) = \langle (\nabla f)^{-1}(y),\, y\rangle - f\big((\nabla f)^{-1}(y)\big)$ for $y \in D$.

Implicit Fenchel identity: equivalently, for every $x \in C$,
$$
f^\ast(\nabla f(x)) \;=\; \langle x, \nabla f(x)\rangle - f(x),
$$
which supplies exact values of $f^\ast$ at the points $\nabla f(x)$ without solving any optimization problem.
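For context, the identity is the equality case of the Fenchel–Young inequality (a standard convex‑analysis fact, added here for completeness):

$$
f(x) + f^\ast(y) \;\ge\; \langle x, y\rangle \quad \text{for all } x, y,
$$

with equality if and only if $y \in \partial f(x)$, i.e., $y = \nabla f(x)$ in the differentiable case; rearranging the equality case gives the implicit Fenchel identity above.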
## Method (DLT) in One Look

Train a network $g_\theta$ to fit the exact conjugate values produced by the identity:
$$
\mathcal{L}(\theta) \;=\; \mathbb{E}_{X \sim \mu}\Big[\big(g_\theta(\nabla f(X)) - \langle X, \nabla f(X)\rangle + f(X)\big)^2\Big],
$$
or empirically,
$$
\widehat{\mathcal{L}}(\theta) \;=\; \frac{1}{N}\sum_{i=1}^{N}\big(g_\theta(\nabla f(x_i)) - \langle x_i, \nabla f(x_i)\rangle + f(x_i)\big)^2, \qquad x_i \sim \mu .
$$
Minimizing $\widehat{\mathcal{L}}$ drives $g_\theta$ toward $f^\ast$ on $D = \nabla f(C)$.
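The same objective works for any smooth convex $f$, with $\nabla f$ obtained by automatic differentiation when no analytic gradient is at hand. Below is a small sketch of target construction (my own illustration; `f_lse` and `dlt_targets` are hypothetical names, not code from this repository) using the log‑sum‑exp function, whose gradient is the softmax.

```python
import torch

def f_lse(x):
    # log-sum-exp: a smooth convex function, maps (batch, d) -> (batch, 1)
    return torch.logsumexp(x, dim=1, keepdim=True)

def dlt_targets(f, x):
    """Exact conjugate values f*(∇f(x)) via the implicit Fenchel identity,
    with ∇f computed by autograd (no analytic gradient required)."""
    x = x.detach().requires_grad_(True)
    y = torch.autograd.grad(f(x).sum(), x)[0]          # ∇f(x); here softmax(x)
    target = (x * y).sum(dim=1, keepdim=True) - f(x)   # <x, ∇f(x)> - f(x)
    return y.detach(), target.detach()

# Usage with any network g: R^d -> R, e.g. the MLP from the demo below:
# y, target = dlt_targets(f_lse, torch.randn(4096, 10))
# loss = ((g(y) - target) ** 2).mean()
```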
Sampling in gradient space. When $g_\theta$ is trained on pairs $(\nabla f(X), \langle X, \nabla f(X)\rangle - f(X))$ with $X \sim \mu$, the approximation is controlled on the pushforward of $\mu$ under $\nabla f$; to target a specific region $D$ of gradient space, sample there directly (next section).

Convexity. Choose an ICNN for $g_\theta$ when the learned conjugate must be convex by construction.

## Approximate Inverse Sampling

When samples with prescribed coverage of the dual domain $D$ are desired, train an approximate inverse $\Psi_{\theta} \approx (\nabla f)^{-1}$ and generate training pairs directly in gradient space.

Using $\Psi_{\theta}$, training proceeds as follows:
- Sample $Y \sim \nu^\dagger$ on $D$.
- Set $X = \Psi_{\theta}(Y)$.
- Form targets $T(Y) = \langle X, Y\rangle - f(X)$ and train $g_\theta(Y)$ to match $T(Y)$.

Architectures for $\Psi_{\theta}$ (a minimal sketch of the first option follows after this list):

- MLP with spectral normalization (simple, fast).
- Monotone/triangular flows (e.g., monotone splines per dimension) when $C$ is box‑shaped.
- i‑ResNets / coupling‑flow blocks if strict invertibility on $D$ is desired.
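For concreteness, here is a minimal sketch of the first option: an MLP inverse model with spectral normalization on its linear layers. The class name `InverseMLP` and the layer sizes are illustrative choices of mine, not code from the repository.

```python
import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import spectral_norm

class InverseMLP(nn.Module):
    """Approximate inverse Psi_theta: maps y in D to x with grad_f(x) ≈ y.
    Spectral normalization keeps each linear map 1-Lipschitz, which stabilizes training."""
    def __init__(self, d, hidden=256):
        super().__init__()
        self.net = nn.Sequential(
            spectral_norm(nn.Linear(d, hidden)), nn.GELU(),
            spectral_norm(nn.Linear(hidden, hidden)), nn.GELU(),
            spectral_norm(nn.Linear(hidden, d)),
        )

    def forward(self, y):
        return self.net(y)

# Trained with the inverse loss from the pseudocode below:
# loss_inv = ((grad_f(Psi_theta(y)) - y) ** 2).mean()
```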
Minimal pseudocode (PyTorch‑style):

```python
# 1) Learn the approximate inverse Psi_theta
for step in range(T_inv):
    y = sample_from_nu_dagger(batch, D)          # desired coverage on D
    x_hat = Psi_theta(y)
    loss_inv = ((grad_f(x_hat) - y)**2).mean()   # push grad_f(Psi_theta(y)) toward y
    loss_inv += lambda_C * barrier_C(x_hat)      # keep x_hat inside the primal domain C
    update(theta_inv, loss_inv)

# 2) Train DLT using inverse-sampled pairs
for step in range(T_dlt):
    y = sample_from_nu_dagger(batch, D)
    x = Psi_theta(y).detach()
    target = (x * y).sum(dim=1, keepdim=True) - f(x)   # <x, y> - f(x)
    loss = ((g_theta(y) - target)**2).mean()
    update(theta, loss)
```

## Certificates: A‑Posteriori Error Estimator

Let $Y = \nabla f(X)$ for held‑out $X \sim \mu$, and let $T(Y) = \langle X, Y\rangle - f(X)$ be the corresponding exact conjugate values. The empirical mean squared error of $g_\theta$ against these targets estimates $\lVert g_\theta - f^\ast\rVert_{L^2(\nu)}^2$, where $\nu$ is the law of $\nabla f(X)$. This provides a straightforward Monte‑Carlo certificate of the $L^2$ approximation error of $g_\theta$, obtained without ever evaluating $f^\ast$ directly.
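In symbols (my notation, matching the held‑out loop in the demo below, where $y_i = \nabla f(x_i)$ and $t_i = \langle x_i, y_i\rangle - f(x_i)$ correspond to the variables `y` and `target`):

$$
\widehat{\varepsilon}^{\,2} \;=\; \frac{1}{N}\sum_{i=1}^{N}\big(g_\theta(y_i) - t_i\big)^2,
\qquad
\mathbb{E}\big[\widehat{\varepsilon}^{\,2}\big] \;=\; \lVert g_\theta - f^\ast\rVert_{L^2(\nu)}^2 .
$$

The reported RMSE is $\widehat{\varepsilon}$.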
## Applications

- Hamilton–Jacobi PDEs (Hopf formula): DLT approximates the time‑parameterized dual appearing in the Hopf representation of the solution (see the sketch after this list).
- Optimal transport / WGANs: Learn convex potentials.
- Symbolic regression (KANs): Recover exact closed‑form expressions for $f^\ast$ in low dimension.
- Economics: Derive indirect utility and profit functions.
- Physics/Thermodynamics: Switch between thermodynamic potentials by exchanging extensive variables (entropy, volume) for intensive conjugates (temperature, pressure).
- Variational analysis: Construct Moreau envelopes.
- Convex optimization: Compute convex potentials in optimal transport problems.
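For the Hamilton–Jacobi case, the connection is the classical Hopf formula (stated here for context; the paper's exact setting may differ). For $u_t + H(\nabla_x u) = 0$ with convex initial condition $u(\cdot, 0) = u_0$, the solution can be written as a time‑parameterized conjugate,

$$
u(x, t) \;=\; \big(u_0^\ast + t\,H\big)^\ast(x)
\;=\; \sup_{y}\big\{\langle x, y\rangle - u_0^\ast(y) - t\,H(y)\big\},
$$

so computing $u(\cdot, t)$ amounts to computing a conjugate of $u_0^\ast + tH$ for each $t$.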
## Results at a Glance

DLT vs. classical grid methods (Lucet LLT) at increasing dimension:
| Dim | Classical (grid/FFT/LLT) | DLT (ResNet/ICNN) |
|---|---|---|
| 2–6 | Fast & accurate on fine grids | Matches the error |
| 8–10 | Time/memory explode | Trains in seconds–minutes |
| 20–200 | Infeasible | Can be trained to low RMSE |
Architectures: ResNet often gives the best approximation in high‑$d$; ICNN guarantees convexity (sometimes slightly higher error); KANs recover exact closed forms in 2D.
## Minimal Working Example (PyTorch)

Below is a tiny end‑to‑end demo of DLT on a quadratic $f(x) = \tfrac{1}{2}\lVert x\rVert^2$, whose conjugate is $f^\ast(y) = \tfrac{1}{2}\lVert y\rVert^2$.
```python
# Minimal DLT demo (PyTorch) — quadratic example
# pip install torch
import torch

# Problem setup
d = 10
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(0)

# Define f and its gradient (quadratic)
def f(x):                                                # x: (batch, d)
    return 0.5 * (x**2).sum(dim=1, keepdim=True)         # (batch, 1)

def grad_f(x):
    return x                                             # ∇f(x) = x

# Approximator g_theta: R^d -> R (aims at f*)
g = torch.nn.Sequential(
    torch.nn.Linear(d, 128),
    torch.nn.GELU(),
    torch.nn.Linear(128, 128),
    torch.nn.GELU(),
    torch.nn.Linear(128, 1),
).to(device)
opt = torch.optim.Adam(g.parameters(), lr=1e-3)

def dlt_loss(x):
    y = grad_f(x)                                        # (batch, d)
    target = (x * y).sum(dim=1, keepdim=True) - f(x)     # <x, ∇f(x)> - f(x)
    pred = g(y)                                          # g_theta(∇f(x))
    return ((pred - target)**2).mean()

# Training
for step in range(2000):
    x = torch.randn(4096, d, device=device)              # sample x ~ N(0, I)
    loss = dlt_loss(x)
    opt.zero_grad()
    loss.backward()
    opt.step()
    if step % 200 == 0:
        print(f"step {step:4d} | loss ~ L2 error^2: {loss.item():.3e}")

# A-posteriori certificate on held-out data
with torch.no_grad():
    x = torch.randn(8192, d, device=device)
    y = grad_f(x)
    target = (x * y).sum(dim=1, keepdim=True) - f(x)
    pred = g(y)
    rmse = ((pred - target)**2).mean().sqrt().item()     # RMSE certificate
print(f"Certified RMSE on ∇f(C): {rmse:.3e}")
```

Note: For general $f$, replace `grad_f` with `torch.autograd.grad` on `f(x).sum()` (keeping `x.requires_grad_(True)`), or use your analytic gradient. For convex outputs, swap `g` for an ICNN (a minimal sketch follows below).
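As a concrete reference for that last point, here is a minimal ICNN sketch in the style of input‑convex neural networks (Amos et al.). The class name, layer sizes, and softplus positivity parameterization are my own illustrative choices, not code from this repository; convexity in the input follows from non‑negative weights on the hidden‑to‑hidden paths combined with convex, non‑decreasing activations.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ICNN(nn.Module):
    """Input-convex network: z_{k+1} = softplus(W_z^(k) z_k + W_y^(k) y + b_k),
    with W_z constrained non-negative so the output is convex in y."""
    def __init__(self, d, hidden=128, depth=2):
        super().__init__()
        self.Wy = nn.ModuleList([nn.Linear(d, hidden) for _ in range(depth)] +
                                [nn.Linear(d, 1)])
        # Hidden-to-hidden weights, stored unconstrained and mapped through softplus
        self.Wz_raw = nn.ParameterList(
            [nn.Parameter(0.01 * torch.randn(hidden, hidden)) for _ in range(depth - 1)] +
            [nn.Parameter(0.01 * torch.randn(1, hidden))])

    def forward(self, y):
        z = F.softplus(self.Wy[0](y))                    # first layer: affine in y, convex activation
        for k, Wz in enumerate(self.Wz_raw[:-1]):
            z = F.softplus(F.linear(z, F.softplus(Wz)) + self.Wy[k + 1](y))
        return F.linear(z, F.softplus(self.Wz_raw[-1])) + self.Wy[-1](y)  # final layer: no activation

# Drop-in replacement for g in the demo above (g_theta is then convex by construction):
# g = ICNN(d).to(device)
```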
## Project Structure

```
├── main_part/   # Core implementation and experiment scripts
├── appendix/    # Supplementary experiments, figures, extended tables
├── images/      # Figures used in the paper
├── LICENSE      # Apache 2.0
└── README.md
```
## Citation

If you use DLT in your research, please cite:

```bibtex
@inproceedings{minabutdinov2025deep,
  title     = {Deep Legendre Transform},
  author    = {Minabutdinov, Aleksey and Cheridito, Patrick},
  booktitle = {NeurIPS},
  year      = {2025}
}
```

## License

This project is licensed under the Apache License 2.0. See LICENSE for details.
## Contact

- Aleksey Minabutdinov (ETH Zurich): aminabutdinov@ethz.ch
- Patrick Cheridito (ETH Zurich): patrickc@ethz.ch
## Acknowledgments

Swiss National Science Foundation, Grant No. 10003723.