Banana Distribution Example
This example demonstrates how to use TransportMaps.jl to approximate a "banana" distribution using polynomial transport maps.
The banana distribution is a common test case in the transport map literature [1]. Its first coordinate x₁ is standard normal, and conditional on x₁, the second coordinate x₂ is normal with mean x₁² and unit variance. This example shows how triangular transport maps capture such nonlinear dependencies [3].
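To make the triangular structure concrete: if z = (z₁, z₂) is standard normal, the banana distribution is exactly the pushforward of z under the lower-triangular map T(z) = (z₁, z₂ + z₁²). The first component depends only on z₁, the second on (z₁, z₂), and ∂T₂/∂z₂ = 1 > 0, so T is increasing in its last argument. The plain-Julia sketch below makes no TransportMaps.jl calls; `banana_map` is a hypothetical helper. It pushes samples through this map and checks that x₂ − x₁² is standard normal again.

```julia
using Random, Statistics

# Exact lower-triangular map from z ~ N(0, I₂) to the banana distribution, applied row-wise:
# x₁ = z₁ depends on z₁ only; x₂ = z₂ + z₁² depends on (z₁, z₂) and is increasing in z₂.
# banana_map is a hypothetical helper, not a TransportMaps.jl function.
banana_map(Z) = hcat(Z[:, 1], Z[:, 2] .+ Z[:, 1] .^ 2)

Random.seed!(42)
Z = randn(100_000, 2)
X = banana_map(Z)

# Removing the nonlinear dependence recovers a standard normal second coordinate.
residual = X[:, 2] .- X[:, 1] .^ 2
println("mean ≈ ", mean(residual), ", var ≈ ", var(residual))
```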
We start with the necessary packages:
using TransportMaps
using Distributions
using Plots
Creating the Transport Map
Next, we create a 2-dimensional polynomial transport map with maximum degree 2, a standard normal reference density, and a Softplus rectifier function:
M = PolynomialMap(2, 2, Normal(), Softplus())
PolynomialMap:
Dimensions: 2
Total coefficients: 9
Reference density: Distributions.Normal{Float64}(μ=0.0, σ=1.0)
Maximum degree: 2
Basis: Hermite
Rectifier: Softplus
Components:
Component 1: 3 basis functions
Component 2: 6 basis functions
Coefficients: min=0.0, max=6.9268182312233e-310, mean=4.6178764940809e-310
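The coefficient count in the output can be checked by hand. Assuming component k uses every Hermite product term of total degree at most p in its k inputs (consistent with the 3 and 6 basis functions reported above), it has binomial(k + p, p) coefficients. `n_basis` below is a hypothetical helper for this count.

```julia
# Size of a total-degree polynomial basis in k variables with maximum degree p
# (assumed basis type; n_basis is a hypothetical helper).
n_basis(k, p) = binomial(k + p, p)

p = 2
counts = [n_basis(k, p) for k in 1:2]   # component k of a 2D triangular map has k inputs
println("per component: ", counts, ", total: ", sum(counts))   # per component: [3, 6], total: 9

# Explicit multi-indices (α₁, α₂) with α₁ + α₂ ≤ 2 for component 2.
multi = [(a1, a2) for a1 in 0:p for a2 in 0:p if a1 + a2 <= p]
println(multi)
```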
Setting up Quadrature
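The optimization objective is an expectation over the reference distribution, which is approximated by a weighted sum over quadrature points. Here we use tensor-product Gauss-Hermite quadrature with 3 points per dimension, i.e. 3² = 9 points in total.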
quadrature = GaussHermiteWeights(3, 2)
GaussHermiteWeights:
Number of points: 9
Dimensions: 2
Quadrature type: Tensor product Gauss-Hermite
Reference measure: Standard Gaussian
Weight range: [0.02777777777777786, 0.44444444444444353]
Weight sum: 0.9999999999999997
Points:
[-1.7320508075688776, -1.7320508075688776] → weight: 0.02777777777777786
[-1.2560739669470201e-15, -1.7320508075688776] → weight: 0.11111111111111116
[1.7320508075688776, -1.7320508075688776] → weight: 0.02777777777777786
[-1.7320508075688776, -1.2560739669470201e-15] → weight: 0.11111111111111116
[-1.2560739669470201e-15, -1.2560739669470201e-15] → weight: 0.44444444444444353
[1.7320508075688776, -1.2560739669470201e-15] → weight: 0.11111111111111116
[-1.7320508075688776, 1.7320508075688776] → weight: 0.02777777777777786
[-1.2560739669470201e-15, 1.7320508075688776] → weight: 0.11111111111111116
[1.7320508075688776, 1.7320508075688776] → weight: 0.02777777777777786
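The printed rule is easy to reproduce. The 1D nodes are −√3, 0, √3 (the −1.256e-15 entries are zero up to round-off) with weights 1/6, 2/3, 1/6, and the 2D weights 1/36, 1/9, 4/9 are their products. The sketch below recomputes them with the Golub-Welsch method: for the standard normal weight, the nodes are the eigenvalues of the symmetric tridiagonal matrix with zero diagonal and off-diagonal entries √1, …, √(n−1), and each weight is the squared first component of the corresponding unit eigenvector. `gauss_hermite_normal` is a hypothetical helper, not a TransportMaps.jl function. The last line illustrates why more points can matter: a 3-point rule is exact for polynomials up to degree 5, so it gets E[z⁴] = 3 right but not E[z⁶] = 15.

```julia
using LinearAlgebra

# n-point Gauss-Hermite rule for N(0, 1) via Golub-Welsch (hypothetical helper name).
function gauss_hermite_normal(n)
    J = SymTridiagonal(zeros(n), sqrt.(1.0:n-1))   # Jacobi matrix of probabilists' Hermite polynomials
    E = eigen(J)
    nodes = E.values                               # eigenvalues, sorted ascending
    weights = E.vectors[1, :] .^ 2                 # total mass of N(0, 1) is 1
    return nodes, weights
end

x, w = gauss_hermite_normal(3)

# Tensor product in 2D, first coordinate varying fastest (same order as the printout).
points  = [(xi, xj) for xj in x for xi in x]
weights = [wi * wj for wj in w for wi in w]
println("weight sum: ", sum(weights))

# Exactness check: degree 4 is integrated exactly, degree 6 is not.
println("E[z^4] ≈ ", sum(w .* x .^ 4), ", E[z^6] ≈ ", sum(w .* x .^ 6))
```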
Defining the Target Density
The banana distribution has the density:
\[p(x) = \phi(x_1) \cdot \phi(x_2 - x_1^2)\]
where $\phi$ is the standard normal PDF.
target_density(x) = pdf(Normal(), x[1]) * pdf(Normal(), x[2] - x[1]^2)
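For reference, the gradient that automatic differentiation has to deliver can be derived by hand. With log p(x) = log φ(x₁) + log φ(x₂ − x₁²), we get ∂/∂x₁ = −x₁ + 2x₁(x₂ − x₁²) and ∂/∂x₂ = −(x₂ − x₁²). The sketch checks this against central finite differences; `logp`, `grad_logp` and `fd_grad` are hypothetical helpers, and whether TransportMaps.jl differentiates the density or its logarithm internally is not stated here.

```julia
# Banana log-density without Distributions: log φ(t) = -t²/2 - log(2π)/2.
# logp, grad_logp and fd_grad are hypothetical helpers.
logφ(t) = -t^2 / 2 - log(2π) / 2
logp(x) = logφ(x[1]) + logφ(x[2] - x[1]^2)

# Hand-derived gradient of log p.
function grad_logp(x)
    r = x[2] - x[1]^2                      # deviation from the parabola x₂ = x₁²
    return [-x[1] + 2x[1] * r, -r]
end

# Central finite differences as an independent check.
function fd_grad(f, x; h = 1e-6)
    g = similar(x)
    for i in eachindex(x)
        e = zero(x); e[i] = h
        g[i] = (f(x + e) - f(x - e)) / (2h)
    end
    return g
end

x = [0.7, -0.3]
println(grad_logp(x), " vs ", fd_grad(logp, x))
```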
Wrap the density in a MapTargetDensity object for the optimizer; `:auto_diff` requests gradients via automatic differentiation:
target = MapTargetDensity(target_density, :auto_diff)
MapTargetDensity(density=target_density, gradient_type=auto_diff)
Optimizing the Map
Now we optimize the map coefficients so that the map pushes the reference density forward onto the target. On a first call, `@time` mostly measures compilation:
@time res = optimize!(M, target, quadrature)
println("Optimization result: ", res)
4.038698 seconds (9.06 M allocations: 470.683 MiB, 1.98% gc time, 98.07% compilation time)
Optimization result: * Status: success
* Candidate solution
Final objective value: 2.837877e+00
* Found with
Algorithm: L-BFGS
* Convergence measures
|x - x'| = 1.43e-06 ≰ 0.0e+00
|x - x'|/|x'| = 1.43e-06 ≰ 0.0e+00
|f(x) - f(x')| = 2.94e-12 ≰ 0.0e+00
|f(x) - f(x')|/|f(x')| = 1.04e-12 ≰ 0.0e+00
|g(x)| = 3.45e-09 ≤ 1.0e-08
* Work counters
Seconds run: 0 (vs limit Inf)
Iterations: 12
f(x) calls: 33
∇f(x) calls: 33
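The final objective value has a simple interpretation under an assumption we have not verified against the package source: that the optimizer minimizes E_z[−log p(T(z)) − log det ∇T(z)] with z ~ N(0, I₂), the usual KL-based formulation. That quantity equals KL(T♯ρ ‖ p) plus the entropy of the 2D standard normal reference, log(2π) + 1 ≈ 2.837877. The reported 2.837877e+00 is therefore what one expects when the KL divergence is numerically zero. The sketch evaluates the objective for the exact map T(z) = (z₁, z₂ + z₁²) on the 9 Gauss-Hermite points printed earlier; `logp`, `T` and `logdetT` are hypothetical helpers.

```julia
# Assumption: the objective is E_z[-log p(T(z)) - log det ∇T(z)] with z ~ N(0, I₂),
# estimated with the 3 × 3 tensor Gauss-Hermite rule printed above.
x = [-sqrt(3.0), 0.0, sqrt(3.0)]   # 1D nodes
w = [1/6, 2/3, 1/6]                # 1D weights

logφ(t) = -t^2 / 2 - log(2π) / 2
logp(y) = logφ(y[1]) + logφ(y[2] - y[1]^2)   # banana log-density

# Exact triangular map; its Jacobian [1 0; 2z₁ 1] has determinant 1.
T(z) = [z[1], z[2] + z[1]^2]
logdetT(z) = 0.0

objective = sum(w[i] * w[j] * (-logp(T([x[i], x[j]])) - logdetT([x[i], x[j]]))
                for i in 1:3, j in 1:3)

println("objective = ", objective, ", log(2π) + 1 = ", log(2π) + 1)
```

Because the integrand is a degree-2 polynomial in z for the exact map, the 3-point rule integrates it exactly.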
Testing the Map
Let's generate some samples from the standard normal distribution and map them through our optimized transport map:
samples_z = randn(1000, 2)
1000×2 Matrix{Float64}:
0.0214098 -0.305181
-0.00413776 -0.23988
1.1182 0.121107
-1.33506 -0.782044
1.19707 -1.09262
0.252417 -0.732866
-1.26508 0.0631072
-0.65111 0.940255
-0.551836 -0.365505
-1.04201 0.522248
⋮
-0.446182 -2.03668
-0.281179 1.29208
-0.31168 0.146151
-0.103084 0.996107
-0.660835 -0.260806
-0.772434 0.0760994
-2.62014 -0.704688
-0.987289 1.31398
-1.02567 0.164213
Map the samples through our transport map:
mapped_samples = evaluate(M, samples_z)
1000×2 Matrix{Float64}:
0.0214098 -0.304722
-0.00413776 -0.239863
1.1182 1.37147
-1.33506 1.00034
1.19707 0.340357
0.252417 -0.669152
-1.26508 1.66354
-0.65111 1.3642
-0.551836 -0.0609824
-1.04201 1.60803
⋮
-0.446182 -1.83761
-0.281179 1.37114
-0.31168 0.243295
-0.103084 1.00673
-0.660835 0.175896
-0.772434 0.672754
-2.62014 6.16044
-0.987289 2.28872
-1.02567 1.21622
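The printed values suggest that the optimizer has recovered the exact banana map T(z) = (z₁, z₂ + z₁²): the first column passes through unchanged, and the second column equals z₂ + z₁² to the displayed precision. The check below uses five rows copied from the two printed matrices; the displayed values are rounded to about six significant digits, hence the loose tolerance.

```julia
# Five (z₁, z₂) rows of samples_z and the matching x₂ values of mapped_samples, as printed above.
z = [ 0.0214098 -0.305181;
      1.1182     0.121107;
     -1.33506   -0.782044;
      1.19707   -1.09262;
     -2.62014   -0.704688]
x2_printed = [-0.304722, 1.37147, 1.00034, 0.340357, 6.16044]

# Second component of the exact banana map.
x2_exact = z[:, 2] .+ z[:, 1] .^ 2
println("max abs difference: ", maximum(abs.(x2_exact .- x2_printed)))
```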
Visualizing Results
Let's create a scatter plot of the mapped samples to see how well our transport map approximates the banana distribution:
scatter(mapped_samples[:, 1], mapped_samples[:, 2],
label="Mapped Samples", alpha=0.5, color=2,
title="Transport Map Approximation of Banana Distribution",
xlabel="x₁", ylabel="x₂")
Quality Assessment
We can assess the quality of our approximation using the variance diagnostic:
var_diag = variance_diagnostic(M, target, samples_z)
println("Variance Diagnostic: ", var_diag)
Variance Diagnostic: 6.059893191728661e-18
Interpretation
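The variance diagnostic measures how well the transport map approximates the target distribution; lower values indicate a better approximation. The value of about 6 × 10⁻¹⁸ obtained here is zero up to floating-point round-off, so the map reproduces the target essentially exactly.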
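In the transport-map literature, the variance diagnostic is commonly defined as half the variance, under the reference ρ, of log ρ(z) − log p(T(z)) − log det ∇T(z). It vanishes exactly when the pushforward matches the target up to a normalizing constant. Assuming TransportMaps.jl follows this definition (not verified here), the sketch below computes it for the exact map and, for contrast, for a deliberately wrong map T(z) = (z₁, z₂ + 0.8 z₁²); `variance_diag` and the map family are hypothetical.

```julia
using Random, Statistics

logφ(t) = -t^2 / 2 - log(2π) / 2
logρ(z) = logφ(z[1]) + logφ(z[2])                # standard normal reference
logp(x) = logφ(x[1]) + logφ(x[2] - x[1]^2)       # banana target

# Hypothetical family T_a(z) = (z₁, z₂ + a z₁²); a = 1 is the exact map.
# Its Jacobian is lower triangular with unit diagonal, so log det ∇T_a = 0.
T(z, a) = [z[1], z[2] + a * z[1]^2]
logdetT(z, a) = 0.0

# Assumed definition: half the variance of log ρ(z) - log p(T(z)) - log det ∇T(z)
# over reference samples. variance_diag is a hypothetical helper.
function variance_diag(a, Z)
    r = [logρ(z) - logp(T(z, a)) - logdetT(z, a) for z in eachrow(Z)]
    return var(r) / 2
end

Random.seed!(0)
Z = randn(10_000, 2)
println("a = 1.0: ", variance_diag(1.0, Z), "   a = 0.8: ", variance_diag(0.8, Z))
```

For the exact map the log-ratio is constant, so only round-off remains, which matches the tiny value reported above.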
The scatter plot should show the characteristic "banana" shape: the samples concentrate around the parabola x₂ = x₁², with unit-variance scatter in x₂ about it.
Further Experiments
You can experiment with:
- Different polynomial degrees (see [3] for monotone map theory)
- Different rectifier functions (IdentityRectifier(), ShiftedELU())
- Different quadrature methods (MonteCarloWeights, LatinHypercubeWeights)
- More quadrature points for higher accuracy
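The rectifier is what keeps each map component increasing in its last variable. In the monotone parameterization common in the triangular-map literature [3], component k is written as T_k(z) = f(z₁, …, z_{k−1}, 0) + ∫₀^{z_k} g(∂_k f(z₁, …, z_{k−1}, t)) dt, where f is an unconstrained polynomial and g is the rectifier; if g > 0, then ∂T_k/∂z_k > 0. Whether TransportMaps.jl uses exactly this form, and how its rectifier types are defined, are assumptions here. The sketch uses a hypothetical f, softplus(s) = log(1 + eˢ), a shifted ELU (eˢ for s ≤ 0, s + 1 otherwise), and the identity, with trapezoidal integration.

```julia
# Candidate rectifiers g (assumed definitions, for illustration).
softplus(s)      = log1p(exp(s))                # > 0 everywhere
shifted_elu(s)   = s <= 0 ? exp(s) : s + 1      # > 0 everywhere
identity_rect(s) = s                            # can be negative

# Hypothetical unconstrained polynomial f(z₁, t) and its derivative in t.
f(z1, t)    = 0.5 + z1 * t - 0.8 * t^2
dfdt(z1, t) = z1 - 1.6 * t

# T₂(z₁, z₂) = f(z₁, 0) + ∫₀^{z₂} g(∂f/∂t(z₁, t)) dt, integrated with the trapezoidal rule.
function monotone_component(g, z1, z2; n = 2000)
    h = z2 / n                                  # signed step, so z₂ < 0 works too
    vals = [g(dfdt(z1, k * h)) for k in 0:n]
    return f(z1, 0.0) + h * (sum(vals) - (vals[1] + vals[end]) / 2)
end

# Is T₂(0.3, ·) strictly increasing on a grid of z₂ values?
z2_grid = range(-3.0, 3.0; length = 61)
is_increasing(g) = all(diff([monotone_component(g, 0.3, z2) for z2 in z2_grid]) .> 0)

for (name, g) in [("softplus", softplus), ("shifted ELU", shifted_elu), ("identity", identity_rect)]
    println(rpad(name, 12), "increasing in z₂: ", is_increasing(g))
end
```

With the identity rectifier the construction simply returns f itself, so monotonicity is no longer guaranteed; a strictly positive rectifier such as Softplus enforces it regardless of the coefficients.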