"""
Utility functions compiled with JIT for gauNEGF.
Contains commonly used pure mathematical functions that are reused
across multiple modules in the gauNEGF package.
"""
import jax.numpy as jnp
from jax import jit
[docs]
@jit
def fractional_matrix_power(S, power, eps=1e-16):
    """
    Calculate matrix power S^p using eigendecomposition.

    Supports fractional powers including negative values like -0.5.

    Parameters
    ----------
    S : jax array
        Input matrix, assumed Hermitian (e.g. an overlap matrix). A
        non-Hermitian input is not supported by the Hermitian eigensolver and
        gives a meaningless result; use a general eigendecomposition (as in
        inv_sqrt_general) for such matrices.
    power : float
        Power to raise matrix to (e.g., 0.5 for sqrt, -0.5 for inverse sqrt).
    eps : float, optional
        Floor applied to the eigenvalues before exponentiation (default
        1e-16). Eigenvalues below ``eps`` -- including slightly negative ones
        produced by round-off -- are replaced by ``eps``. This keeps negative
        and fractional powers finite and real, but for a (near-)singular S the
        corresponding component becomes as large as ``eps**power``.

    Returns
    -------
    jax array
        Matrix power S^p = V @ diag(max(w, eps)^p) @ V^H.

    Notes
    -----
    Uses the eigendecomposition S = V @ diag(w) @ V^H. Unlike JAX's
    matrix_power, this handles fractional and negative powers. The result is
    reconstructed by scaling the columns of V by w^p via broadcasting rather
    than multiplying by an explicit diagonal matrix, which saves one dense
    N^3 matrix product and an N x N temporary.
    """
    # Hermitian solver: real eigenvalues, unitary eigenvector matrix.
    eigenvalues, eigenvectors = jnp.linalg.eigh(S)
    # Clamp tiny/negative eigenvalues so that the power stays finite and real.
    clamped = jnp.maximum(eigenvalues, eps)
    powered = jnp.power(clamped, power)
    # (V * d) scales column j of V by d[j]; identical to V @ diag(d).
    return (eigenvectors * powered) @ eigenvectors.conj().T
[docs]
@jit
def inv_sqrt_general(M):
    """Inverse square root M^(-1/2) for a general diagonalizable matrix.

    Diagonalizes M = V @ diag(D) @ V^(-1) with the general (non-symmetric)
    eigendecomposition and returns Y = V @ diag(D^(-1/2)) @ V^(-1). This
    satisfies Y @ M @ Y = I to machine precision for ANY diagonalizable M --
    Hermitian, complex-symmetric, or neither -- since D^(-1/2) D D^(-1/2) = 1
    holds for every branch of the scalar square root.

    Use this for the effective overlap S_eff = S - X_asymp. The retarded contact
    self-energy (Sigma = A @ g_surf @ A^dagger, g_surf complex-symmetric) carries
    broadening Gamma = i(Sigma - Sigma^dagger) and is therefore NON-Hermitian, so
    X_asymp and S_eff are non-Hermitian. eigh assumes a Hermitian matrix, so it
    silently computes the wrong S_eff^(-1/2); use this routine instead. The
    physical lower-contour density is independent of the sqrt branch chosen
    here, because G(E) = Y @ (E*I - Fbar)^(-1) @ Y reduces to
    [E*S_eff - H_eff]^(-1) for any Y with Y @ Y = S_eff^(-1).

    Parameters
    ----------
    M : jax array (N, N)
        General diagonalizable, nonsingular matrix (may be complex /
        non-Hermitian). A zero eigenvalue produces non-finite entries.

    Returns
    -------
    jax array (N, N), complex
        Y = M^(-1/2), satisfying Y @ M @ Y = I. The scalar powers use the
        principal branch of the complex power.

    Notes
    -----
    Uses JAX's general eig (jnp.linalg.eig) -- the same routine wrapped by
    this module's eig(). Rather than forming V^(-1) explicitly, Y is obtained
    by solving the linear system Y @ V = V @ diag(D^(-1/2)), which is cheaper
    and better conditioned than an explicit inverse. Accuracy still degrades
    if the eigenvector matrix is ill-conditioned (near-defective M); a
    Schur-based matrix power is the robust fallback for that case.
    """
    D, V = jnp.linalg.eig(M)
    D_inv_sqrt = jnp.power(D.astype(jnp.complex128), -0.5)
    # V @ diag(d): scale column j of V by d[j] without building diag(d).
    scaled = V * D_inv_sqrt
    # Y @ V = scaled  <=>  V^T @ Y^T = scaled^T (plain transpose, no
    # conjugate), so Y = solve(V^T, scaled^T)^T -- no explicit inv(V).
    return jnp.linalg.solve(V.T, scaled.T).T
# Thin jit-compiled wrappers around jax.numpy.linalg routines
[docs]
@jit
def inv(A):
    """
    Compute matrix inverse by solving A @ X = I with JAX's linalg solve.

    This is numerically equivalent to jnp.linalg.inv, which also inverts by
    solving against the identity. When the goal is A^(-1) @ B, prefer solving
    A @ X = B directly instead of forming the inverse.

    Parameters
    ----------
    A : jax array
        Square invertible matrix (N, N), or a stack of such matrices with
        shape (..., N, N); each trailing N x N matrix is inverted separately.

    Returns
    -------
    jax array
        Inverse A^(-1) with the same shape as A.
    """
    n = A.shape[-1]
    # Identity of matching dtype, broadcast over any leading batch axes so the
    # solve treats it as one matrix right-hand side per input matrix.
    identity = jnp.broadcast_to(jnp.eye(n, dtype=A.dtype), A.shape)
    return jnp.linalg.solve(A, identity)
[docs]
@jit
def eig(A):
    """
    Eigen-decompose a general square matrix with JAX's non-Hermitian solver.

    Use this when A may be non-Hermitian (complex-symmetric, non-normal, ...).
    For Hermitian input, eigh() is the better choice: it uses a dedicated,
    more stable algorithm and returns real, sorted eigenvalues.

    Parameters
    ----------
    A : jax array
        Square matrix (N, N); complex and non-Hermitian input are allowed.

    Returns
    -------
    eigenvalues : jax array, shape (N,)
        Eigenvalues of A (complex dtype); no particular ordering is implied.
    eigenvectors : jax array, shape (N, N)
        Right eigenvectors stored column-wise, so that
        A @ eigenvectors == eigenvectors @ diag(eigenvalues).
    """
    values, vectors = jnp.linalg.eig(A)
    return values, vectors
[docs]
@jit
def eigh(A):
    """
    Eigen-decompose a Hermitian matrix with JAX's Hermitian solver.

    The input is assumed to satisfy A == A^H; under that assumption the
    Hermitian algorithm is faster and more stable than the general eig().

    Parameters
    ----------
    A : jax array
        Hermitian matrix (N, N).

    Returns
    -------
    eigenvalues : jax array, shape (N,)
        Real eigenvalues of A, in ascending order.
    eigenvectors : jax array, shape (N, N)
        Orthonormal eigenvectors stored column-wise, so that
        A @ eigenvectors == eigenvectors @ diag(eigenvalues).
    """
    decomposition = jnp.linalg.eigh(A)
    return decomposition