Banana Distribution Example

This example demonstrates how to use TransportMaps.jl to approximate a "banana" distribution using polynomial transport maps.

The banana distribution is a common test case in the transport map literature [1]. Its first coordinate is standard normal, and, conditional on x₁, the second coordinate is normal with mean x₁² and unit variance. The example shows how effectively triangular transport maps capture such nonlinear dependencies [3].
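Because the second coordinate is a standard normal shifted by x₁², exact samples from the banana distribution are easy to draw. The following sketch is plain Julia, independent of TransportMaps.jl; the helper name `sample_banana` is our own. It produces reference samples that can later be compared with the output of a transport map:

```julia
using Random, Statistics

# Hypothetical helper: draw exact samples from the banana distribution,
# x₁ ~ N(0, 1) and, given x₁, x₂ ~ N(x₁², 1).
function sample_banana(rng::AbstractRNG, n::Integer)
    x1 = randn(rng, n)
    x2 = x1 .^ 2 .+ randn(rng, n)   # shift a standard normal by x₁²
    return hcat(x1, x2)             # n×2 matrix, one sample per row
end

x = sample_banana(MersenneTwister(42), 100_000)
println("mean of x₂ ≈ ", mean(x[:, 2]))   # close to E[x₁²] = 1
println("var of x₂  ≈ ", var(x[:, 2]))    # close to Var(x₁²) + 1 = 3
```

The moments follow from E[x₁²] = 1 and Var(x₁²) = 2 for a standard normal x₁.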

We start with the necessary packages:

using TransportMaps
using Distributions
using Plots

Creating the Transport Map

First, we create a two-dimensional polynomial transport map of maximum degree 2, with a standard normal reference density and a Softplus rectifier function.

M = PolynomialMap(2, 2, Normal(), Softplus())
PolynomialMap:
  Dimensions: 2
  Total coefficients: 9
  Reference density: Distributions.Normal{Float64}(μ=0.0, σ=1.0)
  Maximum degree: 2
  Basis: Hermite
  Rectifier: Softplus
  Components:
    Component 1: 3 basis functions
    Component 2: 6 basis functions
  Coefficients: min=0.0, max=6.9268182312233e-310, mean=4.6178764940809e-310
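The printed basis sizes are consistent with a total-order polynomial basis (total degree at most 2) in which component k depends on the first k variables, which gives binomial(k + 2, k) terms: 3 for component 1 and 6 for component 2, 9 coefficients in total. We have not confirmed from the TransportMaps.jl source that the basis is built this way; the sketch below, with a hypothetical helper `total_order_indices`, only checks the count under that assumption:

```julia
# Hypothetical helper: enumerate all multi-indices α ∈ ℕ₀ᵏ
# with total degree α₁ + … + αₖ ≤ p.
function total_order_indices(k::Integer, p::Integer)
    k == 0 && return [Int[]]
    indices = Vector{Vector{Int}}()
    for a in 0:p
        # the first k-1 entries may use at most degree p - a in total
        for rest in total_order_indices(k - 1, p - a)
            push!(indices, vcat(rest, a))
        end
    end
    return indices
end

for k in 1:2
    n = length(total_order_indices(k, 2))
    println("Component $k: $n basis functions (binomial(k + p, k) = $(binomial(k + 2, k)))")
end
```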

Setting up Quadrature

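The optimization approximates expectations with respect to the reference density by quadrature, so we need to specify quadrature points and weights. Here we use a tensor-product Gauss-Hermite rule with 3 points per dimension, i.e. 9 points in total.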

quadrature = GaussHermiteWeights(3, 2)
GaussHermiteWeights:
  Number of points: 9
  Dimensions: 2
  Quadrature type: Tensor product Gauss-Hermite
  Reference measure: Standard Gaussian
  Weight range: [0.02777777777777786, 0.44444444444444353]
  Weight sum: 0.9999999999999997
  Points:
    [-1.7320508075688776, -1.7320508075688776] → weight: 0.02777777777777786
    [-1.2560739669470201e-15, -1.7320508075688776] → weight: 0.11111111111111116
    [1.7320508075688776, -1.7320508075688776] → weight: 0.02777777777777786
    [-1.7320508075688776, -1.2560739669470201e-15] → weight: 0.11111111111111116
    [-1.2560739669470201e-15, -1.2560739669470201e-15] → weight: 0.44444444444444353
    [1.7320508075688776, -1.2560739669470201e-15] → weight: 0.11111111111111116
    [-1.7320508075688776, 1.7320508075688776] → weight: 0.02777777777777786
    [-1.2560739669470201e-15, 1.7320508075688776] → weight: 0.11111111111111116
    [1.7320508075688776, 1.7320508075688776] → weight: 0.02777777777777786
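These nodes and weights can be reproduced in a few lines of plain Julia. For the standard Gaussian weight, the 3-point rule has nodes 0 and ±√3 with weights 2/3 and 1/6; the 2D weights are products of these, giving the 4/9, 1/9 and 1/36 printed above (the -1.26e-15 entries are round-off for 0). The sketch uses the Golub-Welsch eigenvalue method; it is an illustration rather than the algorithm GaussHermiteWeights necessarily uses, and the helper names are hypothetical:

```julia
using LinearAlgebra

# Probabilists' Gauss-Hermite rule (weight = standard normal density) via
# Golub-Welsch: nodes are the eigenvalues of the Jacobi matrix of the
# recurrence x·He_k = He_{k+1} + k·He_{k-1}.
function gauss_hermite_1d(n::Integer)
    J = SymTridiagonal(zeros(n), sqrt.(float.(1:n-1)))
    F = eigen(J)
    nodes = F.values
    weights = F.vectors[1, :] .^ 2   # N(0, 1) has total mass 1
    return nodes, weights
end

# Tensor-product rule in d dimensions (first coordinate varies fastest).
function tensor_gauss_hermite(n::Integer, d::Integer)
    x, w = gauss_hermite_1d(n)
    pts = Vector{Vector{Float64}}()
    wts = Float64[]
    for I in CartesianIndices(ntuple(_ -> n, d))
        push!(pts, [x[i] for i in Tuple(I)])
        push!(wts, prod(w[i] for i in Tuple(I)))
    end
    return pts, wts
end

pts, wts = tensor_gauss_hermite(3, 2)
for (p, w) in zip(pts, wts)
    println(p, " → weight: ", w)
end
println("Weight sum: ", sum(wts))
```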

Defining the Target Density

The banana distribution has the density:

\[p(x) = \phi(x_1) \cdot \phi(x_2 - x_1^2)\]

where $\phi$ is the standard normal PDF.

target_density(x) = pdf(Normal(), x[1]) * pdf(Normal(), x[2] - x[1]^2)
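It helps to write down the map that produces this density exactly. Let ρ(z) = φ(z₁)φ(z₂) be the standard normal reference density. Setting x₁ = z₁ and x₂ = z₂ + z₁² defines a lower-triangular map with unit Jacobian determinant, and the change-of-variables formula recovers the banana density. Since T is itself a degree-2 polynomial map, a degree-2 polynomial transport map should be able to approximate it very closely:

\[T(z) = \begin{pmatrix} z_1 \\ z_2 + z_1^2 \end{pmatrix}, \qquad \nabla T(z) = \begin{pmatrix} 1 & 0 \\ 2 z_1 & 1 \end{pmatrix}, \qquad \det \nabla T(z) = 1,\]

\[p(x) = \rho\big(T^{-1}(x)\big)\,\big|\det \nabla T^{-1}(x)\big| = \phi(x_1)\,\phi(x_2 - x_1^2), \qquad T^{-1}(x) = \begin{pmatrix} x_1 \\ x_2 - x_1^2 \end{pmatrix}.\]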

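Next, we wrap the density in a MapTargetDensity object for optimization; the :auto_diff option sets the gradient type so that gradients of the target are obtained by automatic differentiation.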

target = MapTargetDensity(target_density, :auto_diff)
MapTargetDensity(density=target_density, gradient_type=auto_diff)
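We do not know which automatic-differentiation backend MapTargetDensity uses internally. The sketch below uses ForwardDiff.jl purely as an example, to show that AD reproduces the hand-derived gradient of the banana density; `banana` and `banana_grad` are hypothetical names:

```julia
using ForwardDiff

# Banana density written out explicitly (same as target_density above).
banana(x) = exp(-x[1]^2 / 2 - (x[2] - x[1]^2)^2 / 2) / (2π)

# Hand-derived gradient: with u = x₂ - x₁², ∇ log p = (-x₁ + 2x₁u, -u),
# and ∇p = p · ∇ log p.
function banana_grad(x)
    u = x[2] - x[1]^2
    return banana(x) .* [-x[1] + 2x[1] * u, -u]
end

x = [0.5, 1.0]
g_ad = ForwardDiff.gradient(banana, x)
g_exact = banana_grad(x)
println("AD gradient:    ", g_ad)
println("exact gradient: ", g_exact)
```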

Optimizing the Map

Now we optimize the map coefficients to approximate the target density:

@time res = optimize!(M, target, quadrature)
println("Optimization result: ", res)
  4.038698 seconds (9.06 M allocations: 470.683 MiB, 1.98% gc time, 98.07% compilation time)
Optimization result:  * Status: success

 * Candidate solution
    Final objective value:     2.837877e+00

 * Found with
    Algorithm:     L-BFGS

 * Convergence measures
    |x - x'|               = 1.43e-06 ≰ 0.0e+00
    |x - x'|/|x'|          = 1.43e-06 ≰ 0.0e+00
    |f(x) - f(x')|         = 2.94e-12 ≰ 0.0e+00
    |f(x) - f(x')|/|f(x')| = 1.04e-12 ≰ 0.0e+00
    |g(x)|                 = 3.45e-09 ≤ 1.0e-08

 * Work counters
    Seconds run:   0  (vs limit Inf)
    Iterations:    12
    f(x) calls:    33
    ∇f(x) calls:   33
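The final objective value, 2.837877, coincides with 1 + log(2π) ≈ 2.837877, the entropy of the two-dimensional standard normal reference. That is what one would expect if the objective is the usual reference-sample form of the KL divergence, -E_ρ[log p(T(z)) + log det ∇T(z)], and the optimizer has found the exact map T(z) = (z₁, z₂ + z₁²). We have not checked that optimize! uses exactly this form (for example, whether normalizing constants are included). The sketch below, with hypothetical helper names, evaluates that objective for the exact map on the same 3×3 Gauss-Hermite grid:

```julia
# 1D probabilists' Gauss-Hermite rule with 3 nodes.
nodes1d   = [-sqrt(3.0), 0.0, sqrt(3.0)]
weights1d = [1/6, 2/3, 1/6]

# Log of the normalized banana density.
log_banana(x) = -log(2π) - x[1]^2 / 2 - (x[2] - x[1]^2)^2 / 2

# Exact triangular map from N(0, I) to the banana; its Jacobian is
# lower-triangular with unit diagonal, so log det ∇T(z) = 0.
exact_map(z) = [z[1], z[2] + z[1]^2]
logdet_exact_map(z) = 0.0

# Tensor-product quadrature estimate of -E_ρ[log p(T(z)) + log det ∇T(z)].
function quadrature_objective(T, logdetT)
    total = 0.0
    for (z1, w1) in zip(nodes1d, weights1d), (z2, w2) in zip(nodes1d, weights1d)
        z = [z1, z2]
        total -= w1 * w2 * (log_banana(T(z)) + logdetT(z))
    end
    return total
end

objective = quadrature_objective(exact_map, logdet_exact_map)
println("objective = ", objective, ",  1 + log(2π) = ", 1 + log(2π))
```

The 3-point rule integrates this degree-2 integrand exactly, so the quadrature value equals 1 + log(2π) up to round-off.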

Testing the Map

To test the map, we draw 1000 samples from the two-dimensional standard normal reference distribution, one sample per row:

samples_z = randn(1000, 2)
1000×2 Matrix{Float64}:
  0.0214098   -0.305181
 -0.00413776  -0.23988
  1.1182       0.121107
 -1.33506     -0.782044
  1.19707     -1.09262
  0.252417    -0.732866
 -1.26508      0.0631072
 -0.65111      0.940255
 -0.551836    -0.365505
 -1.04201      0.522248
  ⋮           
 -0.446182    -2.03668
 -0.281179     1.29208
 -0.31168      0.146151
 -0.103084     0.996107
 -0.660835    -0.260806
 -0.772434     0.0760994
 -2.62014     -0.704688
 -0.987289     1.31398
 -1.02567      0.164213

Map the samples through our transport map:

mapped_samples = evaluate(M, samples_z)
1000×2 Matrix{Float64}:
  0.0214098   -0.304722
 -0.00413776  -0.239863
  1.1182       1.37147
 -1.33506      1.00034
  1.19707      0.340357
  0.252417    -0.669152
 -1.26508      1.66354
 -0.65111      1.3642
 -0.551836    -0.0609824
 -1.04201      1.60803
  ⋮           
 -0.446182    -1.83761
 -0.281179     1.37114
 -0.31168      0.243295
 -0.103084     1.00673
 -0.660835     0.175896
 -0.772434     0.672754
 -2.62014      6.16044
 -0.987289     2.28872
 -1.02567      1.21622
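The output is consistent with the optimizer having recovered the exact banana map T(z) = (z₁, z₂ + z₁²): the first column passes through unchanged, and the second column equals z₂ + z₁² to the printed precision. The check below uses a handful of rows copied from the two printouts above; in a live session, the full `mapped_samples` could be compared with `samples_z` in the same way:

```julia
# A few (z, T(z)) rows copied from the printed output above (6 significant digits).
z_rows = [ 0.0214098  -0.305181
           1.1182      0.121107
          -1.33506    -0.782044
           1.19707    -1.09262
          -0.446182   -2.03668
          -2.62014    -0.704688]
mapped_rows = [ 0.0214098  -0.304722
                1.1182      1.37147
               -1.33506     1.00034
                1.19707     0.340357
               -0.446182   -1.83761
               -2.62014     6.16044]

# Exact banana map applied row-wise: (z₁, z₂) ↦ (z₁, z₂ + z₁²).
exact_rows = hcat(z_rows[:, 1], z_rows[:, 2] .+ z_rows[:, 1] .^ 2)

max_err = maximum(abs.(mapped_rows .- exact_rows))
println("largest deviation from the exact map: ", max_err)
```

The remaining deviations, around 10⁻⁵, are at the level of the rounding in the printed values.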

Visualizing Results

Let's create a scatter plot of the mapped samples to see how well our transport map approximates the banana distribution:

scatter(mapped_samples[:, 1], mapped_samples[:, 2],
           label="Mapped Samples", alpha=0.5, color=2,
           title="Transport Map Approximation of Banana Distribution",
           xlabel="x₁", ylabel="x₂")

Figure: Banana Samples (scatter plot of the mapped samples)

Quality Assessment

We can assess the quality of our approximation using the variance diagnostic:

var_diag = variance_diagnostic(M, target, samples_z)
println("Variance Diagnostic: ", var_diag)
Variance Diagnostic: 6.059893191728661e-18

Interpretation

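The variance diagnostic measures how closely the transport map reproduces the target distribution: it is zero for an exact map, and lower values indicate a better approximation. The value obtained here, about 6 × 10⁻¹⁸, is at the level of floating-point round-off.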
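In the transport map literature, the variance diagnostic is commonly defined as half the variance, under the reference density ρ, of log ρ(z) - log p̄(T(z)) - log det ∇T(z), where p̄ is the (possibly unnormalized) target density. We have not verified that variance_diagnostic uses precisely this normalization. The sketch below evaluates that formula for a hypothetical family of maps T_c(z) = (z₁, z₂ + c·z₁²), where c = 1 is the exact banana map:

```julia
using Random, Statistics

# Log of the 2D standard normal reference density.
log_reference(z) = -log(2π) - (z[1]^2 + z[2]^2) / 2
# Log of the banana target; additive constants do not affect the variance.
log_banana(x) = -x[1]^2 / 2 - (x[2] - x[1]^2)^2 / 2

# Hypothetical family of triangular maps; c = 1 is exact, and
# log det ∇T_c(z) = 0 for every c (unit diagonal Jacobian).
make_map(c) = z -> [z[1], z[2] + c * z[1]^2]

# Half the sample variance of log ρ(z) - log p̄(T(z)) - log det ∇T(z).
function variance_diagnostic_sketch(T, logdetT, Z)
    r = [log_reference(z) - log_banana(T(z)) - logdetT(z) for z in eachrow(Z)]
    return var(r) / 2
end

Z = randn(MersenneTwister(1), 1000, 2)   # 1000 reference samples, one per row
d = Dict(c => variance_diagnostic_sketch(make_map(c), z -> 0.0, Z) for c in (1.0, 0.9, 0.5))
for c in (1.0, 0.9, 0.5)
    println("c = $c: ", d[c])
end
```

The exact map (c = 1) gives a value at round-off level, like the one reported above, while the perturbed maps give clearly positive values that grow as c moves away from 1.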

The scatter plot should show the characteristic "banana" shape: the samples scatter around the parabola x₂ = x₁² with unit-variance vertical spread.
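The parabola can also be checked numerically: if x₂ given x₁ is N(x₁², 1), the residuals x₂ - x₁² should have mean 0 and variance 1, and should be uncorrelated with both x₁ and x₁². The sketch below computes these statistics for exact banana samples generated as a stand-in; applying the same lines to `mapped_samples` tests the transport map output:

```julia
using Random, Statistics

# Stand-in for `mapped_samples`: exact banana samples, one per row.
rng = MersenneTwister(7)
z1 = randn(rng, 5000)
x = hcat(z1, z1 .^ 2 .+ randn(rng, 5000))

# Residuals around the parabola x₂ = x₁².
r = x[:, 2] .- x[:, 1] .^ 2
println("residual mean      ≈ ", mean(r))
println("residual variance  ≈ ", var(r))
println("corr(x₁, residual) ≈ ", cor(x[:, 1], r))
println("corr(x₁², residual) ≈ ", cor(x[:, 1] .^ 2, r))
```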

Further Experiments

You can experiment with:

  • Different polynomial degrees (see [3] for monotone map theory)
  • Different rectifier functions (IdentityRectifier(), ShiftedELU())
  • Different quadrature methods (MonteCarloWeights, LatinHypercubeWeights)
  • More quadrature points for higher accuracy
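On the last point: an n-point Gauss-Hermite rule integrates polynomials of degree up to 2n - 1 exactly against the standard normal, so the 3-point rule used above is exact only through degree 5. Higher-degree maps generally produce higher-degree (or non-polynomial) integrands in the objective, which is one reason to add points. The sketch below, with hypothetical helper names, compares a few Gaussian moments computed with 3 and 5 points:

```julia
using LinearAlgebra

# n-point probabilists' Gauss-Hermite rule via Golub-Welsch.
function gauss_hermite_1d(n::Integer)
    F = eigen(SymTridiagonal(zeros(n), sqrt.(float.(1:n-1))))
    return F.values, F.vectors[1, :] .^ 2
end

# Exact moments of N(0, 1): E[z^k] = (k - 1)!! for even k, 0 for odd k.
exact_moment(k) = isodd(k) ? 0.0 : Float64(prod(1:2:k-1; init=1))

for n in (3, 5)
    x, w = gauss_hermite_1d(n)
    for k in (4, 6, 8)
        approx = sum(w .* x .^ k)
        println("n = $n, E[z^$k]: quadrature = $(round(approx; digits=6)), exact = $(exact_moment(k))")
    end
end
```

With 3 points, E[z⁴] = 3 is reproduced but E[z⁶] comes out as 9 instead of 15; with 5 points all three moments are exact.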