javiche
A JAX wrapper around ceviche to make interoperability easier. In the future it might make sense to update ceviche itself to use JAX internally.
A small package that makes it easy to use ceviche with a JAX optimizer.
Install
This package is not yet published. Once it is, install it with:
pip install javiche
or
conda install javiche
How to use
Import the decorator:

from javiche import jaxit

Decorate your function (it will be differentiated using ceviche's jacobian, i.e. HIPS autograd):
@jaxit()
def square(A):
"""squares number/array"""
return A**2
Now you can use JAX as usual:
import jax

grad_fn = jax.grad(square)
grad_fn(2.0)
Array(4., dtype=float32, weak_type=True)
In this toy example, that was already possible without the jaxit() decorator. However, jaxit()-decorated functions can contain autograd operators (but no JAX operators):
import autograd.numpy as npa
def sin(A):
"""computes sin of number/array using autograds numpy"""
return npa.sin(A)
grad_sin = jax.grad(sin)
try:
    print(grad_sin(0.0))
except Exception as e:
    print(e)
The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ConcreteArray(0.0, dtype=float32, weak_type=True)>with<JVPTrace(level=2/0)> with
primal = 0.0
tangent = Traced<ShapedArray(float32[], weak_type=True)>with<JaxprTrace(level=1/0)> with
pval = (ShapedArray(float32[], weak_type=True), None)
recipe = LambdaBinding()
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
@jaxit()
def cos(A):
"""computes sin of number/array using autograds numpy"""
return npa.cos(A)
grad_cos = jax.grad(cos)
try:
    print(grad_cos(0.0))
except Exception as e:
    print(e)
-0.0
Use case
This library is intended for use with ceviche while running a JAX optimization stack, as demonstrated in the inverse design example.
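As a rough sketch of that pattern (not the actual inverse design example), the snippet below uses a toy autograd objective as a stand-in for a ceviche FDFD simulation and a plain gradient-descent loop as a stand-in for a full JAX optimization stack; the objective, its parameters, and the step size are made up for illustration:

import jax
import jax.numpy as jnp
import autograd.numpy as npa

from javiche import jaxit

@jaxit()
def objective(eps):
    """toy stand-in for a ceviche simulation: any autograd-differentiable
    function of the design parameters works here"""
    return npa.sum(npa.sin(eps) ** 2)

grad_fn = jax.grad(objective)

eps = jnp.ones(10)            # toy design parameters
for _ in range(100):          # simple gradient descent on the JAX side
    eps = eps - 0.1 * grad_fn(eps)

Any JAX-based optimizer could replace the hand-written update step above; the key point is that the gradient of the decorated function is available to the rest of the JAX stack.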