Source code for gauNEGF.surfGBethe

"""
Surface Green's function implementation for Bethe lattice contacts.

This module provides a Bethe lattice implementation for modeling semi-infinite
metallic contacts in quantum transport calculations. It supports:
- FCC [111] surface geometry with proper orbital symmetries
- Slater-Koster parameterization for s, p, and d orbitals
- Temperature-dependent calculations
- Spin-restricted and unrestricted calculations

The implementation follows the ANT.Gaussian approach [1], using a minimal basis
set with s, p, and d orbitals for each contact atom. The Bethe lattice model
provides an efficient way to describe bulk metallic electrodes while maintaining
proper orbital symmetries and electronic structure. This approach allows for
accurate modeling of metal-molecule interfaces without the computational cost
of explicit periodic boundary conditions.

References
----------
[1] Jacob, D., & Palacios, J. J. (2011). Critical comparison of electrode models 
    in density functional theory based quantum transport calculations.
    The Journal of Chemical Physics, 134(4), 044118.
    DOI: 10.1063/1.3526044
"""

# Developed packages (import config BEFORE jax)
from gauNEGF.density import getFermiContact
from gauNEGF.config import (ETA, TEMPERATURE, SURFACE_GREEN_CONVERGENCE,
                            FERMI_CALCULATION_TOL)
from gauNEGF.utils import fractional_matrix_power

# Python packages
import jax
import jax.numpy as jnp
import jax.numpy.linalg as LA
from jax import jit
import jax.lax as lax

# Module-level constants
dim = 1 + 3 + 5           # orbitals per contact atom in the minimal basis: 1 s + 3 p + 5 d
har_to_eV = 27.211386     # energy conversion: eV per Hartree
bohr_to_ang = 0.529177    # length conversion: Angstrom per Bohr radius

# Bethe lattice surface Green's function for a device with contacts
class surfGB:
    """
    Surface Green's function calculator for Bethe lattice contacts.

    This class implements the Bethe lattice approximation for modeling
    semi-infinite metallic contacts. It handles:
    - Contact geometry detection and setup
    - Slater-Koster parameter management
    - Energy-dependent self-energy calculations
    - Temperature effects
    - Spin configurations

    The Bethe lattice model represents the contacts as a semi-infinite
    tree-like structure with proper coordination number and orbital
    symmetries matching those of bulk FCC metals. This provides an efficient
    way to compute surface Green's functions and self-energies for the
    electrodes, as demonstrated by Jacob & Palacios in their 2011 paper [1].

    Parameters
    ----------
    F : ndarray
        Fock matrix
    S : ndarray
        Overlap matrix
    contacts : list of lists
        List of (1-based) atom indices for each contact; each contact needs
        at least 2 atoms.
    bar : object
        Gaussian interface object providing ``ibfatm``, ``ibftyp`` (basis
        function -> atom / type maps) and ``c`` (flat coordinates in Bohr).
    latFile : str, optional
        Bethe parameter file name without the ``.bethe`` extension
        (default: 'Au')
    spin : str, optional
        Spin configuration: 'r' (restricted), 'u'/'ro' (spin blocks) or 'g'
        (interleaved spin) (default: 'r')
    eta : float, optional
        Broadening parameter in eV (default: ETA from gauNEGF.config)
    T : float, optional
        Temperature in Kelvin (default: TEMPERATURE from gauNEGF.config)

    Attributes
    ----------
    cVecs : list
        Unit normal of each contact plane, oriented away from the device
        centroid
    latVecs : list
        Unit in-plane lattice vector for each contact
    indsLists : list
        Orbital indices (sorted s, p, d) for each contact atom
    dirLists : list
        12 nearest-neighbor direction vectors for each contact
    nIndLists : list
        Indices (into dirLists) of the directions occupied by neighbor atoms
    Slists, Vlists : list
        Overlap / hopping matrices for every direction of every contact
    gList : list
        surfGBAt objects for each contact

    References
    ----------
    [1] Jacob, D., & Palacios, J. J. (2011). Critical comparison of electrode
        models in density functional theory based quantum transport
        calculations. The Journal of Chemical Physics, 134(4), 044118.
        DOI: 10.1063/1.3526044
    """

    def __init__(self, F, S, contacts, bar, latFile='Au', spin='r', eta=ETA, T=TEMPERATURE):
        """Set up the contact geometry, Slater-Koster matrices and Fermi level."""
        # Read contact/orbital information and store
        self.cVecs = []
        self.latVecs = []
        self.indsLists = []
        self.dirLists = []
        self.nIndLists = []

        # S^(1/2): sigma() uses it to de-orthonormalize self-energies when the
        # Bethe lattice basis is orthonormal (ANT.Gaussian technique).
        self.Xi = fractional_matrix_power(S, 0.5)
        if spin != 'r':
            # Spin independent implementation, add degenerate spin terms during sigma generation
            self.Xi = self.Xi[::2, ::2]
        self.spin = spin

        orbMap = bar.ibfatm[bar.ibfatm > 0]
        orbTyp = bar.ibftyp[bar.ibfatm > 0]
        coords = jnp.array([bar.c[i*3:(i+1)*3] for i in range(len(bar.c)//3)])*bohr_to_ang
        self.N = len(orbMap)

        # Collect contact information
        for contact in contacts:
            if len(contact) < 2:
                raise ValueError(f'Error: contact {contact} needs at least 2 atoms to '
                                 'define its surface plane and a lattice direction')
            indsList = []
            cList = []
            for atom in contact:
                inds = jnp.where(jnp.isin(orbMap, atom))[0]
                cList.append(coords[atom-1])
                assert len(inds) == 9, f'Error: Atom {atom} has {len(inds)} basis functions, expecting 9'
                # order basis functions by shell type (s, p, d)
                inds = inds[jnp.argsort(abs(orbTyp[inds])//1000)]
                indsList.append(inds)
            self.indsLists.append(indsList)

            cList = jnp.array(cList)
            contVec = self._contactNormal(cList, coords)
            self.cVecs.append(contVec)

            # The in-contact atom nearest to the first one fixes an in-plane lattice direction
            vInd = jnp.argmin(jnp.array([LA.norm(v - cList[0]) for v in cList[1:]])) + 1
            latVec = cList[vInd] - cList[0]
            latDist = LA.norm(latVec)
            self.latVecs.append(latVec/latDist)

            # Two candidate neighbor sets that differ in the stacking of the
            # out-of-plane triangle (and swap in-plane indices k <-> k+6).
            nVecs1 = self.genNeighbors(contVec, latVec)
            nVecs2 = self.genNeighbors(contVec, -latVec)

            # Pick ONE orientation for the whole contact, then index every atom's
            # neighbors against it, so nIndLists and dirLists are always consistent.
            nAtVecsList = [self._neighborUnitVecs(c, coords, latDist) for c in cList]
            nVecs = self._chooseOrientation(nAtVecsList, nVecs1, nVecs2)
            nIndList = [self._neighborIndices(nAtVecs, nVecs) for nAtVecs in nAtVecsList]

            self.nIndLists.append(nIndList)
            self.dirLists.append(nVecs)

        # Read Bethe lattice parameters and generate hopping/overlap matrices
        self.readBetheParams(latFile)
        self.Slists = []
        self.Vlists = []
        for dirList in self.dirLists:
            self.Slists.append([self.constructMat(self.Sdict, d, self.SOC) for d in dirList])
            self.Vlists.append([self.constructMat(self.Vdict, d, self.SOC) for d in dirList])

        self.num_contacts = len(self.indsLists)
        # One surfGBAt per contact stores the atomic Bethe lattice Green's function
        self.gList = [surfGBAt(self.H0.copy(), Slist, Vlist, eta, T, self.SOC)
                      for Slist, Vlist in zip(self.Slists, self.Vlists)]

        # Calculate fermi level
        # With SOC, 18x18 basis already includes both spins -> use full ne
        # Without SOC, spin-restricted -> ne/2 per spin channel
        ne_fermi = self.ne if self.SOC else self.ne/2
        fermi = self.gList[0].calcFermi(ne_fermi)
        for g in self.gList:
            g.fermi = fermi
            g.fermi0 = fermi

        # Store variables
        self.cList = cList  # coordinates of the LAST contact processed, used for testing
        self.F = F
        self.S = S
        self.eta = eta

    # Note: JIT compilation is on surfGBAt.sigmaSurf/sigmaK (the expensive Dyson
    # iteration). surfGB.sigma is NOT JIT'd so it can access dFermi and shift
    # E before calling the atomic methods -- keeping dFermi outside the JIT
    # boundary avoids stale-closure bugs and unnecessary recompilation.

    @staticmethod
    def _contactNormal(cList, coords):
        """Unit normal of the best-fit plane through cList, oriented away from the device centroid."""
        centeredCoords = cList - jnp.mean(cList, axis=0)
        _, _, Vt = LA.svd(centeredCoords)
        contVec = Vt[-1]
        contDir = jnp.mean(cList, axis=0) - jnp.mean(coords, axis=0)
        if jnp.dot(contDir, contVec) < 0:
            contVec = -contVec
        return contVec

    @staticmethod
    def _neighborUnitVecs(c, coords, latDist):
        """Unit vectors from c to every atom within 0.8-1.2 nearest-neighbor distances."""
        nAtVecs = []
        for c2 in coords:
            l = LA.norm(c2 - c)
            if 0.8 * latDist < l < 1.2 * latDist and not jnp.allclose(c2, c):
                nAtVecs.append((c2 - c)/l)
        return nAtVecs

    @staticmethod
    def _chooseOrientation(nAtVecsList, nVecs1, nVecs2):
        """Return a copy of nVecs2 if any neighbor matches one of its out-of-plane vectors, else of nVecs1."""
        outOfPlane = (3, 4, 5, 9, 10, 11)
        for nAtVecs in nAtVecsList:
            for vec in nAtVecs:
                valList = jnp.array([jnp.dot(vec, direction) for direction in nVecs2])
                dirInd = int(jnp.argmax(valList))
                if dirInd in outOfPlane and valList[dirInd] > 0.9:
                    return list(nVecs2)
        return list(nVecs1)

    @staticmethod
    def _neighborIndices(nAtVecs, nVecs):
        """Map each neighbor unit vector to the index of the best-aligned direction in nVecs."""
        nInds = []
        for vec in nAtVecs:
            valList = jnp.array([jnp.dot(vec, direction) for direction in nVecs])
            dirInd = jnp.argmax(valList)
            if valList[dirInd] > 0.9:
                nInds.append(dirInd)
            else:
                print(f'Warning: Lattice Vec #{dirInd} mismatch, neighbor not recorded')
        return nInds
[docs] def genNeighbors(self, plane_normal, first_neighbor): """ Generate 12 nearest neighbor unit vectors for an FCC [111] surface. Creates a list of unit vectors representing the 12 nearest neighbors in an FCC lattice: - 6 in-plane vectors forming a hexagonal pattern (3 pairs of opposite vectors) - 6 out-of-plane vectors forming triangular patterns (3 pairs of opposite vectors) Parameters ---------- plane_normal : ndarray Vector normal to the crystal plane (will be normalized) first_neighbor : ndarray Vector to one nearest neighbor (will be projected onto plane) Returns ------- list 12 unit vectors representing nearest neighbor directions """ # Project first_neighbor onto plane perpendicular to plane_normal proj = first_neighbor - jnp.dot(first_neighbor, plane_normal) * plane_normal first_neighbor = proj / LA.norm(proj) # Generate in-plane vectors using 60-degree rotations in_plane_vectors = [] rotation_angle = jnp.pi / 3 # 60 degrees for i in range(3): angle = i * rotation_angle # Rodrigues rotation formula cos_theta = jnp.cos(angle) sin_theta = jnp.sin(angle) K = jnp.array([[0, -plane_normal[2], plane_normal[1]], [plane_normal[2], 0, -plane_normal[0]], [-plane_normal[1], plane_normal[0], 0]]) R = jnp.eye(3) + sin_theta * K + (1 - cos_theta) * jnp.matmul(K, K) rotated_vector = jnp.dot(R, first_neighbor) in_plane_vectors.append(rotated_vector / LA.norm(rotated_vector)) # Generate out-of-plane vectors out_of_plane_angle = jnp.arccos(1/jnp.sqrt(3)) # ~54.74 out_of_plane_vectors = [] # Add 30deg = pi/6 rotation to base vector before going out of plane rot_angle = jnp.pi/6 K = jnp.array([[0, -plane_normal[2], plane_normal[1]], [plane_normal[2], 0, -plane_normal[0]], [-plane_normal[1], plane_normal[0], 0]]) R = jnp.eye(3) + jnp.sin(rot_angle) * K + (1 - jnp.cos(rot_angle)) * jnp.matmul(K, K) rotated_first = jnp.dot(R, first_neighbor) out_of_plane_base = jnp.cos(out_of_plane_angle) * rotated_first + \ jnp.sin(out_of_plane_angle) * plane_normal for i in range(3): angle 
= i * 2 * jnp.pi / 3 # 120 degree rotations cos_theta = jnp.cos(angle) sin_theta = jnp.sin(angle) K = jnp.array([[0, -plane_normal[2], plane_normal[1]], [plane_normal[2], 0, -plane_normal[0]], [-plane_normal[1], plane_normal[0], 0]]) R = jnp.eye(3) + sin_theta * K + (1 - cos_theta) * jnp.matmul(K, K) rotated_vector = jnp.dot(R, out_of_plane_base) out_of_plane_vectors.append(rotated_vector) # Add corresponding opposite vectors at the (k+6)%12 location all_vectors = in_plane_vectors + out_of_plane_vectors for i in range(6): all_vectors.append(-all_vectors[i]) # Return vectors return all_vectors
# Read parameters from filename.bethe file, check values, store into dicts
[docs] def readBetheParams(self, filename): """ Read Slater-Koster parameters from a .bethe file. Reads and validates parameters for minimal basis with single s, p, and d orbitals. Parameters are stored in dictionaries for onsite energies, hopping integrals, and overlap matrices. Parameters ---------- filename : str Name of the .bethe file (without extension) Raises ------ AssertionError If parameters are missing or invalid Notes ----- Parameters are sorted into: - Edict: Onsite energies (converted from Hartrees to eV) - Vdict: Hopping parameters (converted from Hartrees to eV) - Sdict: Overlap parameters """ params = {} with open(filename+'.bethe', 'r') as f: for line in f: # Skip empty lines if not line.strip(): continue # Split on comma and strip whitespace line = line.replace(' ','') key, value = line.split('=') params[key] = float(value) # Check to make sure parameters are all specified # Note: set up only for minimal basis with single s, p, and d orbital expected_keys = {'ne', 'es', 'ep', 'edd', 'edt', 'sss', 'sps', 'pps', 'ppp', 'sds', 'pds', 'pdp', 'dds', 'ddp', 'ddd', 'Ssss', 'Ssps', 'Spps', 'Sppp', 'Ssds', 'Spds', 'Spdp', 'Sdds', 'Sddp', 'Sddd'} optional_keys = {'soc_p', 'soc_d'} found_keys = set(params.keys()) assert expected_keys.issubset(found_keys) and found_keys.issubset(expected_keys | optional_keys), \ f"Error reading file: Found Bethe parameters: {list(params.keys())}, expected: {sorted(expected_keys)}" # sort parameters and convert Hartrees to eV self.ne = params['ne'] self.Edict = {k[1:]:params[k]*har_to_eV for k in params if k.startswith('e')} self.Sdict = {k[1:]:params[k] for k in params if k.startswith('S')} self.Vdict = {k:params[k]*har_to_eV for k in params if not k.startswith('e') and not k.startswith('S') and not k.startswith('soc') and k != 'ne'} # Setup onsite H0 matrix before Fermi level shifting hdiag = [self.Edict['s']]+ [self.Edict['p']]*3 + [self.Edict['dd']]+ \ [self.Edict['dt']]*2 + [self.Edict['dd'], self.Edict['dt']] H0 = 
jnp.diag(jnp.array(hdiag)) if 'soc_p' in params and 'soc_d' in params and self.spin != 'r': self.SOC = True lambdas = [0.0, params['soc_p'] * har_to_eV, params['soc_d'] * har_to_eV] from gauNEGF.spinTools import constructSOCterm Hsoc = constructSOCterm(lambdas) self.H0 = jnp.kron(H0, jnp.eye(2)) + jnp.array(Hsoc) else: self.SOC = False self.H0 = H0
[docs] def constructMat(self, Mdict, dirCosines, SOC=False): """ Construct hopping/overlap matrix using Slater-Koster formalism. Builds a 9x9 matrix for s, p, and d orbital interactions based on the Slater-Koster two-center approximation. The matrix is first constructed assuming a [0,0,1] bond direction, then rotated to the given direction using direction cosines. When SOC=True, expands to 18x18 via kron(M, I2). Parameters ---------- Mdict : dict Dictionary of Slater-Koster parameters (sss, sps, pps, etc.) dirCosines : ndarray Array [l,m,n] of direction cosines for the bond Returns ------- ndarray 9x9 matrix containing orbital interactions in the rotated frame Notes ----- Matrix blocks: - [0,0]: s-s interaction - [0:4,0:4]: s-p block - [0:4,4:9]: s-d and p-d blocks - [4:9,4:9]: d-d block """ M = jnp.zeros((dim, dim)) #Original matrix before rotation - assuming [0,0,1] bond direction # s-s coefficient M = M.at[0,0].set(Mdict['sss']) # s-p block M = M.at[0,3].set(Mdict['sps']) #s-pz M = M.at[3,0].set(-Mdict['sps']) #pz-s # p-p block M = M.at[1,1].set(Mdict['ppp']) #px-px M = M.at[2,2].set(Mdict['ppp']) #py-py M = M.at[3,3].set(Mdict['pps']) #pz-pz # s-d block M = M.at[0, 4].set(Mdict['sds']) #s - d3z²-r² M = M.at[4, 0].set(Mdict['sds']) # p-d block M = M.at[1,5].set(Mdict['pdp']) #px - dxz M = M.at[2,6].set(Mdict['pdp']) #py - dyz M = M.at[3,4].set(Mdict['pds']) #pz - d3z²-r² M = M.at[5,1].set(-Mdict['pdp']) #dxz - px M = M.at[6,2].set(-Mdict['pdp']) #dyz - py M = M.at[4,3].set(-Mdict['pds']) #d3z²-r² - pz # d-d block M = M.at[4,4].set(Mdict['dds']) #d3z²-r² - d3z²-r² M = M.at[5,5].set(Mdict['ddp']) #dxz - dxz M = M.at[6,6].set(Mdict['ddp']) #dyz - dyz M = M.at[7,7].set(Mdict['ddd']) #dx²-y² - dx²-y² M = M.at[8,8].set(Mdict['ddd']) #dxy - dxy # Initialize 9x9 transformation matrix and polar directions tr = jnp.zeros((9, 9)) x, y, z = dirCosines theta = jnp.arccos(z) # polar angle from z-axis phi = jnp.arctan2(y, x) # azimuthal angle in x-y plane # s orbital (1x1) at 
position [0,0] - always 1 since spherically symmetric tr = tr.at[0,0].set(1.0) # p orbitals (3x3) at positions [1:4,1:4] # [px,py,pz] block - describes how p orbitals transform under rotation tr = tr.at[1:4,1:4].set(jnp.array([ [jnp.cos(theta) * jnp.cos(phi), -jnp.sin(phi) , jnp.sin(theta)*jnp.cos(phi)], [jnp.cos(theta) * jnp.sin(phi), jnp.cos(phi) , jnp.sin(theta)*jnp.sin(phi)], [-jnp.sin(theta) , 0 , jnp.cos(theta)] ])) # d orbitals (5x5) at positions [4:9,4:9] # [d3z2-r2, dxz, dyz, dx2-y2, dxy] block - transforms the five d orbitals d_block = jnp.zeros((5,5)) # Copying formula from ANT.Gaussian directly d_block = d_block.at[0,0].set((3 * z**2 - 1) / 2) d_block = d_block.at[0,1].set(-jnp.sqrt(3) * jnp.sin(2*theta) / 2) d_block = d_block.at[0,3].set(jnp.sqrt(3) * jnp.sin(theta)**2 / 2) d_10 = jnp.sqrt(3) * jnp.sin(2*theta) * jnp.cos(phi) / 2 d_block = d_block.at[1,0].set(d_10) d_block = d_block.at[1,1].set(jnp.cos(2*theta) * jnp.cos(phi)) d_block = d_block.at[1,2].set(-jnp.cos(theta) * jnp.sin(phi)) d_block = d_block.at[1,3].set(-d_10 / jnp.sqrt(3)) d_block = d_block.at[1,4].set(jnp.sin(theta) * jnp.sin(phi)) d_20 = jnp.sqrt(3) * jnp.sin(2*theta) * jnp.sin(phi) / 2 d_block = d_block.at[2,0].set(d_20) d_block = d_block.at[2,1].set(jnp.cos(2*theta) * jnp.sin(phi)) d_block = d_block.at[2,2].set(jnp.cos(theta) * jnp.cos(phi)) d_block = d_block.at[2,3].set(-d_20 / jnp.sqrt(3)) d_block = d_block.at[2,4].set(-jnp.sin(theta) * jnp.cos(phi)) d_block = d_block.at[3,0].set(jnp.sqrt(3) * jnp.sin(theta)**2 * jnp.cos(2*phi) / 2) d_block = d_block.at[3,1].set(jnp.sin(2*theta) * jnp.cos(2*phi) / 2) d_block = d_block.at[3,2].set(-jnp.sin(theta) * jnp.sin(2*phi)) d_block = d_block.at[3,3].set((1 + jnp.cos(theta)**2) * jnp.cos(2*phi) / 2) d_block = d_block.at[3,4].set(-jnp.cos(theta) * jnp.sin(2*phi)) d_block = d_block.at[4,0].set(jnp.sqrt(3) * jnp.sin(theta)**2 * jnp.sin(2*phi) / 2) d_block = d_block.at[4,1].set(jnp.sin(2*theta) * jnp.sin(2*phi) / 2) d_block = 
d_block.at[4,2].set(jnp.sin(theta) * jnp.cos(2*phi)) d_block = d_block.at[4,3].set((1 + jnp.cos(theta)**2) * jnp.sin(2*phi) / 2) d_block = d_block.at[4,4].set(jnp.cos(theta) * jnp.cos(2*phi)) tr = tr.at[4:9,4:9].set(d_block) # Apply transformation M_rot = tr @ M @ tr.T if SOC: return jnp.kron(M_rot, jnp.eye(2)) return M_rot
[docs] def sigma(self, E, i, conv=SURFACE_GREEN_CONVERGENCE): """ Calculate self-energy matrix for a specific contact. Computes the self-energy matrix for contact i by: 1. Calculating surface self-energies for all 9 directions 2. Summing contributions from directions not connected to the device 3. Applying de-orthonormalization if needed 4. Handling spin configurations Parameters ---------- E : float Energy point for self-energy calculation (in eV) i : int Index of the contact to calculate self-energy for conv : float, optional Convergence criterion for self-energy calculation (default: SURFACE_GREEN_CONVERGENCE) Returns ------- ndarray Self-energy matrix for the specified contact, with dimensions: - (N, N) for restricted calculations - (2N, 2N) for unrestricted or generalized spin calculations References ---------- [1] Jacob, D., & Palacios, J. J. (2011). Critical comparison of electrode models in density functional theory based quantum transport calculations. The Journal of Chemical Physics, 134(4), 044118. 
DOI: 10.1063/1.3526044 """ # Shift E to reference frame of immutable H0/Vlist0 (keeps dFermi # outside the JIT boundary of surfGBAt.sigma) E_shifted = E - self.gList[i].dFermi sigSurf = self.gList[i].sigmaSurf(E_shifted, conv) # Get contact-specific data for this static contact index nIndLists_i = self.nIndLists[i] indsLists_i = self.indsLists[i] if self.SOC: # SOC matrices are 18x18 (spin already included, interleaved ordering) sig = jnp.zeros((2*self.N, 2*self.N), dtype=complex) for nInds, Finds in zip(nIndLists_i, indsLists_i): sigInds = list(set(range(9)) - {int(x) for x in nInds}) sigAtom = sum(sigSurf[j] for j in sigInds) # Expand orbital indices to spin-orbital: i -> [2*i, 2*i+1] socFinds = jnp.array([idx for i in Finds for idx in (2*i, 2*i+1)]) sig = sig.at[jnp.ix_(socFinds, socFinds)].set(sigAtom) # De-orthonormalization with expanded Xi sig = lax.cond(self.Sdict['sss'] == 0, lambda s: jnp.kron(self.Xi, jnp.eye(2)) @ s @ jnp.kron(self.Xi, jnp.eye(2)), lambda s: s, sig) # No kron needed -- spin is already in SOC matrices else: sig = jnp.zeros((self.N, self.N), dtype=complex) for nInds, Finds in zip(nIndLists_i, indsLists_i): sigInds = list(set(range(9)) - {int(x) for x in nInds}) sigAtom = sum(sigSurf[j] for j in sigInds) sig = sig.at[jnp.ix_(Finds, Finds)].set(sigAtom) # Apply de-orthonormalization technique from ANT.Gaussian if orthonormal sig = lax.cond(self.Sdict['sss'] == 0, lambda s: self.Xi @ s @ self.Xi, lambda s: s, sig) # Handle spin - use if/else since spin is static if self.spin == 'u' or self.spin == 'ro': sig = jnp.kron(jnp.eye(2), sig) elif self.spin == 'g': sig = jnp.kron(sig, jnp.eye(2)) # else: spin == 'r', keep sig as-is return sig
[docs] def sigmaTot(self, E, conv=SURFACE_GREEN_CONVERGENCE): """ Calculate total self-energy matrix from all contacts. Sums sigma(E, i) over all contacts and returns the result in the full device basis. Parameters ---------- E : float Energy point for Green's function calculation (in eV) conv : float, optional Convergence criterion for self-energy calculation (default: 1e-5) Returns ------- ndarray Total self-energy matrix in the full device basis """ num_contacts = len(self.indsLists) sigs = [self.sigma(E, i, conv) for i in range(num_contacts)] return sum(sigs)
[docs] def crossTermQ(self, E, i, conv=SURFACE_GREEN_CONVERGENCE): """Cross-term Q_sym for contact i in full device basis. Mirrors surfGB.sigma: iterates atoms in contact i, calls gList[i].crossTermQSurf with per-atom active directions, assembles result, then applies de-orthonormalization (same as sigma). """ nIndLists_i = self.nIndLists[i] indsLists_i = self.indsLists[i] E_shifted = E - self.gList[i].dFermi if self.SOC: # SOC matrices are 18x18 (spin already included, interleaved ordering) sig = jnp.zeros((2*self.N, 2*self.N), dtype=complex) for nInds, Finds in zip(nIndLists_i, indsLists_i): sigInds = list(set(range(9)) - {int(x) for x in nInds}) Q_atom = self.gList[i].crossTermQSurf(E_shifted, sigInds=sigInds, conv=conv) # Expand orbital indices to spin-orbital: k -> [2*k, 2*k+1] socFinds = jnp.array([idx for k in Finds for idx in (2*k, 2*k+1)]) sig = sig.at[jnp.ix_(socFinds, socFinds)].set(Q_atom) # De-orthonormalization with expanded Xi (same condition as sigma) sig = lax.cond(self.Sdict['sss'] == 0, lambda s: jnp.kron(self.Xi, jnp.eye(2)) @ s @ jnp.kron(self.Xi, jnp.eye(2)), lambda s: s, sig) # No trailing spin kron -- spin already in SOC matrices else: sig = jnp.zeros((self.N, self.N), dtype=complex) for nInds, Finds in zip(nIndLists_i, indsLists_i): sigInds = list(set(range(9)) - {int(x) for x in nInds}) Q_atom = self.gList[i].crossTermQSurf(E_shifted, sigInds=sigInds, conv=conv) sig = sig.at[jnp.ix_(Finds, Finds)].set(Q_atom) # Apply de-orthonormalization if orthonormal basis (same as sigma) sig = lax.cond(self.Sdict['sss'] == 0, lambda s: self.Xi @ s @ self.Xi, lambda s: s, sig) if self.spin == 'u' or self.spin == 'ro': sig = jnp.kron(jnp.eye(2), sig) elif self.spin == 'g': sig = jnp.kron(sig, jnp.eye(2)) return sig
[docs] def crossTermQTot(self, E, conv=SURFACE_GREEN_CONVERGENCE): """Total cross-term Q_sym from all contacts.""" num_contacts = len(self.indsLists) qs = [self.crossTermQ(E, i, conv) for i in range(num_contacts)] return sum(qs)
[docs] def getSigma(self, Elist=[None, None], conv=SURFACE_GREEN_CONVERGENCE): """ Helper method for getting the left and right contact self-energies Parameters ---------- Elist : tuple, optional A list of contact energies for selecting sigma, (default: use contact Fermi energy) conv: float, optional Convergence criterion for the self-energy matrix Returns ------- tuple A tuple of both self-energy matrices (ndarrays) """ if Elist[0] is None: Elist[0] = self.gList[0].fermi if Elist[1] is None: Elist[1] = self.gList[-1].fermi return (self.sigma(Elist[0], 0, conv), self.sigma(Elist[1], -1, conv))
[docs] def updateFermi(self, i, Ef): """ Update Fermi energy for a specific contact. Shifts the Hamiltonian of contact i to align its Fermi level with the specified energy. Parameters ---------- i : int Contact index Ef : float New Fermi energy in eV """ self.gList[i].updateH(Ef)
[docs] def setF(self, F, muL, muR): """ Update Fock matrix and contact chemical potentials. Sets the Fock matrix and updates the Fermi levels of the left and right contacts if they have changed. Parameters ---------- F : ndarray New Fock matrix muL : float Chemical potential for left contact in eV muR : float Chemical potential for right contact in eV """ self.F = F if self.gList[0].fermi != muL: self.updateFermi(0, muL) if self.gList[-1].fermi != muR: self.updateFermi(-1, muR)
## TESTING METHODS FOR SLATER-KOSTER INTERACTIONS:
[docs] def testDOrbitalFunctions(self): """ Test d orbital angular functions. Validates the angular dependence of d orbital interactions by checking: - dxy interaction along x-axis (should be zero) - dx2-y2 interaction along x-axis (should be sqrt(3)/2 * sds) - dz2 interaction along x-axis (should be -1/2 * sds) """ # Use values from the Bethe parameter dictionaries Vdict = self.Vdict # Contains hopping parameters # Test along x-axis [1,0,0] M = self.constructMat(self.Vdict, [1, 0, 0]) # dxy should be zero along x-axis jnp.testing.assert_almost_equal(M[0,8], 0.0, err_msg="dxy not zero along x-axis") # dx2-y2 should be sqrt(3)/2 * sds along x-axis jnp.testing.assert_almost_equal(M[0,7], jnp.sqrt(3)/2 * Vdict['sds'], err_msg="dx2-y2 incorrect along x-axis") # dz2 should be -1/2 along x-axis jnp.testing.assert_almost_equal(M[0,4], -0.5 * Vdict['sds'], err_msg="dz2 incorrect along x-axis") print("d orbital angular function tests passed!")
[docs] def testDOrbitalSymmetry(self): """ Test d orbital symmetry properties. Validates that d orbital interactions respect inversion symmetry by comparing interactions along opposite directions. """ # Test inversion symmetry dir1 = [1/jnp.sqrt(2), 1/jnp.sqrt(2), 0] dir2 = [-1/jnp.sqrt(2), -1/jnp.sqrt(2), 0] M1 = self.constructMat(self.Vdict, dir1) M2 = self.constructMat(self.Vdict, dir2) # d-d block should be identical under inversion jnp.testing.assert_array_almost_equal( M1[4:,4:], M2[4:,4:], err_msg="d-d block not symmetric under inversion") print("d orbital symmetry tests passed!")
[docs] def testPDInteraction(self): """ Test p-d orbital interactions. Validates p-d orbital interactions by checking: - px-dxy interaction along x-axis (should be zero) - pz-dz2 interaction along z-axis (should be pure sigma) """ Vdict = self.Vdict # Test px-dxy interaction along x-axis M = self.constructMat(Vdict, [1, 0, 0]) # px-dxy should be zero along x-axis jnp.testing.assert_almost_equal( M[1,8], 0.0, err_msg="px-dxy interaction incorrect along x-axis") # Test pz-dz2 interaction along z-axis M = self.constructMat(Vdict, [0, 0, 1]) expected = Vdict['pds'] # Should be pure sigma jnp.testing.assert_almost_equal( M[3,4], expected, err_msg="pz-dz2 interaction incorrect along z-axis") print("p-d interaction tests passed!")
[docs] def testDDInteraction(self): """ Test d-d orbital interactions. Validates d-d orbital interactions by checking: - dyz-dyz interaction along x-axis (should be pure delta) - dz2-dz2 interaction along z-axis (should be pure sigma) """ Vdict = self.Vdict # Test dyz-dyz interaction along x-axis M = self.constructMat(Vdict, [1, 0, 0]) # Should be pure delta interaction expected = Vdict['ddd'] jnp.testing.assert_almost_equal( M[6,6], expected, err_msg="dyz-dyz interaction incorrect along x-axis") # Test dz2-dz2 interaction along x-axis M = self.constructMat(Vdict, [0, 0, 1]) # Should be pure sigma interaction expected = Vdict['dds'] jnp.testing.assert_almost_equal( M[4,4], expected, err_msg="dz2-dz2 interaction incorrect along z-axis") print("d-d interaction tests passed!")
[docs] def testHoppingPhysics(self): """ Test physical properties of hopping matrices. Validates hopping matrix physics by checking: - s-p hopping antisymmetry - Conservation of total s-p hopping magnitude - Proper angular dependence along principal axes and 45-degree rotations """ eps = 1e-10 # Tolerance for floating point comparisons # Get reference hopping values from [0,0,1] configuration s_p_mag = abs(self.Vdict['sps']) # Magnitude of s-p hopping # Test set of physically important directions test_cases = [ # Principal axes ([0, 0, 1], "z-axis"), ([1, 0, 0], "x-axis"), ([0, 1, 0], "y-axis"), # 45-degree rotations ([1/jnp.sqrt(2), 0, 1/jnp.sqrt(2)], "45° in xz-plane"), ([0, 1/jnp.sqrt(2), 1/jnp.sqrt(2)], "45° in yz-plane"), ([1/jnp.sqrt(2), 1/jnp.sqrt(2), 0], "45° in xy-plane"), ] print("\nTesting hopping matrix physics...") for direction, name in test_cases: direction = jnp.array(direction) x, y, z = direction print(f"\nChecking {name} direction: [{x:.3f}, {y:.3f}, {z:.3f}]") V = self.constructMat(self.Vdict, direction) # Check s-p hopping antisymmetry for i in range(1, 4): # Check all p orbitals assert abs(V[0,i] + V[i,0]) < eps, \ f"s-p hopping not antisymmetric for p{i}" # Check total s-p hopping magnitude is preserved s_p_total = jnp.sqrt(V[0,1]**2 + V[0,2]**2 + V[0,3]**2) assert abs(s_p_total - s_p_mag) < eps, \ f"s-p hopping magnitude not preserved: {s_p_total:.6f} != {s_p_mag:.6f}" # Print values for verification print(f"s-px: {V[0,1]:.3f}, px-s: {V[1,0]:.3f}") print(f"s-py: {V[0,2]:.3f}, py-s: {V[2,0]:.3f}") print(f"s-pz: {V[0,3]:.3f}, pz-s: {V[3,0]:.3f}") print(f"Total s-p magnitude: {s_p_total:.3f}") print("\nAll hopping physics tests passed!")
[docs] def runAllTests(self): """ Run all validation tests for surfGB. Executes all test methods to validate: - d orbital angular functions - d orbital symmetry - p-d interactions - d-d interactions - General hopping physics """ print("Running Slater-Koster projection tests...") self.testDOrbitalFunctions() self.testDOrbitalSymmetry() self.testPDInteraction() self.testDDInteraction() self.testHoppingPhysics() print("\nAll tests passed!")
# Bethe lattice surface Green's function for a single atom
[docs] class surfGBAt: """ Atomic-level Bethe lattice Green's function calculator. This class implements the surface Green's function calculation for a single atom in the Bethe lattice, handling: - Onsite and hopping matrix construction - Self-energy calculations for bulk and surface - Temperature effects - Fermi energy optimization Parameters ---------- H : ndarray Onsite Hamiltonian matrix (9x9 for minimal basis) Slist : list of ndarray List of 12 overlap matrices for nearest neighbors Vlist : list of ndarray List of 12 hopping matrices for nearest neighbors eta : float Broadening parameter in eV T : float, optional Temperature in Kelvin (default: 0) Attributes ---------- NN : int Number of nearest neighbors (fixed to 12 for FCC) fermi : float Current Fermi energy F : ndarray Extended Fock matrix including neighbors S : ndarray Extended overlap matrix including neighbors """
[docs] def __init__(self, H, Slist, Vlist, eta, T=TEMPERATURE, SOC=False): """ Initialize surfGBAt with Hamiltonian and neighbor matrices. Parameters ---------- H : ndarray Onsite Hamiltonian matrix (9x9 for minimal basis) Slist : list of ndarray List of 12 overlap matrices for nearest neighbors Vlist : list of ndarray List of 12 hopping matrices for nearest neighbors eta : float Broadening parameter in eV T : float, optional Temperature in Kelvin (default: 0) SOC : bool, optional Whether to include spin-orbit coupling (default: False) Raises ------ AssertionError If matrix dimensions are incorrect or number of neighbors != 12 """ self.dim = dim*2 if SOC else dim assert jnp.shape(H) == (self.dim,self.dim), f"Error with H dim, should be {self.dim}x{self.dim}" for S,V in zip(Slist, Vlist): assert jnp.shape(S) == (self.dim,self.dim), f"Error with S dim, should be {self.dim}x{self.dim}" assert jnp.shape(V) == (self.dim,self.dim), f"Error with F dim, should be {self.dim}x{self.dim}" self.H = H self.Slist = Slist self.Vlist = Vlist self.spin = 'r' # spin-dependence not implemented yet self.NN = len(Slist) assert self.NN == 12, "Error: surfGBAt only implemented for FCC using 12 NN" self.eta = eta self.T = T self.fermi = None self.H0 = jnp.array(H) # immutable reference (never mutated) self.Vlist0 = jnp.array(Vlist) # immutable reference (never mutated) self.fermi0 = None # reference Fermi level self.dFermi = 0.0 # shift from reference self.num_contacts = 1 self.updateH() # JIT compile methods with self as static argument self.sigmaK = jit(self.sigmaK, static_argnums=(1,2)) self.sigmaSurf = jit(self.sigmaSurf, static_argnums=(1,2))
[docs] def updateH(self, fermi=None): """ Update Hamiltonian and Fock matrix. Sets F = H (9x9) and S = I for protocol compatibility with density.py. Parameters ---------- fermi : float, optional New Fermi energy setpoint in eV (default: None) """ if fermi is not None and fermi != self.fermi: self.fermi = fermi self.dFermi = fermi if self.fermi0 is None else fermi - self.fermi0 if self.fermi0 is None: self.fermi0 = fermi # Build H and Vlist from H0/Vlist0 plus current dFermi (for non-JIT callers) self.H = self.H0 + self.dFermi * jnp.eye(self.dim) self.Vlist = jnp.array([self.Vlist0[j] + self.dFermi * self.Slist[j] for j in range(self.NN)]) self.F = self.H # 9x9 self.S = jnp.eye(self.dim)
# Calculate sigmaK for the bulk
def sigmaK(self, E, conv=SURFACE_GREEN_CONVERGENCE, mix=0.5):
    r"""
    Calculate bulk self-energies for all 12 lattice directions.

    Computes self-energies for an FCC lattice with the following geometry:

             [3x out of plane dir]
                     \|/
    [3x plane dir] - o - [3x plane dir]
                     /|\
             [3x out of plane dir]

    For each direction k, with pair(k) = (k + 6) % 12 the opposite direction:

        Sigma_k = B_k g_k Bbar_k
        g_k     = [E_eff I - H0 - sum_j Sigma_j + Sigma_pair(k)]^-1
        B_k     = E_eff S_k - V_k,   Bbar_k = E_eff S_k^H - V_k^H

    g_k is the neighbor's Green's function with its bond back to the center
    atom removed. The equations are solved by linearly mixed fixed-point
    (Jacobi) iteration inside ``jax.lax.while_loop``.

    Parameters
    ----------
    E : float
        Energy point for Green's function calculation (in eV). It must
        already be shifted by the caller (e.g. sigmaTot subtracts dFermi),
        so that H0/Vlist0 stay compile-time constants under jit.
    conv : float, optional
        Relative convergence criterion for the Dyson equation
        (default: SURFACE_GREEN_CONVERGENCE from gauNEGF.config)
    mix : float, optional
        Weight of the new iterate in the linear mixing (default: 0.5)

    Returns
    -------
    ndarray
        Array of shape (12, dim, dim): one self-energy per lattice direction

    Notes
    -----
    Every call starts from Sigma_k = -i I; there is no warm start from a
    previous energy point. Iteration stops when the relative max change
    drops below ``conv`` or after 1000 sweeps. Non-convergence is not
    reported.
    """
    sigmaK = jnp.array([jnp.eye(self.dim)*-1j for k in range(self.NN)], dtype=complex)
    # E is pre-shifted by the caller (wrapper subtracts dFermi before entering
    # the JIT boundary). Using E directly keeps H0/Vlist0 as immutable
    # constants in the compiled code, eliminating JIT recompilation during SCF.
    E_eff = E + self.eta*1j
    A = E_eff*jnp.eye(self.dim) - self.H0
    maxIter = 1000

    def cond_fun(state):
        count, diff, sigmaK, sigmaK_ = state
        return (diff > conv) & (count < maxIter)

    def body_fun(state):
        count, diff, sigmaK, sigmaK_ = state
        sigmaK_ = sigmaK.copy()
        sigTot = jnp.sum(sigmaK_, axis=0)
        for k in range(self.NN):
            pair_k = (k + 6) % 12  # Opposite direction vector
            # Read the pair from the same pre-sweep snapshot as sigTot, so its
            # contribution is removed exactly (sigmaK[pair_k] may already have
            # been overwritten earlier in this sweep).
            gK = LA.inv(A - sigTot + sigmaK_[pair_k])
            B = E_eff*self.Slist[k] - self.Vlist0[k]
            B_bar = E_eff*self.Slist[k].conj().T - self.Vlist0[k].conj().T
            sigmaK = sigmaK.at[k].set(mix*(B@gK@B_bar) + (1-mix)*sigmaK_[k])
        # Relative max-norm change between sweeps
        diff = jnp.max(jnp.abs(sigmaK - sigmaK_))/jnp.max(jnp.abs(sigmaK_))
        count += 1
        return (count, diff, sigmaK, sigmaK_)

    # State: (count, diff, current sigmaK, previous sigmaK)
    init_state = (0, jnp.inf, sigmaK, sigmaK.copy())
    count, diff, sigmaK, sigmaK_ = lax.while_loop(cond_fun, body_fun, init_state)
    return sigmaK
def sigmaSurf(self, E, conv=SURFACE_GREEN_CONVERGENCE, mix=0.5):
    r"""
    Calculate surface self-energies for an FCC lattice.

    Computes self-energies for atoms at the surface with the geometry:

    [3x plane dir] - o - [3x plane dir]
                     /|\
             [3x out of plane dir]

    The first 9 bulk self-energies from sigmaK are the starting point.
    Directions 9-11 are dropped (presumably those pointing out of the
    surface). The out-of-plane directions 3, 4, 5 keep their bulk values.
    The 6 in-plane directions (0, 1, 2, 6, 7, 8) are iterated because
    their neighbor is itself a surface atom:

        Sigma_k = B_k g_k Bbar_k
        g_k     = [E_eff I - H0 - sum_j Sigma_j + Sigma_pair(k)]^-1

    Here pair(k) = (k + 6) % 12 is also in-plane. This is the same
    Bethe-lattice recursion as sigmaK (Jacob & Palacios 2011), restricted
    to the surface coordination.

    Parameters
    ----------
    E : float
        Energy point for Green's function calculation (in eV), already
        shifted by the caller
    conv : float, optional
        Relative convergence criterion for the Dyson equation
        (default: SURFACE_GREEN_CONVERGENCE from gauNEGF.config)
    mix : float, optional
        Weight of the new iterate in the linear mixing (default: 0.5)

    Returns
    -------
    ndarray
        Array of shape (9, dim, dim): surface self-energies per direction

    Notes
    -----
    Iteration stops when the relative max change drops below ``conv`` or
    after 1000 sweeps. Non-convergence is not reported.

    References
    ----------
    [1] Jacob, D., & Palacios, J. J. (2011). Critical comparison of electrode models
        in density functional theory based quantum transport calculations.
        The Journal of Chemical Physics, 134(4), 044118.
        DOI: 10.1063/1.3526044
    """
    sigSurf = self.sigmaK(E, conv, mix)[:9]
    maxIter = 1000
    E_eff = E + self.eta*1j
    A = E_eff*jnp.eye(self.dim) - self.H0
    planeVec = [0, 1, 2, 6, 7, 8]  # Location of vectors in plane

    def cond_fun(state):
        count, diff, sigSurf, sigSurf_ = state
        return (diff > conv) & (count < maxIter)

    def body_fun(state):
        count, diff, sigSurf, sigSurf_ = state
        sigSurf_ = sigSurf.copy()
        sigTot = jnp.sum(sigSurf_, axis=0)
        for k in planeVec:
            pair_k = (k + 6) % 12  # Opposite direction; also in planeVec (< 9)
            # Neighbor Green's function with its bond back to the center atom
            # removed, taken from the pre-sweep snapshot like sigTot.
            g_k = LA.inv(A - sigTot + sigSurf_[pair_k])
            B = E_eff*self.Slist[k] - self.Vlist0[k]
            B_bar = E_eff*self.Slist[k].conj().T - self.Vlist0[k].conj().T
            sigSurf = sigSurf.at[k].set(mix*(B@g_k@B_bar) + (1-mix)*sigSurf_[k])
        # Relative max-norm change between sweeps
        diff = jnp.max(jnp.abs(sigSurf - sigSurf_))/jnp.max(jnp.abs(sigSurf_))
        count += 1
        return (count, diff, sigSurf, sigSurf_)

    init_state = (0, jnp.inf, sigSurf, sigSurf.copy())
    count, diff, sigSurf, sigSurf_ = lax.while_loop(cond_fun, body_fun, init_state)
    return sigSurf
def crossTermQSurf(self, E, sigInds=None, conv=SURFACE_GREEN_CONVERGENCE, mix=0.5):
    """Symmetrized surface cross-term Q_sym.

    Q_sym = sum_k (B_k g_k S_k^H + S_k g_k Bbar_k) / 2, with
    B_k = E_eff S_k - V_k and Bbar_k = E_eff S_k^H - V_k^H.

    g_k is the Green's function of the neighbor in direction k, not the
    center atom's own: g_k = inv(A - sigTot + sigSurf[pair_k]) with
    pair_k = (k + 6) % 12. For out-of-plane UP directions (3, 4, 5) the
    pairs (9, 10, 11) are not surface directions and are already absent
    from sigTot, so g_k = inv(A - sigTot).

    Parameters
    ----------
    E : float
        Energy in eV, already shifted by the caller.
    sigInds : list of int, optional
        Surface direction indices to include (default: all 9).
    conv, mix :
        Forwarded to sigmaSurf.
    """
    directions = list(range(9)) if sigInds is None else sigInds

    # Converged surface self-energies, shape (9, dim, dim)
    sigSurf = self.sigmaSurf(E, conv, mix)
    E_eff = E + self.eta * 1j
    A = E_eff * jnp.eye(self.dim) - self.H0
    sigTot = jnp.sum(sigSurf, axis=0)

    Q = jnp.zeros((self.dim, self.dim), dtype=complex)
    for k in directions:
        back = (k + 6) % 12
        G_inv = A - sigTot
        if back < 9:
            # Back-bond is a surface direction: remove it from the neighbor.
            G_inv = G_inv + sigSurf[back]
        g_k = LA.inv(G_inv)
        S_k = self.Slist[k]
        S_dag = S_k.conj().T
        B_k = E_eff * S_k - self.Vlist0[k]
        B_bar = E_eff * S_dag - self.Vlist0[k].conj().T
        Q = Q + (B_k @ g_k @ S_dag + S_k @ g_k @ B_bar) / 2
    return Q
def crossTermQBulk(self, E, conv=SURFACE_GREEN_CONVERGENCE, mix=0.5):
    """Bulk cross-term Q_sym summed over all 12 directions.

    Each direction k uses the neighbor's Green's function with the coupling
    back to the center atom removed: g_k = inv(A - sigTot + sigK[pair_k]),
    where pair_k = (k + 6) % 12.

    E is in eV and must already be shifted by the caller. conv and mix are
    forwarded to sigmaK.
    """
    sigK = self.sigmaK(E, conv, mix)  # 12 bulk self-energies
    E_eff = E + self.eta * 1j
    A = E_eff * jnp.eye(self.dim) - self.H0
    sigTot = jnp.sum(sigK, axis=0)

    def directional_term(k):
        # Symmetrized contribution (B_k g_k S_k^H + S_k g_k Bbar_k) / 2
        g_k = LA.inv(A - sigTot + sigK[(k + 6) % 12])
        S_k = self.Slist[k]
        S_dag = S_k.conj().T
        B_k = E_eff * S_k - self.Vlist0[k]
        B_bar = E_eff * S_dag - self.Vlist0[k].conj().T
        return (B_k @ g_k @ S_dag + S_k @ g_k @ B_bar) / 2

    Q = jnp.zeros((self.dim, self.dim), dtype=complex)
    for k in range(12):
        Q = Q + directional_term(k)
    return Q
# Empty function for compatibility with density.py methods
def setF(self, F, mu1, mu2):
    """
    No-op hook kept so this object satisfies the interface density.py uses.

    The Bethe lattice bulk is fully determined by its tight-binding
    parameters, so an external Fock matrix or chemical potentials have
    nothing to update here.

    Parameters
    ----------
    F : ndarray
        Fock matrix (ignored)
    mu1 : float
        First chemical potential (ignored)
    mu2 : float
        Second chemical potential (ignored)
    """
    return None
def sigmaTot(self, E, conv=SURFACE_GREEN_CONVERGENCE):
    """Total bulk self-energy at energy E (eV): sum of all 12 directional terms.

    E is shifted by -dFermi here, before entering the jitted sigmaK.
    """
    E_ref = E - self.dFermi
    directional = self.sigmaK(E_ref, conv)
    return jnp.sum(directional, axis=0)
def sigma(self, E, i, conv=SURFACE_GREEN_CONVERGENCE):
    """Self-energy for contact i (only i=0, single bulk contact).

    The index is accepted for interface compatibility but not used: the
    total bulk self-energy is always returned.
    """
    total = self.sigmaTot(E, conv)
    return total
def crossTermQ(self, E, i, conv=SURFACE_GREEN_CONVERGENCE):
    """Cross-term Q_sym for contact i (delegates to crossTermQBulk).

    The contact index is unused (single bulk contact). E is shifted by
    -dFermi before delegating.
    """
    E_ref = E - self.dFermi
    return self.crossTermQBulk(E_ref, conv)
def crossTermQTot(self, E, conv=SURFACE_GREEN_CONVERGENCE):
    """Total cross-term Q_sym; with a single contact this equals crossTermQ(E, 0)."""
    only_contact = 0
    return self.crossTermQ(E, only_contact, conv)
# Get the surface DOS of the Bethe lattice
def DOS(self, E):
    """
    Surface density of states of the Bethe lattice at energy E.

    Builds the retarded surface Green's function
    Gr = [(E' + i*eta) I - H0 - sum(sigmaSurf(E'))]^-1 with E' = E - dFermi,
    and returns -Im Tr(Gr) / pi.

    Parameters
    ----------
    E : float
        Energy point in eV

    Returns
    -------
    float
        Density of states at energy E
    """
    E_ref = E - self.dFermi
    surface_sigma = jnp.sum(self.sigmaSurf(E_ref), axis=0)
    G_inv = (E_ref + 1j*self.eta)*jnp.eye(self.dim) - self.H0 - surface_sigma
    Gr = LA.inv(G_inv)
    return -jnp.trace(Gr).imag/jnp.pi
# Calculate fermi energy using bisection (to specified tolerance)
def calcFermi(self, ne, tol=FERMI_CALCULATION_TOL):
    """
    Find the bulk Fermi energy that yields ``ne`` electrons.

    Delegates the search to getFermiContact from density.py (up to 1000
    cycles, at this lattice's temperature). The result is stored in
    ``self.fermi``. It is also stored as the reference level ``fermi0``
    if none has been recorded yet.

    Parameters
    ----------
    ne : float
        Target number of electrons
    tol : float, optional
        Convergence tolerance (default: FERMI_CALCULATION_TOL from
        gauNEGF.config)

    Returns
    -------
    float
        Calculated Fermi energy in eV
    """
    print('Calculating Bulk Bethe Lattice Fermi level...')
    found = getFermiContact(self, ne, conv=tol, maxcycles=1000, T=self.T)
    self.fermi = found
    if self.fermi0 is None:
        self.fermi0 = found
    return found