jaxify

The core functionality of javiche. Defines the @jaxit() decorator.

np_cache

 np_cache (*args, **kwargs)

LRU cache implementation for functions whose FIRST parameter is a numpy array

>>> array = np.array([[1, 2, 3], [4, 5, 6]])
>>> @np_cache(maxsize=256)
... def multiply(array, factor):
...     print("Calculating...")
...     return factor*array
>>> multiply(array, 2)
Calculating...
array([[ 2,  4,  6],
       [ 8, 10, 12]])
>>> multiply(array, 2)
array([[ 2,  4,  6],
       [ 8, 10, 12]])
>>> multiply.cache_info()
CacheInfo(hits=1, misses=1, maxsize=256, currsize=1)
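numpy arrays are not hashable, so functools.lru_cache cannot take them as arguments directly. The following is a minimal sketch of one way such a decorator could work, not necessarily how javiche implements np_cache. The name np_cache_sketch and the choice of (shape, dtype, raw bytes) as the cache key are assumptions made for illustration.

```python
from functools import lru_cache, wraps

import numpy as np


def np_cache_sketch(*cache_args, **cache_kwargs):
    """Hypothetical stand-in for np_cache: LRU-cache a function whose
    FIRST parameter is a numpy array (illustrative only)."""

    def decorator(function):
        # The cached inner function only ever sees hashable arguments.
        @lru_cache(*cache_args, **cache_kwargs)
        def cached(shape, dtype, raw, *args, **kwargs):
            # Rebuild the array from its hashable description.
            array = np.frombuffer(raw, dtype=dtype).reshape(shape)
            return function(array, *args, **kwargs)

        @wraps(function)
        def wrapper(array, *args, **kwargs):
            array = np.asarray(array)
            # shape, dtype string and raw bytes together identify the array.
            return cached(array.shape, array.dtype.str, array.tobytes(),
                          *args, **kwargs)

        # Expose the usual lru_cache helpers on the wrapper.
        wrapper.cache_info = cached.cache_info
        wrapper.cache_clear = cached.cache_clear
        return wrapper

    return decorator
```

In this sketch, all remaining arguments must themselves be hashable, and the array handed to the wrapped function is read-only because it is rebuilt from an immutable bytes object.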


as_jax

 as_jax (x)

as_numpy

 as_numpy (x)

jaxit

 jaxit (cache:bool=False)

Makes a function that internally uses autograd compatible with JAX gradient calculations.

Attention: only a single output variable is supported.