scipy.optimize.check_grad
- scipy.optimize.check_grad(func, grad, x0, *args, epsilon=1.4901161193847656e-08, direction='all', seed=None)
Check the correctness of a gradient function by comparing it against a (forward) finite-difference approximation of the gradient.
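To make the comparison concrete, here is a minimal pure-NumPy sketch of the forward-difference idea for a scalar func. Each gradient component is estimated as (func(x0 + epsilon * e_i) - func(x0)) / epsilon, where e_i is the i-th one-hot vector, and the error is the 2-norm of the difference from the analytic gradient. The helper forward_difference_grad is hypothetical and is not SciPy's implementation, which may differ in details such as how the step is applied.

```python
import numpy as np

def forward_difference_grad(func, x0, *args, epsilon=np.sqrt(np.finfo(float).eps)):
    """Hypothetical helper: forward-difference gradient of a scalar func at a 1-D x0."""
    x0 = np.asarray(x0, dtype=float)
    f0 = func(x0, *args)
    approx = np.empty_like(x0)
    for i in range(x0.size):
        step = np.zeros_like(x0)
        step[i] = epsilon  # perturb only the i-th coordinate (one-hot direction)
        approx[i] = (func(x0 + step, *args) - f0) / epsilon
    return approx

def func(x):
    return x[0]**2 - 0.5 * x[1]**3

def grad(x):
    return np.array([2 * x[0], -1.5 * x[1]**2])

x0 = np.array([1.5, -1.5])
approx = forward_difference_grad(func, x0)

# 2-norm of the difference between the analytic and approximate gradients.
err = np.linalg.norm(grad(x0) - approx)
print(err)  # small: on the order of 1e-8 for this smooth function
```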
- Parameters
- func : callable
func(x0, *args)
Function whose derivative is to be checked.
- grad : callable
grad(x0, *args)
Gradient of func.
- x0 : ndarray
Point at which grad is checked against a forward-difference approximation of the gradient computed from func.
- args : *args, optional
Extra arguments passed to func and grad.
- epsilon : float, optional
Step size used for the finite-difference approximation. It defaults to sqrt(np.finfo(float).eps), which is approximately 1.49e-08.
- direction : str, optional
If set to 'random', the gradient along a random vector is used to check grad against the forward-difference approximation computed from func. The default is 'all', in which case all one-hot direction vectors are used to check grad.
- seed : {None, int, numpy.random.Generator, numpy.random.RandomState}, optional
If seed is None (or np.random), the numpy.random.RandomState singleton is used. If seed is an int, a new RandomState instance is used, seeded with seed. If seed is already a Generator or RandomState instance, then that instance is used. Specify seed to make the return value of this function reproducible. The random numbers generated with this seed determine the random vector along which gradients are computed to check grad. Note that seed is only used when the direction argument is set to 'random'.
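The following sketch exercises the optional parameters listed above. The quadratic func, its gradient, the scale factor a, and the step 1e-6 are illustrative assumptions. The value of a is forwarded to both callables through *args, and an integer seed is used so that two direction='random' calls return the same value.

```python
import numpy as np
from scipy.optimize import check_grad

# Illustrative scalar function with an extra parameter `a`, passed via *args.
def func(x, a):
    return a * np.sum(x**2)

def grad(x, a):
    return 2 * a * x

x0 = np.array([1.0, -2.0, 0.5])

# Extra positional arguments after x0 are forwarded to both func and grad.
err_all = check_grad(func, grad, x0, 3.0)

# A larger step than the default sqrt(eps) increases the truncation error.
err_eps = check_grad(func, grad, x0, 3.0, epsilon=1e-6)

# Check along a single random direction; an int seed makes the result repeatable.
err_r1 = check_grad(func, grad, x0, 3.0, direction='random', seed=0)
err_r2 = check_grad(func, grad, x0, 3.0, direction='random', seed=0)

print(err_all, err_eps, err_r1, err_r1 == err_r2)
```

Note that direction='random' only makes sense for a scalar-valued func, as here, because the analytic gradient is projected onto the random vector before the comparison.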
- Returns
- err : float
The square root of the sum of squares (i.e., the 2-norm) of the difference between grad(x0, *args) and the finite-difference approximation of the gradient computed from func at the point x0.
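As a sanity check on this definition of err, the sketch below recomputes it by hand with scipy.optimize.approx_fprime, which returns a forward-difference approximation of the gradient, and compares the result with check_grad. It assumes the default direction='all' and the same default step; under those assumptions the two values are expected to agree closely.

```python
import numpy as np
from scipy.optimize import approx_fprime, check_grad

def func(x):
    return x[0]**2 - 0.5 * x[1]**3

def grad(x):
    return np.array([2 * x[0], -1.5 * x[1]**2])

x0 = np.array([1.5, -1.5])
eps = np.sqrt(np.finfo(float).eps)  # same value as the default epsilon

# Forward-difference approximation of the gradient at x0.
fd_grad = approx_fprime(x0, func, eps)

# err as documented: 2-norm of (analytic gradient - finite-difference gradient).
manual_err = np.linalg.norm(grad(x0) - fd_grad)

print(manual_err, check_grad(func, grad, x0))
```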
See also
Examples
>>> import numpy as np
>>> def func(x):
...     return x[0]**2 - 0.5 * x[1]**3
>>> def grad(x):
...     return [2 * x[0], -1.5 * x[1]**2]
>>> from scipy.optimize import check_grad
>>> check_grad(func, grad, [1.5, -1.5])
2.9802322387695312e-08  # may vary
>>> rng = np.random.default_rng()
>>> check_grad(func, grad, [1.5, -1.5],
...            direction='random', seed=rng)
2.9802322387695312e-08  # may vary
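check_grad is most useful when the gradient is wrong. In this sketch, bad_grad is a deliberately incorrect, hypothetical gradient whose second component uses -0.5 instead of -1.5. The returned err is then roughly the size of that discrepancy, |(-0.5 + 1.5) * (-1.5)**2| = 2.25, rather than a value on the order of 1e-8 as in the examples above.

```python
import numpy as np
from scipy.optimize import check_grad

def func(x):
    return x[0]**2 - 0.5 * x[1]**3

# Deliberately wrong gradient: the second component should be -1.5 * x[1]**2.
def bad_grad(x):
    return np.array([2 * x[0], -0.5 * x[1]**2])

err = check_grad(func, bad_grad, [1.5, -1.5])
print(err)  # approximately 2.25, signalling a bug in bad_grad
```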