Map from Samples Example
This example demonstrates how to use TransportMaps.jl to approximate a "banana" distribution using polynomial transport maps when only samples from the target distribution are available.
Unlike the density-based approach, this method learns the transport map directly from sample data using optimization techniques. This is particularly useful when the target density is unknown or difficult to evaluate [1].
We start with the necessary packages:
using TransportMaps
using Distributions
using LinearAlgebra
using Plots
Generating Target Samples
The banana distribution has the density:
\[p(x) = \phi(x_1) \cdot \phi(x_2 - x_1^2)\]
where $\phi$ is the standard normal PDF.
banana_density(x) = pdf(Normal(), x[1]) * pdf(Normal(), x[2] - x[1]^2)
banana_density (generic function with 1 method)
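As a quick sanity check, the density is maximized at the origin, where it equals $\phi(0)^2 = 1/(2\pi) \approx 0.159$:

# Quick check: p(0, 0) = φ(0)^2 = 1/(2π) ≈ 0.159
banana_density([0.0, 0.0])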
We draw training samples with a simple rejection sampler, which needs no dependencies beyond the packages already loaded:
num_samples = 500
# Generate samples via rejection sampling
function generate_banana_samples(n_samples::Int)
    samples = Matrix{Float64}(undef, n_samples, 2)
    count = 0
    # Proposal: x1 ~ N(0, 2), x2 | x1 ~ N(x1^2, 3). The density ratio
    # p(x)/q(x) = 2exp(-3x1^2/8) * 3exp(-4(x2 - x1^2)^2/9) is bounded by 6.
    bound = 6.0
    while count < n_samples
        x1 = randn() * 2
        x2 = randn() * 3 + x1^2
        q = pdf(Normal(0, 2), x1) * pdf(Normal(x1^2, 3), x2)
        # Accept with probability p(x) / (bound * q(x)) so accepted draws follow p
        if rand() < banana_density([x1, x2]) / (bound * q)
            count += 1
            samples[count, :] = [x1, x2]
        end
    end
    return samples
end
println("Generating samples from banana distribution...")
target_samples = generate_banana_samples(num_samples)
println("Generated $(size(target_samples, 1)) samples")
Generating samples from banana distribution...
Generated 500 samples
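Because the banana factorizes as $x_1 \sim \mathcal{N}(0,1)$ and $x_2 \mid x_1 \sim \mathcal{N}(x_1^2, 1)$, its exact moments are known: $E[x_2] = E[x_1^2] = 1$ and $\mathrm{Var}(x_2) = \mathrm{Var}(x_1^2) + 1 = 3$. A quick sketch to check the sampler against them, using the Statistics standard library:

using Statistics
println("Sample mean: ", mean(target_samples, dims=1))  # expect ≈ [0.0 1.0]
println("Sample std:  ", std(target_samples, dims=1))   # expect ≈ [1.0 √3 ≈ 1.73]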
Creating the Transport Map
We create a 2-dimensional polynomial transport map with degree 2. For sample-based optimization, we typically start with lower degrees and can increase complexity as needed.
M = PolynomialMap(2, 2)
PolynomialMap:
Dimensions: 2
Total coefficients: 9
Reference density: Distributions.Normal{Float64}(μ=0.0, σ=1.0)
Maximum degree: 2
Basis: Hermite
Rectifier: Softplus
Components:
Component 1: 3 basis functions
Component 2: 6 basis functions
Coefficients: uninitialized
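The component sizes are consistent with a lower-triangular map built on total-degree polynomial spaces: component $k$ depends on $x_1, \dots, x_k$, and the number of polynomials of total degree at most 2 in $k$ variables is $\binom{k+2}{2}$. A quick check (assuming a total-degree basis, which matches the counts printed above):

binomial(1 + 2, 2)  # component 1: 3 basis functions
binomial(2 + 2, 2)  # component 2: 6 basis functions (3 + 6 = 9 coefficients)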
Optimizing from Samples
The key difference from density-based optimization is that we optimize directly from the sample data without requiring the density function. Inside the optimization, the map is arranged such that the "forward" direction runs from the (unknown) target distribution to the standard normal reference:
@time res = optimize!(M, target_samples)
println("Optimization result: ", res)
6.065906 seconds (94.94 M allocations: 5.540 GiB, 6.23% gc time, 7.16% compilation time)
Optimization result: * Status: success
* Candidate solution
Final objective value: 2.654886e+00
* Found with
Algorithm: L-BFGS
* Convergence measures
|x - x'| = 5.01e-08 ≰ 0.0e+00
|x - x'|/|x'| = 4.78e-08 ≰ 0.0e+00
|f(x) - f(x')| = 4.88e-15 ≰ 0.0e+00
|f(x) - f(x')|/|f(x')| = 1.84e-15 ≰ 0.0e+00
|g(x)| = 4.02e-10 ≤ 1.0e-08
* Work counters
Seconds run: 6 (vs limit Inf)
Iterations: 19
f(x) calls: 58
∇f(x) calls: 58
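For intuition, the sample-based objective is the empirical negative log-likelihood of the pullback density, $-\frac{1}{N}\sum_{i} \log \tilde{p}(x^{(i)})$, whose minimization is equivalent to minimizing the KL divergence from the samples' distribution to the pullback. A minimal sketch, assuming pullback(M, x) evaluates the learned density as used later in this example; the result should be comparable to the final objective value reported above, up to any constants the implementation drops:

nll = -sum(log(pullback(M, target_samples[i, :])) for i in 1:num_samples) / num_samples
println("Negative log-likelihood per sample: ", nll)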
Testing the Map
Let's generate fresh samples from the banana distribution, together with standard normal samples, and push both through the optimized transport map to verify what it has learned:
new_samples = generate_banana_samples(1000)
norm_samples = randn(1000, 2)
1000×2 Matrix{Float64}:
0.0696263 0.569546
-0.0118946 0.0401304
1.6895 1.43094
0.793694 -1.12277
-1.05876 -0.0345062
-0.767916 -0.162202
-0.374317 0.0290939
-0.851917 -1.14503
1.18843 0.289716
0.169474 -1.79299
⋮
0.351009 -0.0851662
0.915733 -0.274554
1.30743 0.744332
0.246186 -0.673974
-0.923339 -0.586072
-0.267338 -0.437785
-1.71726 2.48609
-0.847807 -1.31439
-0.629006 0.75146
Map the samples through our transport map. Note that in this sample-based setup, evaluate transports from the target to the reference, i.e. mapped_samples should be approximately standard normal:
mapped_samples = evaluate(M, new_samples)
1000×2 Matrix{Float64}:
-1.47742 -1.0165
0.645277 -1.84858
-0.698264 2.33773
-2.43713 0.808078
1.09501 0.339164
-1.55821 -0.672894
0.588278 -0.487786
-1.04991 -0.981329
1.7823 -0.965122
-0.0817615 -1.6889
⋮
-0.317394 0.709783
1.29987 0.00889054
-1.92629 1.27716
-0.171996 -0.0822702
-0.598767 -0.158228
-0.299139 0.580278
0.214492 -0.103406
-1.08893 0.655359
-0.272131 -0.475473
while the inverse map pushes standard normal samples to the target distribution, generating new samples from the banana distribution:
mapped_banana_samples = inverse(M, norm_samples)
1000×2 Matrix{Float64}:
0.145652 0.559234
0.0721646 0.0535075
1.63916 3.96336
0.805268 -0.411785
-0.858081 0.683846
-0.602099 0.211266
-0.252683 0.100732
-0.676223 -0.601504
1.17024 1.61134
0.235871 -1.60682
⋮
0.400501 0.0858838
0.917691 0.570123
1.28102 2.30069
0.305345 -0.526749
-0.739124 -0.0020765
-0.157109 -0.369033
-1.4309 4.17815
-0.6726 -0.762244
-0.479178 0.923286
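Before plotting, a quick round-trip consistency check: mapping forward with evaluate and then back with inverse should recover the original points up to solver tolerance (a sketch using the same calls as above):

roundtrip = inverse(M, evaluate(M, new_samples))
println("Max round-trip error: ", maximum(abs.(roundtrip .- new_samples)))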
Visualizing Results
Let's create a scatter plot comparing the original samples with the mapped samples to see how well our transport map learned the distribution:
p11 = scatter(new_samples[:, 1], new_samples[:, 2],
    label="Original Samples", alpha=0.5, color=1,
    title="Target Space: Original vs. Mapped Samples",
    xlabel="x₁", ylabel="x₂")
scatter!(p11, mapped_banana_samples[:, 1], mapped_banana_samples[:, 2],
    label="Mapped Samples", alpha=0.5, color=2)
plot(p11, size=(800, 400))
and the resulting samples in standard normal space:
p12 = scatter(norm_samples[:, 1], norm_samples[:, 2],
    label="Standard Normal Samples", alpha=0.5, color=1,
    title="Reference Space: Standard Normal vs. Mapped Samples",
    xlabel="x₁", ylabel="x₂")
scatter!(p12, mapped_samples[:, 1], mapped_samples[:, 2],
    label="Mapped Samples", alpha=0.5, color=2)
plot(p12, size=(800, 400))
Density Comparison
We can also compare the learned density (via pullback) with the true density:
x₁ = range(-3, 3, length=100)
x₂ = range(-2.5, 4.0, length=100)
-2.5:0.06565656565656566:4.0
True banana density values:
true_density = [banana_density([x1, x2]) for x2 in x₂, x1 in x₁]
100×100 Matrix{Float64}:
3.38667e-32 2.38559e-30 1.35878e-28 … 2.38559e-30 3.38667e-32
7.19036e-32 4.94664e-30 2.75301e-28 4.94664e-30 7.19036e-32
1.52005e-31 1.0213e-29 5.55386e-28 1.0213e-29 1.52005e-31
3.19956e-31 2.09953e-29 1.1156e-27 2.09953e-29 3.19956e-31
6.70583e-31 4.29754e-29 2.23127e-27 4.29754e-29 6.70583e-31
1.3994e-30 8.75882e-29 4.44349e-27 … 8.75882e-29 1.3994e-30
2.90776e-30 1.77746e-28 8.81094e-27 1.77746e-28 2.90776e-30
6.01595e-30 3.59153e-28 1.7396e-26 3.59153e-28 6.01595e-30
1.2393e-29 7.22585e-28 3.41982e-26 7.22585e-28 1.2393e-29
2.54202e-29 1.44752e-27 6.694e-26 1.44752e-27 2.54202e-29
⋮ ⋱
4.15273e-10 3.40506e-9 2.35888e-8 3.40506e-9 4.15273e-10
5.95589e-10 4.76951e-9 3.22849e-8 4.76951e-9 5.95589e-10
8.50524e-10 6.65197e-9 4.39968e-8 6.65197e-9 8.50524e-10
1.20936e-9 9.2375e-9 5.96995e-8 9.2375e-9 1.20936e-9
1.71219e-9 1.27728e-8 8.06582e-8 … 1.27728e-8 1.71219e-9
2.41365e-9 1.75852e-8 1.08506e-7 1.75852e-8 2.41365e-9
3.38787e-9 2.41065e-8 1.4534e-7 2.41065e-8 3.38787e-9
4.73485e-9 3.2904e-8 1.93842e-7 3.2904e-8 4.73485e-9
6.58892e-9 4.47191e-8 2.57416e-7 4.47191e-8 6.58892e-9
Learned density via pullback through the transport map. The pullback of the standard normal reference density $\rho$ through the map $T$ gives the learned approximation of the target density, $\tilde{p}(x) = \rho(T(x))\,\lvert\det \nabla T(x)\rvert$:
learned_density = [pullback(M, [x1, x2]) for x2 in x₂, x1 in x₁]
100×100 Matrix{Float64}:
2.50122e-36 2.94501e-34 2.74183e-32 … 1.41464e-32 1.47263e-34
6.05018e-36 6.92668e-34 6.27421e-32 3.08527e-32 3.29148e-34
1.4561e-35 1.62095e-33 1.42851e-31 6.69972e-32 7.32501e-34
3.48676e-35 3.77417e-33 3.23605e-31 1.44855e-31 1.6231e-33
8.30722e-35 8.74329e-33 7.29372e-31 3.11833e-31 3.58098e-33
1.9692e-34 2.01526e-32 1.63563e-30 … 6.68377e-31 7.86635e-33
4.64435e-34 4.62153e-32 3.64939e-30 1.42636e-30 1.72052e-32
1.08982e-33 1.05448e-31 8.10128e-30 3.0307e-30 3.74679e-32
2.54439e-33 2.3938e-31 1.7893e-29 6.41154e-30 8.12401e-32
5.91025e-33 5.40667e-31 3.93193e-29 1.35047e-29 1.75384e-31
⋮ ⋱
1.22896e-10 1.12823e-9 8.65197e-9 1.08949e-9 1.12051e-10
1.84307e-10 1.64524e-9 1.22755e-8 1.57459e-9 1.66182e-10
2.74897e-10 2.38609e-9 1.73217e-8 2.26498e-9 2.45309e-10
4.07778e-10 3.44167e-9 2.4309e-8 3.24277e-9 3.60414e-10
6.01588e-10 4.93712e-9 3.39286e-8 … 4.6208e-9 5.27044e-10
8.82664e-10 7.04367e-9 4.70961e-8 6.5534e-9 7.67092e-10
1.28798e-9 9.99406e-9 6.50164e-8 9.25047e-9 1.11122e-9
1.86913e-9 1.41027e-8 8.92642e-8 1.29959e-8 1.60216e-9
2.69764e-9 1.97914e-8 1.21884e-7 1.81715e-8 2.29911e-9
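As a numerical sanity check, the learned density should integrate to approximately 1 over a region containing most of the probability mass. A simple Riemann sum over the grid (a sketch; the grid truncates some mass above $x_2 = 4$, so expect a value slightly below 1):

dx1, dx2 = step(x₁), step(x₂)
println("Grid integral of learned density: ", sum(learned_density) * dx1 * dx2)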
Create contour plots for comparison:
p3 = contour(x₁, x₂, true_density,
    title="True Banana Density",
    xlabel="x₁", ylabel="x₂",
    color=:viridis, levels=10)
p4 = contour(x₁, x₂, learned_density,
    title="Learned Density (Pullback)",
    xlabel="x₁", ylabel="x₂",
    color=:viridis, levels=10)
plot(p3, p4, layout=(1, 2), size=(800, 400))
Combined Visualization
Finally, let's create a combined plot showing both the original samples and the density contours:
scatter(target_samples[:, 1], target_samples[:, 2],
label="Original Samples", alpha=0.3, color=1,
xlabel="x₁", ylabel="x₂",
title="Banana Distribution: Samples and Learned Density")
contour!(x₁, x₂, learned_density ./ maximum(learned_density),
    levels=5, color=:viridis, alpha=0.8,
    label="Learned Density Contours")
xlims!(-3, 3)
ylims!(-2.5, 4.0)
Quality Assessment
We can assess the quality of our sample-based approximation by comparing sample statistics: the original samples reflect the banana's moments, while the mapped samples should be close to standard normal (mean ≈ 0, std ≈ 1 in each coordinate):
println("\nSample Statistics Comparison:")
println("Original samples - Mean: ", Distributions.mean(target_samples, dims=1))
println("Original samples - Std: ", Distributions.std(target_samples, dims=1))
println("Mapped samples - Mean: ", Distributions.mean(mapped_samples, dims=1))
println("Mapped samples - Std: ", Distributions.std(mapped_samples, dims=1))
Sample Statistics Comparison:
Original samples - Mean: [0.09452083278294168 0.8001840789955877]
Original samples - Std: [0.9025665298226256 1.455511399656578]
Mapped samples - Mean: [-0.1233219235694241 -0.010209446099293619]
Mapped samples - Std: [1.0167904276237183 1.0510315264614853]
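A complementary check is the sample covariance of the mapped samples, which should be close to the $2 \times 2$ identity if the map pushes the target to $\mathcal{N}(0, I)$ (a sketch; cov comes from the Statistics standard library, norm and I from LinearAlgebra):

using Statistics
C = cov(mapped_samples)
println("Covariance of mapped samples:\n", C)
println("Frobenius deviation from identity: ", norm(C - I))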
Interpretation
The sample-based approach learns the transport map by fitting to the empirical distribution of the samples. This method is particularly useful when:
- The target density is unknown or expensive to evaluate
- Only sample data is available from experiments or simulations
- The distribution is complex and difficult to express analytically
The quality of the approximation depends on:
- The number and quality of the original samples
- The polynomial degree of the transport map
- The optimization algorithm and convergence criteria
Further Experiments
You can experiment with:
- Different polynomial degrees for more complex distributions
- Different rectifier functions (Softplus(), ShiftedELU())
- More sophisticated MCMC sampling strategies
- Cross-validation techniques to assess generalization (see the sketch below)
- Different sample sizes to study convergence behavior
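As an illustration of the cross-validation idea above, here is a minimal sketch that holds out 20% of the samples, fits a fresh map on the rest, and scores the held-out data under the learned pullback density (randperm comes from the Random standard library; the TransportMaps.jl calls match those used earlier):

using Random
perm = randperm(size(target_samples, 1))
n_train = round(Int, 0.8 * length(perm))
train = target_samples[perm[1:n_train], :]
test  = target_samples[perm[n_train+1:end], :]
M_cv = PolynomialMap(2, 2)
optimize!(M_cv, train)
test_nll = -sum(log(pullback(M_cv, test[i, :])) for i in 1:size(test, 1)) / size(test, 1)
println("Held-out negative log-likelihood per sample: ", test_nll)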