from javiche import jaxit
from jax import grad
Simple Tests
Some simple tests to check that the fundamental functionality works: functions decorated with jaxit can be differentiated with jax.grad.
@jaxit()
def plus(A, B):
    """adds two numbers/arrays"""
    return A+B

@jaxit()
def times(A, B):
    """multiplies two numbers/arrays"""
    return A*B

@jaxit()
def square(A):
    """squares number/array"""
    return A**2
assert grad(plus, argnums=[0,1])(2.0,2.0) == (1,1)
assert grad(times, argnums=[0,1])(2.0,3.0) == (3,2)
assert grad(square, argnums=0)(2.0) == 4
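The expected values in these assertions are the analytic partial derivatives: ∂(A+B)/∂A = ∂(A+B)/∂B = 1, ∂(AB)/∂A = B and ∂(AB)/∂B = A, and d(A²)/dA = 2A. The following minimal sketch cross-checks them with central finite differences, independently of JAX and javiche. The helper `numerical_grad` is hypothetical and exists only for illustration. The functions are written as lambdas so they do not shadow the decorated `plus`, `times` and `square`.

```python
def numerical_grad(f, args, argnum, eps=1e-6):
    """Approximate the partial derivative of f w.r.t. args[argnum] by central differences."""
    up = list(args)
    down = list(args)
    # Perturb only the selected argument, in both directions.
    up[argnum] += eps
    down[argnum] -= eps
    return (f(*up) - f(*down)) / (2 * eps)


# (label, function, arguments, argnum, expected derivative), mirroring the asserts above
checks = [
    ("plus, d/dA", lambda A, B: A + B, (2.0, 2.0), 0, 1.0),
    ("plus, d/dB", lambda A, B: A + B, (2.0, 2.0), 1, 1.0),
    ("times, d/dA", lambda A, B: A * B, (2.0, 3.0), 0, 3.0),
    ("times, d/dB", lambda A, B: A * B, (2.0, 3.0), 1, 2.0),
    ("square, d/dA", lambda A: A ** 2, (2.0,), 0, 4.0),
]

for label, f, args, argnum, expected in checks:
    approx = numerical_grad(f, args, argnum)
    assert abs(approx - expected) < 1e-6, (label, approx)
    print(f"{label}: {approx:.6f} (expected {expected})")
```

Central differences are exact (up to rounding) for linear and quadratic functions, which is why such a tight tolerance works here.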
import time
import timeit
Caching
We can also use caching to avoid recomputing the result when a function is called again with the same input parameters.
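Conceptually, such a cache is a dictionary keyed by the function's arguments. The following is a minimal pure-Python sketch of that idea under this assumption. It is not javiche's actual implementation, and the names `memoize` and `slow_cubed` are hypothetical.

```python
import functools
import time


def memoize(func):
    """Cache results of func, keyed on its (hashable) positional and keyword arguments."""
    cache = {}

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Looking the key up in a dict hashes it; this raises TypeError for unhashable arguments.
        key = (args, tuple(sorted(kwargs.items())))
        if key not in cache:
            cache[key] = func(*args, **kwargs)
        return cache[key]

    wrapper.cache = cache  # exposed for inspection
    return wrapper


@memoize
def slow_cubed(A):
    """Computes A to the power of three, slowly."""
    time.sleep(0.2)  # stand-in for an expensive computation
    return A ** 3


start = time.perf_counter()
first = slow_cubed(2.0)   # computed: includes the sleep
t_first = time.perf_counter() - start

start = time.perf_counter()
second = slow_cubed(2.0)  # same input: served from the cache
t_second = time.perf_counter() - start

print(f"first call: {t_first:.3f} s, second call: {t_second:.6f} s")
```

For plain hashable arguments the standard library's functools.lru_cache gives the same behaviour without a hand-written dictionary.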
@jaxit(cache=True)
def cubed(A):
    """computes A to the power of three"""
    time.sleep(1)  # simulate an expensive computation
    return A**3
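First call (computed, including the one-second sleep):
1.01 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
Repeated call with the same input (served from the cache):
342 µs ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)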
At the same time, the cached function can still be differentiated with JAX:
1.55 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
1.01 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
assert grad(cubed, argnums=0)(2.0) == 12
Note that caching relies on all inputs to the function being hashable. JAX arrays and NumPy arrays are supported as well, even though they are not hashable themselves.
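Calling hash() on a NumPy array raises TypeError, so array inputs must first be turned into some hashable key. One possible approach is to key an array on its shape, dtype and raw bytes. The sketch below assumes that approach, which is not necessarily what javiche does. The helper `make_key` is hypothetical. It uses duck typing on the NumPy attributes `shape`, `dtype` and `tobytes()` so that it runs without NumPy installed. A JAX array can be converted with numpy.asarray before being passed in.

```python
def make_key(value):
    """Turn a (possibly array-valued) argument into a hashable cache key."""
    # NumPy-style arrays expose shape, dtype and tobytes().
    if all(hasattr(value, attr) for attr in ("shape", "dtype", "tobytes")):
        return ("array", tuple(value.shape), str(value.dtype), value.tobytes())
    # Recurse into lists and tuples so that nested arrays are handled too;
    # the type name keeps [1, 2] and (1, 2) apart.
    if isinstance(value, (list, tuple)):
        return (type(value).__name__, tuple(make_key(v) for v in value))
    # Everything else must already be hashable; hash() raises TypeError otherwise.
    hash(value)
    return value


args = ([1.0, 2.0], 3)
print(make_key(args))  # ('tuple', (('list', (1.0, 2.0)), 3))
```

With this scheme, two arrays with equal values, shape and dtype map to the same cache entry. For large arrays one might store a digest of the bytes (for example from hashlib) instead of the bytes themselves.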