Simple Tests

Some simple tests to check that the fundamental functionality is working.
from javiche import jaxit
from jax import grad
@jaxit()
def plus(A, B):
  """adds two numbers/arrays"""
  return A+B

@jaxit()
def times(A, B):
  """multiplies two numbers/arrays"""
  return A*B

@jaxit()
def square(A):
  """squares number/array"""
  return A**2
assert grad(plus, argnums=[0,1])(2.0,2.0) == (1,1)
assert grad(times, argnums=[0,1])(2.0,3.0) == (3,2)
assert grad(square, argnums=0)(2.0) == 4
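
The expected values in these assertions follow from the analytic derivatives: d(A+B)/dA = d(A+B)/dB = 1, d(A*B)/dA = B, d(A*B)/dB = A, and d(A**2)/dA = 2A. As an independent sanity check, the same numbers can be approximated with central finite differences. The following is a minimal sketch in plain Python that uses neither javiche nor JAX; the helper `central_diff` and the step size `h` are assumptions introduced for illustration only.

```python
def central_diff(f, args, i, h=1e-6):
    """Approximate the partial derivative of f with respect to args[i]."""
    up = list(args)
    down = list(args)
    up[i] += h
    down[i] -= h
    return (f(*up) - f(*down)) / (2 * h)


# Plain (undecorated) versions of the functions tested above.
def plus(a, b):
    return a + b


def times(a, b):
    return a * b


def square(a):
    return a ** 2


# Evaluate at the same points as the assertions above.
grad_plus = tuple(central_diff(plus, (2.0, 2.0), i) for i in range(2))
grad_times = tuple(central_diff(times, (2.0, 3.0), i) for i in range(2))
grad_square = central_diff(square, (2.0,), 0)

print(grad_plus, grad_times, grad_square)
```

The printed values agree with (1, 1), (3, 2) and 4 up to floating-point rounding, which is why a check of this kind should compare with a tolerance rather than with `==`.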
import time
import timeit

Caching

We can also use caching to avoid recomputing the result of a function when it is called again with the same input parameters.

@jaxit(cache=True)
def cubed(A):
  """computes A to the power of three"""
  time.sleep(1)
  return A**3
1.01 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
342 µs ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
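
The gap between the two timings above is the usual signature of memoization: the first call pays for the computation, and a repeated call with the same argument only pays for a lookup. The effect can be reproduced with the standard library alone. This is a rough sketch using `functools.lru_cache`, not a description of how javiche implements its cache; `slow_cube` and `timed` are hypothetical names, and the sleep is shortened to 0.2 s.

```python
import time
from functools import lru_cache


@lru_cache(maxsize=None)
def slow_cube(a):
    """Hypothetical stand-in for an expensive computation."""
    time.sleep(0.2)  # simulate slow work
    return a ** 3


def timed(f, *args):
    """Call f(*args) and return (result, elapsed seconds)."""
    start = time.perf_counter()
    result = f(*args)
    return result, time.perf_counter() - start


first_result, first_time = timed(slow_cube, 2.0)    # computed, includes the sleep
second_result, second_time = timed(slow_cube, 2.0)  # served from the cache

print(f"first call:  {first_time:.4f} s")
print(f"second call: {second_time:.6f} s")
```

Both calls return the same value, but only the first one executes the function body.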

At the same time, we retain the ability to compute gradients of the cached function using JAX.

1.55 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
1.01 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
assert grad(cubed, argnums=0)(2.0) == 12

Note that this functionality relies on all inputs to the function being hashable. In addition, JAX arrays and NumPy arrays are supported.
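
The hashability requirement comes from the fact that a cache has to look up previous inputs, typically as dictionary keys, and mutable containers cannot serve as keys directly. One common workaround is to convert such inputs into an immutable form before the lookup. The sketch below illustrates the idea with nested Python lists standing in for arrays. It is an assumption-laden illustration, not javiche's actual mechanism; `to_hashable` and `sum_of_squares` are hypothetical names.

```python
def to_hashable(value):
    """Recursively convert lists/tuples into tuples so they can be dict keys."""
    if isinstance(value, (list, tuple)):
        return tuple(to_hashable(v) for v in value)
    return value


# A plain list cannot be hashed, so it cannot be used as a cache key as-is.
try:
    hash([1.0, 2.0])
    list_is_hashable = True
except TypeError:
    list_is_hashable = False

cache = {}
calls = []  # records every time the real computation runs


def sum_of_squares(rows):
    """Cached computation keyed on an immutable copy of the input."""
    key = to_hashable(rows)
    if key not in cache:
        calls.append(key)
        cache[key] = sum(v ** 2 for row in rows for v in row)
    return cache[key]


first = sum_of_squares([[1.0, 2.0], [3.0, 4.0]])
second = sum_of_squares([[1.0, 2.0], [3.0, 4.0]])  # equal content, cache hit

print(list_is_hashable, first, second, len(calls))
```

Keying on content rather than object identity means that two distinct but equal inputs share one cache entry, at the cost of converting the input on every call.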