Array API Standard Support: signal#
This page explains some caveats of the signal module and provides (currently
incomplete) tables about the
CPU,
GPU and
JIT support.
Caveats#
JAX and CuPy provide alternative
implementations for some signal functions. When such a function is called, a
decorator decides which implementation to use by determining the array namespace xp from the arguments of the call.
Hence, discrepancies in behavior between the default NumPy-based implementation and the JAX and CuPy backends can occur, especially during CI testing. Skipping the incompatible backends in unit tests, as described in the Adding tests section, is the currently recommended workaround.
The functions are decorated by the code in file
scipy/signal/_support_alternative_backends.py:
# Standard-library helpers used by the delegation decorator and the
# decoration loop further down in this file.
import functools
import types
from scipy._lib._array_api import (
    is_cupy, is_jax, scipy_namespace_for, SCIPY_ARRAY_API, xp_capabilities
)

# Pull every public name of the bare implementation module into this namespace;
# the decoration loop in this file later rebinds each name listed in
# ``_signal_api.__all__``.
from ._signal_api import *   # noqa: F403
from . import _signal_api
from . import _delegators
__all__ = _signal_api.__all__


# Submodule name looked up under ``cupyx.scipy`` and under the namespace
# returned by ``scipy_namespace_for`` when delegating.
MODULE_NAME = 'signal'

# jax.scipy.signal has only partial coverage of scipy.signal, so we keep the list
# of functions we can delegate to JAX
# https://jax.readthedocs.io/en/latest/jax.scipy.html
JAX_SIGNAL_FUNCS = [
    'fftconvolve', 'convolve', 'convolve2d', 'correlate', 'correlate2d',
    'csd', 'detrend', 'istft', 'welch'
]

# some cupyx.scipy.signal functions are incompatible with their scipy counterparts
CUPY_BLACKLIST = [
    'lfilter_zi', 'sosfilt_zi', 'get_window', 'besselap', 'envelope', 'remez', 'bessel'
]

# freqz_sos is a sosfreqz rename, and cupy does not have the new name yet (in v13.x)
CUPY_RENAMES = {'freqz_sos': 'sosfreqz'}


def delegate_xp(delegator, module_name):
    """Create a decorator that dispatches a call to a CuPy or JAX namesake.

    Parameters
    ----------
    delegator : callable
        Called with the positional and keyword arguments of the wrapped
        function; its return value is used as the array namespace ``xp``.
    module_name : str
        Name of the submodule that is looked up under ``cupyx.scipy`` (CuPy)
        or on the object returned by ``scipy_namespace_for(xp)`` (JAX).

    Returns
    -------
    callable
        A decorator. The function it produces calls, depending on ``xp``,
        the ``cupyx`` namesake, the JAX namesake, or the wrapped function
        itself.

    Raises
    ------
    TypeError
        Re-raised from `delegator` unless the wrapped function is named
        ``tf2ss``, in which case NumPy is used as the namespace.
    """
    def inner(func):
        @functools.wraps(func)
        def wrapper(*args, **kwds):
            try:
                xp = delegator(*args, **kwds)
            except TypeError:
                # object arrays
                if func.__name__ == "tf2ss":
                    import numpy as np
                    xp = np
                else:
                    raise

            # try delegating to a cupyx/jax namesake
            if is_cupy(xp) and func.__name__ not in CUPY_BLACKLIST:
                # some functions go by a different name in cupyx
                func_name = CUPY_RENAMES.get(func.__name__, func.__name__)

                # https://github.com/cupy/cupy/issues/8336
                import importlib
                cupyx_module = importlib.import_module(f"cupyx.scipy.{module_name}")
                cupyx_func = getattr(cupyx_module, func_name)
                # the `xp` keyword is dropped before calling the namesake
                kwds.pop('xp', None)
                return cupyx_func(*args, **kwds)
            elif is_jax(xp) and func.__name__ in JAX_SIGNAL_FUNCS:
                spx = scipy_namespace_for(xp)
                jax_module = getattr(spx, module_name)
                jax_func = getattr(jax_module, func.__name__)
                # the `xp` keyword is dropped before calling the namesake
                kwds.pop('xp', None)
                return jax_func(*args, **kwds)
            else:
                # the original function
                return func(*args, **kwds)
        return wrapper
    return inner


# Although most of these functions currently exist in CuPy and some in JAX,
# there are no alternative backend tests for any of them in the current
# test suite. Each will be documented as np_only until tests are added.
untested = {
    "argrelextrema",
    "argrelmax",
    "argrelmin",
    "band_stop_obj",
    "check_NOLA",
    "chirp",
    "coherence",
    "csd",
    "czt_points",
    "dbode",
    "dfreqresp",
    "dlsim",
    "dstep",
    "find_peaks",
    "find_peaks_cwt",
    "findfreqs",
    "freqresp",
    "gausspulse",
    "lombscargle",
    "lsim",
    "max_len_seq",
    "peak_prominences",
    "peak_widths",
    "periodogram",
    "place_pols",
    "sawtooth",
    "sepfir2d",
    "square",
    "ss2tf",
    "ss2zpk",
    "step",
    "sweep_poly",
    "symiirorder1",
    "symiirorder2",
    "tf2ss",
    "unit_impulse",
    "welch",
    "zoom_fft",
    "zpk2ss",
}


def get_default_capabilities(func_name, delegator):
    """Return the capabilities decorator used when no override is declared.

    Parameters
    ----------
    func_name : str
        Name of the function being decorated.
    delegator : callable or None
        The signature function found for `func_name`, or None if there is none.

    Returns
    -------
    The result of ``xp_capabilities(np_only=True)`` when there is no
    delegator or the function is listed in ``untested``; otherwise the
    result of ``xp_capabilities()`` with no arguments.
    """
    if delegator is None or func_name in untested:
        return xp_capabilities(np_only=True)
    return xp_capabilities()

# Backend-specific caveats; each string is passed as ``extra_note`` to
# ``xp_capabilities`` for the matching entries of ``capabilities_overrides``.
bilinear_extra_note = \
    """CuPy does not accept complex inputs.

    """

uses_choose_conv_extra_note = \
    """CuPy does not support inputs with ``ndim>1`` when ``method="auto"``
    but does support higher dimensional arrays for ``method="direct"``
    and ``method="fft"``.

    """

resample_poly_extra_note = \
    """CuPy only supports ``padtype="constant"``.

    """

upfirdn_extra_note = \
    """CuPy only supports ``mode="constant"`` and ``cval=0.0``.

    """

xord_extra_note = \
    """The ``torch`` backend on GPU does not support the case where
    `wp` and `ws` specify a Bandstop filter.

    """

convolve2d_extra_note = \
    """The JAX backend only supports ``boundary="fill"`` and ``fillvalue=0``.

    """

zpk2tf_extra_note = \
    """The CuPy and JAX backends both support only 1d input.

    """

# Per-function capability declarations, keyed by function name. The decoration
# loop in this file uses an entry from this dict in preference to the result of
# ``get_default_capabilities``.
capabilities_overrides = {
    "bessel": xp_capabilities(cpu_only=True, jax_jit=False, allow_dask_compute=True),
    "bilinear": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                jax_jit=False, allow_dask_compute=True,
                                reason="Uses np.polynomial.Polynomial",
                                extra_note=bilinear_extra_note),
    "bilinear_zpk": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                    jax_jit=False, allow_dask_compute=True),
    "butter": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                              allow_dask_compute=True),
    "buttord": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                               jax_jit=False, allow_dask_compute=True,
                               extra_note=xord_extra_note),
    "cheb1ord": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                jax_jit=False, allow_dask_compute=True,
                                extra_note=xord_extra_note),
    "cheb2ord": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                jax_jit=False, allow_dask_compute=True,
                                extra_note=xord_extra_note),
    "cheby1": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                              allow_dask_compute=True),

    "cheby2": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                              allow_dask_compute=True),
    "cont2discrete": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "convolve": xp_capabilities(cpu_only=True, exceptions=["cupy", "jax.numpy"],
                                allow_dask_compute=True,
                                extra_note=uses_choose_conv_extra_note),
    "convolve2d": xp_capabilities(cpu_only=True, exceptions=["cupy", "jax.numpy"],
                                  allow_dask_compute=True,
                                  extra_note=convolve2d_extra_note),
    "correlate": xp_capabilities(cpu_only=True, exceptions=["cupy", "jax.numpy"],
                                 allow_dask_compute=True,
                                 extra_note=uses_choose_conv_extra_note),
    "correlate2d": xp_capabilities(cpu_only=True, exceptions=["cupy", "jax.numpy"],
                                   allow_dask_compute=True,
                                   extra_note=convolve2d_extra_note),
    "correlation_lags": xp_capabilities(out_of_scope=True),
    "cspline1d": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                 jax_jit=False, allow_dask_compute=True),
    "cspline1d_eval": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                      jax_jit=False, allow_dask_compute=True),
    "cspline2d": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                 jax_jit=False, allow_dask_compute=True),
    "czt": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "deconvolve": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                  allow_dask_compute=True,
                                  skip_backends=[("jax.numpy", "item assignment")]),
    "decimate": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "detrend": xp_capabilities(cpu_only=True, exceptions=["cupy", "jax.numpy"],
                               allow_dask_compute=True),
    "dimpulse": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "dlti": xp_capabilities(np_only=True,
                            reason="works in CuPy but delegation isn't set up yet"),
    "ellip": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                             allow_dask_compute=True,
                             reason="scipy.special.ellipk"),
    "ellipord": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                jax_jit=False, allow_dask_compute=True,
                                reason="scipy.special.ellipk"),
    "firls": xp_capabilities(cpu_only=True, allow_dask_compute=True, jax_jit=False,
                             reason="lstsq"),
    "firwin": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                              jax_jit=False, allow_dask_compute=True),
    "firwin2": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                               jax_jit=False, allow_dask_compute=True,
                               reason="firwin uses np.interp"),
    "fftconvolve": xp_capabilities(cpu_only=True, exceptions=["cupy", "jax.numpy"]),
    "freqs": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                             jax_jit=False, allow_dask_compute=True),
    "freqs_zpk": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 jax_jit=False, allow_dask_compute=True),
    "freqz": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                             jax_jit=False, allow_dask_compute=True),
    "freqz_sos": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 jax_jit=False, allow_dask_compute=True),
    "group_delay": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                   jax_jit=False, allow_dask_compute=True),
    "hilbert": xp_capabilities(
        cpu_only=True, exceptions=["cupy", "torch"],
        skip_backends=[("jax.numpy", "item assignment")],
    ),
    "hilbert2": xp_capabilities(
        cpu_only=True, exceptions=["cupy", "torch"],
        skip_backends=[("jax.numpy", "item assignment")],
    ),
    "invres": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "invresz": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "iircomb": xp_capabilities(xfail_backends=[("jax.numpy", "inaccurate")]),
    "iirfilter": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 jax_jit=False, allow_dask_compute=True),
    "kaiser_atten": xp_capabilities(
        out_of_scope=True, reason="scalars in, scalars out"
    ),
    "kaiser_beta": xp_capabilities(out_of_scope=True, reason="scalars in, scalars out"),
    "kaiserord": xp_capabilities(out_of_scope=True, reason="scalars in, scalars out"),
    "lfilter": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                               allow_dask_compute=True, jax_jit=False),
    "lfilter_zi": xp_capabilities(cpu_only=True, allow_dask_compute=True,
                                  jax_jit=False),
    "lfiltic": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                               allow_dask_compute=True),
    "lp2bp": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                             allow_dask_compute=True,
                             skip_backends=[("jax.numpy", "in-place item assignment")]),
    "lp2bp_zpk": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 allow_dask_compute=True, jax_jit=False),
    "lp2bs": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                             allow_dask_compute=True,
                             skip_backends=[("jax.numpy", "in-place item assignment")]),
    "lp2bs_zpk": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 allow_dask_compute=True, jax_jit=False),
    "lp2lp": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                             allow_dask_compute=True, jax_jit=False),
    "lp2lp_zpk": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 allow_dask_compute=True, jax_jit=False),
    "lp2hp": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                             allow_dask_compute=True,
                             skip_backends=[("jax.numpy", "in-place item assignment")]),
    "lp2hp_zpk": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 allow_dask_compute=True, jax_jit=False),
    "lti": xp_capabilities(np_only=True,
                           reason="works in CuPy but delegation isn't set up yet"),
    "medfilt": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                               allow_dask_compute=True, jax_jit=False,
                               reason="uses scipy.ndimage.rank_filter"),
    "medfilt2d": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                 allow_dask_compute=True, jax_jit=False,
                                 reason="c extension module"),
    "minimum_phase": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                     allow_dask_compute=True, jax_jit=False),
    "normalize": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 jax_jit=False, allow_dask_compute=True),
    "oaconvolve": xp_capabilities(
        cpu_only=True, exceptions=["cupy", "torch"],
        skip_backends=[("jax.numpy", "fails all around")],
        xfail_backends=[("dask.array", "wrong answer")],
    ),
    "order_filter": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                    allow_dask_compute=True, jax_jit=False,
                                    reason="uses scipy.ndimage.rank_filter"),
    "qspline1d": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                 jax_jit=False, allow_dask_compute=True),
    "qspline1d_eval": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                      jax_jit=False, allow_dask_compute=True),
    "qspline2d": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "remez": xp_capabilities(cpu_only=True, allow_dask_compute=True, jax_jit=False),
    "resample": xp_capabilities(
        cpu_only=True, exceptions=["cupy"],
        skip_backends=[
            ("dask.array", "XXX something in dask"),
            ("jax.numpy", "XXX: immutable arrays"),
        ]
    ),
    "resample_poly": xp_capabilities(
        cpu_only=True, exceptions=["cupy"],
        jax_jit=False, skip_backends=[("dask.array", "XXX something in dask")],
        extra_note=resample_poly_extra_note,
    ),
    "residue": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "residuez": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "savgol_filter": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                     jax_jit=False,
                                     reason="convolve1d is cpu-only"),
    "sepfir2d": xp_capabilities(np_only=True),
    "sos2zpk": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                               allow_dask_compute=True),
    "sos2tf": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                              allow_dask_compute=True),
    "sosfilt": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                               allow_dask_compute=True),
    "sosfiltfilt": xp_capabilities(
        cpu_only=True, exceptions=["cupy"],
        skip_backends=[
            (
                "dask.array",
                "sosfiltfilt directly sets shape attributes on arrays"
                " which dask doesn't like"
            ),
            ("torch", "negative strides"),
            ("jax.numpy", "sosfilt works in-place"),
        ],
    ),
    "sosfreqz": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                jax_jit=False, allow_dask_compute=True),
    "spline_filter": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                     jax_jit=False, allow_dask_compute=True),
    "tf2sos": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                              allow_dask_compute=True),
    "tf2zpk": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                              allow_dask_compute=True),
    "unique_roots": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "upfirdn": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                               allow_dask_compute=True,
                               reason="Cython implementation",
                               extra_note=upfirdn_extra_note),
    "vectorstrength": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                      allow_dask_compute=True, jax_jit=False),
    "wiener": xp_capabilities(cpu_only=True, exceptions=["cupy", "jax.numpy"],
                              allow_dask_compute=True, jax_jit=False,
                              reason="uses scipy.signal.correlate"),
    "zpk2sos": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                               allow_dask_compute=True),
    "zpk2tf": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                              allow_dask_compute=True,
                              extra_note=zpk2tf_extra_note),
    "spectrogram": xp_capabilities(out_of_scope=True),  # legacy
    "stft": xp_capabilities(out_of_scope=True),  # legacy
    "istft": xp_capabilities(out_of_scope=True),  # legacy
    "check_COLA": xp_capabilities(out_of_scope=True),  # legacy
}


# ### decorate ###
for obj_name in _signal_api.__all__:
    bare_obj = getattr(_signal_api, obj_name)
    # a signature function named "<obj_name>_signature" acts as the delegator;
    # None means that no such function exists in _delegators
    delegator = getattr(_delegators, obj_name + "_signature", None)

    # the delegation wrapper is only applied when SCIPY_ARRAY_API is truthy
    # and a delegator was found
    if SCIPY_ARRAY_API and delegator is not None:
        f = delegate_xp(delegator, MODULE_NAME)(bare_obj)
    else:
        f = bare_obj

    # everything except modules additionally receives a capabilities
    # decorator: the explicit override if there is one, the default otherwise
    if not isinstance(f, types.ModuleType):
        capabilities = capabilities_overrides.get(
            obj_name, get_default_capabilities(obj_name, delegator)
        )
        f = capabilities(f)

    # add the decorated function to the namespace, to be imported in __init__.py
    vars()[obj_name] = f
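Note that a function will only be wrapped by the delegate_xp decorator if the environment variable
SCIPY_ARRAY_API is set and a matching signature function (named after the function, with the suffix _signature) is defined in the file
scipy/signal/_delegators.py. The capabilities decorator, in contrast, is applied to every non-module object. E.g., for firwin, the signature
function looks like this: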
def firwin_signature(numtaps, cutoff, *args, **kwds):
    """Return the array namespace used to dispatch ``firwin``.

    Only `cutoff` is inspected; `numtaps` and any further positional or
    keyword arguments are accepted but not used.

    Returns
    -------
    ``np_compat`` if `cutoff` is a Python ``int`` or ``float``, otherwise
    the result of ``array_namespace(cutoff)``.
    """
    if isinstance(cutoff, int | float):
        xp = np_compat
    else:
        xp = array_namespace(cutoff)
    return xp
Support on CPU#
Legend
✔️ = supported
✖ = unsupported
N/A = out-of-scope
function |
torch |
jax |
dask |
|---|---|---|---|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
N/A |
N/A |
N/A |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
N/A |
N/A |
N/A |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
N/A |
N/A |
N/A |
|
N/A |
N/A |
N/A |
|
N/A |
N/A |
N/A |
|
N/A |
N/A |
N/A |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
N/A |
N/A |
N/A |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
N/A |
N/A |
N/A |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
Support on GPU#
Legend
✔️ = supported
✖ = unsupported
N/A = out-of-scope
function |
cupy |
torch |
jax |
|---|---|---|---|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
N/A |
N/A |
N/A |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✖ |
✔️ |
|
N/A |
N/A |
N/A |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
N/A |
N/A |
N/A |
|
N/A |
N/A |
N/A |
|
N/A |
N/A |
N/A |
|
N/A |
N/A |
N/A |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
N/A |
N/A |
N/A |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
N/A |
N/A |
N/A |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
Support with JIT#
Legend
✔️ = supported
✖ = unsupported
N/A = out-of-scope
function |
jax |
|---|---|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✔️ |
|
✔️ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
N/A |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✔️ |
|
✔️ |
|
✔️ |
|
N/A |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✔️ |
|
✔️ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✔️ |
|
✔️ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✔️ |
|
✔️ |
|
✔️ |
|
✖ |
|
✖ |
|
N/A |
|
N/A |
|
N/A |
|
N/A |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✔️ |
|
✖ |
|
✖ |
|
N/A |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
N/A |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |