
scipy.stats.matrix_normal

scipy.stats.matrix_normal(mean=None, rowcov=1, colcov=1, seed=None) = <scipy.stats._multivariate.matrix_normal_gen object>

A matrix normal random variable.

The mean keyword specifies the mean. The rowcov keyword specifies the among-row covariance matrix. The colcov keyword specifies the among-column covariance matrix.

Parameters
X : array_like

Quantiles, with the last two axes of X denoting the components.

mean : array_like, optional

Mean of the distribution (default: None)

rowcov : array_like, optional

Among-row covariance matrix of the distribution (default: 1)

colcov : array_like, optional

Among-column covariance matrix of the distribution (default: 1)

seed : {None, int, np.random.RandomState, np.random.Generator}, optional

Used for drawing random variates. If seed is None, the RandomState singleton is used. If seed is an int, a new RandomState instance is used, seeded with seed. If seed is already a RandomState or Generator instance, that object is used. Default is None.
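As a minimal sketch of the seeding rules above, assuming a SciPy release in which the frozen constructor accepts seed (as in the signature shown) and rvs accepts random_state: two sources seeded identically should yield identical draws. The mean and covariance values here are arbitrary illustrations.

```python
import numpy as np
from scipy.stats import matrix_normal

M = np.zeros((2, 3))  # illustrative 2 x 3 mean
U = np.eye(2)         # among-row covariance (2 x 2)
V = np.eye(3)         # among-column covariance (3 x 3)

# Two frozen objects seeded with the same int produce the same draws
a = matrix_normal(mean=M, rowcov=U, colcov=V, seed=123).rvs()
b = matrix_normal(mean=M, rowcov=U, colcov=V, seed=123).rvs()

# A Generator can also be passed per call via random_state
rng1 = np.random.default_rng(7)
rng2 = np.random.default_rng(7)
c = matrix_normal.rvs(mean=M, rowcov=U, colcov=V, random_state=rng1)
d = matrix_normal.rvs(mean=M, rowcov=U, colcov=V, random_state=rng2)
print(np.array_equal(a, b), np.array_equal(c, d))  # True True
```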

Alternatively, the object may be called (as a function) to fix the mean and covariance parameters, returning a “frozen” matrix normal random variable:

rv = matrix_normal(mean=None, rowcov=1, colcov=1)

  • Frozen object with the same methods but holding the given mean and covariance fixed.
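A short sketch, reusing the values from the Examples section, that checks the frozen object's pdf against the equivalent non-frozen call. It assumes the frozen methods take only the quantile argument (plus size and random_state for rvs), which is what the Methods list reduces to once the parameters are fixed.

```python
import numpy as np
from scipy.stats import matrix_normal

M = np.arange(6.0).reshape(3, 2)  # mean, 3 x 2
U = np.diag([1.0, 2.0, 3.0])      # among-row covariance, 3 x 3
V = 0.3 * np.identity(2)          # among-column covariance, 2 x 2
X = M + 0.1                       # a single 3 x 2 quantile

rv = matrix_normal(mean=M, rowcov=U, colcov=V)  # frozen distribution

# The frozen pdf needs only X; the parameters are already stored in rv
frozen_pdf = rv.pdf(X)
direct_pdf = matrix_normal.pdf(X, mean=M, rowcov=U, colcov=V)
print(np.isclose(frozen_pdf, direct_pdf))  # True
```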

Notes

If mean is None, a matrix of zeros is used for the mean.

The dimensions of this matrix are inferred from the shape of rowcov and colcov, if these are provided, or set to 1 if ambiguous.
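The inference rule can be seen from sample shapes. This is a sketch under the assumption that rvs with its default size=1 returns a single matrix of the inferred shape; the covariance values are illustrative.

```python
import numpy as np
from scipy.stats import matrix_normal

# mean=None: a zero mean whose shape comes from rowcov (m x m) and colcov (n x n)
x = matrix_normal.rvs(rowcov=np.eye(3), colcov=np.eye(2))
print(x.shape)  # (3, 2)

# Scalars carry no shape information, so both dimensions fall back to 1
y = matrix_normal.rvs()
print(y.shape)  # (1, 1)
```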

rowcov and colcov can be two-dimensional array_likes specifying the covariance matrices directly. Alternatively, a one-dimensional array will be interpreted as the entries of a diagonal matrix, and a scalar or zero-dimensional array will be interpreted as this value times the identity matrix.
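The three accepted forms can be checked for equivalence. In this sketch the explicit mean fixes \(m = 3\) and \(n = 2\), so the scalar rowcov is expanded to \(2 I_3\) and the one-dimensional colcov to a \(2 \times 2\) diagonal matrix; the quantile X is an arbitrary illustrative value.

```python
import numpy as np
from scipy.stats import matrix_normal

M = np.zeros((3, 2))  # explicit mean fixes m = 3, n = 2
X = np.array([[0.5, -1.0],
              [1.5, 0.2],
              [-0.3, 0.8]])

# Scalar rowcov -> 2.0 * I_3; 1-D colcov -> diag([1.0, 3.0])
p_short = matrix_normal.pdf(X, mean=M, rowcov=2.0, colcov=[1.0, 3.0])
p_full = matrix_normal.pdf(X, mean=M, rowcov=2.0 * np.eye(3),
                           colcov=np.diag([1.0, 3.0]))
print(np.isclose(p_short, p_full))  # True
```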

The covariance matrices specified by rowcov and colcov must be (symmetric) positive definite. If the samples in X are \(m \times n\), then rowcov must be \(m \times m\) and colcov must be \(n \times n\). mean must be \(m \times n\), the same shape as each sample in X.
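Because the last two axes of X hold the components, several \(m \times n\) quantiles can be evaluated in one call by stacking them along a leading axis. A minimal sketch, assuming the batched call returns one density per stacked matrix:

```python
import numpy as np
from scipy.stats import matrix_normal

M = np.arange(6.0).reshape(3, 2)  # each sample is 3 x 2
U = np.diag([1.0, 2.0, 3.0])      # 3 x 3
V = 0.3 * np.identity(2)          # 2 x 2

# Stack four 3 x 2 quantiles along a leading axis: shape (4, 3, 2)
offsets = np.array([0.0, 0.1, 0.2, 0.3])
X = M + offsets[:, None, None]

batch = matrix_normal.pdf(X, mean=M, rowcov=U, colcov=V)
single = [matrix_normal.pdf(x, mean=M, rowcov=U, colcov=V) for x in X]
print(batch.shape)                 # (4,)
print(np.allclose(batch, single))  # True
```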

The probability density function for matrix_normal is

\[f(X) = (2 \pi)^{-\frac{mn}{2}}|U|^{-\frac{n}{2}} |V|^{-\frac{m}{2}} \exp\left( -\frac{1}{2} \mathrm{Tr}\left[ U^{-1} (X-M) V^{-1} (X-M)^T \right] \right),\]

where \(M\) is the mean, \(U\) the among-row covariance matrix, and \(V\) the among-column covariance matrix.
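To make the formula concrete, the helper below (matrix_normal_pdf is a hypothetical name, not part of SciPy) evaluates it term by term with NumPy and compares the result against matrix_normal.pdf. It uses linear solves rather than explicit inverses; for large or ill-conditioned matrices a log-determinant formulation would be more robust.

```python
import numpy as np
from scipy.stats import matrix_normal

def matrix_normal_pdf(X, M, U, V):
    """Hypothetical helper: evaluate the matrix normal density formula directly."""
    m, n = X.shape
    D = X - M
    # Tr[U^{-1} (X - M) V^{-1} (X - M)^T], using solves instead of inverses
    quad = np.trace(np.linalg.solve(U, D) @ np.linalg.solve(V, D.T))
    # (2 pi)^{-mn/2} |U|^{-n/2} |V|^{-m/2}
    norm = ((2 * np.pi) ** (-m * n / 2)
            * np.linalg.det(U) ** (-n / 2)
            * np.linalg.det(V) ** (-m / 2))
    return norm * np.exp(-0.5 * quad)

M = np.arange(6.0).reshape(3, 2)
U = np.diag([1.0, 2.0, 3.0])
V = 0.3 * np.identity(2)
X = M + 0.1
print(np.isclose(matrix_normal_pdf(X, M, U, V),
                 matrix_normal.pdf(X, mean=M, rowcov=U, colcov=V)))  # True
```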

The allow_singular behaviour of the multivariate_normal distribution is not currently supported. Covariance matrices must be full rank.
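A sketch of what the full-rank requirement means in practice: passing a rank-deficient rowcov is expected to raise an error rather than return a degenerate density. The exact exception type is an assumption that may differ between releases, so the example catches both np.linalg.LinAlgError and ValueError.

```python
import numpy as np
from scipy.stats import matrix_normal

singular_rowcov = np.diag([1.0, 0.0])  # rank 1, so not positive definite
try:
    matrix_normal.pdf(np.zeros((2, 2)), rowcov=singular_rowcov, colcov=np.eye(2))
    raised = False
except (np.linalg.LinAlgError, ValueError):
    raised = True
print(raised)  # True
```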

The matrix_normal distribution is closely related to the multivariate_normal distribution. Specifically, \(\mathrm{Vec}(X)\) (the vector formed by concatenating the columns of \(X\)) has a multivariate normal distribution with mean \(\mathrm{Vec}(M)\) and covariance \(V \otimes U\) (where \(\otimes\) is the Kronecker product). Sampling and pdf evaluation are \(\mathcal{O}(m^3 + n^3 + m^2 n + m n^2)\) for the matrix normal, but \(\mathcal{O}(m^3 n^3)\) for the equivalent multivariate normal, making this equivalent form algorithmically inefficient.
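The cost gap comes from never forming the \(mn \times mn\) matrix \(V \otimes U\). A standard construction, sketched below and not necessarily identical to SciPy's internal code, factors \(U = A A^T\) and \(V = B B^T\) with Cholesky decompositions and sets \(X = M + A Z B^T\), where \(Z\) has i.i.d. standard normal entries. Since \(\mathrm{Vec}(A Z B^T) = (B \otimes A)\,\mathrm{Vec}(Z)\), the covariance of \(\mathrm{Vec}(X)\) is \((B B^T) \otimes (A A^T) = V \otimes U\). The helper name sample_matrix_normal is hypothetical, and the check compares the empirical covariance of the vectorised samples with the Kronecker product, so agreement holds only up to sampling noise.

```python
import numpy as np

def sample_matrix_normal(M, U, V, size, rng):
    """Hypothetical helper: draw `size` samples of X = M + A Z B^T."""
    A = np.linalg.cholesky(U)  # U = A A^T, O(m^3)
    B = np.linalg.cholesky(V)  # V = B B^T, O(n^3)
    m, n = M.shape
    Z = rng.standard_normal((size, m, n))  # i.i.d. standard normal entries
    # Broadcasting matmul over the sample axis: O(m^2 n + m n^2) per draw
    return M + A @ Z @ B.T

rng = np.random.default_rng(0)
M = np.arange(6.0).reshape(3, 2)
U = np.diag([1.0, 2.0, 3.0])
V = np.array([[0.3, 0.1],
              [0.1, 0.3]])
samples = sample_matrix_normal(M, U, V, size=200_000, rng=rng)

# Vec(X) stacks columns, i.e. the transpose flattened row by row
vecs = samples.transpose(0, 2, 1).reshape(len(samples), -1)
emp_cov = np.cov(vecs, rowvar=False)
print(np.allclose(emp_cov, np.kron(V, U), atol=0.05))
print(np.allclose(vecs.mean(axis=0), M.T.flatten(), atol=0.02))
```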

New in version 0.17.0.

Examples

>>> import numpy as np
>>> from scipy.stats import matrix_normal
>>> M = np.arange(6).reshape(3,2); M
array([[0, 1],
       [2, 3],
       [4, 5]])
>>> U = np.diag([1,2,3]); U
array([[1, 0, 0],
       [0, 2, 0],
       [0, 0, 3]])
>>> V = 0.3*np.identity(2); V
array([[0.3, 0. ],
       [0. , 0.3]])
>>> X = M + 0.1; X
array([[0.1, 1.1],
       [2.1, 3.1],
       [4.1, 5.1]])
>>> matrix_normal.pdf(X, mean=M, rowcov=U, colcov=V)
0.023410202050005054
>>> # Equivalent multivariate normal
>>> from scipy.stats import multivariate_normal
>>> vectorised_X = X.T.flatten()
>>> equiv_mean = M.T.flatten()
>>> equiv_cov = np.kron(V,U)
>>> multivariate_normal.pdf(vectorised_X, mean=equiv_mean, cov=equiv_cov)
0.023410202050005054

Methods

``pdf(X, mean=None, rowcov=1, colcov=1)``

Probability density function.

``logpdf(X, mean=None, rowcov=1, colcov=1)``

Log of the probability density function.

``rvs(mean=None, rowcov=1, colcov=1, size=1, random_state=None)``

Draw random samples.
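A brief sketch exercising the three methods with the values from the Examples section. It assumes that rvs returns a single \(m \times n\) matrix when size=1 and an array of shape (size, m, n) otherwise, and that an int random_state seeds a fresh RandomState.

```python
import numpy as np
from scipy.stats import matrix_normal

M = np.arange(6.0).reshape(3, 2)
U = np.diag([1.0, 2.0, 3.0])
V = 0.3 * np.identity(2)
X = M + 0.1

# logpdf should agree with the log of pdf
p = matrix_normal.pdf(X, mean=M, rowcov=U, colcov=V)
logp = matrix_normal.logpdf(X, mean=M, rowcov=U, colcov=V)
print(np.isclose(np.exp(logp), p))  # True

# size=1 returns a single m x n matrix; larger sizes add a leading axis
one = matrix_normal.rvs(mean=M, rowcov=U, colcov=V, random_state=0)
many = matrix_normal.rvs(mean=M, rowcov=U, colcov=V, size=5, random_state=0)
print(one.shape, many.shape)  # (3, 2) (5, 3, 2)
```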
