Map from Samples Example

This example demonstrates how to use TransportMaps.jl to approximate a "banana" distribution using polynomial transport maps when only samples from the target distribution are available.

Unlike the density-based approach, this method learns the transport map directly from sample data using optimization techniques. This is particularly useful when the target density is unknown or difficult to evaluate [1].

We start with the necessary packages:

using TransportMaps
using Distributions
using LinearAlgebra
using Plots

Generating Target Samples

The banana distribution has the density:

\[p(x) = \phi(x_1) \cdot \phi(x_2 - x_1^2)\]

where $\phi$ is the standard normal PDF.

banana_density(x) = pdf(Normal(), x[1]) * pdf(Normal(), x[2] - x[1]^2)
banana_density (generic function with 1 method)

Set the number of samples to draw from the target distribution:

num_samples = 500

Generate samples using rejection sampling (no external sampling package required):

function generate_banana_samples(n_samples::Int)
    samples = Matrix{Float64}(undef, n_samples, 2)

    # Proposal: x₁ ~ N(0, 2²) and x₂ | x₁ ~ N(x₁², 3²)
    proposal_density(x) = pdf(Normal(0, 2), x[1]) * pdf(Normal(x[1]^2, 3), x[2])

    count = 0
    while count < n_samples
        x1 = randn() * 2
        x2 = randn() * 3 + x1^2
        x = [x1, x2]

        # Accept with probability p(x) / (M q(x)); M = 6 bounds p/q everywhere
        if rand() < banana_density(x) / (6 * proposal_density(x))
            count += 1
            samples[count, :] = x
        end
    end

    return samples
end
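Since the banana density factorizes as $\phi(x_1) \cdot \phi(x_2 - x_1^2)$, exact sampling is also possible without rejection. A minimal sketch (the function name generate_banana_samples_exact is our own):

# Exact sampling: draw x₁ ~ N(0, 1), then set x₂ = z + x₁² with z ~ N(0, 1)
function generate_banana_samples_exact(n_samples::Int)
    x1 = randn(n_samples)
    x2 = randn(n_samples) .+ x1 .^ 2
    return hcat(x1, x2)
end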

println("Generating samples from banana distribution...")
target_samples = generate_banana_samples(num_samples)
println("Generated $(size(target_samples, 1)) samples")
Generating samples from banana distribution...
Generated 500 samples

Creating the Transport Map

We create a 2-dimensional polynomial transport map with maximum total degree 2. For sample-based optimization, we typically start with lower degrees and increase complexity as needed. Since the map is lower triangular, component $k$ depends only on $x_1, \ldots, x_k$; with total degree 2 this gives 3 basis functions for the first component and 6 for the second, i.e. 9 coefficients in total.

M = PolynomialMap(2, 2)
PolynomialMap:
  Dimensions: 2
  Total coefficients: 9
  Reference density: Distributions.Normal{Float64}(μ=0.0, σ=1.0)
  Maximum degree: 2
  Basis: Hermite
  Rectifier: Softplus
  Components:
    Component 1: 3 basis functions
    Component 2: 6 basis functions
  Coefficients: uninitialized

Optimizing from Samples

The key difference from density-based optimization is that we optimize directly on the sample data, without needing the density function. Inside the optimization, the map is arranged such that the "forward" direction runs from the (unknown) target distribution to the standard normal reference.
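Under that convention, the map is found by maximum-likelihood estimation: with a standard normal reference density, minimizing the KL divergence estimated from samples reduces (up to an additive constant) to

\[\hat{J}(S) = \frac{1}{N} \sum_{i=1}^{N} \left[ \frac{1}{2} \left\| S(x^{(i)}) \right\|^2 - \log \det \nabla S(x^{(i)}) \right]\]

where $x^{(i)}$ are the target samples. We assume optimize! minimizes an objective of this form over the map coefficients: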

@time res = optimize!(M, target_samples)
println("Optimization result: ", res)
  6.065906 seconds (94.94 M allocations: 5.540 GiB, 6.23% gc time, 7.16% compilation time)
Optimization result:  * Status: success

 * Candidate solution
    Final objective value:     2.654886e+00

 * Found with
    Algorithm:     L-BFGS

 * Convergence measures
    |x - x'|               = 5.01e-08 ≰ 0.0e+00
    |x - x'|/|x'|          = 4.78e-08 ≰ 0.0e+00
    |f(x) - f(x')|         = 4.88e-15 ≰ 0.0e+00
    |f(x) - f(x')|/|f(x')| = 1.84e-15 ≰ 0.0e+00
    |g(x)|                 = 4.02e-10 ≤ 1.0e-08

 * Work counters
    Seconds run:   6  (vs limit Inf)
    Iterations:    19
    f(x) calls:    58
    ∇f(x) calls:   58

Testing the Map

Let's generate fresh samples from the banana distribution, along with standard normal samples, and map each set through the optimized transport map to verify the learned distribution:

new_samples = generate_banana_samples(1000)
norm_samples = randn(1000, 2)
1000×2 Matrix{Float64}:
  0.0696263   0.569546
 -0.0118946   0.0401304
  1.6895      1.43094
  0.793694   -1.12277
 -1.05876    -0.0345062
 -0.767916   -0.162202
 -0.374317    0.0290939
 -0.851917   -1.14503
  1.18843     0.289716
  0.169474   -1.79299
  ⋮          
  0.351009   -0.0851662
  0.915733   -0.274554
  1.30743     0.744332
  0.246186   -0.673974
 -0.923339   -0.586072
 -0.267338   -0.437785
 -1.71726     2.48609
 -0.847807   -1.31439
 -0.629006    0.75146

Map the samples through our transport map. Since the forward direction runs from target to reference, evaluate transports the banana samples into the reference space, i.e. mapped_samples should be approximately standard normal:

mapped_samples = evaluate(M, new_samples)
1000×2 Matrix{Float64}:
 -1.47742    -1.0165
  0.645277   -1.84858
 -0.698264    2.33773
 -2.43713     0.808078
  1.09501     0.339164
 -1.55821    -0.672894
  0.588278   -0.487786
 -1.04991    -0.981329
  1.7823     -0.965122
 -0.0817615  -1.6889
  ⋮          
 -0.317394    0.709783
  1.29987     0.00889054
 -1.92629     1.27716
 -0.171996   -0.0822702
 -0.598767   -0.158228
 -0.299139    0.580278
  0.214492   -0.103406
 -1.08893     0.655359
 -0.272131   -0.475473

while mapping standard normal samples back to the target distribution via inverse generates new samples from the banana distribution:

mapped_banana_samples = inverse(M, norm_samples)
1000×2 Matrix{Float64}:
  0.145652    0.559234
  0.0721646   0.0535075
  1.63916     3.96336
  0.805268   -0.411785
 -0.858081    0.683846
 -0.602099    0.211266
 -0.252683    0.100732
 -0.676223   -0.601504
  1.17024     1.61134
  0.235871   -1.60682
  ⋮          
  0.400501    0.0858838
  0.917691    0.570123
  1.28102     2.30069
  0.305345   -0.526749
 -0.739124   -0.0020765
 -0.157109   -0.369033
 -1.4309      4.17815
 -0.6726     -0.762244
 -0.479178    0.923286

Visualizing Results

Let's create a scatter plot comparing the original samples with the mapped samples to see how well our transport map learned the distribution:

p11 = scatter(new_samples[:, 1], new_samples[:, 2],
            label="Original Samples", alpha=0.5, color=1,
            title="Banana Distribution: Original vs. Mapped Samples",
            xlabel="x₁", ylabel="x₂")

scatter!(p11, mapped_banana_samples[:, 1], mapped_banana_samples[:, 2],
            label="Mapped Samples", alpha=0.5, color=2)

plot(p11, size=(800, 400))

Sample Comparison

and the resulting samples in standard normal space:

p12 = scatter(norm_samples[:, 1], norm_samples[:, 2],
            label="Standard Normal Samples", alpha=0.5, color=1,
            title="Reference Space: Normal vs. Mapped Samples",
            xlabel="x₁", ylabel="x₂")

scatter!(p12, mapped_samples[:, 1], mapped_samples[:, 2],
            label="Mapped Samples", alpha=0.5, color=2)

plot(p12, size=(800, 400))

Reference Space Comparison

Density Comparison

We can also compare the learned density (via pullback) with the true density:

x₁ = range(-3, 3, length=100)
x₂ = range(-2.5, 4.0, length=100)
-2.5:0.06565656565656566:4.0

True banana density values:

true_density = [banana_density([x1, x2]) for x2 in x₂, x1 in x₁]
100×100 Matrix{Float64}:
 3.38667e-32  2.38559e-30  1.35878e-28  …  2.38559e-30  3.38667e-32
 7.19036e-32  4.94664e-30  2.75301e-28     4.94664e-30  7.19036e-32
 1.52005e-31  1.0213e-29   5.55386e-28     1.0213e-29   1.52005e-31
 3.19956e-31  2.09953e-29  1.1156e-27      2.09953e-29  3.19956e-31
 6.70583e-31  4.29754e-29  2.23127e-27     4.29754e-29  6.70583e-31
 1.3994e-30   8.75882e-29  4.44349e-27  …  8.75882e-29  1.3994e-30
 2.90776e-30  1.77746e-28  8.81094e-27     1.77746e-28  2.90776e-30
 6.01595e-30  3.59153e-28  1.7396e-26      3.59153e-28  6.01595e-30
 1.2393e-29   7.22585e-28  3.41982e-26     7.22585e-28  1.2393e-29
 2.54202e-29  1.44752e-27  6.694e-26       1.44752e-27  2.54202e-29
 ⋮                                      ⋱               
 4.15273e-10  3.40506e-9   2.35888e-8      3.40506e-9   4.15273e-10
 5.95589e-10  4.76951e-9   3.22849e-8      4.76951e-9   5.95589e-10
 8.50524e-10  6.65197e-9   4.39968e-8      6.65197e-9   8.50524e-10
 1.20936e-9   9.2375e-9    5.96995e-8      9.2375e-9    1.20936e-9
 1.71219e-9   1.27728e-8   8.06582e-8   …  1.27728e-8   1.71219e-9
 2.41365e-9   1.75852e-8   1.08506e-7      1.75852e-8   2.41365e-9
 3.38787e-9   2.41065e-8   1.4534e-7       2.41065e-8   3.38787e-9
 4.73485e-9   3.2904e-8    1.93842e-7      3.2904e-8    4.73485e-9
 6.58892e-9   4.47191e-8   2.57416e-7      4.47191e-8   6.58892e-9

Learned density via pullback through the transport map. The pullback transports the reference density back to the target space, giving the learned approximation of the target density.
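For a map $S$ from target to reference with reference density $\rho$, the pullback density is

\[S^{\sharp} \rho(x) = \rho(S(x)) \, \left| \det \nabla S(x) \right|\]

which is exactly the quantity fitted to the samples above. We evaluate it on the same grid: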

learned_density = [pullback(M, [x1, x2]) for x2 in x₂, x1 in x₁]
100×100 Matrix{Float64}:
 2.50122e-36  2.94501e-34  2.74183e-32  …  1.41464e-32  1.47263e-34
 6.05018e-36  6.92668e-34  6.27421e-32     3.08527e-32  3.29148e-34
 1.4561e-35   1.62095e-33  1.42851e-31     6.69972e-32  7.32501e-34
 3.48676e-35  3.77417e-33  3.23605e-31     1.44855e-31  1.6231e-33
 8.30722e-35  8.74329e-33  7.29372e-31     3.11833e-31  3.58098e-33
 1.9692e-34   2.01526e-32  1.63563e-30  …  6.68377e-31  7.86635e-33
 4.64435e-34  4.62153e-32  3.64939e-30     1.42636e-30  1.72052e-32
 1.08982e-33  1.05448e-31  8.10128e-30     3.0307e-30   3.74679e-32
 2.54439e-33  2.3938e-31   1.7893e-29      6.41154e-30  8.12401e-32
 5.91025e-33  5.40667e-31  3.93193e-29     1.35047e-29  1.75384e-31
 ⋮                                      ⋱               
 1.22896e-10  1.12823e-9   8.65197e-9      1.08949e-9   1.12051e-10
 1.84307e-10  1.64524e-9   1.22755e-8      1.57459e-9   1.66182e-10
 2.74897e-10  2.38609e-9   1.73217e-8      2.26498e-9   2.45309e-10
 4.07778e-10  3.44167e-9   2.4309e-8       3.24277e-9   3.60414e-10
 6.01588e-10  4.93712e-9   3.39286e-8   …  4.6208e-9    5.27044e-10
 8.82664e-10  7.04367e-9   4.70961e-8      6.5534e-9    7.67092e-10
 1.28798e-9   9.99406e-9   6.50164e-8      9.25047e-9   1.11122e-9
 1.86913e-9   1.41027e-8   8.92642e-8      1.29959e-8   1.60216e-9
 2.69764e-9   1.97914e-8   1.21884e-7      1.81715e-8   2.29911e-9

Create contour plots for comparison:

p3 = contour(x₁, x₂, true_density,
            title="True Banana Density",
            xlabel="x₁", ylabel="x₂",
            colormap=:viridis, levels=10)

p4 = contour(x₁, x₂, learned_density,
            title="Learned Density (Pullback)",
            xlabel="x₁", ylabel="x₂",
            colormap=:viridis, levels=10)

plot(p3, p4, layout=(1, 2), size=(800, 400))

Density Comparison

Combined Visualization

Finally, let's create a combined plot showing both the original samples and the density contours:

scatter(target_samples[:, 1], target_samples[:, 2],
        label="Original Samples", alpha=0.3, color=1,
        xlabel="x₁", ylabel="x₂",
        title="Banana Distribution: Samples and Learned Density")

contour!(x₁, x₂, learned_density./maximum(learned_density),
        levels=5, colormap=:viridis, alpha=0.8,
        label="Learned Density Contours")

xlims!(-3, 3)
ylims!(-2.5, 4.0)

Combined Result

Quality Assessment

We can assess the quality of our sample-based approximation by comparing statistics of the original and mapped samples. Note that the two sets live in different spaces: the original samples follow the banana distribution, while the mapped samples should be approximately standard normal, with mean near 0 and standard deviation near 1 in each coordinate:

println("\nSample Statistics Comparison:")
println("Original samples - Mean: ", Distributions.mean(target_samples, dims=1))
println("Original samples - Std:  ", Distributions.std(target_samples, dims=1))
println("Mapped samples - Mean:   ", Distributions.mean(mapped_samples, dims=1))
println("Mapped samples - Std:    ", Distributions.std(mapped_samples, dims=1))

Sample Statistics Comparison:
Original samples - Mean: [0.09452083278294168 0.8001840789955877]
Original samples - Std:  [0.9025665298226256 1.455511399656578]
Mapped samples - Mean:   [-0.1233219235694241 -0.010209446099293619]
Mapped samples - Std:    [1.0167904276237183 1.0510315264614853]
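As a quick additional diagnostic, the empirical covariance of the mapped samples should be close to the 2×2 identity if the map has pushed the target samples to a standard normal. A minimal sketch using only the standard library:

using Statistics, LinearAlgebra

# Mapped samples should be approximately N(0, I); compare their
# empirical covariance with the identity matrix
Σ̂ = cov(mapped_samples)
println("Empirical covariance:\n", Σ̂)
println("Deviation from identity: ", norm(Σ̂ - I))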

Interpretation

The sample-based approach learns the transport map by fitting to the empirical distribution of the samples. This method is particularly useful when:

  • The target density is unknown or expensive to evaluate
  • Only sample data is available from experiments or simulations
  • The distribution is complex and difficult to express analytically

The quality of the approximation depends on:

  • The number and quality of the original samples
  • The polynomial degree of the transport map
  • The optimization algorithm and convergence criteria

Further Experiments

You can experiment with:

  • Different polynomial degrees for more complex distributions (see the sketch below)
  • Different rectifier functions (Softplus(), ShiftedELU())
  • More sophisticated sampling strategies (e.g. MCMC) for generating the target samples
  • Cross-validation techniques to assess generalization
  • Different sample sizes to study convergence behavior
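For instance, a degree-3 map can be fit to the same samples using only the calls shown above; a minimal sketch:

# Degree-3 triangular map in 2 dimensions, fit to the same target samples
M3 = PolynomialMap(2, 3)
optimize!(M3, target_samples)

# Evaluate its learned density on the grid used earlier for comparison
learned_density_3 = [pullback(M3, [x1, x2]) for x2 in x₂, x1 in x₁]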