jaxify
The core functionality of javiche. Defines the @jaxit() decorator.
np_cache
np_cache (*args, **kwargs)
LRU cache implementation for functions whose FIRST parameter is a numpy array.

>>> array = np.array([[1, 2, 3], [4, 5, 6]])
>>> @np_cache(maxsize=256)
... def multiply(array, factor):
...     print("Calculating...")
...     return factor*array
>>> multiply(array, 2)
Calculating...
array([[ 2,  4,  6],
       [ 8, 10, 12]])
>>> multiply(array, 2)
array([[ 2,  4,  6],
       [ 8, 10, 12]])
>>> multiply.cache_info()
CacheInfo(hits=1, misses=1, maxsize=256, currsize=1)
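NumPy arrays are unhashable, so they cannot be used directly as `functools.lru_cache` keys. The following is a minimal sketch of one way such a cache could work, assuming the array is converted to a hashable nested tuple before the lookup. `np_cache_sketch` and `to_tuple` are hypothetical names for illustration; this is not necessarily how javiche implements `np_cache`.

```python
from functools import lru_cache, wraps

import numpy as np


def to_tuple(array):
    # Hypothetical helper: recursively turn an ndarray into nested tuples
    # (0-d entries become plain Python scalars) so it can be hashed.
    if array.ndim == 0:
        return array.item()
    return tuple(to_tuple(sub) for sub in array)


def np_cache_sketch(*cache_args, **cache_kwargs):
    """Hypothetical LRU cache for functions whose FIRST argument is a NumPy array."""

    def decorator(function):
        @lru_cache(*cache_args, **cache_kwargs)
        def cached(hashable_array, *args, **kwargs):
            # Rebuild an ndarray from the hashable key before calling the function.
            return function(np.array(hashable_array), *args, **kwargs)

        @wraps(function)
        def wrapper(array, *args, **kwargs):
            # Convert the unhashable array to nested tuples to form the cache key.
            return cached(to_tuple(np.asarray(array)), *args, **kwargs)

        # Expose the usual lru_cache helpers on the wrapper.
        wrapper.cache_info = cached.cache_info
        wrapper.cache_clear = cached.cache_clear
        return wrapper

    return decorator


@np_cache_sketch(maxsize=256)
def multiply(array, factor):
    print("Calculating...")
    return factor * array


array = np.array([[1, 2, 3], [4, 5, 6]])
multiply(array, 2)  # computed, prints "Calculating..."
multiply(array, 2)  # served from the cache
print(multiply.cache_info())
```

Two caveats of this design: the tuple conversion walks the whole array on every call, so it only pays off when the wrapped function is expensive; and `np.array(...)` re-infers the dtype from the tuple, so, for example, a `float32` input reaches the function as `float64`. A cache hit also returns the same array object as the original call, so mutating a returned result would change later cached results.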
as_jax
as_jax (x)
as_numpy
as_numpy (x)
jaxit
jaxit (cache:bool=False)
Make a function that internally uses autograd compatible with JAX gradient calculations.
Attention: only a single output variable is supported.
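To illustrate the general idea of making externally differentiated NumPy code usable with JAX's autodiff, here is a minimal sketch built on `jax.custom_vjp` and `jax.pure_callback`. It is a hypothetical stand-in, not javiche's `jaxit`. In particular, to keep it runnable without autograd, the gradient is passed in explicitly as `grad_fun` (with autograd this could be `autograd.grad(fun)`). Also, the `cache` option is not modeled. As with `jaxit`, only a single (scalar) output is supported.

```python
import jax
import jax.numpy as jnp
import numpy as np


def jaxit_sketch(fun, grad_fun):
    """Hypothetical stand-in for @jaxit.

    fun:      NumPy function mapping an array to a single scalar output.
    grad_fun: NumPy function returning d(fun)/dx for the same input
              (assumption: supplied by the caller, e.g. via autograd.grad).
    """

    def host_call(f, x, shape):
        # Run the NumPy function on the host through a callback, so it also
        # works while JAX is tracing (jax.jit, jax.grad).
        spec = jax.ShapeDtypeStruct(shape, x.dtype)
        return jax.pure_callback(lambda a: np.asarray(f(a), dtype=spec.dtype), spec, x)

    @jax.custom_vjp
    def wrapped(x):
        return host_call(fun, x, ())  # single scalar output only

    def wrapped_fwd(x):
        # Forward pass: compute the output and keep x as the residual.
        return wrapped(x), x

    def wrapped_bwd(x, cotangent):
        # Backward pass: scale the host-computed gradient by the incoming cotangent.
        return (cotangent * host_call(grad_fun, x, x.shape),)

    wrapped.defvjp(wrapped_fwd, wrapped_bwd)
    return wrapped


# Usage: a NumPy sum of squares with its hand-written gradient 2*x.
sum_sq = jaxit_sketch(lambda a: np.sum(a ** 2), lambda a: 2 * a)
x = jnp.array([1.0, 2.0, 3.0])
print(sum_sq(x))
print(jax.grad(sum_sq)(x))
```

Routing both the function and its gradient through `jax.pure_callback` keeps the NumPy code opaque to JAX, while `custom_vjp` tells JAX how to propagate gradients across that opaque boundary.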