"""
Surface Green's function implementation for 1D chain contacts.
This module provides a surface Green's function calculator for quasi-1D chain
contacts in quantum transport calculations. It supports three usage patterns for
specifying contact parameters: fully automatic extraction from Fock/Overlap
matrices, custom coupling matrices with automatic onsite parameters, or fully
specified contact parameters. The implementation uses iterative solvers with JAX
JIT compilation for efficient computation.
The surfG class handles:
- Semi-infinite 1D chain Green's function calculations
- Multiple contact geometries and coupling specifications
- Orthogonal and non-orthogonal contact overlap matrices
- De-orthonormalization for proper self-energy contributions
- Contact regularization via congruent eigenvalue clipping (currently disabled)
- Self-energy and cross-term calculations with convergence control
Key exports:
- surfG: Main surface Green's function calculator class
"""
# Python packages
import jax
import jax.numpy as jnp
import jax.lax as lax
from jax import jit
# Configuration
from gauNEGF.config import (ETA, SURFACE_GREEN_CONVERGENCE, SURFACE_RELAXATION_FACTOR,
OVERLAP_EIGENVALUE_RATIO, FERMI_DEBUG)
from gauNEGF.utils import fractional_matrix_power, inv, eigh
from gauNEGF.density import densityComplex
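# Constants: numerical defaults (ETA, convergence, relaxation, overlap ratio)
# are imported from gauNEGF.config above.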
class surfG:
    """
    Surface Green's function calculator for 1D chain contacts.

    This class implements the surface Green's function calculation for 1D chain
    contacts. It supports three usage patterns:

    a) Fully automatic extraction from Fock matrix:
       - Provide contact indices and connection indices
       - All parameters extracted from F/S matrices
       Example: surfG(F, S, [[c1], [c2]], [[c1conn], [c2conn]])

    b) Fock matrix with custom coupling:
       - Provide contact indices and coupling matrices
       - Onsite contact parameters from F/S, coupling specified manually
       - If staus=None (default), assumes orthonormal coupling
       Example: surfG(F, S, [[c1], [c2]], [tau1, tau2])

    c) Fully specified contacts:
       - All contact parameters provided manually
       - If aOverlaps=None, onsite overlap defaults to identity (orthonormal)
       - If bOverlaps=None, hopping overlap defaults to zeros (orthonormal)
       Example: surfG(F, S, [[c1], [c2]], [tau1, tau2], [stau1, stau2],
                      [alpha1, alpha2], [salpha1, salpha2], [beta1, beta2], [sbeta1, sbeta2])

    Parameters
    ----------
    Fock : ndarray
        Fock matrix for the extended system
    Overlap : ndarray
        Overlap matrix for the extended system
    indsList : list of lists
        Lists of orbital indices for each contact region
    taus : list or None, optional
        Either coupling matrices or connection indices (default: None)
        - If indices: [[contact1connection], [contact2connection]]
        - If matrices: [tau1, tau2]
        - If None: contact 0 uses indsList[-1] and contact 1 uses indsList[0]
          as connection indices
    staus : list or None, optional
        Overlap matrices for coupling (default: None = orthonormal coupling).
        None entries trigger de-orthonormalization in sigma(). Ignored when
        taus are connection indices (overlaps are then sliced from Overlap).
    alphas : list of ndarray or None, optional
        On-site energies for contacts, required for pattern (c) (default: None)
    aOverlaps : list of ndarray or None, optional
        On-site overlap matrices; defaults to identity when None (orthonormal)
    betas : list of ndarray or None, optional
        Hopping matrices between contact unit cells, required for pattern (c) (default: None)
    bOverlaps : list of ndarray or None, optional
        Overlap matrices between contact unit cells; defaults to zeros when None
    eta : float, optional
        Broadening parameter in eV (default: gauNEGF.config.ETA)
    spin : str, optional
        Spin configuration, stored as ``self.spin`` (default: 'r')

    Attributes
    ----------
    F : ndarray
        Fock matrix
    S : ndarray
        Overlap matrix
    S_orig : ndarray
        Copy of the overlap matrix; contact onsite overlaps are read from it
    X : ndarray
        Inverse square root of overlap matrix for orthogonalization (S^-0.5)
    Xi : ndarray
        Square root of overlap matrix for de-orthonormalization (S^+0.5 = inv(X))
    indsList : list of jax arrays
        Orbital indices of each contact region
    tauFromFock : bool
        True when taus were given as connection indices (stored in tauInds)
    contactFromFock : bool
        True when alphas was None and onsite blocks are read from F/S
    tauList : list
        Contact coupling matrices
    stauList : list
        Contact coupling overlap matrices (None entries -> orthonormal coupling)
    aList, aSList : list
        On-site energy / overlap matrices for contacts
    bList, bSList : list
        Hopping / overlap matrices between contact unit cells
    aList0, aSList0, bList0, bSList0 : list
        Snapshots of the contact matrices that g() reads
    CList : list
        Per-contact congruence transforms (identity while regularization is disabled)
    dFermiList, fermi0List, fermiList : list
        Rigid-band shift bookkeeping (used when contactFromFock is False)
    """

    # Congruent clipping in _regularizeContacts is switched off pending
    # development/debugging; set to True to re-enable it.
    _REGULARIZE_CONTACTS = False

    def __init__(self, Fock, Overlap, indsList, taus=None, staus=None, alphas=None, aOverlaps=None, betas=None, bOverlaps=None, eta=ETA, spin='r'):
        """
        Initialize the surface Green's function calculator.

        The initialization follows one of three patterns:
        a) Fully automatic: only provide Fock, Overlap, indsList, and connection indices in taus
        b) Custom coupling: provide Fock, Overlap, indsList, coupling matrices in taus
           - staus=None (default) means orthonormal coupling -> de-ortho applied in sigma()
        c) Fully specified: provide all parameters including alphas, aOverlaps, betas, bOverlaps
           - aOverlaps=None defaults to identity (orthonormal onsite)
           - bOverlaps=None defaults to zeros (orthonormal hopping)

        Parameters
        ----------
        Fock : ndarray
            Fock matrix for the extended system
        Overlap : ndarray
            Overlap matrix for the extended system
        indsList : list of lists
            Lists of orbital indices for each contact region
        taus : list or None, optional
            Coupling matrices, or connection indices (1D entries) (default: None)
        staus : list or None, optional
            Overlap matrices for coupling (default: None = orthonormal)
        alphas : list of ndarray or None, optional
            On-site energies for contacts, required for pattern (c) (default: None)
        aOverlaps : list of ndarray or None, optional
            On-site overlap matrices for contacts (default: None = identity)
        betas : list of ndarray or None, optional
            Hopping matrices between contact unit cells, required for pattern (c) (default: None)
        bOverlaps : list of ndarray or None, optional
            Overlap matrices between contact unit cells (default: None = zeros)
        eta : float, optional
            Broadening parameter in eV (default: gauNEGF.config.ETA)
        spin : str, optional
            Spin configuration ('r' for restricted) (default: 'r')
        """
        # Set up system
        self.F = jnp.array(Fock)
        self.S = jnp.array(Overlap)
        self.S_orig = jnp.array(Overlap)
        self.spin = spin
        self.X = jnp.array(fractional_matrix_power(Overlap, -0.5))
        self.Xi = jnp.linalg.inv(self.X)
        # Keep indsList as a Python list -- loops over contacts unroll with
        # concrete indices under JIT.
        self.indsList = [jnp.array(inds) for inds in indsList]
        self.num_contacts = len(indsList)

        # Set contact coupling. 1D entries are connection indices (pattern a);
        # anything else is treated as explicit coupling matrices (patterns b/c).
        if taus is None:
            taus = [self.indsList[-1], self.indsList[0]]
        taus = [jnp.array(tau) for tau in taus]
        if len(jnp.shape(taus[0])) == 1:
            self.tauFromFock = True
            self.tauInds = taus
            taus, staus = self._tausFromFock()
        else:
            self.tauFromFock = False
            # TODO: orthogonal-basis foot-gun. When a user works in a Lowdin-orthogonal
            # basis they often pass alpha_S=I and beta_S=I (or tau_S=I), meaning
            # "orthogonal everything". beta_S=I is wrong: in an orthogonal basis the
            # inter-cell overlap is 0, not I, and beta_S=I makes the periodic contact
            # overlap S(k) = (1 + 2 cos k) I non-PSD, which corrupts the surface GF
            # (per-contact TSW collapses to ~0, device sigma is bogus at deep E,
            # calcEmin reports negative DOS, calcTSW finds wildly wrong Eminf).
            # Two fixes worth considering:
            # (1) Auto-canonicalize: in this matrix-input branch, detect
            #     beta_S equal to identity when alpha_S is also identity and warn
            #     (or treat as zero). Mirror the all-zero -> None canonicalization
            #     already done in _tausFromFock for staus.
            # (2) Add an explicit orthogonal=True flag to scfE.setContact1D that
            #     overrides aOverlaps/bOverlaps to I and 0 respectively. Removes
            #     the foot-gun for the common Lowdin-orthogonal case entirely.
        self.tauList = taus
        self.stauList = ([None] * len(taus) if staus is None
                         else [None if stau is None else jnp.array(stau) for stau in staus])

        # Broadening for the retarded Green's function.
        self.eta = eta
        # Rigid-band shift bookkeeping (used in the contactFromFock=False branch).
        # fermi0 is captured on the first setF call; dFermi = mu - fermi0 thereafter.
        # All three lists are initialized before _setContacts: with FERMI_DEBUG,
        # _regularizeContacts passes `self` to densityComplex, which can call back
        # into sigma/crossTermQ (these read dFermiList).
        self.dFermiList = [0.0] * self.num_contacts
        self.fermi0List = [None] * self.num_contacts
        self.fermiList = [None] * self.num_contacts

        # Set up contact blocks. _setContacts also snapshots them and creates the
        # JIT wrapper for g (one compiled version per static contact index i).
        if alphas is None:
            self.contactFromFock = True
            self._setContacts()
        else:
            self.contactFromFock = False
            self._setContacts(alphas, aOverlaps, betas, bOverlaps)

    def _tausFromFock(self):
        """Slice coupling and coupling-overlap blocks for both contacts out of F/S.

        Uses self.tauInds (connection indices) and self.indsList. Returns
        (taus, staus); all-zero overlap blocks are replaced by None, the
        orthogonal-contact sentinel consumed by sigma() and crossTermQ().
        """
        conn, inds = self.tauInds, self.indsList
        tau_blocks = [self.F[jnp.ix_(conn[0], inds[0])],
                      self.F[jnp.ix_(conn[1], inds[-1])]]
        stau_blocks = [self.S[jnp.ix_(conn[0], inds[0])],
                       self.S[jnp.ix_(conn[1], inds[-1])]]
        stau_blocks = [stau if bool(jnp.any(stau)) else None for stau in stau_blocks]
        return tau_blocks, stau_blocks

    def _setContacts(self, alphas=None, aOverlaps=None, betas=None, bOverlaps=None):
        """Internal: build aList/aSList/bList/bSList and regularize contacts.

        contactFromFock=True: extracts alpha/Salpha from self.F and self.S_orig
        and uses tauList/stauList as the inter-cell hopping/overlap.
        contactFromFock=False: uses provided alphas/aOverlaps/betas/bOverlaps.
        Calls _regularizeContacts(), then snapshots the result into the *List0
        lists read by g() and refreshes the JIT wrapper for g.
        """
        if self.contactFromFock:
            self.aList = [self.F[jnp.ix_(inds, inds)] for inds in self.indsList]
            self.aSList = [self.S_orig[jnp.ix_(inds, inds)] for inds in self.indsList]
            self.bList = [jnp.array(tau) for tau in self.tauList]
            self.bSList = [jnp.zeros_like(tau) if stau is None else jnp.array(stau)
                           for tau, stau in zip(self.tauList, self.stauList)]
        else:
            self.aList = [jnp.array(alpha) for alpha in alphas]
            self.bList = [jnp.array(beta) for beta in betas]
            if aOverlaps is None:
                self.aSList = [jnp.eye(len(alpha)) for alpha in alphas]
            else:
                self.aSList = [jnp.array(aOverlap) for aOverlap in aOverlaps]
            if bOverlaps is None:
                self.bSList = [jnp.zeros_like(beta) for beta in betas]
            else:
                self.bSList = [jnp.zeros_like(beta) if bOverlap is None else jnp.array(bOverlap)
                               for beta, bOverlap in zip(betas, bOverlaps)]
        # Initial snapshot (+ fresh JIT wrapper) so g() sees the current contacts
        # during the FERMI_DEBUG density call inside _regularizeContacts.
        self._snapshotContacts()
        self._regularizeContacts()
        # Final snapshot: intrinsic contact reference state. g/sigma read these
        # so the rigid-band shift is applied via E only -- aList itself never
        # mutates post-init in the contactFromFock=False branch.
        self._snapshotContacts()

    def _snapshotContacts(self):
        """Copy the contact matrices into the *List0 lists read by g() and re-jit g.

        The re-jit is required because the JIT'd g bakes the snapshot values in
        as constants on first trace (see _rejit).
        """
        self.aList0 = [jnp.array(a) for a in self.aList]
        self.bList0 = [jnp.array(b) for b in self.bList]
        self.aSList0 = [jnp.array(s) for s in self.aSList]
        self.bSList0 = [jnp.array(s) for s in self.bSList]
        self._rejit()

    def _regularizeContacts(self):
        """Make the two-cell contact overlap well conditioned via congruent eigenvalue clipping.

        NOTE: This function is currently DISABLED (``_REGULARIZE_CONTACTS = False``)
        and needs development/debugging. While disabled every CList[i] stays the
        identity and the contact matrices are left untouched.

        When enabled, for each contact i the two-cell overlap
        S2 = [[S0, S1], [S1^H, S0]] (S0 = aSList[i], S1 = bSList[i]) is
        diagonalized. If its minimum eigenvalue is already
        >= OVERLAP_EIGENVALUE_RATIO * max(eigenvalue) no transform is applied.
        Otherwise a transform C that floors small eigenvalues is built and
        applied as C^H @ M2 @ C to the two-cell overlap and Hamiltonian, from
        which aSList[i], bSList[i], aList[i], bList[i] are re-read in place.
        sigma() and crossTermQ() fold the matching block of C into the couplings.

        NOTE(review): by Sylvester's law of inertia a congruence cannot change the
        number of negative eigenvalues, so this can lift small positive
        eigenvalues but cannot repair a genuinely indefinite S2 -- the
        sqrt(lam_prime / |lam|) scaling maps a negative lam to -lam_prime. Also
        C = U @ diag(...) has no trailing U^H, so as written it is not the
        closest-to-identity transform the design describes.
        See docs/plans/2026-03-10-congruent-clipping-design.md.
        """
        if not hasattr(self, "CList"):
            self.CList = [jnp.eye(len(A)*2) for A in self.aList]
        regularized = []
        for i in range(len(self.indsList)):
            if not self._REGULARIZE_CONTACTS:
                continue
            S0 = self.aSList[i]
            S1 = self.bSList[i]
            S2 = jnp.block([[S0, S1],
                            [S1.conj().T, S0]])
            n = S1.shape[0]
            eigvals, U = eigh(S2)
            lam_min_thresh = OVERLAP_EIGENVALUE_RATIO * jnp.max(eigvals)
            if jnp.min(eigvals) >= lam_min_thresh:
                # Already well conditioned -- drop any transform from a previous call.
                self.CList[i] = jnp.eye(2 * n)
                continue
            lam_prime = jnp.maximum(eigvals, lam_min_thresh)
            C = U @ jnp.diag(jnp.sqrt(lam_prime / jnp.abs(eigvals)))
            C = C.astype(self.aSList[i].dtype)
            self.CList[i] = C.copy()
            S2_reg = C.conj().T @ S2 @ C
            self.aSList[i] = (S2_reg[:n, :n] + S2_reg[n:, n:])/2
            self.bSList[i] = S2_reg[:n, n:]
            regularized.append(i)
            print(f'Contact overlap regularized (contact {i}): '
                  f'min_eig {eigvals[0]:.4e} -> {lam_min_thresh:.4e}')
        if FERMI_DEBUG:
            rho_, _ = densityComplex(self.F, self.S, self, -1e6, 1e6)
            print(f"Total Spectral Weight: {jnp.trace(rho_@self.S).real}")
        # Apply the same congruence to the two-cell Hamiltonian of each contact
        # whose overlap was transformed (an identity C would be a no-op).
        for i in regularized:
            H0 = self.aList[i]
            H1 = self.bList[i]
            n = len(H0)
            H2 = jnp.block([[H0, H1],
                            [H1.conj().T, H0]])
            C = self.CList[i]
            H2_reg = C.conj().T @ H2 @ C
            self.aList[i] = (H2_reg[n:, n:] + H2_reg[:n, :n])/2
            self.bList[i] = H2_reg[:n, n:]

    def _rejit(self):
        """Recreate the JIT wrapper for g to pick up updated contact parameters.

        JAX JIT caches compiled functions keyed on shape/dtype of the arguments;
        closed-over arrays (here the *List0 snapshots read through `self`) are
        baked in as constants on first trace. Creating a fresh JIT wrapper
        forces a re-trace on the next call. Only g is jitted; sigma and
        crossTermQ are plain methods that call the jitted g.
        self.__class__.g always refers to the original class method regardless
        of what self.g currently points to (instance vs class attribute).
        """
        self.g = jit(self.__class__.g.__get__(self), static_argnums=(1,))

    def g(self, E, i, conv=SURFACE_GREEN_CONVERGENCE, relFactor=0.5):  # not SURFACE_RELAXATION_FACTOR
        """
        Calculate surface Green's function for a contact.

        Iterates g <- inv(A - B @ g @ B_bar), with A = z*Salpha - alpha,
        B = z*Sbeta - beta, B_bar = z*Sbeta^H - beta^H and z = E + i*eta,
        starting from inv(A) and mixing each new iterate with the previous one.
        Reads the *List0 snapshots of the contact matrices.

        Parameters
        ----------
        E : float
            Energy point in eV
        i : int
            Contact index (static argument for JAX JIT compilation)
        conv : float, optional
            Convergence criterion: max elementwise relative change between
            iterates (default: gauNEGF.config.SURFACE_GREEN_CONVERGENCE)
        relFactor : float, optional
            Weight of the new iterate in linear mixing (default: 0.5; the
            config value SURFACE_RELAXATION_FACTOR is deliberately not used)

        Returns
        -------
        ndarray
            Surface Green's function matrix for contact i. If the criterion is
            not met within MAX_ITER = 10000 iterations, the last iterate is
            returned without warning.
        """
        alpha = self.aList0[i]
        Salpha = self.aSList0[i]
        beta = self.bList0[i]
        Sbeta = self.bSList0[i]
        z = E + 1j*self.eta
        A = z*Salpha - alpha
        B = z*Sbeta - beta
        B_bar = z*Sbeta.conj().T - beta.conj().T

        MAX_ITER = 10000

        def cond_fun(state):
            count, diff, _ = state
            return (diff > conv) & (count < MAX_ITER)

        def body_fun(state):
            count, _, g_old = state
            g_new = inv(A - B @ g_old @ B_bar)
            # Relative change, floored at conv so near-zero entries cannot stall convergence.
            dg = jnp.abs(g_new - g_old) / jnp.maximum(jnp.abs(g_new), conv)
            g_mixed = g_new * relFactor + g_old * (1 - relFactor)
            return (count + 1, jnp.max(dg), g_mixed)

        # State: (iteration count, last relative change, current g)
        _, _, g_surf = lax.while_loop(cond_fun, body_fun, (0, jnp.inf, inv(A)))
        return g_surf

    def setF(self, F, mu1=None, mu2=None):
        """
        Update the Fock matrix and contact chemical potentials.

        Couplings taken from the Fock matrix (tauFromFock) are re-sliced. If the
        contacts are extracted from the Fock matrix (contactFromFock), their
        blocks are rebuilt and g is re-jitted. Otherwise mu1/mu2 drive a
        rigid-band shift: the first mu seen for a contact is recorded as its
        reference fermi0, and later calls set dFermi = mu - fermi0.

        Parameters
        ----------
        F : ndarray
            New Fock matrix for the system
        mu1 : float or None, optional
            Chemical potential for first contact in eV (default: None)
        mu2 : float or None, optional
            Chemical potential for last contact in eV (default: None)
        """
        self.F = jnp.array(F)
        if self.tauFromFock:
            # Rebuild coupling arrays from the new F
            self.tauList, self.stauList = self._tausFromFock()
        if self.contactFromFock:
            # Rebuild aList/bList from the new F. _setContacts refreshes the
            # *List0 snapshots and re-jits g, so g sees the updated alphas.
            # dFermiList stays at zero -- the Fermi shift is already baked
            # into the new F.
            self._setContacts()
        else:
            # Track chemical potentials and compute rigid-shift dFermi.
            # Pre-shifting E in sigma/crossTermQ is algebraically identical
            # to shifting alpha/beta by dFermi*S; mu enters self-energy only
            # via this shift (and via Fermi function in density integration).
            for slot, mu in zip((0, -1), (mu1, mu2)):
                if mu is None:
                    continue
                self.fermiList[slot] = mu
                if self.fermi0List[slot] is None:
                    self.fermi0List[slot] = mu
                    self.dFermiList[slot] = 0.0
                else:
                    self.dFermiList[slot] = mu - self.fermi0List[slot]

    def _regCouplings(self, E, i):
        """Return (t_reg, bar_t_reg): EOM device-contact couplings for contact i at raw E.

        t = E*stau - tau and bar_t = E*stau^H - tau^H (just -tau / -tau^H when
        stau is None), multiplied by the first diagonal block of CList[i].
        """
        stau = self.stauList[i]
        tau = self.tauList[i]
        if stau is None:
            t, bar_t = -tau, -tau.conj().T
        else:
            t = E*stau - tau
            bar_t = E*stau.conj().T - tau.conj().T
        n = len(self.aList[i])
        C_mid = self.CList[i][:n, :n]
        return t @ C_mid, C_mid.conj().T @ bar_t

    def sigma(self, E, i, conv=SURFACE_GREEN_CONVERGENCE):
        """
        Calculate self-energy matrix for a contact.

        Computes sig = t @ g(E - dFermi_i) @ bar_t for contact i and embeds it
        in the (indsList[i], indsList[i]) block of a full-size zero matrix.
        When stauList[i] is None (orthonormal coupling), the embedded matrix is
        de-orthonormalized as sigma -> Xi @ sigma @ Xi with the full
        Xi = S^+0.5 = inv(X), so the result can extend beyond the contact block.

        Parameters
        ----------
        E : float
            Energy point in eV
        i : int (static)
            Contact index
        conv : float, optional
            Convergence criterion for surface Green's function
            (default: gauNEGF.config.SURFACE_GREEN_CONVERGENCE)

        Returns
        -------
        ndarray
            Self-energy matrix for contact i
        """
        inds = self.indsList[i]
        # Pre-shift E for the surface Green's function call only. The device-
        # contact coupling tau (F_dc block) is not rigid-shifted, so t/bar_t
        # use raw E. Matches surfGBAt convention.
        t_reg, bar_t_reg = self._regCouplings(E, i)
        sig = t_reg @ self.g(E - self.dFermiList[i], i, conv) @ bar_t_reg
        sigma_full = jnp.zeros(self.F.shape, dtype=complex)
        sigma_full = sigma_full.at[jnp.ix_(inds, inds)].add(sig)
        # De-orthonormalize: stau is None signals orthonormal tau, so Xi @ sig @ Xi
        # maps sigma back to the non-orthogonal AO basis. Only safe when alpha/beta
        # are also orthonormal (i.e., contactFromFock=False with identity aOverlaps).
        # NOTE(review): _tausFromFock also maps an all-zero S coupling block to
        # None, which triggers this transform in pattern (a) too -- confirm that
        # is intended for non-orthogonal F/S.
        # `stau is None` is known at trace time, so a Python branch is used
        # instead of lax.cond (which would trace both branches).
        if self.stauList[i] is None:
            sigma_full = self.Xi @ sigma_full @ self.Xi
        return sigma_full

    def crossTermQ(self, E, i, conv=SURFACE_GREEN_CONVERGENCE):
        """Symmetrized cross-term matrix Q_sym_i in full device basis.

        Q_sym = (t_eff @ g_surf @ S_LD + S_DL @ g_surf @ bar_t_eff) / 2
        where t_eff is the regularized tau_DL and bar_t_eff is the regularized
        tau_LD (EOM coupling with unconjugated E, NOT t_eff^dagger). Returns
        None if contact i has orthogonal coupling (stauList[i] is None).
        """
        stau = self.stauList[i]
        if stau is None:
            return None
        inds = self.indsList[i]
        # The surface GF sees the rigid-band-shifted energy; couplings use raw E.
        g_surf = self.g(E - self.dFermiList[i], i, conv)
        t_reg, bar_t_reg = self._regCouplings(E, i)
        Q_fwd = t_reg @ g_surf @ stau.conj().T
        Q_rev = stau @ g_surf @ bar_t_reg
        Q = jnp.zeros(self.F.shape, dtype=complex)
        return Q.at[jnp.ix_(inds, inds)].set((Q_fwd + Q_rev) / 2)

    def crossTermQTot(self, E, conv=SURFACE_GREEN_CONVERGENCE):
        """Sum of Q_sym over all contacts. Returns None if all contacts orthogonal."""
        Q_tot = None
        for i in range(self.num_contacts):
            Q_i = self.crossTermQ(E, i, conv)
            if Q_i is not None:
                Q_tot = Q_i if Q_tot is None else Q_tot + Q_i
        return Q_tot

    def sigmaTot(self, E, conv=SURFACE_GREEN_CONVERGENCE):
        """
        Calculate total self-energy matrix from all contacts.

        Sums sigma(E, i) over every contact index i.

        Parameters
        ----------
        E : float
            Energy point in eV
        conv : float, optional
            Convergence criterion for surface Green's functions
            (default: gauNEGF.config.SURFACE_GREEN_CONVERGENCE)

        Returns
        -------
        ndarray
            Total self-energy matrix from all contacts
        """
        # Python loop -- unrolls with concrete (static) contact indices.
        total = jnp.zeros(self.F.shape, dtype=complex)
        for i in range(self.num_contacts):
            total = total + self.sigma(E, i, conv)
        return total