Usage
Use it as an (almost) drop-in replacement for baryrat. Note that, instead of a callable barycentric rational, the nodes, values and weights of the barycentric approximation are returned, together with its poles.
from diffaaable import aaa
import jax
import jax.numpy as jnp
### sample points ###
# 20 x 20 grid of complex sample points covering the square [0, 3] + [0, 3]i.
# The real and imaginary axes share the same 1-D sampling.
z_k_r = z_k_i = jnp.linspace(0, 3, 20)
Z_k_r, Z_k_i = jnp.meshgrid(z_k_r, z_k_i)
z_k = Z_k_r + 1j*Z_k_i
### function to be approximated ###
def f(x, a):
    """Return tan(a*x), evaluated elementwise with jax.numpy."""
    return jnp.tan(a*x)

# Bind the parameter ``a`` by keyword. A positional Partial(f, jnp.pi) would
# bind the first parameter ``x`` instead; that only gave the same values
# because a*x == x*a, and calling f_pi(x=...) would fail.
f_pi = jax.tree_util.Partial(f, a=jnp.pi)
### alternatively use pre-calculated function values ###
# Here aaa receives the sampled values f_pi(z_k) at the sample points z_k.
# It returns the nodes z_j, values f_j and weights w_j of the barycentric
# approximation, plus z_n, which the gradient example treats as the poles.
z_j, f_j, w_j, z_n = aaa(z_k, f_pi(z_k))
# Trailing bare expression: displays the poles as the cell output.
z_n
array([ 7.54560590666641-9.21209430143431e+00j, -4.38199026211057-9.35421717102445e+00j, 7.76091436976603-2.40067271413564e+00j, 6.29895508178859-4.10638898230448e-01j, 5.47871396438639+1.67056245828073e-02j, -4.73367678785584-2.47426506847172e+00j, 4.50004885877796-2.93450413083927e-05j, 3.50000000030108+8.46691784414970e-10j, -3.28011054196975-4.33694068666063e-01j, 2.5 -9.97797541005238e-17j, -2.47548936742825+2.00143333230860e-02j, -1.50006438520528-4.48103190576403e-05j, 1.5 +1.25893302335084e-16j, -0.50000000059607+1.64574177138415e-09j, 0.5 +4.86626530780097e-16j])
Gradients
diffaaable is differentiable with JAX, so you can freely compose it with other JAX functionality and obtain gradients.
def loss(a):
    """Return the real part of a selected pole of the AAA fit, shifted by -2.

    Samples ``f(z_k, a)`` at the sample points ``z_k`` and fits them with
    ``aaa``. Among the returned poles ``z_n`` whose real part exceeds 1e-2,
    it picks the one with the smallest real part and returns that real
    part minus 2. The constant offset does not affect the gradient.
    """
    f_k = f(z_k, a)
    # Only the poles are needed; the nodes, values and weights are discarded.
    _, _, _, z_n = aaa(z_k, f_k)
    # NOTE(review): boolean-mask indexing produces a data-dependent shape, so
    # this loss likely cannot be wrapped in jax.jit as written. It also
    # fails if no pole has real part > 1e-2 (argmin of an empty array).
    selected_poles = z_n[z_n.real > 1e-2]
    relevant_pole = selected_poles[jnp.argmin(selected_poles.real)]
    return jnp.real(relevant_pole - 2)

# Derivative of the loss with respect to the parameter a.
g = jax.grad(loss)
g(jnp.pi/2)
Array(-7.25252718023279, dtype=float64, weak_type=True)