"""
Surface Green's function implementation for Bethe lattice contacts.
This module provides a Bethe lattice implementation for modeling semi-infinite
metallic contacts in quantum transport calculations. It supports:
- FCC [111] surface geometry with proper orbital symmetries
- Slater-Koster parameterization for s, p, and d orbitals
- Temperature-dependent calculations
- Spin-restricted and unrestricted calculations
The implementation follows the ANT.Gaussian approach [1], using a minimal basis
set with s, p, and d orbitals for each contact atom. The Bethe lattice model
provides an efficient way to describe bulk metallic electrodes while maintaining
proper orbital symmetries and electronic structure. This approach allows for
accurate modeling of metal-molecule interfaces without the computational cost
of explicit periodic boundary conditions.
References
----------
[1] Jacob, D., & Palacios, J. J. (2011). Critical comparison of electrode models
in density functional theory based quantum transport calculations.
The Journal of Chemical Physics, 134(4), 044118.
DOI: 10.1063/1.3526044
"""
# Developed packages (import config BEFORE jax)
from gauNEGF.density import getFermiContact
from gauNEGF.config import (ETA, TEMPERATURE, SURFACE_GREEN_CONVERGENCE,
FERMI_CALCULATION_TOL)
from gauNEGF.utils import fractional_matrix_power
# Python packages
import jax
import jax.numpy as jnp
import jax.numpy.linalg as LA
from jax import jit
import jax.lax as lax
# Module-level constants
# Basis functions per contact atom in the minimal basis: one s, three p, five d
dim = 1 + 3 + 5
# Unit conversions
har_to_eV = 27.211386  # eV per Hartree
bohr_to_ang = 0.529177  # Angstrom per Bohr radius
# Bethe lattice surface Green's function for a device with contacts
class surfGB:
    """
    Surface Green's function calculator for Bethe lattice contacts.

    This class implements the Bethe lattice approximation for modeling
    semi-infinite metallic contacts. It handles:
    - Contact geometry detection and setup
    - Slater-Koster parameter management
    - Energy-dependent self-energy calculations
    - Temperature effects
    - Spin configurations

    The Bethe lattice model represents the contacts as a semi-infinite tree-like
    structure with proper coordination number and orbital symmetries matching
    those of bulk FCC metals. This provides an efficient way to compute surface
    Green's functions and self-energies for the electrodes, as demonstrated by
    Jacob & Palacios in their 2011 paper [1].

    Parameters
    ----------
    F : ndarray
        Fock matrix
    S : ndarray
        Overlap matrix
    contacts : list of lists
        List of atom indices (1-based) for each contact
    bar : object
        Gaussian interface object containing basis set information
        (``ibfatm``, ``ibftyp`` and atomic coordinates ``c`` in Bohr)
    latFile : str, optional
        Base name of the Bethe lattice parameter file; '.bethe' is appended
        (default: 'Au')
    spin : str, optional
        Spin configuration: 'r' (restricted), 'u'/'ro' (spin blocks added as
        kron(I2, sigma)) or 'g' (interleaved spin, kron(sigma, I2))
        (default: 'r')
    eta : float, optional
        Broadening parameter in eV (default: ETA from gauNEGF.config)
    T : float, optional
        Temperature in Kelvin (default: TEMPERATURE from gauNEGF.config)

    Attributes
    ----------
    cVecs : list
        Normal vectors for each contact surface
    latVecs : list
        Lattice vectors for each contact
    indsLists : list
        Orbital indices for each contact atom
    dirLists : list
        Direction vectors for nearest neighbors
    nIndLists : list
        Nearest neighbor indices for each atom
    Slists, Vlists : list
        Per-contact lists of 12 overlap / hopping matrices (one per direction)
    gList : list
        surfGBAt objects for each contact

    References
    ----------
    [1] Jacob, D., & Palacios, J. J. (2011). Critical comparison of electrode models
    in density functional theory based quantum transport calculations.
    The Journal of Chemical Physics, 134(4), 044118.
    DOI: 10.1063/1.3526044
    """

    def __init__(self, F, S, contacts, bar, latFile='Au', spin='r', eta=ETA, T=TEMPERATURE):
        # Read contact/orbital information and store
        self.cVecs = []
        self.latVecs = []
        self.indsLists = []
        self.dirLists = []
        self.nIndLists = []
        # S^(1/2), used to de-orthonormalize sigma when the Bethe basis is orthonormal
        self.Xi = fractional_matrix_power(S, 0.5)
        if spin != 'r':
            # Spin independent implementation: keep one spin copy here and add the
            # degenerate spin terms back during sigma generation
            self.Xi = self.Xi[::2, ::2]
        self.spin = spin
        orbMap = bar.ibfatm[bar.ibfatm > 0]
        orbTyp = bar.ibftyp[bar.ibfatm > 0]
        # One xyz row per atom, converted from Bohr to Angstrom
        coords = jnp.array([bar.c[i*3:(i+1)*3] for i in range(len(bar.c)//3)])*bohr_to_ang
        self.N = len(orbMap)
        # Collect contact information
        for contact in contacts:
            indsList = []
            cList = []
            for atom in contact:
                inds = jnp.where(jnp.isin(orbMap, atom))[0]
                cList.append(coords[atom-1])
                assert len(inds) == 9, f'Error: Atom {atom} has {len(inds)} basis functions, expecting 9'
                # Sort by abs(ibftyp)//1000 (presumably the shell type, giving the
                # s, p, d order constructMat expects -- TODO confirm)
                inds = inds[jnp.argsort(abs(orbTyp[inds])//1000)]
                indsList.append(inds)
            self.indsLists.append(indsList)
            # Calculate plane direction using SVD: the last right-singular vector is
            # the normal of the best-fit plane through the contact atoms
            cList = jnp.array(cList)
            centeredCoords = cList - jnp.mean(cList, axis=0)
            _, _, Vt = LA.svd(centeredCoords)
            contDir = jnp.mean(cList, axis=0) - jnp.mean(coords, axis=0)
            contVec = Vt[-1]
            # Orient the normal away from the centroid of the whole structure
            if jnp.dot(contDir, contVec) < 0:
                contVec *= -1
            self.cVecs.append(contVec)
            # Calculate one lattice direction (to the closest contact atom) for lining up atoms
            vInd = jnp.argmin(jnp.array([LA.norm(v - cList[0]) for v in cList[1:]]))+1
            latVec = cList[vInd] - cList[0]
            latDist = LA.norm(latVec)
            self.latVecs.append(latVec/latDist)
            # Calculate rest of lattice directions
            nVecs1 = self.genNeighbors(contVec, latVec)
            # Calculate second direction set (in case off by rotation)
            nVecs2 = self.genNeighbors(contVec, -latVec)
            # Use lattice vectors to see what nearest neighbors
            nIndList = []
            nVecs = nVecs1.copy()
            for c in cList:
                nAtVecs = []
                for c2 in coords:
                    dist = LA.norm(c2 - c)
                    # if within 0.2*nearest neighbor dist and not the same atom
                    if dist > 0.8 * latDist and dist < 1.2 * latDist and not jnp.allclose(c2, c):
                        nAtVecs.append((c2 - c)/dist)  # Unit vector for that direction
                # Align out of plane vectors (two options)
                nVecs = nVecs1.copy()
                outOfPlane = [3, 4, 5, 9, 10, 11]
                for vec in nAtVecs:
                    valList = jnp.array([jnp.dot(vec, direction) for direction in nVecs2])
                    dirInd = jnp.argmax(valList)
                    if dirInd in outOfPlane and valList[dirInd] > 0.9:
                        nVecs = nVecs2.copy()
                        break
                # Now that orientation is fixed, track all neighbors
                nInds = []
                for vec in nAtVecs:
                    valList = jnp.array([jnp.dot(vec, direction) for direction in nVecs])
                    dirInd = jnp.argmax(valList)
                    if valList[dirInd] > 0.9:
                        nInds.append(dirInd)
                    else:
                        print(f'Warning: Lattice Vec #{dirInd} mismatch, neighbor not recorded')
                # write neighbor indices for each atom
                nIndList.append(nInds)
            # write direction vectors and neighbors for each contact
            # (nVecs is the orientation selected for the last atom of the contact)
            self.nIndLists.append(nIndList)
            self.dirLists.append(nVecs)
        # Read Bethe lattice parameters and generate hopping/overlap matrices
        self.readBetheParams(latFile)
        self.Slists = []
        self.Vlists = []
        for dirList in self.dirLists:
            # Construct hopping/overlap matrices for every direction and store to contact
            self.Slists.append([self.constructMat(self.Sdict, d, self.SOC) for d in dirList])
            self.Vlists.append([self.constructMat(self.Vdict, d, self.SOC) for d in dirList])
        self.num_contacts = len(self.indsLists)
        # Use surfGBAt() object to store the atomic Bethe lattice green's function for each contact
        self.gList = [surfGBAt(self.H0.copy(), Slist, Vlist, eta, T, self.SOC)
                      for Slist, Vlist in zip(self.Slists, self.Vlists)]
        # Calculate fermi level
        # With SOC, 18x18 basis already includes both spins -> use full ne
        # Without SOC, spin-restricted -> ne/2 per spin channel
        ne_fermi = self.ne if self.SOC else self.ne/2
        fermi = self.gList[0].calcFermi(ne_fermi)
        for g in self.gList:
            g.fermi = fermi
            g.fermi0 = fermi
        # Store variables
        self.cList = cList  # coords of the last contact processed, used for testing
        self.F = F
        self.S = S
        self.eta = eta
        # Note: JIT compilation is on surfGBAt.sigmaSurf/sigmaK (the expensive Dyson
        # iteration). surfGB.sigma is NOT JIT'd so it can access dFermi and shift
        # E before calling the atomic methods -- keeping dFermi outside the JIT
        # boundary avoids stale-closure bugs and unnecessary recompilation.

    def genNeighbors(self, plane_normal, first_neighbor):
        """
        Generate 12 nearest neighbor unit vectors for an FCC [111] surface.

        Returned order:
        - indices 0-2: in-plane vectors at 0, 60 and 120 degrees from first_neighbor
        - indices 3-5: out-of-plane vectors on the +plane_normal side, 120 degrees apart
        - indices 6-11: negatives of indices 0-5, so direction k+6 is opposite to k

        Parameters
        ----------
        plane_normal : ndarray
            Vector normal to the crystal plane (normalized internally)
        first_neighbor : ndarray
            Vector to one nearest neighbor (projected onto the plane and normalized)

        Returns
        -------
        list
            12 unit vectors representing nearest neighbor directions
        """
        # Rodrigues' formula below requires a unit rotation axis
        plane_normal = plane_normal / LA.norm(plane_normal)
        # Project first_neighbor onto plane perpendicular to plane_normal
        proj = first_neighbor - jnp.dot(first_neighbor, plane_normal) * plane_normal
        first_neighbor = proj / LA.norm(proj)
        # Cross-product matrix of the rotation axis, shared by every rotation below
        K = jnp.array([[0, -plane_normal[2], plane_normal[1]],
                       [plane_normal[2], 0, -plane_normal[0]],
                       [-plane_normal[1], plane_normal[0], 0]])
        K2 = jnp.matmul(K, K)

        def rotate(vector, angle):
            # Rodrigues rotation of `vector` by `angle` about plane_normal
            R = jnp.eye(3) + jnp.sin(angle) * K + (1 - jnp.cos(angle)) * K2
            return jnp.dot(R, vector)

        # In-plane vectors using 60-degree rotations
        rotation_angle = jnp.pi / 3
        in_plane_vectors = []
        for i in range(3):
            rotated_vector = rotate(first_neighbor, i * rotation_angle)
            in_plane_vectors.append(rotated_vector / LA.norm(rotated_vector))
        # Out-of-plane vectors: rotate the base vector by 30 degrees in-plane, raise it
        # arccos(1/sqrt(3)) (~54.74 degrees) above the plane, then apply 120-degree rotations
        out_of_plane_angle = jnp.arccos(1/jnp.sqrt(3))
        rotated_first = rotate(first_neighbor, jnp.pi/6)
        out_of_plane_base = jnp.cos(out_of_plane_angle) * rotated_first + \
            jnp.sin(out_of_plane_angle) * plane_normal
        out_of_plane_vectors = [rotate(out_of_plane_base, i * 2 * jnp.pi / 3) for i in range(3)]
        # Add corresponding opposite vectors at the (k+6)%12 location
        all_vectors = in_plane_vectors + out_of_plane_vectors
        all_vectors.extend([-v for v in all_vectors[:6]])
        return all_vectors

    # Read parameters from filename.bethe file, check values, store into dicts
    def readBetheParams(self, filename):
        """
        Read Slater-Koster parameters from a .bethe file.

        Reads and validates parameters for minimal basis with single s, p, and d orbitals.
        Parameters are stored in dictionaries for onsite energies, hopping integrals,
        and overlap matrices, and the onsite matrix ``self.H0`` is assembled.

        Parameters
        ----------
        filename : str
            Name of the .bethe file (without extension)

        Raises
        ------
        AssertionError
            If required parameters are missing or unknown keys are present

        Notes
        -----
        The file holds one ``key = value`` entry per line; blank lines are skipped.
        Parameters are sorted into:
        - ne: electron count used for the Fermi level search
        - Edict: Onsite energies (converted from Hartrees to eV)
        - Vdict: Hopping parameters (converted from Hartrees to eV)
        - Sdict: Overlap parameters
        Optional ``soc_p``/``soc_d`` (Hartree) enable spin-orbit coupling when
        ``self.spin != 'r'``; H0 is then 18x18 instead of 9x9.
        """
        params = {}
        with open(filename+'.bethe', 'r') as f:
            for line in f:
                # Skip empty lines
                if not line.strip():
                    continue
                # Remove spaces, then split "key=value" on the equals sign
                line = line.replace(' ', '')
                key, value = line.split('=')
                params[key] = float(value)
        # Check to make sure parameters are all specified
        # Note: set up only for minimal basis with single s, p, and d orbital
        expected_keys = {'ne', 'es', 'ep', 'edd', 'edt', 'sss', 'sps', 'pps', 'ppp',
                         'sds', 'pds', 'pdp', 'dds', 'ddp', 'ddd', 'Ssss', 'Ssps',
                         'Spps', 'Sppp', 'Ssds', 'Spds', 'Spdp', 'Sdds', 'Sddp', 'Sddd'}
        optional_keys = {'soc_p', 'soc_d'}
        found_keys = set(params.keys())
        assert expected_keys.issubset(found_keys) and found_keys.issubset(expected_keys | optional_keys), \
            f"Error reading file: Found Bethe parameters: {list(params.keys())}, expected: {sorted(expected_keys)}"
        # sort parameters and convert Hartrees to eV
        self.ne = params['ne']
        self.Edict = {k[1:]: params[k]*har_to_eV for k in params if k.startswith('e')}
        self.Sdict = {k[1:]: params[k] for k in params if k.startswith('S')}
        self.Vdict = {k: params[k]*har_to_eV for k in params
                      if not k.startswith('e') and not k.startswith('S')
                      and not k.startswith('soc') and k != 'ne'}
        # Setup onsite H0 matrix before Fermi level shifting
        # d order is [d3z2-r2, dxz, dyz, dx2-y2, dxy] -> onsite [dd, dt, dt, dd, dt]
        hdiag = [self.Edict['s']] + [self.Edict['p']]*3 + [self.Edict['dd']] + \
            [self.Edict['dt']]*2 + [self.Edict['dd'], self.Edict['dt']]
        H0 = jnp.diag(jnp.array(hdiag))
        if 'soc_p' in params and 'soc_d' in params and self.spin != 'r':
            self.SOC = True
            lambdas = [0.0, params['soc_p'] * har_to_eV, params['soc_d'] * har_to_eV]
            from gauNEGF.spinTools import constructSOCterm
            Hsoc = constructSOCterm(lambdas)
            self.H0 = jnp.kron(H0, jnp.eye(2)) + jnp.array(Hsoc)
        else:
            self.SOC = False
            self.H0 = H0

    def constructMat(self, Mdict, dirCosines, SOC=False):
        """
        Construct hopping/overlap matrix using Slater-Koster formalism.

        Builds a 9x9 matrix for s, p, and d orbital interactions based on the
        Slater-Koster two-center approximation. The matrix is first constructed
        assuming a [0,0,1] bond direction, then rotated to the given direction
        using direction cosines.

        Parameters
        ----------
        Mdict : dict
            Dictionary of Slater-Koster parameters (sss, sps, pps, etc.)
        dirCosines : array_like
            Array [l,m,n] of direction cosines for the bond (unit vector)
        SOC : bool, optional
            If True, expand the result to 18x18 via kron(M, I2) (default: False)

        Returns
        -------
        ndarray
            9x9 matrix (18x18 when SOC=True) of orbital interactions in the rotated frame

        Notes
        -----
        Orbital order: [s, px, py, pz, d3z2-r2, dxz, dyz, dx2-y2, dxy].
        Matrix blocks:
        - [0,0]: s-s interaction
        - [0:4,0:4]: s-p block
        - [0:4,4:9]: s-d and p-d blocks
        - [4:9,4:9]: d-d block
        """
        M = jnp.zeros((dim, dim))
        # Original matrix before rotation - assuming [0,0,1] bond direction
        # s-s coefficient
        M = M.at[0, 0].set(Mdict['sss'])
        # s-p block
        M = M.at[0, 3].set(Mdict['sps'])   # s-pz
        M = M.at[3, 0].set(-Mdict['sps'])  # pz-s
        # p-p block
        M = M.at[1, 1].set(Mdict['ppp'])  # px-px
        M = M.at[2, 2].set(Mdict['ppp'])  # py-py
        M = M.at[3, 3].set(Mdict['pps'])  # pz-pz
        # s-d block
        M = M.at[0, 4].set(Mdict['sds'])  # s - d3z²-r²
        M = M.at[4, 0].set(Mdict['sds'])
        # p-d block
        M = M.at[1, 5].set(Mdict['pdp'])   # px - dxz
        M = M.at[2, 6].set(Mdict['pdp'])   # py - dyz
        M = M.at[3, 4].set(Mdict['pds'])   # pz - d3z²-r²
        M = M.at[5, 1].set(-Mdict['pdp'])  # dxz - px
        M = M.at[6, 2].set(-Mdict['pdp'])  # dyz - py
        M = M.at[4, 3].set(-Mdict['pds'])  # d3z²-r² - pz
        # d-d block
        M = M.at[4, 4].set(Mdict['dds'])  # d3z²-r² - d3z²-r²
        M = M.at[5, 5].set(Mdict['ddp'])  # dxz - dxz
        M = M.at[6, 6].set(Mdict['ddp'])  # dyz - dyz
        M = M.at[7, 7].set(Mdict['ddd'])  # dx²-y² - dx²-y²
        M = M.at[8, 8].set(Mdict['ddd'])  # dxy - dxy
        # Initialize 9x9 transformation matrix and polar directions
        tr = jnp.zeros((9, 9))
        x, y, z = dirCosines
        theta = jnp.arccos(z)  # polar angle from z-axis
        phi = jnp.arctan2(y, x)  # azimuthal angle in x-y plane
        # s orbital (1x1) at position [0,0] - always 1 since spherically symmetric
        tr = tr.at[0, 0].set(1.0)
        # p orbitals (3x3) at positions [1:4,1:4]
        # [px,py,pz] block - describes how p orbitals transform under rotation
        tr = tr.at[1:4, 1:4].set(jnp.array([
            [jnp.cos(theta) * jnp.cos(phi), -jnp.sin(phi), jnp.sin(theta)*jnp.cos(phi)],
            [jnp.cos(theta) * jnp.sin(phi), jnp.cos(phi), jnp.sin(theta)*jnp.sin(phi)],
            [-jnp.sin(theta), 0, jnp.cos(theta)]
        ]))
        # d orbitals (5x5) at positions [4:9,4:9]
        # [d3z2-r2, dxz, dyz, dx2-y2, dxy] block - transforms the five d orbitals
        d_block = jnp.zeros((5, 5))
        # Copying formula from ANT.Gaussian directly
        d_block = d_block.at[0, 0].set((3 * z**2 - 1) / 2)
        d_block = d_block.at[0, 1].set(-jnp.sqrt(3) * jnp.sin(2*theta) / 2)
        d_block = d_block.at[0, 3].set(jnp.sqrt(3) * jnp.sin(theta)**2 / 2)
        d_10 = jnp.sqrt(3) * jnp.sin(2*theta) * jnp.cos(phi) / 2
        d_block = d_block.at[1, 0].set(d_10)
        d_block = d_block.at[1, 1].set(jnp.cos(2*theta) * jnp.cos(phi))
        d_block = d_block.at[1, 2].set(-jnp.cos(theta) * jnp.sin(phi))
        d_block = d_block.at[1, 3].set(-d_10 / jnp.sqrt(3))
        d_block = d_block.at[1, 4].set(jnp.sin(theta) * jnp.sin(phi))
        d_20 = jnp.sqrt(3) * jnp.sin(2*theta) * jnp.sin(phi) / 2
        d_block = d_block.at[2, 0].set(d_20)
        d_block = d_block.at[2, 1].set(jnp.cos(2*theta) * jnp.sin(phi))
        d_block = d_block.at[2, 2].set(jnp.cos(theta) * jnp.cos(phi))
        d_block = d_block.at[2, 3].set(-d_20 / jnp.sqrt(3))
        d_block = d_block.at[2, 4].set(-jnp.sin(theta) * jnp.cos(phi))
        d_block = d_block.at[3, 0].set(jnp.sqrt(3) * jnp.sin(theta)**2 * jnp.cos(2*phi) / 2)
        d_block = d_block.at[3, 1].set(jnp.sin(2*theta) * jnp.cos(2*phi) / 2)
        d_block = d_block.at[3, 2].set(-jnp.sin(theta) * jnp.sin(2*phi))
        d_block = d_block.at[3, 3].set((1 + jnp.cos(theta)**2) * jnp.cos(2*phi) / 2)
        d_block = d_block.at[3, 4].set(-jnp.cos(theta) * jnp.sin(2*phi))
        d_block = d_block.at[4, 0].set(jnp.sqrt(3) * jnp.sin(theta)**2 * jnp.sin(2*phi) / 2)
        d_block = d_block.at[4, 1].set(jnp.sin(2*theta) * jnp.sin(2*phi) / 2)
        d_block = d_block.at[4, 2].set(jnp.sin(theta) * jnp.cos(2*phi))
        d_block = d_block.at[4, 3].set((1 + jnp.cos(theta)**2) * jnp.sin(2*phi) / 2)
        d_block = d_block.at[4, 4].set(jnp.cos(theta) * jnp.cos(2*phi))
        tr = tr.at[4:9, 4:9].set(d_block)
        # Apply transformation
        M_rot = tr @ M @ tr.T
        if SOC:
            return jnp.kron(M_rot, jnp.eye(2))
        return M_rot

    def _openDirections(self, nInds):
        """Return the surface directions (0-8) with no neighbor atom recorded in nInds."""
        return list(set(range(9)) - {int(x) for x in nInds})

    def _deviceZeros(self):
        """Return a zero complex matrix in the device basis (2N x 2N with SOC, else N x N)."""
        size = 2*self.N if self.SOC else self.N
        return jnp.zeros((size, size), dtype=complex)

    def _deviceIndices(self, Finds):
        """Map an atom's orbital indices to device rows (k -> [2k, 2k+1] when SOC is on)."""
        if self.SOC:
            # SOC matrices use interleaved spin-orbital ordering
            return jnp.array([idx for k in Finds for idx in (2*k, 2*k+1)])
        return Finds

    def _finishSigma(self, sig):
        """De-orthonormalize an assembled contact matrix and attach the spin structure."""
        # ANT.Gaussian de-orthonormalization, applied when the Bethe parameters
        # describe an orthonormal basis (zero s-s overlap). The predicate is a
        # static Python value, so a plain `if` suffices (no lax.cond tracing).
        if self.Sdict['sss'] == 0:
            Xi = jnp.kron(self.Xi, jnp.eye(2)) if self.SOC else self.Xi
            sig = Xi @ sig @ Xi
        if self.SOC:
            # No kron needed -- spin is already in SOC matrices
            return sig
        if self.spin == 'u' or self.spin == 'ro':
            return jnp.kron(jnp.eye(2), sig)
        if self.spin == 'g':
            return jnp.kron(sig, jnp.eye(2))
        # spin == 'r': keep sig as-is
        return sig

    def sigma(self, E, i, conv=SURFACE_GREEN_CONVERGENCE):
        """
        Calculate self-energy matrix for a specific contact.

        Computes the self-energy matrix for contact i by:
        1. Calculating surface self-energies for all 9 directions
        2. Summing, per contact atom, the directions not connected to another atom
        3. Applying de-orthonormalization (Xi @ sigma @ Xi) when the Bethe
           parameters are orthonormal (Sdict['sss'] == 0)
        4. Handling spin configurations

        Parameters
        ----------
        E : float
            Energy point for self-energy calculation (in eV)
        i : int
            Index of the contact to calculate self-energy for
        conv : float, optional
            Convergence criterion for self-energy calculation (default: SURFACE_GREEN_CONVERGENCE)

        Returns
        -------
        ndarray
            Self-energy matrix for the specified contact, with dimensions:
            - (N, N) for restricted calculations
            - (2N, 2N) for unrestricted, generalized spin, or SOC calculations

        References
        ----------
        [1] Jacob, D., & Palacios, J. J. (2011). Critical comparison of electrode models
        in density functional theory based quantum transport calculations.
        The Journal of Chemical Physics, 134(4), 044118.
        DOI: 10.1063/1.3526044
        """
        g = self.gList[i]
        # Shift E to reference frame of immutable H0/Vlist0 (keeps dFermi
        # outside the JIT boundary of surfGBAt.sigmaSurf)
        sigSurf = g.sigmaSurf(E - g.dFermi, conv)
        sig = self._deviceZeros()
        for nInds, Finds in zip(self.nIndLists[i], self.indsLists[i]):
            sigAtom = sum(sigSurf[j] for j in self._openDirections(nInds))
            rows = self._deviceIndices(Finds)
            sig = sig.at[jnp.ix_(rows, rows)].set(sigAtom)
        return self._finishSigma(sig)

    def sigmaTot(self, E, conv=SURFACE_GREEN_CONVERGENCE):
        """
        Calculate total self-energy matrix from all contacts.

        Sums sigma(E, i) over all contacts and returns the result in the full device basis.

        Parameters
        ----------
        E : float
            Energy point for Green's function calculation (in eV)
        conv : float, optional
            Convergence criterion for self-energy calculation
            (default: SURFACE_GREEN_CONVERGENCE)

        Returns
        -------
        ndarray
            Total self-energy matrix in the full device basis
        """
        num_contacts = len(self.indsLists)
        sigs = [self.sigma(E, i, conv) for i in range(num_contacts)]
        return sum(sigs)

    def crossTermQ(self, E, i, conv=SURFACE_GREEN_CONVERGENCE):
        """
        Cross-term Q_sym for contact i in full device basis.

        Mirrors surfGB.sigma: iterates atoms in contact i, calls
        gList[i].crossTermQSurf with per-atom active directions, assembles the
        result, then applies the same de-orthonormalization and spin handling.
        """
        g = self.gList[i]
        E_shifted = E - g.dFermi
        sig = self._deviceZeros()
        for nInds, Finds in zip(self.nIndLists[i], self.indsLists[i]):
            Q_atom = g.crossTermQSurf(E_shifted, sigInds=self._openDirections(nInds), conv=conv)
            rows = self._deviceIndices(Finds)
            sig = sig.at[jnp.ix_(rows, rows)].set(Q_atom)
        return self._finishSigma(sig)

    def crossTermQTot(self, E, conv=SURFACE_GREEN_CONVERGENCE):
        """Total cross-term Q_sym from all contacts."""
        num_contacts = len(self.indsLists)
        qs = [self.crossTermQ(E, i, conv) for i in range(num_contacts)]
        return sum(qs)

    def getSigma(self, Elist=(None, None), conv=SURFACE_GREEN_CONVERGENCE):
        """
        Helper method for getting the left and right contact self-energies.

        Parameters
        ----------
        Elist : sequence, optional
            Energies (eV) for the first and last contact; a None entry uses that
            contact's current Fermi energy (default: (None, None)).
            The sequence is not modified.
        conv: float, optional
            Convergence criterion for the self-energy matrix

        Returns
        -------
        tuple
            A tuple of both self-energy matrices (ndarrays)
        """
        # Resolve into locals instead of writing back into Elist: mutating a
        # (default) list would make later calls reuse a stale Fermi energy.
        EL = self.gList[0].fermi if Elist[0] is None else Elist[0]
        ER = self.gList[-1].fermi if Elist[1] is None else Elist[1]
        return (self.sigma(EL, 0, conv), self.sigma(ER, -1, conv))

    def updateFermi(self, i, Ef):
        """
        Update Fermi energy for a specific contact.

        Shifts the Hamiltonian of contact i to align its Fermi level with
        the specified energy.

        Parameters
        ----------
        i : int
            Contact index
        Ef : float
            New Fermi energy in eV
        """
        self.gList[i].updateH(Ef)

    def setF(self, F, muL, muR):
        """
        Update Fock matrix and contact chemical potentials.

        Sets the Fock matrix and updates the Fermi levels of the left and
        right contacts if they have changed.

        Parameters
        ----------
        F : ndarray
            New Fock matrix
        muL : float
            Chemical potential for left contact in eV
        muR : float
            Chemical potential for right contact in eV
        """
        self.F = F
        if self.gList[0].fermi != muL:
            self.updateFermi(0, muL)
        if self.gList[-1].fermi != muR:
            self.updateFermi(-1, muR)

    ## TESTING METHODS FOR SLATER-KOSTER INTERACTIONS:
    def testDOrbitalFunctions(self):
        """
        Test d orbital angular functions.

        Validates the angular dependence of s-d interactions along the x-axis:
        - s-dxy should be zero
        - s-dx2-y2 should be sqrt(3)/2 * sds
        - s-d3z2-r2 should be -1/2 * sds
        """
        import numpy.testing as npt  # jax.numpy has no `testing` submodule
        Vdict = self.Vdict  # Contains hopping parameters
        # Test along x-axis [1,0,0]
        M = self.constructMat(Vdict, [1, 0, 0])
        # s-dxy should be zero along x-axis
        npt.assert_almost_equal(M[0, 8], 0.0,
                                err_msg="dxy not zero along x-axis")
        # s-dx2-y2 should be sqrt(3)/2 * sds along x-axis
        npt.assert_almost_equal(M[0, 7], jnp.sqrt(3)/2 * Vdict['sds'],
                                err_msg="dx2-y2 incorrect along x-axis")
        # s-dz2 should be -1/2 * sds along x-axis
        npt.assert_almost_equal(M[0, 4], -0.5 * Vdict['sds'],
                                err_msg="dz2 incorrect along x-axis")
        print("d orbital angular function tests passed!")

    def testDOrbitalSymmetry(self):
        """
        Test d orbital symmetry properties.

        Validates that d orbital interactions respect inversion symmetry
        by comparing interactions along opposite directions.
        """
        import numpy.testing as npt  # jax.numpy has no `testing` submodule
        # Test inversion symmetry
        dir1 = [1/jnp.sqrt(2), 1/jnp.sqrt(2), 0]
        dir2 = [-1/jnp.sqrt(2), -1/jnp.sqrt(2), 0]
        M1 = self.constructMat(self.Vdict, dir1)
        M2 = self.constructMat(self.Vdict, dir2)
        # d-d block should be identical under inversion (d orbitals are even)
        npt.assert_array_almost_equal(
            M1[4:, 4:], M2[4:, 4:],
            err_msg="d-d block not symmetric under inversion")
        print("d orbital symmetry tests passed!")

    def testPDInteraction(self):
        """
        Test p-d orbital interactions.

        Validates p-d orbital interactions by checking:
        - px-dxy interaction along x-axis (should be zero)
        - pz-dz2 interaction along z-axis (should be pure sigma)
        """
        import numpy.testing as npt  # jax.numpy has no `testing` submodule
        Vdict = self.Vdict
        # Test px-dxy interaction along x-axis
        M = self.constructMat(Vdict, [1, 0, 0])
        # px-dxy should be zero along x-axis
        npt.assert_almost_equal(
            M[1, 8], 0.0,
            err_msg="px-dxy interaction incorrect along x-axis")
        # Test pz-dz2 interaction along z-axis
        M = self.constructMat(Vdict, [0, 0, 1])
        expected = Vdict['pds']  # Should be pure sigma
        npt.assert_almost_equal(
            M[3, 4], expected,
            err_msg="pz-dz2 interaction incorrect along z-axis")
        print("p-d interaction tests passed!")

    def testDDInteraction(self):
        """
        Test d-d orbital interactions.

        Validates d-d orbital interactions by checking:
        - dyz-dyz interaction along x-axis (should be pure delta)
        - dz2-dz2 interaction along z-axis (should be pure sigma)
        """
        import numpy.testing as npt  # jax.numpy has no `testing` submodule
        Vdict = self.Vdict
        # Test dyz-dyz interaction along x-axis
        M = self.constructMat(Vdict, [1, 0, 0])
        # Should be pure delta interaction
        expected = Vdict['ddd']
        npt.assert_almost_equal(
            M[6, 6], expected,
            err_msg="dyz-dyz interaction incorrect along x-axis")
        # Test dz2-dz2 interaction along z-axis
        M = self.constructMat(Vdict, [0, 0, 1])
        # Should be pure sigma interaction
        expected = Vdict['dds']
        npt.assert_almost_equal(
            M[4, 4], expected,
            err_msg="dz2-dz2 interaction incorrect along z-axis")
        print("d-d interaction tests passed!")

    def testHoppingPhysics(self):
        """
        Test physical properties of hopping matrices.

        Validates hopping matrix physics by checking:
        - s-p hopping antisymmetry
        - Conservation of total s-p hopping magnitude
        - Proper angular dependence along principal axes and 45-degree rotations
        """
        eps = 1e-10  # Tolerance for floating point comparisons
        # Get reference hopping values from [0,0,1] configuration
        s_p_mag = abs(self.Vdict['sps'])  # Magnitude of s-p hopping
        # Test set of physically important directions
        test_cases = [
            # Principal axes
            ([0, 0, 1], "z-axis"),
            ([1, 0, 0], "x-axis"),
            ([0, 1, 0], "y-axis"),
            # 45-degree rotations
            ([1/jnp.sqrt(2), 0, 1/jnp.sqrt(2)], "45° in xz-plane"),
            ([0, 1/jnp.sqrt(2), 1/jnp.sqrt(2)], "45° in yz-plane"),
            ([1/jnp.sqrt(2), 1/jnp.sqrt(2), 0], "45° in xy-plane"),
        ]
        print("\nTesting hopping matrix physics...")
        for direction, name in test_cases:
            direction = jnp.array(direction)
            x, y, z = direction
            print(f"\nChecking {name} direction: [{x:.3f}, {y:.3f}, {z:.3f}]")
            V = self.constructMat(self.Vdict, direction)
            # Check s-p hopping antisymmetry
            for i in range(1, 4):  # Check all p orbitals
                assert abs(V[0, i] + V[i, 0]) < eps, \
                    f"s-p hopping not antisymmetric for p{i}"
            # Check total s-p hopping magnitude is preserved
            s_p_total = jnp.sqrt(V[0, 1]**2 + V[0, 2]**2 + V[0, 3]**2)
            assert abs(s_p_total - s_p_mag) < eps, \
                f"s-p hopping magnitude not preserved: {s_p_total:.6f} != {s_p_mag:.6f}"
            # Print values for verification
            print(f"s-px: {V[0,1]:.3f}, px-s: {V[1,0]:.3f}")
            print(f"s-py: {V[0,2]:.3f}, py-s: {V[2,0]:.3f}")
            print(f"s-pz: {V[0,3]:.3f}, pz-s: {V[3,0]:.3f}")
            print(f"Total s-p magnitude: {s_p_total:.3f}")
        print("\nAll hopping physics tests passed!")

    def runAllTests(self):
        """
        Run all validation tests for surfGB.

        Executes all test methods to validate:
        - d orbital angular functions
        - d orbital symmetry
        - p-d interactions
        - d-d interactions
        - General hopping physics
        """
        print("Running Slater-Koster projection tests...")
        self.testDOrbitalFunctions()
        self.testDOrbitalSymmetry()
        self.testPDInteraction()
        self.testDDInteraction()
        self.testHoppingPhysics()
        print("\nAll tests passed!")
# Bethe lattice surface Green's function for a single atom
[docs]
class surfGBAt:
"""
Atomic-level Bethe lattice Green's function calculator.
This class implements the surface Green's function calculation for a single atom
in the Bethe lattice, handling:
- Onsite and hopping matrix construction
- Self-energy calculations for bulk and surface
- Temperature effects
- Fermi energy optimization
Parameters
----------
H : ndarray
Onsite Hamiltonian matrix (9x9 for minimal basis)
Slist : list of ndarray
List of 12 overlap matrices for nearest neighbors
Vlist : list of ndarray
List of 12 hopping matrices for nearest neighbors
eta : float
Broadening parameter in eV
T : float, optional
Temperature in Kelvin (default: 0)
Attributes
----------
NN : int
Number of nearest neighbors (fixed to 12 for FCC)
fermi : float
Current Fermi energy
F : ndarray
Extended Fock matrix including neighbors
S : ndarray
Extended overlap matrix including neighbors
"""
def __init__(self, H, Slist, Vlist, eta, T=TEMPERATURE, SOC=False):
    """
    Initialize surfGBAt with Hamiltonian and neighbor matrices.

    Parameters
    ----------
    H : ndarray
        Onsite Hamiltonian matrix (9x9 for minimal basis, 18x18 with SOC)
    Slist : list of ndarray
        List of 12 overlap matrices for nearest neighbors
    Vlist : list of ndarray
        List of 12 hopping matrices for nearest neighbors (same order as Slist)
    eta : float
        Broadening parameter in eV
    T : float, optional
        Temperature in Kelvin (default: TEMPERATURE from gauNEGF.config)
    SOC : bool, optional
        Whether to include spin-orbit coupling (default: False)

    Raises
    ------
    AssertionError
        If matrix dimensions are incorrect, Slist and Vlist differ in length,
        or the number of neighbors != 12
    """
    self.dim = dim*2 if SOC else dim
    assert jnp.shape(H) == (self.dim, self.dim), f"Error with H dim, should be {self.dim}x{self.dim}"
    # zip() below would silently truncate mismatched lists, and JAX clamps
    # out-of-range indices in updateH/sigmaK, so check the lengths explicitly.
    assert len(Slist) == len(Vlist), \
        f"Error: got {len(Slist)} overlap matrices but {len(Vlist)} hopping matrices"
    for S, V in zip(Slist, Vlist):
        assert jnp.shape(S) == (self.dim, self.dim), f"Error with S dim, should be {self.dim}x{self.dim}"
        assert jnp.shape(V) == (self.dim, self.dim), f"Error with V dim, should be {self.dim}x{self.dim}"
    self.H = H
    self.Slist = Slist
    self.Vlist = Vlist
    self.spin = 'r'  # spin-dependence not implemented yet
    self.NN = len(Slist)
    assert self.NN == 12, "Error: surfGBAt only implemented for FCC using 12 NN"
    self.eta = eta
    self.T = T
    self.fermi = None
    self.H0 = jnp.array(H)  # immutable reference (never mutated)
    self.Vlist0 = jnp.array(Vlist)  # immutable reference (never mutated)
    self.fermi0 = None  # reference Fermi level
    self.dFermi = 0.0  # shift from reference
    self.num_contacts = 1
    self.updateH()
    # JIT compile the bound methods; argnums refer to the bound signature
    # (E=0, conv=1, mix=2), so conv and mix are treated as static
    self.sigmaK = jit(self.sigmaK, static_argnums=(1, 2))
    self.sigmaSurf = jit(self.sigmaSurf, static_argnums=(1, 2))
def updateH(self, fermi=None):
    """
    Regenerate the Fermi-shifted onsite and hopping matrices.

    When a new setpoint is given and differs from the current ``self.fermi``,
    the shift ``dFermi`` is recomputed relative to ``fermi0``. If no reference
    exists yet, the setpoint itself is used as the shift and stored as the
    reference. ``H`` and ``Vlist`` are then rebuilt from the immutable
    ``H0``/``Vlist0`` plus that shift. ``F`` (= H) and ``S`` (= identity) are
    set for protocol compatibility with density.py.

    Parameters
    ----------
    fermi : float, optional
        New Fermi energy setpoint in eV (default: None, keep the current shift)
    """
    if fermi is not None and fermi != self.fermi:
        self.fermi = fermi
        if self.fermi0 is None:
            # No reference level yet: the setpoint is the shift and becomes the reference
            self.dFermi = fermi
            self.fermi0 = fermi
        else:
            self.dFermi = fermi - self.fermi0
    shift = self.dFermi
    # Non-JIT callers read H/Vlist directly, so materialise the shifted copies
    self.H = self.H0 + shift * jnp.eye(self.dim)
    shifted_hops = [self.Vlist0[j] + shift * self.Slist[j] for j in range(self.NN)]
    self.Vlist = jnp.array(shifted_hops)
    self.F = self.H
    self.S = jnp.eye(self.dim)
# Calculate sigmaK for the bulk
[docs]
def sigmaK(self, E, conv=SURFACE_GREEN_CONVERGENCE, mix=0.5):
    r"""
    Calculate bulk self-energies for all 12 lattice directions.

    Computes self-energies for an FCC lattice with the following geometry:

            [3x out of plane dir]
                    \|/
    [3x plane dir] - o - [3x plane dir]
                    /|\
            [3x out of plane dir]

    For every direction k the Bethe-lattice Dyson equation

        Sigma_k = B_k g_k B_bar_k
        g_k     = (E_eff*I - H0 - sum_j Sigma_j + Sigma_pair(k))^-1
        B_k     = E_eff*S_k - V_k,    B_bar_k = E_eff*S_k^H - V_k^H

    is solved, with E_eff = E + i*eta and pair(k) = (k + 6) % 12 the direction
    opposite to k. Adding Sigma_pair(k) back removes the neighbour's bond to
    the central atom from the neighbour's Green's function. The equations are
    iterated with linear mixing inside ``lax.while_loop`` (at most 1000
    iterations), starting from Sigma_k = -i*I.

    Parameters
    ----------
    E : float
        Energy point in eV. It must already be shifted by ``-self.dFermi``
        (``sigmaTot``/``crossTermQ``/``DOS`` do this); the unshifted ``H0``
        and ``Vlist0`` are used here so they stay JIT compile-time constants.
    conv : float, optional
        Relative convergence threshold on max|dSigma| / max|Sigma|
        (default: SURFACE_GREEN_CONVERGENCE).
    mix : float, optional
        Weight of the new iterate in the linear mixing (default: 0.5).

    Returns
    -------
    jax.Array
        Complex array of shape (NN, dim, dim), one self-energy per lattice
        direction in direction order.

    Notes
    -----
    No warm start from a previous energy point is used. If the iteration cap
    is reached, the last iterate is returned without a warning.
    """
    nn = self.NN
    maxIter = 1000
    # E is pre-shifted by the caller (wrappers subtract dFermi before entering
    # the JIT boundary). Using E directly keeps H0/Vlist0 as immutable
    # constants in the compiled code, avoiding JIT recompilation during SCF.
    E_eff = E + self.eta * 1j
    A = E_eff * jnp.eye(self.dim) - self.H0
    # Couplings depend only on E, so build them once, batched over directions.
    S = jnp.asarray(self.Slist)                      # (nn, dim, dim)
    V = self.Vlist0                                  # (nn, dim, dim)
    B = E_eff * S - V
    B_bar = (E_eff * jnp.conj(jnp.swapaxes(S, -1, -2))
             - jnp.conj(jnp.swapaxes(V, -1, -2)))
    # Index of the opposite direction for each k
    pair_idx = jnp.array([(k + 6) % 12 for k in range(nn)])

    sigma0 = jnp.array([jnp.eye(self.dim) * -1j for _ in range(nn)], dtype=complex)

    def cond_fun(state):
        count, diff, _ = state
        return (diff > conv) & (count < maxIter)

    def body_fun(state):
        count, _, sig_old = state
        sig_tot = jnp.sum(sig_old, axis=0)
        # Jacobi update: every direction uses the previous iterate, so the
        # Sigma_pair(k) added back is exactly the one contained in sig_tot.
        # (The former in-place loop added back an already-updated
        # Sigma_pair(k) for k >= 6; the fixed point is the same.)
        g = LA.inv(A - sig_tot + sig_old[pair_idx])  # batched neighbour GFs
        sig_new = mix * (B @ g @ B_bar) + (1 - mix) * sig_old
        diff = jnp.max(jnp.abs(sig_new - sig_old)) / jnp.max(jnp.abs(sig_old))
        return (count + 1, diff, sig_new)

    _, _, sigmaK = lax.while_loop(cond_fun, body_fun, (0, jnp.inf, sigma0))
    return sigmaK
[docs]
def sigmaSurf(self, E, conv=SURFACE_GREEN_CONVERGENCE, mix=0.5):
    r"""
    Calculate surface self-energies for an FCC lattice.

    Computes self-energies for an atom at the surface with the geometry:

    [3x plane dir] - o - [3x plane dir]
                    /|\
            [3x out of plane dir]

    The surface atom keeps directions 0-8 of the 12 bulk directions. The
    out-of-plane directions 3, 4, 5 keep their bulk self-energies from
    ``sigmaK``. The six in-plane directions 0, 1, 2, 6, 7, 8 are re-solved
    self-consistently with the Bethe-lattice Dyson equation

        Sigma_k = B_k g_k B_bar_k
        g_k     = (E_eff*I - H0 - sum_j Sigma_j + Sigma_pair(k))^-1

    where the sum runs over the 9 surface directions, E_eff = E + i*eta, and
    pair(k) = (k + 6) % 12 is the (also in-plane) opposite direction. As in
    ``sigmaK``, adding Sigma_pair(k) back removes the neighbour's bond to the
    central atom from the neighbour's Green's function. The approach follows
    the Bethe lattice model of Jacob & Palacios (2011).

    Parameters
    ----------
    E : float
        Energy point in eV, already shifted by ``-self.dFermi`` by the caller.
    conv : float, optional
        Relative convergence threshold on max|dSigma| / max|Sigma|
        (default: SURFACE_GREEN_CONVERGENCE).
    mix : float, optional
        Weight of the new iterate in the linear mixing (default: 0.5).

    Returns
    -------
    jax.Array
        Complex array of shape (9, dim, dim) with the surface self-energies.

    Notes
    -----
    At most 1000 iterations are performed. If the cap is reached, the last
    iterate is returned without a warning.

    References
    ----------
    [1] Jacob, D., & Palacios, J. J. (2011). Critical comparison of electrode models
    in density functional theory based quantum transport calculations.
    The Journal of Chemical Physics, 134(4), 044118.
    DOI: 10.1063/1.3526044
    """
    # Start from the bulk solution restricted to the 9 surface directions.
    sigSurf0 = self.sigmaK(E, conv, mix)[:9]
    maxIter = 1000
    E_eff = E + self.eta * 1j
    A = E_eff * jnp.eye(self.dim) - self.H0
    planeVec = [0, 1, 2, 6, 7, 8]                 # in-plane directions
    pairVec = [(k + 6) % 12 for k in planeVec]    # their opposites (all in-plane)
    plane_idx = jnp.array(planeVec)
    pair_idx = jnp.array(pairVec)
    # Energy-only couplings for the in-plane directions, hoisted out of the loop
    S = jnp.stack([jnp.asarray(self.Slist[k]) for k in planeVec])
    V = self.Vlist0[plane_idx]
    B = E_eff * S - V
    B_bar = (E_eff * jnp.conj(jnp.swapaxes(S, -1, -2))
             - jnp.conj(jnp.swapaxes(V, -1, -2)))

    def cond_fun(state):
        count, diff, _ = state
        return (diff > conv) & (count < maxIter)

    def body_fun(state):
        count, _, sig_old = state
        sig_tot = jnp.sum(sig_old, axis=0)
        # Neighbour Green's functions: the bond back to the centre
        # (direction pair(k)) is subtracted from the total.
        g = LA.inv(A - sig_tot + sig_old[pair_idx])
        updated = mix * (B @ g @ B_bar) + (1 - mix) * sig_old[plane_idx]
        sig_new = sig_old.at[plane_idx].set(updated)
        diff = jnp.max(jnp.abs(sig_new - sig_old)) / jnp.max(jnp.abs(sig_old))
        return (count + 1, diff, sig_new)

    _, _, sigSurf = lax.while_loop(cond_fun, body_fun, (0, jnp.inf, sigSurf0))
    return sigSurf
[docs]
def crossTermQSurf(self, E, sigInds=None, conv=SURFACE_GREEN_CONVERGENCE, mix=0.5):
    """
    Symmetrized surface cross-term Q_sym = sum_k (B_k g_k S_k^H + S_k g_k B_bar_k) / 2.

    Parameters
    ----------
    E : float
        Energy in eV, already shifted by the caller.
    sigInds : iterable of int, optional
        Surface direction indices to include (default: all nine, 0-8).
    conv, mix : float, optional
        Forwarded to ``self.sigmaSurf``.

    Returns
    -------
    jax.Array
        Complex (dim, dim) matrix.

    Notes
    -----
    g_k is the neighbour's Green's function in direction k,
    inv(A - sum_j Sigma_j + Sigma_pair(k)) with pair(k) = (k + 6) % 12, not the
    central atom's own Green's function. For directions whose opposite has
    index >= 9 (directions 3, 4, 5), that self-energy is not part of the
    surface sum, so g_k = inv(A - sum_j Sigma_j).
    """
    directions = range(9) if sigInds is None else sigInds
    sig_surf = self.sigmaSurf(E, conv, mix)  # (9, dim, dim)
    e_cplx = E + self.eta * 1j
    # (E + i*eta) I - H0 minus the total surface self-energy
    base = e_cplx * jnp.eye(self.dim) - self.H0 - jnp.sum(sig_surf, axis=0)
    q_total = jnp.zeros((self.dim, self.dim), dtype=complex)
    for k in directions:
        opposite = (k + 6) % 12
        g_nb = LA.inv(base + sig_surf[opposite] if opposite < 9 else base)
        s_k = self.Slist[k]
        s_k_h = s_k.conj().T
        b_fwd = e_cplx * s_k - self.Vlist0[k]
        b_rev = e_cplx * s_k_h - self.Vlist0[k].conj().T
        q_total = q_total + (b_fwd @ g_nb @ s_k_h + s_k @ g_nb @ b_rev) / 2
    return q_total
[docs]
def crossTermQBulk(self, E, conv=SURFACE_GREEN_CONVERGENCE, mix=0.5):
    """
    Bulk cross-term Q_sym summed over all 12 lattice directions.

    Each direction k contributes (B_k g_k S_k^H + S_k g_k B_bar_k) / 2, where
    g_k = inv(A - sum_j Sigma_j + Sigma_pair(k)) is the neighbour's Green's
    function with its coupling back to the central atom removed, and
    pair(k) = (k + 6) % 12. E must already be shifted by the caller.
    """
    sig_bulk = self.sigmaK(E, conv, mix)  # 12 bulk self-energies
    e_cplx = E + self.eta * 1j
    base = (e_cplx * jnp.eye(self.dim) - self.H0) - jnp.sum(sig_bulk, axis=0)

    def _term(k):
        # Contribution of a single lattice direction k
        g_nb = LA.inv(base + sig_bulk[(k + 6) % 12])
        s_k = self.Slist[k]
        s_k_h = s_k.conj().T
        b_fwd = e_cplx * s_k - self.Vlist0[k]
        b_rev = e_cplx * s_k_h - self.Vlist0[k].conj().T
        return (b_fwd @ g_nb @ s_k_h + s_k @ g_nb @ b_rev) / 2

    start = jnp.zeros((self.dim, self.dim), dtype=complex)
    return sum((_term(k) for k in range(12)), start)
# Empty function for compatibility with density.py methods
[docs]
def setF(self, F, mu1, mu2):
    """
    No-op hook kept so this class satisfies the contact interface used by density.py.

    The Bethe lattice bulk is fully determined by its tight-binding
    parameters, so nothing here depends on the device Fock matrix or on the
    chemical potentials.

    Parameters
    ----------
    F : ndarray
        Fock matrix (ignored).
    mu1 : float
        First chemical potential (ignored).
    mu2 : float
        Second chemical potential (ignored).
    """
    return None
[docs]
def sigmaTot(self, E, conv=SURFACE_GREEN_CONVERGENCE):
    """
    Total bulk self-energy (sum over all lattice directions) at energy E.

    E is on the absolute energy scale (eV). The current Fermi shift
    ``self.dFermi`` is removed before calling ``sigmaK``.
    """
    shifted = E - self.dFermi
    per_direction = self.sigmaK(shifted, conv)
    return jnp.sum(per_direction, axis=0)
[docs]
def sigma(self, E, i, conv=SURFACE_GREEN_CONVERGENCE):
    """
    Self-energy of contact ``i``.

    The Bethe lattice represents a single bulk contact (intended i=0). ``i``
    is not checked, and the total self-energy from ``sigmaTot`` is returned.
    """
    total = self.sigmaTot(E, conv)
    return total
[docs]
def crossTermQ(self, E, i, conv=SURFACE_GREEN_CONVERGENCE):
    """
    Cross-term Q_sym for contact ``i`` (single bulk contact; ``i`` is not used).

    Removes the Fermi shift ``self.dFermi`` from E and delegates to
    ``crossTermQBulk``.
    """
    shifted = E - self.dFermi
    return self.crossTermQBulk(shifted, conv)
[docs]
def crossTermQTot(self, E, conv=SURFACE_GREEN_CONVERGENCE):
    """Total cross-term Q_sym; with only one contact this equals contact 0's term."""
    single_contact = 0
    return self.crossTermQ(E, single_contact, conv)
# Get the surface DOS of the Bethe lattice
[docs]
def DOS(self, E):
    """
    Surface density of states of the Bethe lattice, -Im Tr G(E) / pi.

    G is built from ``H0`` plus the summed surface self-energies from
    ``sigmaSurf``, evaluated at E shifted by ``-self.dFermi``.

    Parameters
    ----------
    E : float
        Energy point in eV (absolute scale).

    Returns
    -------
    float
        Density of states at energy E.
    """
    energy = E - self.dFermi
    sigma_surface = jnp.sum(self.sigmaSurf(energy), axis=0)
    inverse_gf = (energy + 1j * self.eta) * jnp.eye(self.dim) - self.H0 - sigma_surface
    green = LA.inv(inverse_gf)
    return -jnp.trace(green).imag / jnp.pi
# Calculate fermi energy using bisection (to specified tolerance)
[docs]
def calcFermi(self, ne, tol=FERMI_CALCULATION_TOL):
    """
    Determine the bulk Bethe-lattice Fermi energy for ``ne`` electrons.

    Delegates the search to ``getFermiContact`` (gauNEGF.density) with
    ``maxcycles=1000`` and the contact temperature ``self.T``.

    Parameters
    ----------
    ne : float
        Target number of electrons.
    tol : float, optional
        Tolerance forwarded as ``conv`` (default: FERMI_CALCULATION_TOL).

    Returns
    -------
    float
        Fermi energy in eV. It is also stored in ``self.fermi``, and becomes
        the reference level ``self.fermi0`` if none was set yet.
    """
    print('Calculating Bulk Bethe Lattice Fermi level...')
    fermi = getFermiContact(self, ne, conv=tol, maxcycles=1000, T=self.T)
    self.fermi = fermi
    if self.fermi0 is None:
        self.fermi0 = fermi
    return fermi