scipy.stats.wasserstein_distance
- scipy.stats.wasserstein_distance(u_values, v_values, u_weights=None, v_weights=None)
Compute the first Wasserstein distance between two 1D distributions.
This distance is also known as the earth mover’s distance, since it can be seen as the minimum amount of “work” required to transform \(u\) into \(v\), where “work” is measured as the amount of distribution weight that must be moved, multiplied by the distance it has to be moved.
New in version 1.0.0.
- Parameters:
- u_values, v_values : array_like
Values observed in the (empirical) distribution.
- u_weights, v_weights : array_like, optional
Weight for each value. If unspecified, each value is assigned the same weight. u_weights (resp. v_weights) must have the same length as u_values (resp. v_values). If the weight sum differs from 1, it must still be positive and finite so that the weights can be normalized to sum to 1.
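The weighting rule above can be sketched in a few lines of pure Python. This is a minimal illustration, not SciPy's code: `normalize_weights` is a hypothetical helper, and it assumes the individual weights are non-negative (the sketch does not check that). It shows that only the relative weights matter, because they are rescaled to sum to 1.

```python
import math


def normalize_weights(values, weights=None):
    """Hypothetical helper: turn optional weights into probabilities summing to 1.

    Assumes individual weights are non-negative; this sketch does not check that.
    """
    if weights is None:
        # Unspecified weights: every value is assigned the same weight.
        return [1.0 / len(values)] * len(values)
    if len(weights) != len(values):
        raise ValueError("weights must have the same length as values")
    total = sum(weights)
    # The sum may differ from 1, but it must be positive and finite
    # (a NaN total also fails this check).
    if not (0 < total < math.inf):
        raise ValueError("weight sum must be positive and finite")
    return [w / total for w in weights]


print(normalize_weights([0, 1], [3, 1]))  # [0.75, 0.25]
print(normalize_weights([0, 1], [6, 2]))  # [0.75, 0.25], same distribution
```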
- Returns:
- distance : float
The computed distance between the distributions.
Notes
The first Wasserstein distance between the distributions \(u\) and \(v\) is:
\[l_1 (u, v) = \inf_{\pi \in \Gamma (u, v)} \int_{\mathbb{R} \times \mathbb{R}} |x-y| \mathrm{d} \pi (x, y)\]
where \(\Gamma (u, v)\) is the set of (probability) distributions on \(\mathbb{R} \times \mathbb{R}\) whose marginals are \(u\) and \(v\) on the first and second factors, respectively.
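In one dimension, when both samples have the same number of equally weighted values, the infimum over couplings \(\pi\) is attained by pairing the sorted values in order (the monotone coupling). Each pair then carries mass \(1/n\), so the distance is the mean absolute difference of the sorted samples. The sketch below covers only that equal-size, unweighted case; `w1_equal_size` is a hypothetical name, not part of SciPy.

```python
def w1_equal_size(u_values, v_values):
    """Hypothetical helper: first Wasserstein distance between two
    unweighted samples of equal size, via the sorted (monotone) coupling."""
    if len(u_values) != len(v_values):
        raise ValueError("this sketch only handles samples of equal size")
    # Pair the i-th smallest u value with the i-th smallest v value;
    # each pair carries mass 1/n under the coupling pi.
    pairs = zip(sorted(u_values), sorted(v_values))
    return sum(abs(x - y) for x, y in pairs) / len(u_values)


print(w1_equal_size([0, 1, 3], [5, 6, 8]))  # 5.0
```

Every unit of mass moves a distance of 5 here, which is why the first doctest in the Examples section also returns 5.0.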
If \(U\) and \(V\) are the respective CDFs of \(u\) and \(v\), this distance also equals:
\[l_1(u, v) = \int_{-\infty}^{+\infty} |U(x)-V(x)| \, \mathrm{d}x\]
See [2] for a proof of the equivalence of both definitions.
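For the discrete inputs accepted here, \(U\) and \(V\) are step functions that jump only at the observed values, so the CDF integral reduces to a finite sum over the gaps between consecutive sorted points. The following is a minimal pure-Python sketch of that idea, not SciPy's implementation: `w1_from_cdfs` is a hypothetical name, the loop is quadratic for readability rather than speed, and no input validation is done.

```python
def w1_from_cdfs(u_values, v_values, u_weights=None, v_weights=None):
    """Hypothetical sketch of l_1(u, v) as the integral of |U(x) - V(x)|."""

    def masses(values, weights):
        # Missing weights mean equal weights; then normalize to sum to 1.
        if weights is None:
            weights = [1.0] * len(values)
        total = sum(weights)
        return [(x, w / total) for x, w in zip(values, weights)]

    u = masses(u_values, u_weights)
    v = masses(v_values, v_weights)
    # Both CDFs are constant between consecutive distinct support points.
    points = sorted(set(u_values) | set(v_values))
    distance = 0.0
    for left, right in zip(points, points[1:]):
        # CDF values on the interval [left, right).
        U = sum(w for x, w in u if x <= left)
        V = sum(w for x, w in v if x <= left)
        distance += abs(U - V) * (right - left)
    return distance


print(round(w1_from_cdfs([0, 1, 3], [5, 6, 8]), 10))  # 5.0
print(w1_from_cdfs([0, 1], [0, 1], [3, 1], [2, 2]))  # 0.25
```

Up to floating-point rounding, this reproduces the values in the Examples section, including about 4.0781331 for the weighted third call.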
The input distributions can be empirical, that is, built from samples whose values are passed directly to the function, or they can be viewed as generalized functions, in which case each is a weighted sum of Dirac delta functions located at the specified values.
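To connect the two views, an empirical sample can be collapsed into point masses: each distinct value becomes one Dirac delta whose weight is its relative frequency. This is an illustrative sketch only; `to_point_masses` is a hypothetical helper, not a SciPy function.

```python
from collections import Counter


def to_point_masses(samples):
    """Hypothetical helper: represent an empirical sample as a weighted sum
    of Dirac deltas, returned as {location: probability}."""
    counts = Counter(samples)
    n = len(samples)
    # One delta per distinct value, weighted by how often the value occurs.
    return {value: count / n for value, count in counts.items()}


print(to_point_masses([0, 0, 0, 1]))  # {0: 0.75, 1: 0.25}
```

So the unweighted sample `[0, 0, 0, 1]` describes the same distribution as values `[0, 1]` with weights `[3, 1]`, the form used for \(u\) in the second example of the Examples section.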
References
[1] “Wasserstein metric”, https://en.wikipedia.org/wiki/Wasserstein_metric
[2] Ramdas, Garcia, Cuturi, “On Wasserstein Two Sample Testing and Related Families of Nonparametric Tests” (2015). arXiv:1509.02237.
Examples
```python
>>> from scipy.stats import wasserstein_distance
>>> wasserstein_distance([0, 1, 3], [5, 6, 8])
5.0
>>> wasserstein_distance([0, 1], [0, 1], [3, 1], [2, 2])
0.25
>>> wasserstein_distance([3.4, 3.9, 7.5, 7.8], [4.5, 1.4],
...                      [1.4, 0.9, 3.1, 7.2], [3.2, 3.5])
4.0781331438047861
```