Getting Started with TransportMaps.jl

This guide will help you get started with TransportMaps.jl for constructing and using transport maps.

Basic Concepts

What is a Transport Map?

A transport map $T$ is a function that transforms samples from a reference distribution (typically standard Gaussian) to a target distribution [1]. The key property is that if $X \sim \rho_0$ (reference) and $Y = T(X)$, then $Y \sim \rho_1$ (target).

Triangular Maps

TransportMaps.jl focuses on triangular transport maps [3], where:

\[T(\boldsymbol{x}) = \left(\begin{array}{c} T_1(x_1) \\ T_2(x_1, x_2) \\ T_3(x_1, x_2, x_3) \\ \vdots \\ T_n(x_1, x_2 \dots, x_n) \end{array} \right)\]

This structure ensures that the map is invertible and the Jacobian determinant is easy to compute. The construction follows the Knothe-Rosenblatt rearrangement [6].

First Example: A Simple 2D Transport Map

using TransportMaps
using Distributions
using Random
using Plots
using LinearAlgebra

Let's create a simple 2D transport map:

Set random seed for reproducibility

Random.seed!(1234)

Create a 2D polynomial map with degree 2

M = PolynomialMap(2, 2, Normal(), Softplus())
PolynomialMap:
  Dimensions: 2
  Total coefficients: 9
  Reference density: Distributions.Normal{Float64}(μ=0.0, σ=1.0)
  Maximum degree: 2
  Basis: Hermite
  Rectifier: Softplus
  Components:
    Component 1: 3 basis functions
    Component 2: 6 basis functions
  Coefficients: min=0.0, max=6.92681825896096e-310, mean=3.0755954533446e-310

The map is initially identity (coefficients are zero)

println("Initial coefficients: ", getcoefficients(M))
Initial coefficients: [6.92681825896096e-310, 8.4879832073e-314, 6.89912258832325e-310, 6.9267847172494e-310, 6.9267847172474e-310, 0.0, 0.0, 0.0, 0.0]

Defining a Target Distribution

For optimization, you need to define your target probability density. Let's start with a simple correlated Gaussian:

Example: Correlated Gaussian

function correlated_gaussian(x; ρ=0.8)
    Σ = [1.0 ρ; ρ 1.0]
    return pdf(MvNormal(zeros(2), Σ), x)
end

Create a MapTargetDensity object for optimization

target_density = MapTargetDensity(correlated_gaussian, :auto_diff)
MapTargetDensity(density=correlated_gaussian, gradient_type=auto_diff)

Setting up Quadrature

Choose an appropriate quadrature scheme for map optimization:

Gauss-Hermite quadrature (good for Gaussian-like targets)

quadrature = GaussHermiteWeights(5, 2)  # 5 points per dimension, 2D
GaussHermiteWeights:
  Number of points: 25
  Dimensions: 2
  Quadrature type: Tensor product Gauss-Hermite
  Reference measure: Standard Gaussian
  Weight range: [0.00012672930980149358, 0.28444444444444505]
  Weight sum: 1.0000000000000002
  Points (first 5):
    [-2.8569700138728056, -2.8569700138728056] → weight: 0.00012672930980149358
    [-1.3556261799742675, -2.8569700138728056] → weight: 0.0024999999999999927
    [-1.2560739669470201e-15, -2.8569700138728056] → weight: 0.0060039527081176955
    [1.3556261799742675, -2.8569700138728056] → weight: 0.0024999999999999927
    [2.8569700138728056, -2.8569700138728056] → weight: 0.00012672930980149358
    ... and 20 more

Alternative options (commented out): quadrature = MonteCarloWeights(1000, 2) # 1000 samples, 2D quadrature = LatinHypercubeWeights(1000, 2)

Optimizing the Map

Fit the transport map to your target distribution:

println("Optimizing the map...")
@time result = optimize!(M, target_density, quadrature)

println("Optimization result: ", result)
println("Final coefficients: ", getcoefficients(M))
Optimizing the map...
  1.004850 seconds (3.43 M allocations: 187.633 MiB, 3.84% gc time, 89.32% compilation time)
Optimization result:  * Status: success

 * Candidate solution
    Final objective value:     2.837877e+00

 * Found with
    Algorithm:     L-BFGS

 * Convergence measures
    |x - x'|               = 3.50e-08 ≰ 0.0e+00
    |x - x'|/|x'|          = 4.37e-08 ≰ 0.0e+00
    |f(x) - f(x')|         = 2.22e-15 ≰ 0.0e+00
    |f(x) - f(x')|/|f(x')| = 7.82e-16 ≰ 0.0e+00
    |g(x)|                 = 6.45e-11 ≤ 1.0e-08

 * Work counters
    Seconds run:   0  (vs limit Inf)
    Iterations:    7
    f(x) calls:    19
    ∇f(x) calls:   19

Final coefficients: [2.944013880231288e-12, 0.5413248546146763, -7.222147401065582e-12, -2.8435575707418865e-12, 0.8000000000013396, -0.19587036834756433, 7.950291974275175e-12, -6.317184866841605e-17, 2.548287156231699e-12]

Generating Samples

Once optimized, use the map to generate samples:

Generate reference samples (standard Gaussian)

n_samples = 1000
reference_samples = randn(n_samples, 2)
1000×2 Matrix{Float64}:
  0.970656   -0.563375
 -0.979218   -0.321198
  0.901861   -1.08085
 -0.0328031   0.1828
 -0.600792   -1.10277
 -1.44518     0.0973357
  2.70742    -1.50738
  1.52445     0.495961
  0.759804    1.65377
 -0.881437   -0.902006
  ⋮          
  0.736417    0.898635
  0.191944    0.0989677
  0.764671   -0.723075
  0.460548    0.805013
 -1.45535    -0.952593
 -0.73168     0.66637
 -0.463285   -0.0398125
  0.511219    0.288282
 -1.29112    -3.55823

Transform to target distribution

target_samples = evaluate(M, reference_samples)
1000×2 Matrix{Float64}:
  0.970656    0.4385
 -0.979218   -0.976093
  0.901861    0.0729767
 -0.0328031   0.0834377
 -0.600792   -1.1423
 -1.44518    -1.09774
  2.70742     1.26151
  1.52445     1.51713
  0.759804    1.60011
 -0.881437   -1.24635
  ⋮          
  0.736417    1.12831
  0.191944    0.212936
  0.764671    0.177892
  0.460548    0.851446
 -1.45535    -1.73584
 -0.73168    -0.185522
 -0.463285   -0.394516
  0.511219    0.581945
 -1.29112    -3.16784

Visualizing Results

Let's plot both the reference and target samples:

p1 = scatter(reference_samples[:, 1], reference_samples[:, 2],
            alpha=0.6, title="Reference Samples",
            xlabel="Z₁", ylabel="Z₂", legend=false, aspect_ratio=:equal)

p2 = scatter(target_samples[:, 1], target_samples[:, 2],
            alpha=0.6, title="Target Samples",
            xlabel="X₁", ylabel="X₂", legend=false, aspect_ratio=:equal)

plot(p1, p2, layout=(1,2), size=(800, 400))

Transport Map Samples

Evaluating Map Quality

Check how well your map approximates the target:

Variance diagnostic (should be close to 1 for good maps)

var_diag = variance_diagnostic(M, target_density, reference_samples)
println("Variance diagnostic: ", var_diag)
Variance diagnostic: 6.2788905504523345e-22

You can also check the Jacobian determinant

sample_point = [0.0, 0.0]
jac = jacobian(M, sample_point)
det_jac = det(jac)
println("Jacobian determinant at origin: ", det_jac)
Jacobian determinant at origin: 0.6000000000001051

Working with Different Rectifiers

The rectifier function affects the map's behavior. Let's compare different options:

ShiftedELU rectifier

M_elu = PolynomialMap(2, 2, Normal(), ShiftedELU())
result_elu = optimize!(M_elu, target_density, quadrature)
var_diag_elu = variance_diagnostic(M_elu, target_density, reference_samples)

println("Variance diagnostics:")
println("  Softplus: ", var_diag)
println("  ShiftedELU: ", var_diag_elu)
Variance diagnostics:
  Softplus: 6.2788905504523345e-22
  ShiftedELU: 7.343633321845726e-24

More Complex Example: Banana Distribution

Now let's try a more challenging target - the banana distribution:

Define banana density

banana_density(x) = pdf(Normal(), x[1]) * pdf(Normal(), x[2] - x[1]^2)
target_density_banana = MapTargetDensity(banana_density, :auto_diff)
MapTargetDensity(density=banana_density, gradient_type=auto_diff)

Create a new map for this target

M_banana = PolynomialMap(2, 2, Normal(), Softplus())
result_banana = optimize!(M_banana, target_density_banana, quadrature)
 * Status: success

 * Candidate solution
    Final objective value:     2.837877e+00

 * Found with
    Algorithm:     L-BFGS

 * Convergence measures
    |x - x'|               = 1.22e-07 ≰ 0.0e+00
    |x - x'|/|x'|          = 1.22e-07 ≰ 0.0e+00
    |f(x) - f(x')|         = 1.53e-13 ≰ 0.0e+00
    |f(x) - f(x')|/|f(x')| = 5.38e-14 ≰ 0.0e+00
    |g(x)|                 = 2.18e-09 ≤ 1.0e-08

 * Work counters
    Seconds run:   0  (vs limit Inf)
    Iterations:    13
    f(x) calls:    34
    ∇f(x) calls:   34

Display optimized map

display(M_banana)

Generate samples

banana_samples = evaluate(M_banana, reference_samples)
1000×2 Matrix{Float64}:
  0.970656    0.378799
 -0.979218    0.637671
  0.901861   -0.2675
 -0.0328031   0.183876
 -0.600792   -0.741822
 -1.44518     2.18587
  2.70742     5.82276
  1.52445     2.8199
  0.759804    2.23107
 -0.881437   -0.125075
  ⋮          
  0.736417    1.44094
  0.191944    0.13581
  0.764671   -0.138353
  0.460548    1.01712
 -1.45535     1.16545
 -0.73168     1.20173
 -0.463285    0.174821
  0.511219    0.549628
 -1.29112    -1.89123

Visualize the banana distribution

x1_grid = range(-3, 3, length=100)
x2_grid = range(-3, 6, length=100)
posterior_values = [banana_density([x₁, x₂]) for x₂ in x2_grid, x₁ in x1_grid]

scatter(banana_samples[:, 1], banana_samples[:, 2],
        alpha=0.6, title="Banana Distribution Samples",
        xlabel="X₁", ylabel="X₂", legend=false, aspect_ratio=:equal)
contour!(x1_grid, x2_grid, posterior_values, colormap=:viridis, label="Posterior Density")

Banana Distribution Samples

Check quality

var_diag_banana = variance_diagnostic(M_banana, target_density_banana, reference_samples)
println("Banana distribution variance diagnostic: ", var_diag_banana)
Banana distribution variance diagnostic: 1.2687437142582097e-18