Usage

Use aaa as an (almost) drop-in replacement for baryrat's aaa. Note that, instead of a callable barycentric rational, it returns the nodes (support points) z_j, the function values f_j and the weights w_j of the barycentric approximation, together with its poles z_n.
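Because no callable is returned, evaluating the approximation at new points means applying the barycentric formula r(z) = Σ w_j f_j / (z - z_j) / Σ w_j / (z - z_j) to the returned z_j, f_j and w_j yourself. The sketch below does this with a hypothetical helper bary_eval (it is not part of diffaaable as documented here). It assumes 1-D arrays of equal length and handles evaluation exactly at a node, where the formula interpolates f_j. The two-node example uses weights chosen so that r(z) = z + 1.

```python
import jax.numpy as jnp


def bary_eval(z, z_j, f_j, w_j):
    """Evaluate r(z) = sum_j w_j f_j / (z - z_j) / sum_j w_j / (z - z_j).

    Hypothetical helper (not part of diffaaable): z may be a scalar or an
    array of points; z_j, f_j, w_j are 1-D arrays of equal length.
    """
    z = jnp.asarray(z)
    diff = z[..., None] - z_j              # shape (..., m)
    at_node = diff == 0                    # True where z coincides with a node
    safe_diff = jnp.where(at_node, 1.0, diff)
    c = w_j / safe_diff                    # terms w_j / (z - z_j)
    r = jnp.sum(c * f_j, axis=-1) / jnp.sum(c, axis=-1)
    # at a node the barycentric form interpolates: r(z_j) = f_j
    hit = jnp.any(at_node, axis=-1)
    f_hit = jnp.sum(jnp.where(at_node, f_j, 0.0), axis=-1)
    return jnp.where(hit, f_hit, r)


# Two nodes with weights chosen so that r(z) = z + 1
z_j = jnp.array([0.0, 1.0])
f_j = jnp.array([1.0, 2.0])
w_j = jnp.array([1.0, -1.0])
print(bary_eval(jnp.array([0.5, 3.0, 0.0, 1.0]), z_j, f_j, w_j))  # approx [1.5, 4.0, 1.0, 2.0]
```

With the arrays returned by aaa, the same call would read bary_eval(z, z_j, f_j, w_j).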

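from diffaaable import aaa
import jax
import jax.numpy as jnp

# use double precision (the outputs shown below are float64 / complex128)
jax.config.update("jax_enable_x64", True)

### sample points: a 20 x 20 grid on the square [0, 3] x [0, 3] of the complex plane ###
z_k_r = z_k_i = jnp.linspace(0, 3, 20)
Z_k_r, Z_k_i = jnp.meshgrid(z_k_r, z_k_i)
z_k = Z_k_r + 1j*Z_k_i

### function to be approximated: tan(a*x), here with a fixed to pi ###
def f(x, a):
    return jnp.tan(a*x)
f_pi = jax.tree_util.Partial(f, jnp.pi)

### alternatively use pre-calculated function values ###
# z_j: nodes, f_j: values at the nodes, w_j: weights, z_n: poles of the approximation
z_j, f_j, w_j, z_n = aaa(z_k, f_pi(z_k))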

z_n
array([ 7.5456060716143 -9.21209421136196e+00j, -4.38199039195268-9.35421724831426e+00j,  7.76091435697344-2.40067258846522e+00j,  6.29895508631654-4.10638833538889e-01j,  5.47871397547399+1.67056231112251e-02j, -4.7336768366249 -2.47426502515057e+00j,  4.50004885871537-2.93450441078303e-05j,  3.50000000030108+8.46673883206123e-10j, -3.28011056065685-4.33694051754996e-01j,  2.5             -2.88743182202112e-16j, -2.47548937007102+2.00143304761715e-02j, -1.50006438519416-4.48103159303254e-05j,  1.5             +1.76250623179801e-16j, -0.50000000059613+1.64574818329065e-09j,  0.5             +0.00000000000000e+00j])
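Since tan(π x) is singular where π x = π/2 + kπ, its exact poles are the half-integers x = k + 1/2. The sketch below checks the printed z_n against them, using the values copied from the output above. It keeps only the poles inside the sampled square [0, 3] × [0, 3], with an assumed tolerance of 1e-6 on the imaginary part because 0.5, 1.5 and 2.5 lie on the square's lower edge. It then measures each pole's distance to the nearest exact pole.

```python
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

# Poles returned by aaa above for f = tan(pi*x) (copied from the printed z_n)
z_n = jnp.array([
    7.5456060716143 - 9.21209421136196e+00j, -4.38199039195268 - 9.35421724831426e+00j,
    7.76091435697344 - 2.40067258846522e+00j, 6.29895508631654 - 4.10638833538889e-01j,
    5.47871397547399 + 1.67056231112251e-02j, -4.7336768366249 - 2.47426502515057e+00j,
    4.50004885871537 - 2.93450441078303e-05j, 3.50000000030108 + 8.46673883206123e-10j,
    -3.28011056065685 - 4.33694051754996e-01j, 2.5 - 2.88743182202112e-16j,
    -2.47548937007102 + 2.00143304761715e-02j, -1.50006438519416 - 4.48103159303254e-05j,
    1.5 + 1.76250623179801e-16j, -0.50000000059613 + 1.64574818329065e-09j,
    0.5 + 0.00000000000000e+00j,
])

a = jnp.pi
# tan(a*x) is singular where a*x = pi/2 + k*pi, i.e. at x = (k + 1/2) * pi / a
k = jnp.arange(-10, 11)
exact = (k + 0.5) * jnp.pi / a

# keep the poles inside the sampled square (assumed 1e-6 tolerance on the imaginary part)
inside = (z_n.real >= 0) & (z_n.real <= 3) & (z_n.imag >= -1e-6) & (z_n.imag <= 3)
found = z_n[inside]

# distance of each retained pole to the nearest exact pole
err = jnp.min(jnp.abs(found[:, None] - exact[None, :]), axis=1)
print(found.real, err)
```

For the three poles inside the square the distance is below 1e-15. Poles further from the samples, such as 4.50004885871537-2.93450441078303e-05j, are noticeably less accurate.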

Gradients

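diffaaable is differentiable with JAX, so aaa can be composed with other JAX transformations such as jax.grad. The example below differentiates the location of one pole of the approximation with respect to the parameter a of the sampled function f.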
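As a small illustration of such composition, the sketch below applies jax.grad and jax.vmap to the barycentric evaluation formula. The nodes and weights are held fixed. It differentiates only the evaluation step, not the AAA fit itself, so it says nothing about how diffaaable obtains gradients through aaa. The helper bary_eval is hypothetical. With respect to f_j, the gradient of r(z) is the barycentric basis (w_j / (z - z_j)) / Σ_k w_k / (z - z_k), which for the two-node example r(z) = z + 1 equals the linear Lagrange basis [1 - z, z].

```python
import jax
import jax.numpy as jnp


def bary_eval(z, z_j, f_j, w_j):
    # barycentric rational r(z) at a single point z that is not a node
    # (hypothetical helper; not part of diffaaable)
    c = w_j / (z - z_j)
    return jnp.sum(c * f_j) / jnp.sum(c)


z_j = jnp.array([0.0, 1.0])
f_j = jnp.array([1.0, 2.0])
w_j = jnp.array([1.0, -1.0])  # with these weights r(z) = z + 1

# gradient of r(z) with respect to the values f_j (argument index 2)
dr_df = jax.grad(bary_eval, argnums=2)

# map the gradient over several evaluation points, keeping the node data fixed
z_eval = jnp.array([0.5, 3.0])
basis = jax.vmap(dr_df, in_axes=(0, None, None, None))(z_eval, z_j, f_j, w_j)
print(basis)  # approx [[0.5, 0.5], [-2.0, 3.0]]
```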

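def loss(a):
    # sample f on the grid z_k for the given parameter a
    f_k = f(z_k, a)

    # fit the AAA approximation; z_n holds its poles
    z_j, f_j, w_j, z_n = aaa(z_k, f_k)

    # among the poles with real part above 1e-2, pick the one with the smallest real part
    # (boolean-mask indexing has a data-dependent shape: fine under jax.grad, not under jax.jit)
    selected_poles = z_n[z_n.real>1e-2]
    relevant_pole = selected_poles[jnp.argmin(selected_poles.real)]
    return jnp.real(relevant_pole - 2)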

g = jax.grad(loss)
g(jnp.pi/2)
Array(-7.25252718023279, dtype=float64, weak_type=True)
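For this particular f, the exact poles are known in closed form, which gives a reference point when checking gradients of pole locations. For a > 0, tan(a x) has its smallest positive pole at x* = π/(2a), so dx*/da = -π/(2a²), which is -2/π ≈ -0.637 at a = π/2. The sketch below computes this with jax.grad. The names first_positive_pole and reference_loss are hypothetical. Note that loss picks the AAA pole with the smallest real part above 1e-2 anywhere in the complex plane, which need not be the pole near x = 1, so the two gradients are not guaranteed to agree.

```python
import jax
import jax.numpy as jnp


def first_positive_pole(a):
    # tan(a*x) has poles at a*x = pi/2 + k*pi; for a > 0 the smallest positive one is x = pi/(2a)
    return jnp.pi / (2 * a)


def reference_loss(a):
    # same offset of 2 as in loss above; a constant does not change the gradient
    return first_positive_pole(a) - 2


a0 = jnp.pi / 2
print(first_positive_pole(a0))        # 1.0
print(jax.grad(reference_loss)(a0))   # -pi / (2 * a0**2) = -2/pi, approx -0.6366
```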