# Python packages
import jax
import jax.numpy as np
import jax.lax as lax
from jax import jit
# Configuration
from gauNEGF.config import (ETA, SURFACE_GREEN_CONVERGENCE, SURFACE_RELAXATION_FACTOR)
from gauNEGF.utils import fractional_matrix_power, inv
#Constants
[docs]
class surfG:
"""
Surface Green's function calculator for 1D chain contacts.
This class implements the surface Green's function calculation for 1D chain
contacts. It supports three usage patterns:
a) Fully automatic extraction from Fock matrix:
- Provide contact indices and connection indices
- All parameters extracted from F/S matrices
Example: surfG1D(F, S, [[c1], [c2]], [[c1conn], [c2conn]])
b) Fock matrix with custom coupling:
- Provide contact indices and coupling matrices
- Onsite contact parameters from F/S, coupling specified manually
Example: surfG1D(F, S, [[c1], [c2]], [tau1, tau2], [stau1, stau2])
c) Fully specified contacts:
- All contact parameters provided manually
Example: surfG1D(F, S, [[c1], [c2]], [tau1, tau2], [stau1, stau2],
[alpha1, alpha2], [salpha1, salpha2], [beta1, beta2], [sbeta1, sbeta2])
Parameters
----------
Fock : ndarray
Fock matrix for the extended system
Overlap : ndarray
Overlap matrix for the extended system
indsList : list of lists
Lists of orbital indices for each contact region
taus : list or None, optional
Either coupling matrices or connection indices (default: None)
- If indices: [[contact1connection], [contact2connection]]
- If matrices: [tau1, tau2]
staus : list or None, optional
Overlap matrices for coupling, required if taus are matrices (default: None)
alphas : list of ndarray or None, optional
On-site energies for contacts, required for pattern (c) (default: None)
aOverlaps : list of ndarray or None, optional
On-site overlap matrices for contacts, required for pattern (c) (default: None)
betas : list of ndarray or None, optional
Hopping matrices between contact unit cells, required for pattern (c) (default: None)
bOverlaps : list of ndarray or None, optional
Overlap matrices between contact unit cells, required for pattern (c) (default: None)
eta : float, optional
Broadening parameter in eV (default: 1e-9)
Attributes
----------
F : ndarray
Fock matrix
S : ndarray
Overlap matrix
X : ndarray
Inverse square root of overlap matrix for orthogonalization
tauList : list
Contact coupling matrices
stauList : list
Contact coupling overlap matrices
aList : list
On-site energy matrices for contacts
aSList : list
On-site overlap matrices for contacts
bList : list
Hopping matrices between contact unit cells
bSList : list
Overlap matrices between contact unit cells
gPrev : list
Previous surface Green's functions for convergence
"""
[docs]
def __init__(self, Fock, Overlap, indsList, taus=None, staus=None, alphas=None, aOverlaps=None, betas=None, bOverlaps=None, eta=ETA):
    """
    Initialize the surface Green's function calculator.

    The initialization follows one of three patterns:
    a) Fully automatic: Only provide Fock, Overlap, indsList, and connection indices in taus
    b) Custom coupling: Provide Fock, Overlap, indsList, coupling matrices in taus, and staus
    c) Fully specified: Provide all parameters including alphas, aOverlaps, betas, bOverlaps

    Parameters
    ----------
    Fock : ndarray
        Fock matrix for the extended system
    Overlap : ndarray
        Overlap matrix for the extended system
    indsList : list of lists
        Lists of orbital indices for each contact region
    taus : list or None, optional
        Either coupling matrices or connection indices (default: None)
        - If indices (1-D entries): [[contact1connection], [contact2connection]]
        - If matrices: [tau1, tau2]
        - If None: the index lists [indsList[-1], indsList[0]] are used
    staus : list or None, optional
        Overlap matrices for coupling, required if taus are matrices (default: None)
    alphas : list of ndarray or None, optional
        On-site energies for contacts, required for pattern (c) (default: None)
    aOverlaps : list of ndarray or None, optional
        On-site overlap matrices for contacts, required for pattern (c) (default: None)
    betas : list of ndarray or None, optional
        Hopping matrices between contact unit cells, required for pattern (c) (default: None)
    bOverlaps : list of ndarray or None, optional
        Overlap matrices between contact unit cells, required for pattern (c) (default: None)
    eta : float, optional
        Broadening parameter (default: ETA from gauNEGF.config)

    Raises
    ------
    TypeError
        If taus holds coupling matrices but staus is None.

    Notes
    -----
    Whether taus holds indices or matrices is decided from the
    dimensionality of its first entry only. Validation of alphas/aOverlaps
    and betas/bOverlaps is delegated to setContacts.
    """
    # Set up system
    self.F = np.array(Fock)
    self.S = np.array(Overlap)
    self.X = np.array(fractional_matrix_power(Overlap, -0.5))
    # Keep indsList as Python list - loop unrolls with concrete indices
    self.indsList = [np.array(inds) for inds in indsList]
    first_inds, last_inds = self.indsList[0], self.indsList[-1]

    # Set Contact Coupling
    if taus is None:
        taus = [last_inds, first_inds]
    taus = [np.array(tau) for tau in taus]
    if taus[0].ndim == 1:
        # 1-D entries are orbital indices: slice the couplings out of F and S
        self.tauFromFock = True
        self.tauInds = taus
        self.tauList = [self.F[np.ix_(taus[0], first_inds)], self.F[np.ix_(taus[1], last_inds)]]
        self.stauList = [self.S[np.ix_(taus[0], first_inds)], self.S[np.ix_(taus[1], last_inds)]]
    else:
        # Explicit coupling matrices need matching overlaps; fail with a
        # clear message instead of "'NoneType' object is not iterable"
        if staus is None:
            raise TypeError("staus must be provided when taus are coupling matrices")
        self.tauFromFock = False
        self.tauList = taus
        self.stauList = [np.array(stau) for stau in staus]

    # Set up contact information
    if alphas is None:
        self.contactFromFock = True
        self.setContacts()
    else:
        self.contactFromFock = False
        self.setContacts(alphas, aOverlaps, betas, bOverlaps)
    self.fermiList = [None]*len(indsList)

    # Set up broadening for retarded/advanced Green's function
    self.eta = eta
    # Store number of contacts for loop bounds
    self.num_contacts = len(indsList)

    # JIT compile g and sigma with the contact index and the convergence
    # settings static, so one specialised version is compiled per contact.
    # NOTE: attributes read through self inside g/sigma are captured as
    # constants when a version is first traced; rebinding them afterwards
    # is not seen by an already-compiled version.
    self.g = jit(self.g, static_argnums=(1,2,3))
    self.sigma = jit(self.sigma, static_argnums=(1,2))
[docs]
def g(self, E, i, conv=SURFACE_GREEN_CONVERGENCE, relFactor=SURFACE_RELAXATION_FACTOR):
    """
    Calculate surface Green's function for a contact.

    Solves g = inv(A - B g B^dagger) by damped fixed-point iteration, where
    A = z*Salpha - alpha, B = z*Sbeta - beta and z = E + 1j*eta are built
    from the stored matrices of contact i.

    Parameters
    ----------
    E : float
        Energy point
    i : int
        Contact index (indexes aList, aSList, bList, bSList)
    conv : float, optional
        Convergence criterion: the loop stops once the largest element-wise
        relative change drops to conv or below
        (default: SURFACE_GREEN_CONVERGENCE)
    relFactor : float, optional
        Weight given to the freshly computed g when mixing it with the
        previous iterate (default: SURFACE_RELAXATION_FACTOR)

    Returns
    -------
    ndarray
        Surface Green's function matrix for contact i

    Notes
    -----
    The iteration starts from inv(A) and is capped at 2000 sweeps. Hitting
    the cap is not reported: the last mixed iterate is returned as is.
    """
    max_sweeps = 2000

    # Complex energy shared by the on-site and hopping blocks
    z = E + 1j*self.eta
    onsite = z*self.aSList[i] - self.aList[i]
    hop = z*self.bSList[i] - self.bList[i]
    hop_dag = hop.conj().T

    def keep_going(carry):
        # carry = (sweeps done, last relative change, current iterate)
        n_done, rel_change, _ = carry
        return (rel_change > conv) & (n_done < max_sweeps)

    def sweep(carry):
        n_done, _, g_old = carry
        g_fresh = inv(onsite - hop @ g_old @ hop_dag)
        # Largest element-wise relative change; the floor avoids dividing by zero
        rel_change = np.max(np.abs(g_fresh - g_old) / np.maximum(np.abs(g_fresh), 1e-12))
        # Damped update: blend the fresh solution into the previous iterate
        g_mixed = g_fresh * relFactor + g_old * (1 - relFactor)
        return (n_done + 1, rel_change, g_mixed)

    # Infinite initial change guarantees at least one sweep
    start = (0, np.inf, inv(onsite))
    _, _, g_surface = lax.while_loop(keep_going, sweep, start)
    return g_surface
[docs]
def setF(self, F, mu1=None, mu2=None):
    """
    Update the Fock matrix and contact chemical potentials.

    Stores the new Fock matrix, refreshes the coupling blocks when they are
    sliced from the Fock matrix, shifts manually specified contacts to the
    requested chemical potentials, and invalidates compiled versions of
    g and sigma so they pick up the new matrices.

    Parameters
    ----------
    F : ndarray
        New Fock matrix for the system
    mu1 : float or None, optional
        Chemical potential for first contact (default: None = leave as is)
    mu2 : float or None, optional
        Chemical potential for last contact (default: None = leave as is)

    Notes
    -----
    - With index-based couplings (tauFromFock), the contact diagonal blocks
      of F are overwritten with the blocks at the connection indices before
      tauList/stauList are rebuilt.
    - For manually specified contacts (contactFromFock is False), the first
      non-None potential seen for a contact is stored as its reference
      without shifting; later values shift aList/bList by the difference.
      Each contact is tracked independently.
    - Contacts derived from the Fock matrix are not re-extracted here.
      NOTE(review): confirm whether setContacts() should be called in that
      case.
    """
    self.F = np.array(F)
    if self.tauFromFock:
        taus = self.tauInds
        first_inds, last_inds = self.indsList[0], self.indsList[-1]
        # Copy the connection blocks onto the contact blocks. The second
        # update deliberately reads the already-updated matrix.
        self.F = self.F.at[np.ix_(first_inds, first_inds)].set(self.F[np.ix_(taus[0], taus[0])])
        self.F = self.F.at[np.ix_(last_inds, last_inds)].set(self.F[np.ix_(taus[1], taus[1])])
        # Rebuild coupling blocks from the new F
        self.tauList = [self.F[np.ix_(taus[0], first_inds)], self.F[np.ix_(taus[1], last_inds)]]
        self.stauList = [self.S[np.ix_(taus[0], first_inds)], self.S[np.ix_(taus[1], last_inds)]]

    if not self.contactFromFock:
        for i, mu in zip([0, -1], [mu1, mu2]):
            if mu is None:
                continue
            fermi = self.fermiList[i]
            if fermi is not None and fermi != mu:
                dFermi = mu - fermi
                # Use JAX immutable updates
                # NOTE(review): alpha is shifted by dFermi*identity while
                # beta is shifted by dFermi*overlap - confirm alpha should
                # not use aSList[i].
                self.aList = self.aList.at[i].set(self.aList[i] + dFermi*np.eye(len(self.aList[i])))
                self.bList = self.bList.at[i].set(self.bList[i] + dFermi*self.bSList[i])
            # First sighting just records the reference level
            self.fermiList[i] = mu

    # jit-compiled g/sigma captured the previous matrices as trace-time
    # constants; drop their caches so the next call retraces with the
    # updated state. Plain (non-jitted) callables have no cache to clear.
    for compiled in (self.g, self.sigma):
        clear = getattr(compiled, "clear_cache", None)
        if callable(clear):
            clear()
[docs]
def sigma(self, E, i, conv=SURFACE_GREEN_CONVERGENCE):
    """
    Calculate self-energy matrix for a contact.

    Builds t = E*stau - tau from the stored coupling of contact i, forms
    t @ g(E, i) @ t^dagger with the surface Green's function, and places
    that block on the contact's orbital indices inside a zero matrix with
    the shape of the Fock matrix.

    Parameters
    ----------
    E : float
        Energy point
    i : int
        Contact index. Must be a concrete Python int: it indexes Python
        lists and is declared static in the jit wrapper set up by __init__.
    conv : float, optional
        Convergence criterion forwarded to g
        (default: SURFACE_GREEN_CONVERGENCE)

    Returns
    -------
    ndarray
        Complex self-energy matrix for contact i, same shape as self.F
    """
    # Energy-dependent coupling between device and contact
    coupling = E*self.stauList[i] - self.tauList[i]
    contact_block = coupling @ self.g(E, i, conv) @ coupling.conj().T

    # Embed the block at the contact's rows/columns of the full matrix
    orbitals = self.indsList[i]
    full = np.zeros(self.F.shape, dtype=complex)
    return full.at[np.ix_(orbitals, orbitals)].add(contact_block)
[docs]
def sigmaTot(self, E, conv=SURFACE_GREEN_CONVERGENCE):
    """
    Calculate total self-energy matrix from all contacts.

    Adds the self-energy of every contact (indices 0 .. num_contacts-1),
    in index order, onto a complex zero matrix shaped like the Fock matrix.

    Parameters
    ----------
    E : float
        Energy point
    conv : float, optional
        Convergence criterion forwarded to sigma for each contact
        (default: SURFACE_GREEN_CONVERGENCE)

    Returns
    -------
    ndarray
        Total self-energy matrix from all contacts
    """
    # Plain Python iteration keeps every contact index concrete
    per_contact = (self.sigma(E, k, conv) for k in range(self.num_contacts))
    # Built-in sum folds left to right starting from the zero matrix
    return sum(per_contact, np.zeros(self.F.shape, dtype=complex))