# Source code for gauNEGF.surfG1D

"""
Surface Green's function implementation for 1D chain contacts.

This module provides a surface Green's function calculator for quasi-1D chain
contacts in quantum transport calculations. It supports three usage patterns for
specifying contact parameters: fully automatic extraction from Fock/Overlap
matrices, custom coupling matrices with automatic onsite parameters, or fully
specified contact parameters. The implementation uses iterative solvers with JAX
JIT compilation for efficient computation.

The surfG class handles:
- Semi-infinite 1D chain Green's function calculations
- Multiple contact geometries and coupling specifications
- Orthogonal and non-orthogonal contact overlap matrices
- De-orthonormalization for proper self-energy contributions
- Contact regularization via congruent eigenvalue clipping (currently disabled)
- Self-energy and cross-term calculations with convergence control

Key exports:
- surfG: Main surface Green's function calculator class
"""

# Python packages
import jax
import jax.numpy as jnp
import jax.lax as lax
from jax import jit

# Configuration
from gauNEGF.config import (ETA, SURFACE_GREEN_CONVERGENCE, SURFACE_RELAXATION_FACTOR,
                             OVERLAP_EIGENVALUE_RATIO, FERMI_DEBUG)
from gauNEGF.utils import fractional_matrix_power, inv, eigh
from gauNEGF.density import densityComplex

#Constants

class surfG:
    """
    Surface Green's function calculator for 1D chain contacts.

    Each contact is treated as a semi-infinite chain of identical unit cells
    described by an on-site block (alpha, S_alpha) and an inter-cell hopping
    block (beta, S_beta). ``g`` iterates the surface Green's function of that
    chain and ``sigma`` turns it into a self-energy placed on the contact
    orbitals of the device.

    Three usage patterns are supported:

    a) Fully automatic extraction from the Fock matrix
       - Provide contact indices and connection indices.
       - All parameters are extracted from the F/S matrices.
       - Example: ``surfG(F, S, [[c1], [c2]], [[c1conn], [c2conn]])``

    b) Fock matrix with custom coupling
       - Provide contact indices and coupling matrices.
       - On-site contact parameters come from F/S; the coupling is given
         explicitly.
       - If ``staus=None`` (default), the coupling is treated as orthonormal.
       - Example: ``surfG(F, S, [[c1], [c2]], [tau1, tau2])``

    c) Fully specified contacts
       - All contact parameters are provided explicitly.
       - If ``aOverlaps=None``, on-site overlaps default to identity.
       - If ``bOverlaps=None``, hopping overlaps default to zeros.
       - Example: ``surfG(F, S, [[c1], [c2]], [tau1, tau2], [stau1, stau2],
         [alpha1, alpha2], [salpha1, salpha2], [beta1, beta2],
         [sbeta1, sbeta2])``

    Parameters
    ----------
    Fock : ndarray
        Fock matrix for the extended system.
    Overlap : ndarray
        Overlap matrix for the extended system.
    indsList : list of lists
        Orbital indices of each contact region in the device.
    taus : list or None, optional
        Either connection indices (``[[c1conn], [c2conn]]``) or coupling
        matrices (``[tau1, tau2]``). Defaults to
        ``[indsList[-1], indsList[0]]``, interpreted as connection indices.
    staus : list or None, optional
        Coupling overlap matrices. None (or a None entry) means orthonormal
        coupling, which triggers de-orthonormalization in ``sigma``.
        Ignored when ``taus`` holds connection indices (the overlaps are then
        extracted from ``Overlap``).
    alphas : list of ndarray or None, optional
        On-site blocks of the contacts; required for pattern (c).
    aOverlaps : list of ndarray or None, optional
        On-site overlaps; None (or a None entry) defaults to identity.
    betas : list of ndarray or None, optional
        Inter-cell hopping blocks; required whenever ``alphas`` is given.
    bOverlaps : list of ndarray or None, optional
        Inter-cell overlaps; None (or a None entry) defaults to zeros.
    eta : float, optional
        Broadening in eV, added as ``E + 1j*eta`` (default: ``ETA`` from
        ``gauNEGF.config``).
    spin : str, optional
        Spin configuration, stored as ``self.spin`` (default: ``'r'``).

    Attributes
    ----------
    F, S : ndarray
        Current Fock matrix and the overlap matrix.
    S_orig : ndarray
        Copy of the overlap matrix used to extract on-site contact overlaps.
    X : ndarray
        Inverse square root of the overlap matrix (S^-0.5).
    Xi : ndarray
        inv(X) = S^+0.5, used by ``sigma`` for de-orthonormalization.
    indsList : list of jax arrays
        Contact orbital indices.
    num_contacts : int
        Number of contacts (``len(indsList)``).
    tauFromFock : bool
        True when the coupling is re-extracted from F/S (connection indices).
    tauInds : list of jax arrays
        Connection indices; only set when ``tauFromFock`` is True.
    tauList, stauList : list
        Contact coupling matrices and their overlaps (None entries mean
        orthonormal coupling).
    contactFromFock : bool
        True when on-site contact blocks are extracted from F/S.
    aList, aSList, bList, bSList : list
        On-site / hopping blocks and overlaps of each contact chain.
    aList0, aSList0, bList0, bSList0 : list
        Snapshot of the contact blocks that ``g`` actually reads.
    CList : list
        Per-contact congruence matrices (identity unless regularization is
        enabled).
    fermiList, fermi0List, dFermiList : list
        Chemical-potential bookkeeping for the rigid-band shift applied in
        ``sigma``/``crossTermQ`` when ``contactFromFock`` is False.
    """

    # Congruent clipping of the contact overlap is currently disabled (it
    # needs development/debugging). When False, _regularizeContacts leaves
    # CList as identity and skips the eigendecomposition entirely.
    _REGULARIZE_CONTACT_OVERLAP = False

    def __init__(self, Fock, Overlap, indsList, taus=None, staus=None,
                 alphas=None, aOverlaps=None, betas=None, bOverlaps=None,
                 eta=ETA, spin='r'):
        """
        Initialize the surface Green's function calculator.

        The initialization follows one of three patterns:
        a) Fully automatic: provide Fock, Overlap, indsList and connection
           indices in ``taus``.
        b) Custom coupling: provide Fock, Overlap, indsList and coupling
           matrices in ``taus``; ``staus=None`` (default) means orthonormal
           coupling, so de-orthonormalization is applied in ``sigma``.
        c) Fully specified: additionally provide alphas, aOverlaps, betas,
           bOverlaps; ``aOverlaps=None`` defaults to identity and
           ``bOverlaps=None`` to zeros.

        See the class docstring for the full parameter description.

        Raises
        ------
        TypeError
            If ``alphas`` is given without ``betas``.
        """
        # Set up system
        self.F = jnp.array(Fock)
        self.S = jnp.array(Overlap)
        self.S_orig = jnp.array(Overlap)
        self.spin = spin
        # X = S^-1/2; Xi = inv(X) = S^+1/2 maps orthonormal-basis self-energies
        # back to the AO basis in sigma().
        self.X = jnp.array(fractional_matrix_power(Overlap, -0.5))
        self.Xi = jnp.linalg.inv(self.X)

        # Keep indsList as Python list - loop unrolls with concrete indices
        self.indsList = [jnp.array(inds) for inds in indsList]

        # Set contact coupling
        if taus is None:
            taus = [self.indsList[-1], self.indsList[0]]
        taus = [jnp.array(tau) for tau in taus]
        if len(jnp.shape(taus[0])) == 1:
            # 1-D entries are connection indices: tau/stau come from F/S.
            self.tauFromFock = True
            self.tauInds = taus
            taus, staus = self._couplingFromFock()
        else:
            self.tauFromFock = False
            # TODO: orthogonal-basis foot-gun. When a user works in a Lowdin-orthogonal
            # basis they often pass alpha_S=I and beta_S=I (or tau_S=I), meaning
            # "orthogonal everything". beta_S=I is wrong: in an orthogonal basis the
            # inter-cell overlap is 0, not I, and beta_S=I makes the periodic contact
            # overlap S(k) = (1 + 2 cos k) I non-PSD, which corrupts the surface GF
            # (per-contact TSW collapses to ~0, device sigma is bogus at deep E,
            # calcEmin reports negative DOS, calcTSW finds wildly wrong Eminf).
            # Two fixes worth considering:
            #   (1) Auto-canonicalize: in the matrix-input branch above, detect
            #       beta_S equal to identity when alpha_S is also identity and warn
            #       (or treat as zero). Mirror the all-zero -> None canonicalization
            #       already done in the tauFromFock branch for staus/bOverlaps.
            #   (2) Add an explicit orthogonal=True flag to scfE.setContact1D that
            #       overrides aOverlaps/bOverlaps to I and 0 respectively. Removes
            #       the foot-gun for the common Lowdin-orthogonal case entirely.
        self.tauList = taus
        self.stauList = ([None] * len(taus) if staus is None
                         else [None if stau is None else jnp.array(stau)
                               for stau in staus])

        # Store number of contacts for loop bounds
        self.num_contacts = len(indsList)

        # Broadening for the retarded Green's function
        self.eta = eta

        # Rigid-band shift bookkeeping (used in contactFromFock=False branch).
        # fermi0 captured on first setF call; dFermi = mu - fermi0 thereafter.
        # Must be initialized before _setContacts -- regularization can trigger
        # a debug density calc that calls crossTermQ which reads dFermiList.
        self.dFermiList = [0.0] * self.num_contacts
        self.fermi0List = [None] * self.num_contacts
        # Also set before _setContacts so the attribute already exists should
        # the FERMI_DEBUG density calculation read it.
        self.fermiList = [None] * len(indsList)

        # Set up contact information
        if alphas is None:
            self.contactFromFock = True
            self._setContacts()
        else:
            self.contactFromFock = False
            self._setContacts(alphas, aOverlaps, betas, bOverlaps)

        # JIT-compile g with the contact index as a static argument, so each
        # contact (i=0, i=1, ...) gets its own compiled specialization.
        # sigma/crossTermQ are not jitted themselves; they call self.g.
        self._rejit()

    def _couplingFromFock(self):
        """Extract ``(taus, staus)`` for the first and last contact from F/S.

        Rows come from the connection indices in ``self.tauInds``, columns
        from the contact indices. All-zero overlap blocks are canonicalized
        to None, the orthogonal-coupling sentinel checked by ``sigma`` and
        ``crossTermQ``.
        """
        first, last = self.indsList[0], self.indsList[-1]
        taus = [self.F[jnp.ix_(self.tauInds[0], first)],
                self.F[jnp.ix_(self.tauInds[1], last)]]
        staus = [self.S[jnp.ix_(self.tauInds[0], first)],
                 self.S[jnp.ix_(self.tauInds[1], last)]]
        staus = [stau if bool(jnp.any(stau)) else None for stau in staus]
        return taus, staus

    def _snapshotContacts(self):
        """Copy the contact lists into the ``*List0`` snapshots read by ``g``."""
        self.aList0 = [jnp.array(a) for a in self.aList]
        self.bList0 = [jnp.array(b) for b in self.bList]
        self.aSList0 = [jnp.array(s) for s in self.aSList]
        self.bSList0 = [jnp.array(s) for s in self.bSList]

    def _setContacts(self, alphas=None, aOverlaps=None, betas=None, bOverlaps=None):
        """Internal: build aList/aSList/bList/bSList and regularize contacts.

        contactFromFock=True: on-site blocks come from ``self.F`` and
        ``self.S_orig``; hopping blocks are the coupling matrices in
        ``self.tauList`` (zero overlap where ``stauList`` entries are None).

        contactFromFock=False: uses the provided alphas/aOverlaps/betas/
        bOverlaps. None (or None entries) in aOverlaps default to identity,
        in bOverlaps to zeros.

        Snapshots the lists before and after ``_regularizeContacts``.

        Raises
        ------
        TypeError
            If contactFromFock is False and ``betas`` is None.
        """
        if self.contactFromFock:
            self.aList = [self.F[jnp.ix_(inds, inds)] for inds in self.indsList]
            self.aSList = [self.S_orig[jnp.ix_(inds, inds)] for inds in self.indsList]
            self.bList = [jnp.array(tau) for tau in self.tauList]
            self.bSList = [jnp.zeros_like(tau) if stau is None else jnp.array(stau)
                           for tau, stau in zip(self.tauList, self.stauList)]
        else:
            if betas is None:
                raise TypeError("betas must be provided when alphas is given")
            self.aList = [jnp.array(alpha) for alpha in alphas]
            self.bList = [jnp.array(beta) for beta in betas]
            if aOverlaps is None:
                aOverlaps = [None] * len(alphas)
            self.aSList = [jnp.eye(len(alpha)) if aOverlap is None else jnp.array(aOverlap)
                           for alpha, aOverlap in zip(alphas, aOverlaps)]
            if bOverlaps is None:
                bOverlaps = [None] * len(betas)
            self.bSList = [jnp.zeros_like(beta) if bOverlap is None else jnp.array(bOverlap)
                           for beta, bOverlap in zip(betas, bOverlaps)]

        # Initial snapshot so g() works during regularization's debug-density
        # call (if FERMI_DEBUG). Refreshed below to capture regularized state.
        self._snapshotContacts()
        self._regularizeContacts()
        # Final snapshot: intrinsic contact reference state. g/sigma read these
        # so the rigid-band shift is applied via E only -- aList itself never
        # mutates post-init in the contactFromFock=False branch.
        self._snapshotContacts()

    def _regularizeContacts(self):
        """Apply the per-contact congruence ``CList[i]`` to the contact blocks.

        NOTE: overlap clipping is currently DISABLED via
        ``_REGULARIZE_CONTACT_OVERLAP`` (needs development/debugging); in that
        case every ``CList[i]`` stays the identity created on first call.

        When enabled, ``_clipContactOverlap`` may replace ``CList[i]`` and the
        overlap blocks ``aSList[i]``/``bSList[i]``. In all cases the
        Hamiltonian blocks ``aList[i]``/``bList[i]`` are then rebuilt from
        ``C^H [[H0, H1], [H1^H, H0]] C``. ``sigma`` and ``crossTermQ`` apply
        the upper-left block of ``CList[i]`` to the device-contact coupling.
        See docs/plans/2026-03-10-congruent-clipping-design.md.
        """
        if not hasattr(self, "CList"):
            self.CList = [jnp.eye(len(A) * 2) for A in self.aList]

        if self._REGULARIZE_CONTACT_OVERLAP:
            for i in range(len(self.indsList)):
                self._clipContactOverlap(i)

        if FERMI_DEBUG:
            rho_, _ = densityComplex(self.F, self.S, self, -1e6, 1e6)
            print(f"Total Spectral Weight: {jnp.trace(rho_@self.S).real}")

        # Transform the Hamiltonian blocks with the (possibly identity) congruence.
        for i in range(len(self.indsList)):
            H0 = self.aList[i]
            H1 = self.bList[i]
            n = len(H0)
            H2 = jnp.block([[H0, H1], [H1.conj().T, H0]])
            C = self.CList[i]
            H2_reg = C.conj().T @ H2 @ C
            self.aList[i] = (H2_reg[n:, n:] + H2_reg[:n, :n]) / 2
            self.bList[i] = H2_reg[:n, n:]

    def _clipContactOverlap(self, i):
        """Congruently floor small eigenvalues of contact i's two-cell overlap.

        Builds ``S2 = [[S0, S1], [S1^H, S0]]`` from ``aSList[i]``/``bSList[i]``.
        If its smallest eigenvalue is at least
        ``OVERLAP_EIGENVALUE_RATIO * max(eigenvalue)``, nothing changes.
        Otherwise it stores a congruence C in ``CList[i]`` and replaces
        ``aSList[i]``/``bSList[i]`` with blocks of ``C^H S2 C``.

        NOTE(review): a congruence cannot change the number of negative
        eigenvalues (Sylvester's law of inertia). With
        ``sqrt(lam_prime / |eigvals|)`` a negative eigenvalue maps to
        ``-lam_prime``, so only small positive eigenvalues are floored.
        """
        S0 = self.aSList[i]
        S1 = self.bSList[i]
        # Hermitian two-cell block, built the same way as H2 in _regularizeContacts.
        S2 = jnp.block([[S0, S1], [S1.conj().T, S0]])
        n = S1.shape[0]
        eigvals, U = eigh(S2)
        lam_min_thresh = OVERLAP_EIGENVALUE_RATIO * jnp.max(eigvals)
        if jnp.min(eigvals) >= lam_min_thresh:
            # Already well conditioned -- no transform needed
            return
        lam_prime = jnp.maximum(eigvals, lam_min_thresh)
        C = U @ jnp.diag(jnp.sqrt(lam_prime / jnp.abs(eigvals)))
        C = C.astype(self.aSList[i].dtype)
        self.CList[i] = C.copy()
        S2_reg = C.conj().T @ S2 @ C
        self.aSList[i] = (S2_reg[:n, :n] + S2_reg[n:, n:]) / 2
        self.bSList[i] = S2_reg[:n, n:]
        print(f'Contact overlap regularized (contact {i}): '
              f'min_eig {eigvals[0]:.4e} -> {lam_min_thresh:.4e}')

    def _rejit(self):
        """Re-create the JIT wrapper of ``g`` to pick up updated contact blocks.

        ``g`` reads ``self.aList0`` etc. as closed-over constants that are
        baked into the compiled trace; a fresh ``jit`` wrapper forces a
        re-trace on the next call. Only ``g`` is jitted (contact index as a
        static argument). ``self.__class__.g`` always refers to the original
        class method even after ``self.g`` is shadowed by this instance
        attribute.
        """
        self.g = jit(self.__class__.g.__get__(self), static_argnums=(1,))

    def g(self, E, i, conv=SURFACE_GREEN_CONVERGENCE, relFactor=0.5):
        """
        Calculate the surface Green's function of contact i.

        Iterates ``g <- inv(A - B g B_bar)`` with
        ``A = (E + i*eta) S_alpha - alpha``,
        ``B = (E + i*eta) S_beta - beta`` and
        ``B_bar = (E + i*eta) S_beta^H - beta^H``, mixing each new iterate
        with the previous one by ``relFactor``. Iteration stops when the
        largest element-wise relative change is at most ``conv`` or after
        10000 iterations; in the latter case the last iterate is returned
        without any non-convergence signal.

        Parameters
        ----------
        E : float
            Energy point in eV.
        i : int
            Contact index (static argument for JAX JIT compilation).
        conv : float, optional
            Convergence criterion (default: SURFACE_GREEN_CONVERGENCE).
        relFactor : float, optional
            Weight of the new iterate in the mixing (default: 0.5).
            NOTE(review): the imported SURFACE_RELAXATION_FACTOR is not used
            as this default; confirm 0.5 is intended.

        Returns
        -------
        ndarray
            Surface Green's function matrix for contact i.
        """
        alpha = self.aList0[i]
        Salpha = self.aSList0[i]
        beta = self.bList0[i]
        Sbeta = self.bSList0[i]

        z = E + 1j * self.eta
        A = z * Salpha - alpha
        B = z * Sbeta - beta
        B_bar = z * Sbeta.conj().T - beta.conj().T

        MAX_ITER = 10000

        def cond_fun(state):
            count, diff, _ = state
            return (diff > conv) & (count < MAX_ITER)

        def body_fun(state):
            count, _, g_old = state
            g_new = inv(A - B @ g_old @ B_bar)
            # Relative change, floored at conv to avoid dividing by ~0 entries
            diff = jnp.max(jnp.abs(g_new - g_old) / jnp.maximum(jnp.abs(g_new), conv))
            g_mixed = g_new * relFactor + g_old * (1 - relFactor)
            return (count + 1, diff, g_mixed)

        # Initial state: (count, diff, g)
        init_state = (0, jnp.inf, inv(A))
        _, _, g_surf = lax.while_loop(cond_fun, body_fun, init_state)
        return g_surf

    def setF(self, F, mu1=None, mu2=None):
        """
        Update the Fock matrix and contact chemical potentials.

        If the coupling was extracted from F/S (``tauFromFock``), it is
        re-extracted. If the contacts were extracted from F/S
        (``contactFromFock``), the contact blocks are rebuilt and ``g`` is
        re-jitted; ``mu1``/``mu2`` are then ignored because the shift is
        already contained in the new F. Otherwise ``mu1``/``mu2`` update the
        rigid-band shift: the first value seen for a contact becomes its
        reference ``fermi0`` and ``dFermi = mu - fermi0`` afterwards.

        Parameters
        ----------
        F : ndarray
            New Fock matrix for the system.
        mu1 : float or None, optional
            Chemical potential of the first contact in eV (default: None).
        mu2 : float or None, optional
            Chemical potential of the last contact in eV (default: None).
        """
        self.F = jnp.array(F)
        if self.tauFromFock:
            # Rebuild coupling arrays from the new F
            self.tauList, self.stauList = self._couplingFromFock()

        if self.contactFromFock:
            # Rebuild aList/bList from new F and re-trace the JIT'd g.
            # _setContacts also refreshes the *List0 snapshots, so g sees the
            # updated alphas. dFermiList stays at zero -- the Fermi shift is
            # already baked into the new F.
            self._setContacts()
            self._rejit()
        else:
            # Track chemical potentials and compute rigid-shift dFermi.
            # Pre-shifting E in sigma/crossTermQ is algebraically identical
            # to shifting alpha/beta by dFermi*S; mu enters self-energy only
            # via this shift (and via Fermi function in density integration).
            for slot, mu in zip([0, -1], [mu1, mu2]):
                if mu is None:
                    continue
                self.fermiList[slot] = mu
                if self.fermi0List[slot] is None:
                    self.fermi0List[slot] = mu
                    self.dFermiList[slot] = 0.0
                else:
                    self.dFermiList[slot] = mu - self.fermi0List[slot]

    def _couplingTerms(self, E, i):
        """Return ``(t_reg, bar_t_reg)``, the contact-i coupling at energy E.

        ``t = E*stau - tau`` and ``bar_t = E*stau^H - tau^H`` (or ``-tau`` and
        ``-tau^H`` when ``stau`` is None), multiplied by the upper-left n x n
        block of ``CList[i]``. Uses the raw E: the device-contact coupling is
        not rigid-shifted.
        """
        stau = self.stauList[i]
        tau = self.tauList[i]
        if stau is None:
            t = -tau
            bar_t = -tau.conj().T
        else:
            t = E * stau - tau
            bar_t = E * stau.conj().T - tau.conj().T
        n = len(self.aList[i])
        C_mid = self.CList[i][:n, :n]
        return t @ C_mid, C_mid.conj().T @ bar_t

    def sigma(self, E, i, conv=SURFACE_GREEN_CONVERGENCE):
        """
        Calculate the self-energy matrix of contact i.

        ``sig = t_reg @ g(E - dFermi_i, i) @ bar_t_reg`` is placed on the
        ``(inds, inds)`` block of a zero matrix the size of F. When
        ``stauList[i]`` is None (orthonormal coupling), the full matrix is
        de-orthonormalized as ``Xi @ sigma @ Xi`` with the full
        ``Xi = S^+0.5 = inv(X)``.

        Parameters
        ----------
        E : float
            Energy point in eV.
        i : int
            Contact index.
        conv : float, optional
            Convergence criterion for the surface Green's function
            (default: SURFACE_GREEN_CONVERGENCE).

        Returns
        -------
        ndarray
            Complex self-energy matrix for contact i, same shape as F.
        """
        inds = self.indsList[i]
        stau = self.stauList[i]
        t_reg, bar_t_reg = self._couplingTerms(E, i)
        # Only the surface GF sees the rigid-band shift. Matches surfGBAt
        # convention.
        sig = t_reg @ self.g(E - self.dFermiList[i], i, conv) @ bar_t_reg

        sig_full = jnp.zeros(self.F.shape, dtype=complex)
        sig_full = sig_full.at[jnp.ix_(inds, inds)].add(sig)
        # stau is None signals orthonormal tau; Xi @ sig @ Xi maps sigma back to
        # the non-orthogonal AO basis. Only safe when alpha/beta are also
        # orthonormal (i.e., contactFromFock=False with identity aOverlaps).
        # `stau is None` is a static Python bool, so a plain `if` suffices.
        if stau is None:
            sig_full = self.Xi @ sig_full @ self.Xi
        return sig_full

    def crossTermQ(self, E, i, conv=SURFACE_GREEN_CONVERGENCE):
        """Symmetrized cross-term matrix Q_sym_i in the full device basis.

        Q_sym = (t_eff @ g_surf @ S_LD + S_DL @ g_surf @ bar_t_eff) / 2
        where t_eff is the regularized tau_DL and bar_t_eff is the regularized
        tau_LD (EOM coupling with unconjugated E, NOT t_eff^dagger). g_surf is
        evaluated at the rigid-shifted energy ``E - dFermiList[i]``.

        Returns None if contact i has orthogonal coupling (stauList[i] is None).
        """
        stau = self.stauList[i]
        if stau is None:
            return None
        inds = self.indsList[i]
        t_reg, bar_t_reg = self._couplingTerms(E, i)
        g_surf = self.g(E - self.dFermiList[i], i, conv)
        Q_fwd = t_reg @ g_surf @ stau.conj().T
        Q_rev = stau @ g_surf @ bar_t_reg
        Q = jnp.zeros(self.F.shape, dtype=complex)
        return Q.at[jnp.ix_(inds, inds)].set((Q_fwd + Q_rev) / 2)

    def crossTermQTot(self, E, conv=SURFACE_GREEN_CONVERGENCE):
        """Sum of Q_sym over all contacts. Returns None if all contacts orthogonal."""
        Q_tot = None
        for i in range(self.num_contacts):
            Q_i = self.crossTermQ(E, i, conv)
            if Q_i is not None:
                Q_tot = Q_i if Q_tot is None else Q_tot + Q_i
        return Q_tot

    def sigmaTot(self, E, conv=SURFACE_GREEN_CONVERGENCE):
        """
        Calculate the total self-energy matrix from all contacts.

        Parameters
        ----------
        E : float
            Energy point in eV.
        conv : float, optional
            Convergence criterion for the surface Green's functions
            (default: SURFACE_GREEN_CONVERGENCE).

        Returns
        -------
        ndarray
            Complex sum of ``sigma(E, i, conv)`` over all contacts.
        """
        # Python for loop - unrolled with concrete contact indices
        sig_tot = jnp.zeros(self.F.shape, dtype=complex)
        for i in range(self.num_contacts):
            sig_tot = sig_tot + self.sigma(E, i, conv)
        return sig_tot