
Learning disentangled representations from unlabelled data is a fundamental challenge in machine learning. Solving it may unlock other problems, such as generalization, interpretability, or fairness. Although remarkably challenging to solve in theory, disentanglement is often achieved in practice through prior matching. Furthermore, recent works have shown that prior matching approaches can be enhanced by leveraging geometrical considerations, e.g., by learning representations that preserve geometric features of the data, such as distances or angles between points. However, matching the prior while preserving geometric features is challenging, as a mapping that fully preserves these features while aligning the data distribution with the prior does not exist in general. To address these challenges, we introduce a novel approach to disentangled representation learning based on quadratic optimal transport. We formulate the problem using Gromov-Monge maps that transport one distribution onto another with minimal distortion of predefined geometric features, preserving them as much as can be achieved. To compute such maps, we propose the Gromov-Monge-Gap (GMG), a regularizer quantifying whether a map moves a reference distribution with minimal geometry distortion. We demonstrate the effectiveness of our approach for disentanglement across four standard benchmarks, outperforming other methods leveraging geometric considerations.
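To make the objects above concrete, here is a rough sketch in our own notation (the paper's exact cost functions and regularization may differ). Given a reference measure $\rho$ and geometric features encoded by pairwise costs $c_\mathcal{X}$ on the data space and $c_\mathcal{Z}$ on the representation space, a Gromov-Monge map pushes $\rho$ onto a target while distorting these features as little as possible, and the Gromov-Monge-Gap measures how far a given map $T$ is from that minimal distortion:

$$
\mathrm{GM}(\rho,\nu) \;=\; \inf_{T\,:\,T_\#\rho=\nu}\; \iint \big| c_\mathcal{X}(x,x') - c_\mathcal{Z}\big(T(x),T(x')\big) \big|^2 \,\mathrm{d}\rho(x)\,\mathrm{d}\rho(x'),
$$

$$
\mathrm{GMG}_\rho(T) \;=\; \iint \big| c_\mathcal{X}(x,x') - c_\mathcal{Z}\big(T(x),T(x')\big) \big|^2 \,\mathrm{d}\rho(x)\,\mathrm{d}\rho(x') \;-\; \mathrm{GM}\big(\rho,\,T_\#\rho\big) \;\ge\; 0.
$$

Under this reading, $\mathrm{GMG}_\rho(T)=0$ exactly when $T$ itself transports $\rho$ onto $T_\#\rho$ with the least possible distortion of the chosen geometric features, which is what makes it usable as a regularizer alongside a prior-matching objective.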

*Equal contribution
**Equal advising
†CREST-ENSAE
‡Helmholtz Munich
§TU Munich
¶MCML
††Tübingen AI Center

Related readings and updates.

Given a source and a target probability measure supported on $\mathbb{R}^d$, the Monge problem aims for the most efficient way to map one distribution to the other. This efficiency is quantified by defining a cost function between source and target data. Such a cost is often set by default in the machine learning literature to the squared-Euclidean distance, $\ell_2^2(x,y)=\tfrac{1}{2}\|x-y\|_2^2$. The benefits…
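For reference, with this cost the Monge problem between a source $\mu$ and a target $\nu$ can be written in the standard form (notation ours, not taken from the linked article):

$$
\inf_{T\,:\,T_\#\mu=\nu}\; \int_{\mathbb{R}^d} \tfrac{1}{2}\,\|x - T(x)\|_2^2 \,\mathrm{d}\mu(x).
$$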
Optimal transport (OT) theory has been used in machine learning to study and characterize maps that can efficiently push forward a probability measure onto another. Recent works have drawn inspiration from Brenier's theorem, which states that when the ground cost is the squared-Euclidean distance, the "best" map to morph a continuous measure in $\mathcal{P}(\mathbb{R}^d)$ into another must be the gradient of a convex function. To…
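In other words (a standard statement of Brenier's theorem, given here for context rather than quoted from the linked article), when the cost is $\ell_2^2$ and the source measure $\mu$ admits a density, the optimal map takes the form

$$
T^\star = \nabla f, \qquad f \text{ convex}, \qquad (\nabla f)_\#\mu = \nu.
$$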