Source code for staaax.tildify
from staaax.propagation import angled_sqrt
import jax.numpy as jnp
[docs]
def tildify(k, Cs, bcs, nan_tolerance=0, sign=1, normalize=False):
    """Coordinate transform to mitigate the influence of square-root type branch cuts.

    See also https://doi.org/10.1002/lpor.202500811

    The transformed coordinate is the normalized sum, over all pairs ``(C, bc)``
    of branchpoint and branch angle, of
    ``angled_sqrt(sign * (k**2 - C**2), bc_angle=bc, nan_tolerance=nan_tolerance)``.

    Args:
        k (complex): Coordinate in untransformed space.
        Cs (List[complex]): Branchpoints. Must contain at least one entry.
        bcs (List[float]): Direction of the branch cuts (branch angles), one per branchpoint.
        nan_tolerance (float, optional): Angular vicinity (in radians) to the branch cut
            in which the square root will return nan; passed through to ``angled_sqrt``.
            Defaults to 0.
        sign (int, optional): Used to flip the transform to allow for transforming kx
            instead of k0. Defaults to 1.
        normalize (bool, optional): If True, the transform is normalized by the sum of
            the branchpoints, otherwise by the number of branchpoints. Defaults to False.

    Returns:
        complex: Transformed coordinate.

    Raises:
        ValueError: If ``Cs`` and ``bcs`` differ in length, or if no branchpoints are given.
    """
    if len(bcs) != len(Cs):
        raise ValueError("Provide same number of branchpoints and branch angles")
    if len(Cs) == 0:
        # Without branchpoints the normalization divides by zero: a
        # ZeroDivisionError for normalize=False, a silent nan for normalize=True.
        raise ValueError("Provide at least one branchpoint")

    norm = jnp.sum(jnp.array(Cs)) if normalize else len(Cs)

    # One square-root term per branchpoint; stacking on a new leading axis and
    # summing over it keeps any array shape of ``k`` intact.
    terms = jnp.array([
        angled_sqrt(
            sign * (k**2 - C**2),
            bc_angle=bc,
            nan_tolerance=nan_tolerance,
        )
        for C, bc in zip(Cs, bcs)
    ])
    return 1 / norm * jnp.sum(terms, axis=0)
[docs]
def inverse_tildify(k_tilde, branchpoints, sign=1, normalize=False, single_branch=False):
    """Inverse coordinate transform corresponding to tildify.

    Args:
        k_tilde (complex): Transformed-space coordinate (scalar or array).
        branchpoints (List[complex]): Branchpoints (1 or 2 of them).
        sign (int, optional): As in tildify.
        normalize (bool, optional): As in tildify.
        single_branch (bool, optional): If False, return both candidate branches
            ``k`` and ``-k`` that could have led to this coordinate in transformed
            space, concatenated along the first axis (a scalar input yields a
            length-2 array). If True, return only the principal square-root
            branch ``k``. Defaults to False.

    Raises:
        NotImplementedError: The inverse coordinate transform currently only
            supports 1 or 2 branchpoints.

    Returns:
        complex: Untransformed-space coordinate(s).
    """
    n_branchpoints = len(branchpoints)
    if n_branchpoints not in (1, 2):
        raise NotImplementedError("Only 1 or 2 branchpoints are supported")

    def mv(k):
        """Return both branches ``[k, -k]`` (or just ``k`` if single_branch)."""
        if single_branch:
            return k
        # atleast_1d makes 0-d (scalar) results concatenable; it is a no-op for
        # inputs with ndim >= 1, so array results are unchanged.
        return jnp.concatenate([jnp.atleast_1d(k), jnp.atleast_1d(-k)])

    # Undo tildify's 1/norm prefactor. Rebind rather than using `*=`, which
    # would modify a mutable caller array (e.g. a numpy array) in place.
    norm = jnp.sum(jnp.array(branchpoints)) if normalize else n_branchpoints
    k_tilde = k_tilde * norm
    k_hat = k_tilde**2 / sign

    if n_branchpoints == 1:
        return mv(jnp.sqrt(k_hat + branchpoints[0]**2))

    # Two branchpoints: closed-form k**2 solving
    # sqrt(k**2 - b1**2) + sqrt(k**2 - b2**2) = k_tilde (shown for sign=1),
    # with C1, C2 the squared branchpoints. Squaring discards which
    # square-root branch was taken, hence the multivalued result.
    C1 = branchpoints[0]**2
    C2 = branchpoints[1]**2
    nom = ((k_hat + C1 + C2)**2) / 4 - C1 * C2
    den = k_hat
    return mv(jnp.sqrt(nom / den))
if __name__ == "__main__":
    import numpy as np

    # Manual round-trip demo: push random complex points through tildify and
    # map them back with inverse_tildify (which returns both +k and -k).
    n_samples = 3
    re_part = (np.random.random(n_samples) - 0.5) * 10
    im_part = (np.random.random(n_samples) - 0.5) * 10
    k = re_part + 1j * im_part

    branchpoints = (1, 2)
    branch_angles = (np.pi / 2, np.pi / 2)

    k_tilde = tildify(k, branchpoints, branch_angles)
    k_prime = inverse_tildify(k_tilde, branchpoints)
    print(k, k_prime)