Source code for gauNEGF.surfG3D

# Developed packages
from gauNEGF.density import *
from gauNEGF.config import (ETA, TEMPERATURE, shard_array)
from gauNEGF.utils import fractional_matrix_power

# Python packages
import jax
import jax.numpy as jnp
import numpy as np
from jax.numpy import linalg as LA

# Constants
dim: int = 9                 # orbitals per atom in the minimal basis: 1 s + 3 p + 5 d
har_to_eV: float = 27.211386   # eV per Hartree; converts .bethe energies to eV


# Surface Green's function for a device whose contacts are FCC [111] surfaces (3D k-grid per contact)
class surfG3:
    """
    Surface Green's function calculator for contacts with an FCC [111] surface.

    For every contact the constructor:

    1. collects the 9 (s, p, d) basis-function indices and the Cartesian
       coordinates of each contact atom from ``bar``;
    2. fits the surface plane of the contact atoms (SVD) and builds the 12
       FCC nearest-neighbor directions with :meth:`genNeighbors`;
    3. records, per atom, which in-plane neighbor directions are occupied by
       other contact atoms (those directions receive no self-energy);
    4. builds Slater-Koster hopping/overlap matrices for the 12 directions from
       a ``.bethe`` parameter file and wraps them in one ``surfGAt3D`` per
       contact.

    Parameters
    ----------
    F : ndarray
        Fock matrix from DFT calculation
    S : ndarray
        Overlap matrix from DFT calculation
    contacts : list of lists
        Atom indices for each contact region, using the same (1-based)
        numbering as ``bar.ibfatm``
    bar : QCBinAr
        Gaussian interface object; ``ibfatm``, ``ibftyp`` and ``c`` are read
    latFile : str, optional
        Name of the .bethe file (without extension) containing Slater-Koster
        parameters (default: 'Au')
    spin : {'r', 'u', 'ro', 'g'}, optional
        Spin treatment. 'r' yields N x N self-energies; 'u' and 'ro' yield a
        block-diagonal 2N x 2N spin copy; 'g' yields a spin-interleaved
        2N x 2N matrix (default: 'r')
    eta : float, optional
        Broadening parameter in eV (default: ``gauNEGF.config.ETA``)
    T : float, optional
        Temperature in Kelvin (default: ``gauNEGF.config.TEMPERATURE``)

    Attributes
    ----------
    F : ndarray
        Fock matrix
    S : ndarray
        Overlap matrix
    gList : list of surfGAt3D
        Atomic surface Green's function calculators, one per contact
    """

    def __init__(self, F, S, contacts, bar, latFile='Au', spin='r', eta=ETA, T=TEMPERATURE):
        # Per-contact bookkeeping (one entry per contact in each list)
        self.cVecs = []      # fitted surface-plane normal
        self.latVecs = []    # unit vector from the first atom to its closest contact neighbor
        self.indsLists = []  # per atom: its 9 basis-function indices, sorted by shell
        self.dirLists = []   # the 12 nearest-neighbor unit vectors from genNeighbors
        self.nIndLists = []  # per atom: indices into dirLists occupied by other contact atoms

        self.Xi = fractional_matrix_power(S, 0.5)
        if spin != 'r':
            # Spin-independent implementation: keep one spin block here and add
            # the degenerate spin copy when the self-energy is generated
            self.Xi = self.Xi[::2, ::2]
        self.spin = spin
        orbMap = bar.ibfatm[bar.ibfatm > 0]
        orbTyp = bar.ibftyp[bar.ibfatm > 0]
        self.N = len(orbMap)

        cList = None
        for contact in contacts:
            indsList, cList = self._contactOrbitals(contact, bar, orbMap, orbTyp)
            self.indsLists.append(indsList)

            # Surface-plane normal = direction of least variance of the atom positions
            centeredCoords = cList - jnp.mean(cList, axis=0)
            _, _, Vt = LA.svd(centeredCoords)
            normal = Vt[-1]
            # NOTE(review): the sign of an SVD singular vector is arbitrary, so whether
            # directions 3-5 (see genNeighbors) point toward or away from the device
            # is not fixed here — confirm against surfGAt3D.gSurf's convention.
            self.cVecs.append(normal)

            # One in-plane lattice direction: first atom -> its closest contact neighbor
            assert len(cList) > 1, \
                'Error: each contact needs at least two atoms to define a lattice direction'
            vInd = jnp.argmin(LA.norm(cList[1:] - cList[0], axis=1)) + 1
            latVec = cList[vInd] - cList[0]
            self.latVecs.append(latVec / LA.norm(latVec))

            # Remaining lattice directions, then which of them are occupied per atom
            nVecs = self.genNeighbors(normal, latVec)
            self.nIndLists.append(self._contactNeighbors(cList, latVec, nVecs))
            self.dirLists.append(nVecs)

        # Read Bethe lattice parameters and build the hopping/overlap matrix per direction
        self.readBetheParams(latFile)
        self.Slists = []
        self.Vlists = []
        for dirList in self.dirLists:
            self.Slists.append([self.constructMat(self.Sdict, d) for d in dirList])
            self.Vlists.append([self.constructMat(self.Vdict, d) for d in dirList])
        self.num_contacts = len(self.indsLists)

        # One atomic Bethe-lattice Green's function object per contact
        self.gList = [surfGAt3D(self.H0.copy(), Slist, Vlist, vecs, eta, T)
                      for Slist, Vlist, vecs in zip(self.Slists, self.Vlists, self.dirLists)]

        # Fermi level from the first contact (ne read from the .bethe file), shared by all
        fermi = self.gList[0].calcFermi(self.ne/2)
        for g in self.gList:
            g.fermi = fermi
            g.fermi0 = fermi

        # Store variables
        self.cList = cList  # coordinates of the LAST contact processed, used for testing
        self.F = F
        self.S = S
        self.eta = eta

    @staticmethod
    def _contactOrbitals(contact, bar, orbMap, orbTyp):
        """Return (per-atom basis-function index arrays, atom coordinate array) for one contact."""
        indsList = []
        coords = []
        for atom in contact:
            inds = jnp.where(jnp.isin(orbMap, atom))[0]
            # bar.c holds x, y, z for each atom consecutively; atoms are 1-based
            coords.append(jnp.array(bar.c[(atom-1)*3:atom*3]))
            assert len(inds) == dim, f'Error: Atom {atom} has {len(inds)} basis functions, expecting {dim}'
            # Order the atom's basis functions by |ibftyp| // 1000
            # (presumably the shell type, giving s, p, d order — confirm)
            inds = inds[jnp.argsort(abs(orbTyp[inds])//1000)]
            indsList.append(inds)
        return indsList, jnp.array(coords)

    @staticmethod
    def _contactNeighbors(cList, latVec, nVecs):
        """For each atom in cList, return the indices of nVecs pointing to other contact atoms."""
        cutoff = 1.5 * LA.norm(latVec)
        dirs = jnp.stack(nVecs)
        nIndList = []
        for c in cList:
            nInds = []
            for c2 in cList:
                l = LA.norm(c2 - c)
                # neighbor if within 1.5 * nearest-neighbor distance and not the same atom
                if l < cutoff and not jnp.allclose(c2, c):
                    valList = dirs @ ((c2 - c) / l)
                    best = int(jnp.argmax(valList))
                    # Contact atoms must sit on in-plane lattice sites (0-2 or 6-8)
                    assert valList[best] > 0.9 and best in [0, 1, 2, 6, 7, 8], \
                        'Error: Lattice mismatch in atoms!'
                    nInds.append(best)
            nIndList.append(nInds)
        return nIndList

    def genNeighbors(self, plane_normal, first_neighbor):
        """
        Generate 12 nearest-neighbor unit vectors for an FCC [111] surface.

        Ordering of the returned list:

        - 0-2: in-plane vectors at 0, 60 and 120 degrees from ``first_neighbor``
        - 3-5: out-of-plane vectors on the ``+plane_normal`` side, elevated by
          arccos(1/sqrt(3)) (~54.74 deg) above the plane, azimuths 30, 150 and
          270 degrees
        - 6-11: the negatives of vectors 0-5, so index k+6 is opposite index k

        Parameters
        ----------
        plane_normal : ndarray
            Vector normal to the crystal plane (normalized here)
        first_neighbor : ndarray
            Vector to one nearest neighbor (projected onto the plane and normalized)

        Returns
        -------
        list
            12 unit vectors representing nearest-neighbor directions
        """
        plane_normal = jnp.asarray(plane_normal)
        plane_normal = plane_normal / jnp.linalg.norm(plane_normal)  # Rodrigues needs a unit axis
        first_neighbor = jnp.asarray(first_neighbor)
        proj = first_neighbor - jnp.dot(first_neighbor, plane_normal) * plane_normal
        first_neighbor = proj / jnp.linalg.norm(proj)

        # Cross-product matrix of the rotation axis (loop invariant for Rodrigues' formula)
        K = jnp.array([[0, -plane_normal[2], plane_normal[1]],
                       [plane_normal[2], 0, -plane_normal[0]],
                       [-plane_normal[1], plane_normal[0], 0]])
        K2 = jnp.matmul(K, K)

        def rotate(vec, angle):
            """Rotate vec about plane_normal by angle (radians)."""
            R = jnp.eye(3) + jnp.sin(angle) * K + (1 - jnp.cos(angle)) * K2
            return jnp.dot(R, vec)

        # In-plane hexagon: 60-degree rotations
        in_plane_vectors = []
        for i in range(3):
            v = rotate(first_neighbor, i * jnp.pi / 3)
            in_plane_vectors.append(v / jnp.linalg.norm(v))

        # Out-of-plane triangle: rotate 30 degrees in-plane, tilt toward the normal,
        # then apply 120-degree rotations
        out_of_plane_angle = jnp.arccos(1/jnp.sqrt(3))  # ~54.74 degrees
        out_of_plane_base = jnp.cos(out_of_plane_angle) * rotate(first_neighbor, jnp.pi/6) + \
            jnp.sin(out_of_plane_angle) * plane_normal
        out_of_plane_vectors = [rotate(out_of_plane_base, i * 2 * jnp.pi / 3) for i in range(3)]

        # Opposite vectors at index k+6
        all_vectors = in_plane_vectors + out_of_plane_vectors
        all_vectors += [-v for v in all_vectors]
        return all_vectors

    # Read parameters from filename.bethe file, check values, store into dicts
    def readBetheParams(self, filename):
        """
        Read Slater-Koster parameters from a .bethe file.

        Each non-empty line has the form ``key = value``; all whitespace is
        ignored and lines starting with '#' are skipped. Only a minimal basis
        with single s, p and d orbitals is supported.

        Parameters
        ----------
        filename : str
            Name of the .bethe file (without extension)

        Raises
        ------
        AssertionError
            If parameters are missing or unexpected

        Notes
        -----
        Parameters are sorted into:
        - ne: electron count used for the Fermi-level search
        - Edict: onsite energies (Hartree -> eV), keys 's', 'p', 'dd', 'dt'
        - Vdict: hopping parameters (Hartree -> eV), keys 'sss', 'sps', ...
        - Sdict: overlap parameters (dimensionless), same keys as Vdict
        Also builds the 9x9 diagonal onsite matrix ``self.H0``.
        """
        params = {}
        with open(filename+'.bethe', 'r') as f:
            for line in f:
                line = ''.join(line.split())  # drop all whitespace, including newlines/tabs
                if not line or line.startswith('#'):
                    continue
                # Split on '=' into key and value
                key, value = line.split('=')
                params[key] = float(value)

        # Check to make sure parameters are all specified
        # Note: set up only for minimal basis with single s, p, and d orbital
        expected_keys = ['ne', 'es', 'ep', 'edd', 'edt',
                         'sss', 'sps', 'pps', 'ppp', 'sds', 'pds', 'pdp', 'dds', 'ddp', 'ddd',
                         'Ssss', 'Ssps', 'Spps', 'Sppp', 'Ssds', 'Spds', 'Spdp', 'Sdds', 'Sddp', 'Sddd']
        assert set(params.keys()) == set(expected_keys), \
            f"Error reading file: Found Bethe parameters: {list(params.keys())}, expected: {expected_keys}"

        # Sort parameters and convert Hartrees to eV ('ne' is a count, not an energy)
        self.ne = params['ne']
        self.Edict = {k[1:]: v*har_to_eV for k, v in params.items() if k.startswith('e')}
        self.Sdict = {k[1:]: v for k, v in params.items() if k.startswith('S')}
        self.Vdict = {k: v*har_to_eV for k, v in params.items()
                      if k != 'ne' and not k.startswith(('e', 'S'))}

        # Onsite H0 before Fermi level shifting; d order [d3z2-r2, dxz, dyz, dx2-y2, dxy]
        E = self.Edict
        Hdiag = jnp.array([E['s']] + [E['p']]*3 + [E['dd'], E['dt'], E['dt'], E['dd'], E['dt']])
        self.H0 = jnp.diag(Hdiag)

    def constructMat(self, Mdict, dirCosines):
        """
        Construct a hopping/overlap matrix using the Slater-Koster formalism.

        The 9x9 matrix is first built for a bond along [0,0,1], then rotated
        to the bond direction given by ``dirCosines``. Orbital order is
        [s, px, py, pz, d3z2-r2, dxz, dyz, dx2-y2, dxy].

        Parameters
        ----------
        Mdict : dict
            Slater-Koster parameters keyed 'sss', 'sps', 'pps', 'ppp', 'sds',
            'pds', 'pdp', 'dds', 'ddp', 'ddd'
        dirCosines : array_like
            Direction cosines [l, m, n] of the bond (expected unit length)

        Returns
        -------
        ndarray
            9x9 matrix ``tr @ M @ tr.T`` in the rotated frame
        """
        # Bond-frame matrix for a bond along +z: (row, col, value) entries
        entries = [
            (0, 0, Mdict['sss']),                                    # s-s
            (0, 3, Mdict['sps']), (3, 0, -Mdict['sps']),             # s-pz / pz-s
            (1, 1, Mdict['ppp']), (2, 2, Mdict['ppp']),              # px-px, py-py
            (3, 3, Mdict['pps']),                                    # pz-pz
            (0, 4, Mdict['sds']), (4, 0, Mdict['sds']),              # s-d3z2-r2
            (1, 5, Mdict['pdp']), (2, 6, Mdict['pdp']),              # px-dxz, py-dyz
            (3, 4, Mdict['pds']),                                    # pz-d3z2-r2
            (5, 1, -Mdict['pdp']), (6, 2, -Mdict['pdp']),            # dxz-px, dyz-py
            (4, 3, -Mdict['pds']),                                   # d3z2-r2-pz
            (4, 4, Mdict['dds']),                                    # d3z2-r2 pair
            (5, 5, Mdict['ddp']), (6, 6, Mdict['ddp']),              # dxz, dyz pairs
            (7, 7, Mdict['ddd']), (8, 8, Mdict['ddd']),              # dx2-y2, dxy pairs
        ]
        M = jnp.zeros((dim, dim))
        for r, c, v in entries:
            M = M.at[r, c].set(v)

        # Polar angles of the bond; clip so rounding error cannot push arccos out of domain
        x, y, z = dirCosines
        z = jnp.clip(z, -1.0, 1.0)
        theta = jnp.arccos(z)    # polar angle from z-axis
        phi = jnp.arctan2(y, x)  # azimuthal angle in x-y plane
        ct, st = jnp.cos(theta), jnp.sin(theta)
        cp, sp = jnp.cos(phi), jnp.sin(phi)
        c2t, s2t = jnp.cos(2*theta), jnp.sin(2*theta)
        c2p, s2p = jnp.cos(2*phi), jnp.sin(2*phi)
        sqrt3 = jnp.sqrt(3)

        # p orbitals: ordinary 3D rotation of [px, py, pz]
        p_block = jnp.array([[ct*cp, -sp, st*cp],
                             [ct*sp, cp, st*sp],
                             [-st, 0.0, ct]])

        # d orbitals [d3z2-r2, dxz, dyz, dx2-y2, dxy]; formula copied from ANT.Gaussian
        d10 = sqrt3 * s2t * cp / 2
        d20 = sqrt3 * s2t * sp / 2
        d_block = jnp.array([
            [(3 * z**2 - 1) / 2, -sqrt3 * s2t / 2, 0.0, sqrt3 * st**2 / 2, 0.0],
            [d10, c2t * cp, -ct * sp, -d10 / sqrt3, st * sp],
            [d20, c2t * sp, ct * cp, -d20 / sqrt3, -st * cp],
            [sqrt3 * st**2 * c2p / 2, s2t * c2p / 2, -st * s2p, (1 + ct**2) * c2p / 2, -ct * s2p],
            [sqrt3 * st**2 * s2p / 2, s2t * s2p / 2, st * c2p, (1 + ct**2) * s2p / 2, ct * c2p],
        ])

        # Block-diagonal transformation: s is rotation invariant
        tr = jnp.zeros((dim, dim))
        tr = tr.at[0, 0].set(1.0)
        tr = tr.at[1:4, 1:4].set(p_block)
        tr = tr.at[4:9, 4:9].set(d_block)
        return tr @ M @ tr.T

    def _toDeviceSpin(self, mat):
        """Map an N x N contact matrix to the device basis and spin layout."""
        if self.Sdict['sss'] == 0:
            # Orthonormal Bethe basis: apply the ANT.Gaussian de-orthonormalization with Xi = S^(1/2)
            mat = times(self.Xi, mat, self.Xi)
        if self.spin == 'u' or self.spin == 'ro':
            mat = jnp.kron(jnp.eye(2), mat)
        elif self.spin == 'g':
            mat = jnp.kron(mat, jnp.eye(2))
        return mat

    def _contactMatrix(self, E, i, conv, atomMethod):
        """
        Assemble per-atom blocks of contact i into the full device basis.

        ``atomMethod`` names a surfGAt3D method (``'sigmaSurf'`` or
        ``'crossTermQSurf'``) called as ``method(E_shifted, active_dirs=..., G_AB=...)``.
        """
        g = self.gList[i]
        # H0/Vlist0 are never shifted, so shift E into their reference frame instead
        E_shifted = E - g.dFermi
        G_AB = g.gSurf(E_shifted, conv)
        atomFn = getattr(g, atomMethod)
        mat = jnp.zeros((self.N, self.N), dtype=complex)
        for nInds, Finds in zip(self.nIndLists[i], self.indsLists[i]):
            # Active = first 9 directions (in-plane 0-2, 6-8 and out-of-plane 3-5)
            # that are not already occupied by another contact atom
            activeDirs = sorted(set(range(9)) - {int(x) for x in nInds})
            block = atomFn(E_shifted, active_dirs=activeDirs, G_AB=G_AB)
            mat = mat.at[jnp.ix_(Finds, Finds)].set(block)
        return self._toDeviceSpin(mat)

    def sigma(self, E, i, conv=1e-4):
        """
        Calculate the self-energy matrix for a specific contact.

        For each atom of contact i, sums surface self-energies over the first
        9 nearest-neighbor directions that are not occupied by other contact
        atoms. The result is de-orthonormalized when the Bethe overlap is
        zero ('sss' overlap == 0) and expanded for spin.

        Parameters
        ----------
        E : float
            Energy point for self-energy calculation (in eV)
        i : int
            Index of the contact to calculate self-energy for
        conv : float, optional
            Convergence criterion for self-energy calculation (default: 1e-4)

        Returns
        -------
        ndarray
            Self-energy matrix for the specified contact, with dimensions:
            - (N, N) for restricted calculations
            - (2N, 2N) for unrestricted or generalized spin calculations

        References
        ----------
        [1] Jacob, D., & Palacios, J. J. (2011). Critical comparison of electrode
            models in density functional theory based quantum transport
            calculations. The Journal of Chemical Physics, 134(4), 044118.
            DOI: 10.1063/1.3526044
        """
        return self._contactMatrix(E, i, conv, 'sigmaSurf')

    def sigmaTot(self, E, conv=1e-4):
        """
        Calculate the total self-energy matrix from all contacts.

        Parameters
        ----------
        E : float
            Energy point for the self-energy calculation (in eV)
        conv : float, optional
            Convergence criterion for self-energy calculation (default: 1e-4)

        Returns
        -------
        ndarray
            Sum of sigma(E, i) over all contacts in the full device basis
        """
        return sum(self.sigma(E, i, conv) for i in range(self.num_contacts))

    def crossTermQ(self, E, i, conv=1e-4):
        """
        Cross-term Q_sym for contact i in the full device basis.

        Same assembly as :meth:`sigma`, but each atom block comes from
        ``surfGAt3D.crossTermQSurf`` instead of ``sigmaSurf``.
        """
        return self._contactMatrix(E, i, conv, 'crossTermQSurf')

    def crossTermQTot(self, E, conv=1e-4):
        """Total cross-term Q_sym summed over all contacts."""
        return sum(self.crossTermQ(E, i, conv) for i in range(self.num_contacts))

    def getSigma(self, Elist=(None, None), conv=1e-4):
        """
        Get the left (first) and right (last) contact self-energies.

        Parameters
        ----------
        Elist : sequence of two floats or None, optional
            Energies for the first and last contacts; a None entry means
            "use that contact's current Fermi energy" (default: both None).
            The argument is not modified.
        conv : float, optional
            Convergence criterion for the self-energy matrix

        Returns
        -------
        tuple
            (sigma of the first contact, sigma of the last contact)
        """
        # Read into locals: never mutate the (shared default) argument, which
        # previously froze the first-call Fermi energies into later calls
        E0, E1 = Elist[0], Elist[1]
        if E0 is None:
            E0 = self.gList[0].fermi
        if E1 is None:
            E1 = self.gList[-1].fermi
        return (self.sigma(E0, 0, conv), self.sigma(E1, -1, conv))

    def updateFermi(self, i, Ef):
        """
        Update the Fermi energy of contact i via ``gList[i].updateH(Ef)``.

        Parameters
        ----------
        i : int
            Contact index
        Ef : float
            New Fermi energy in eV
        """
        self.gList[i].updateH(Ef)

    def setF(self, F, muL, muR):
        """
        Update the Fock matrix and the first/last contact chemical potentials.

        Parameters
        ----------
        F : ndarray
            New Fock matrix
        muL : float
            Chemical potential for the left (first) contact in eV
        muR : float
            Chemical potential for the right (last) contact in eV
        """
        self.F = F
        if self.gList[0].fermi != muL:
            self.updateFermi(0, muL)
        if self.gList[-1].fermi != muR:
            self.updateFermi(-1, muR)

    ## TESTING METHODS FOR SLATER-KOSTER INTERACTIONS:
    def testDOrbitalFunctions(self):
        """
        Test s-d angular functions along the x-axis.

        Checks for a bond along [1,0,0]:
        - s-dxy is zero
        - s-dx2-y2 equals sqrt(3)/2 * sds
        - s-d3z2-r2 equals -1/2 * sds
        """
        Vdict = self.Vdict  # hopping parameters
        M = self.constructMat(Vdict, [1, 0, 0])
        # jax.numpy has no testing module; use numpy.testing
        np.testing.assert_almost_equal(M[0, 8], 0.0, err_msg="dxy not zero along x-axis")
        np.testing.assert_almost_equal(M[0, 7], jnp.sqrt(3)/2 * Vdict['sds'],
                                       err_msg="dx2-y2 incorrect along x-axis")
        np.testing.assert_almost_equal(M[0, 4], -0.5 * Vdict['sds'],
                                       err_msg="dz2 incorrect along x-axis")
        print("d orbital angular function tests passed!")

    def testDOrbitalSymmetry(self):
        """
        Test that the d-d block is unchanged under bond inversion.
        """
        dir1 = [1/jnp.sqrt(2), 1/jnp.sqrt(2), 0]
        dir2 = [-1/jnp.sqrt(2), -1/jnp.sqrt(2), 0]
        M1 = self.constructMat(self.Vdict, dir1)
        M2 = self.constructMat(self.Vdict, dir2)
        # d orbitals are even under inversion
        np.testing.assert_array_almost_equal(
            M1[4:, 4:], M2[4:, 4:],
            err_msg="d-d block not symmetric under inversion")
        print("d orbital symmetry tests passed!")

    def testPDInteraction(self):
        """
        Test p-d interactions.

        Checks that px-dxy vanishes along the x-axis and that pz-d3z2-r2 is
        pure sigma (pds) along the z-axis.
        """
        Vdict = self.Vdict
        M = self.constructMat(Vdict, [1, 0, 0])
        np.testing.assert_almost_equal(
            M[1, 8], 0.0,
            err_msg="px-dxy interaction incorrect along x-axis")
        M = self.constructMat(Vdict, [0, 0, 1])
        np.testing.assert_almost_equal(
            M[3, 4], Vdict['pds'],
            err_msg="pz-dz2 interaction incorrect along z-axis")
        print("p-d interaction tests passed!")

    def testDDInteraction(self):
        """
        Test d-d interactions.

        Checks that dyz-dyz is pure delta (ddd) along the x-axis and that
        d3z2-r2 pairs are pure sigma (dds) along the z-axis.
        """
        Vdict = self.Vdict
        M = self.constructMat(Vdict, [1, 0, 0])
        np.testing.assert_almost_equal(
            M[6, 6], Vdict['ddd'],
            err_msg="dyz-dyz interaction incorrect along x-axis")
        M = self.constructMat(Vdict, [0, 0, 1])
        np.testing.assert_almost_equal(
            M[4, 4], Vdict['dds'],
            err_msg="dz2-dz2 interaction incorrect along z-axis")
        print("d-d interaction tests passed!")

    def testHoppingPhysics(self):
        """
        Test physical properties of hopping matrices.

        For principal axes and 45-degree directions, checks:
        - s-p hopping antisymmetry (V[0,i] == -V[i,0])
        - the s-p hopping magnitude equals |sps| (rotation preserves norm)
        """
        eps = 1e-10  # Tolerance for floating point comparisons
        s_p_mag = abs(self.Vdict['sps'])  # Magnitude of s-p hopping

        test_cases = [
            # Principal axes
            ([0, 0, 1], "z-axis"),
            ([1, 0, 0], "x-axis"),
            ([0, 1, 0], "y-axis"),
            # 45-degree rotations
            ([1/jnp.sqrt(2), 0, 1/jnp.sqrt(2)], "45° in xz-plane"),
            ([0, 1/jnp.sqrt(2), 1/jnp.sqrt(2)], "45° in yz-plane"),
            ([1/jnp.sqrt(2), 1/jnp.sqrt(2), 0], "45° in xy-plane"),
        ]

        print("\nTesting hopping matrix physics...")
        for direction, name in test_cases:
            direction = jnp.array(direction)
            x, y, z = direction
            print(f"\nChecking {name} direction: [{x:.3f}, {y:.3f}, {z:.3f}]")
            V = self.constructMat(self.Vdict, direction)

            for i in range(1, 4):  # all p orbitals
                assert abs(V[0, i] + V[i, 0]) < eps, \
                    f"s-p hopping not antisymmetric for p{i}"

            s_p_total = jnp.sqrt(V[0, 1]**2 + V[0, 2]**2 + V[0, 3]**2)
            assert abs(s_p_total - s_p_mag) < eps, \
                f"s-p hopping magnitude not preserved: {s_p_total:.6f} != {s_p_mag:.6f}"

            print(f"s-px: {V[0,1]:.3f}, px-s: {V[1,0]:.3f}")
            print(f"s-py: {V[0,2]:.3f}, py-s: {V[2,0]:.3f}")
            print(f"s-pz: {V[0,3]:.3f}, pz-s: {V[3,0]:.3f}")
            print(f"Total s-p magnitude: {s_p_total:.3f}")

        print("\nAll hopping physics tests passed!")

    def runAllTests(self):
        """
        Run all Slater-Koster validation tests.

        Executes the d-orbital angular, d-orbital symmetry, p-d, d-d and
        general hopping-physics tests in that order.
        """
        print("Running Slater-Koster projection tests...")
        self.testDOrbitalFunctions()
        self.testDOrbitalSymmetry()
        self.testPDInteraction()
        self.testDDInteraction()
        self.testHoppingPhysics()
        print("\nAll tests passed!")
# 3D surface Green's function for a single atom
[docs] class surfGAt3D:
    """
    Atomic-level 3D Green's function calculator for a single contact atom.

    Holds the onsite Hamiltonian and the 12 nearest-neighbor hopping/overlap
    matrices of an FCC lattice with a [111] surface, plus the 2D (surface)
    and 3D (bulk) k-meshes used for Brillouin-zone sums. A Fermi-level shift
    is tracked relative to the reference level ``fermi0``; the unshifted
    matrices ``H0`` and ``Vlist0`` are kept as immutable references.

    Parameters
    ----------
    H : ndarray
        Onsite Hamiltonian matrix (9x9 for minimal basis)
    Slist : list of ndarray
        List of 12 overlap matrices for nearest neighbors
    Vlist : list of ndarray
        List of 12 hopping matrices for nearest neighbors
    vecs : list of ndarray
        The 12 nearest-neighbor unit vectors, in the same order as Slist/Vlist
    eta : float
        Broadening parameter in eV
    T : float, optional
        Temperature in Kelvin (default: ``gauNEGF.config.TEMPERATURE``)
    kPoints : int, optional
        Number of k-points per reciprocal direction (default: 11)

    Attributes
    ----------
    NN : int
        Number of nearest neighbors (fixed to 12 for FCC)
    fermi : float
        Current Fermi energy
    fermi0 : float
        Reference Fermi energy that H0/Vlist0 correspond to
    dFermi : float
        Current shift fermi - fermi0 applied to H and Vlist
    F : ndarray
        9x9 single-atom Hamiltonian (same as H) set by updateH
    S : ndarray
        9x9 identity overlap set by updateH
    """
[docs] def __init__(self, H, Slist, Vlist, vecs, eta, T=TEMPERATURE, kPoints=11): """ Initialize surfGAt3D with Hamiltonian and neighbor matrices. Parameters ---------- H : ndarray Onsite Hamiltonian matrix (9x9 for minimal basis) Slist : list of ndarray List of 12 overlap matrices for nearest neighbors Vlist : list of ndarray List of 12 hopping matrices for nearest neighbors, first six in plane, eta : float Broadening parameter in eV (recommended: >= 1e-4 for numerical stability) T : float, optional Temperature in Kelvin (default: 0) kPoints : int, optional Number of k-points per direction for BZ integration (default: 11) Recommended values: - 11: Fast, reasonable accuracy (121 points total) - 15: Good balance of speed/accuracy (225 points) - 21: High accuracy for production (441 points) Note: Odd values use Monkhorst-Pack grid, even use Gamma-centered Raises ------ AssertionError If matrix dimensions are incorrect or number of neighbors != 12 """ assert jnp.shape(H) == (dim,dim), f"Error with H dim, should be {dim}x{dim}" for S,V in zip(Slist, Vlist): assert jnp.shape(S) == (dim,dim), f"Error with S dim, should be {dim}x{dim}" assert jnp.shape(V) == (dim,dim), f"Error with F dim, should be {dim}x{dim}" self.H = jnp.array(H) self.H0 = jnp.array(H) # immutable reference (never mutated) self.Slist = jnp.array(Slist) self.Vlist = jnp.array(Vlist) self.Vlist0 = jnp.array(Vlist) # immutable reference (never mutated) self.NN = len(Slist) self.kPoints = kPoints assert self.NN == 12, "Error: surfGAt3D only implemented for FCC using 12 NN" assert len(vecs) == 12, "Error: surfGAt3D only implemented for FCC using 12 NN" self.vecs = jnp.array(vecs) # Define 3D lattice vectors for bulk periodicity # For FCC [111]: vecs[0,1,2] are in-plane (z approx 0), vecs[3,4,5] are out-of-plane (+z) self.a1 = self.vecs[0] # First in-plane vector self.a2 = self.vecs[1] # Second in-plane vector self.a3 = self.vecs[3] # First out-of-plane vector (upward) # Reciprocal lattice vectors computed 
separately for 2D (surface) and 3D (bulk) # See _setup_kmesh_2D() and _setup_kmesh_3D() self.eta = eta self.T = T self.fermi = None self.fermi0 = None # reference Fermi level self.dFermi = 0.0 # shift from reference self.num_contacts = 1 # Pre-compute k-mesh for surface (2D) and bulk (3D) self._setup_kmesh_2D() self._setup_kmesh_3D() self.updateH()
[docs] def updateH(self, fermi=None): """ Update Hamiltonian and protocol matrices F/S. Updates onsite and hopping matrices. F and S are the 9x9 single-atom Hamiltonian and identity overlap, compatible with density.py and the SurfGProtocol interface. Self-energies from bulk neighbors are provided via sigmaTot()/sigma(). Parameters ---------- fermi : float, optional New Fermi energy setpoint in eV (default: None) Notes ----- When fermi is provided and different from current value: - Shifts onsite energies by the Fermi level difference - Updates hopping matrices with overlap contributions """ if fermi is not None and fermi != self.fermi: self.fermi = fermi self.dFermi = fermi if self.fermi0 is None else fermi - self.fermi0 if self.fermi0 is None: self.fermi0 = fermi # Build H and Vlist from H0/Vlist0 plus current dFermi (for non-JIT callers) self.H = self.H0 + self.dFermi * jnp.eye(dim) self.Vlist = jnp.array([self.Vlist0[j] + self.dFermi * self.Slist[j] for j in range(self.NN)]) self.F = self.H # 9x9 (dim x dim) self.S = jnp.eye(dim)
def _setup_kmesh_2D(self):
    """
    Pre-compute the 2D surface k-mesh and its Bloch phase factors.

    The 2D reciprocal vectors lie in the plane perpendicular to the surface
    normal n = a1 x a2, so every k-point has zero component along n (no
    periodicity perpendicular to the surface):

        b1 = 2*pi * (a2 x n) / [a1 . (a2 x n)]
        b2 = 2*pi * (n x a1) / [a2 . (n x a1)]

    which gives b_i . a_j = 2*pi*delta_ij for i, j in {1, 2}.

    Sets
    ----
    self.b1_2D, self.b2_2D : 2D reciprocal vectors (3-vectors orthogonal to n)
    self.kmesh_2D : Cartesian k-points, shape (kPoints**2, 3)
    self.expList_2D : phase factors exp(+i k . R) for every neighbor vector,
        shape (kPoints**2, len(vecs))
    """
    surface_normal = jnp.cross(self.a1, self.a2)
    n_hat = surface_normal / jnp.linalg.norm(surface_normal)

    cross_a2_n = jnp.cross(self.a2, n_hat)
    cross_n_a1 = jnp.cross(n_hat, self.a1)
    self.b1_2D = 2 * jnp.pi * cross_a2_n / jnp.dot(self.a1, cross_a2_n)
    self.b2_2D = 2 * jnp.pi * cross_n_a1 / jnp.dot(self.a2, cross_n_a1)

    # Monkhorst-Pack fractions (2i+1)/(2nK) - 1/2 for every nK: odd nK contains
    # Gamma, even nK is offset from it by half a spacing.
    k = (2 * jnp.arange(self.kPoints) + 1) / (2 * self.kPoints) - 0.5
    K1, K2 = jnp.meshgrid(k, k, indexing='ij')
    self.kmesh_2D = (K1.flatten()[:, jnp.newaxis] * self.b1_2D
                     + K2.flatten()[:, jnp.newaxis] * self.b2_2D)  # nK^2 x 3

    # Phase factors exp(+ik.r) for the forward Fourier transform
    self.expList_2D = jnp.exp(+1j * self.kmesh_2D @ (self.vecs.T))  # nK^2 x NN


def _setup_kmesh_3D(self):
    """
    Pre-compute the 3D bulk k-mesh and its Bloch phase factors.

    Reciprocal vectors follow b_i = 2*pi * (a_j x a_k) / [a_i . (a_j x a_k)]
    for the primitive vectors a1, a2, a3.

    Sets
    ----
    self.b1_3D, self.b2_3D, self.b3_3D : 3D reciprocal lattice vectors
    self.kmesh_3D : Cartesian k-points, shape (kPoints**3, 3)
    self.expList_3D : phase factors exp(+i k . R) for every neighbor vector,
        shape (kPoints**3, len(vecs))
    """
    vol = jnp.dot(self.a1, jnp.cross(self.a2, self.a3))
    self.b1_3D = 2 * jnp.pi * jnp.cross(self.a2, self.a3) / vol
    self.b2_3D = 2 * jnp.pi * jnp.cross(self.a3, self.a1) / vol
    self.b3_3D = 2 * jnp.pi * jnp.cross(self.a1, self.a2) / vol

    # Same Monkhorst-Pack fractions as the 2D mesh (odd nK contains Gamma).
    k = (2 * jnp.arange(self.kPoints) + 1) / (2 * self.kPoints) - 0.5
    K1, K2, K3 = jnp.meshgrid(k, k, k, indexing='ij')
    # Fractional -> Cartesian using the 3D reciprocal lattice
    self.kmesh_3D = (K1.flatten()[:, jnp.newaxis] * self.b1_3D
                     + K2.flatten()[:, jnp.newaxis] * self.b2_3D
                     + K3.flatten()[:, jnp.newaxis] * self.b3_3D)  # nK^3 x 3

    # Phase factors exp(+ik.r) for the forward Fourier transform
    self.expList_3D = jnp.exp(+1j * self.kmesh_3D @ (self.vecs.T))  # nK^3 x NN
def gBulk(self, E):
    """
    Bulk Green's function with full 3D periodicity, by direct inversion.

    With all 12 neighbors folded into the Bloch sums, no Dyson iteration is
    needed:

        g(k) = [(E + i*eta) S(k) - H(k)]^-1
        H(k) = H0 + sum_n V_n exp(+i k.R_n)
        S(k) = I  + sum_n S_n exp(+i k.R_n)

    evaluated on the kPoints**3 mesh built by _setup_kmesh_3D().

    Parameters
    ----------
    E : float
        Energy in eV, pre-shifted by the caller to remove dFermi (H0/Vlist0
        are the unshifted references).

    Returns
    -------
    ndarray
        Real-space propagator G_AB of shape (12*dim, 12*dim). Block
        G_AB[A*dim:(A+1)*dim, B*dim:(B+1)*dim] equals
        (1/Nk) sum_k exp(+ik.R_A) g(k) exp(-ik.R_B).
    """
    phases = self.expList_3D  # nK^3 x 12
    nK3 = phases.shape[0]
    n_dirs = phases.shape[1]

    # Contract the neighbor axis directly instead of materializing an
    # (nK^3, 12, dim, dim) intermediate; H0 and I broadcast over k.
    Hk = jnp.einsum('kn,nij->kij', phases, self.Vlist0) + self.H0
    Sk = jnp.einsum('kn,nij->kij', phases, self.Slist) + jnp.eye(dim)

    # Shard k-point data across devices for parallel computation
    Hk_sharded = shard_array(Hk, axis=0)
    Sk_sharded = shard_array(Sk, axis=0)

    z = E + self.eta * 1j
    g_k = jax.vmap(lambda H, S: LA.inv(z * S - H))(Hk_sharded, Sk_sharded)

    # Inverse FT to the real-space propagator between neighbor directions
    G_AB_blocks = jnp.einsum('ka,kij,kb->abij', phases, g_k, phases.conj()) / nK3
    return G_AB_blocks.transpose(0, 2, 1, 3).reshape(n_dirs * dim, n_dirs * dim)
def sigmaBulk(self, E):
    """
    Total bulk self-energy of the center atom, shape (dim, dim).

    Sigma = tau @ G_AB @ bar_tau, with G_AB the (12*dim, 12*dim) propagator
    from gBulk(), tau = [tau_0 | ... | tau_11] where tau_n = z*S_n - V_n,
    and bar_tau the vertical stack of z*S_n^H - V_n^H (z = E + i*eta, not
    conjugated).

    Parameters
    ----------
    E : float
        Energy in eV, pre-shifted by the caller to remove dFermi.

    Returns
    -------
    ndarray
        Self-energy matrix of shape (dim, dim).
    """
    propagator = self.gBulk(E)
    z = E + self.eta * 1j
    left = jnp.hstack([z * self.Slist[n] - self.Vlist0[n] for n in range(12)])
    right = jnp.vstack([z * self.Slist[n].conj().T - self.Vlist0[n].conj().T
                        for n in range(12)])
    return left @ propagator @ right
# Calculate Green's function for the surface
def gSurf(self, E, conv=1e-4, mix=0.1, maxIter=5000):
    """
    Surface Green's function propagator on the 2D k-mesh.

    For every surface k-point the Dyson equation

        g(k) = [A(k) - B(k) g(k) B_bar(k)]^-1

    is iterated with linear mixing, where (z = E + i*eta)
        A(k)     = z*S_A(k) - H_A(k)   onsite + in-plane neighbors 0,1,2,6,7,8
        B(k)     = z*S_B(k) - H_B(k)   out-of-plane neighbors 3,4,5
        B_bar(k) = z*S_B(k)^H - H_B(k)^H
    The converged g(k) are transformed back to a real-space propagator
    between the 9 surface directions 0..8.

    Parameters
    ----------
    E : float
        Energy in eV, pre-shifted by the caller to remove dFermi.
    conv : float, optional
        Tolerance on ||g_n - g_{n-1}|| / ||g_{n-1}|| between successive
        *mixed* iterates (default: 1e-4). Because each step is damped by
        ``mix``, the undamped change ||g_new - g_{n-1}|| is diff/mix.
    mix : float, optional
        Weight of the new iterate in linear mixing (default: 0.1).
    maxIter : int, optional
        Maximum Dyson iterations per k-point (default: 5000). A k-point
        that hits this limit is used as-is; no error is raised.

    Returns
    -------
    ndarray
        Propagator G_AB of shape (9*dim, 9*dim); block (A, B) is
        (1/Nk) sum_k exp(+ik.R_A) g(k) exp(-ik.R_B).
    """
    z = E + self.eta * 1j
    phases = self.expList_2D  # nK^2 x NN
    nK2 = phases.shape[0]
    in_plane_indices = jnp.array([0, 1, 2, 6, 7, 8])  # in-plane directions
    out_plane_indices = jnp.array([3, 4, 5])  # upward couplings to the next layer

    def bloch_sum(mats, idx):
        # sum_n mats[n] * exp(+ik.R_n) over the selected neighbors, per k,
        # without materializing an (nK^2, NN, dim, dim) intermediate.
        return jnp.einsum('kn,nij->kij', phases[:, idx], mats[idx])

    Fak = bloch_sum(self.Vlist0, in_plane_indices) + self.H0
    Sak = bloch_sum(self.Slist, in_plane_indices) + jnp.eye(dim)
    A = z * Sak - Fak

    Fbk = bloch_sum(self.Vlist0, out_plane_indices)
    Sbk = bloch_sum(self.Slist, out_plane_indices)
    B = z * Sbk - Fbk
    B_bar = z * jnp.conj(Sbk).transpose(0, 2, 1) - jnp.conj(Fbk).transpose(0, 2, 1)

    def converge_single_k(A_k, B_k, B_bar_k):
        """Iterate g = [A - B g B_bar]^-1 with linear mixing for one k-point."""
        def cond_fun(state):
            count, diff, _, _ = state
            return (diff > conv) & (count < maxIter)

        def body_fun(state):
            count, _, g, _ = state
            sig = B_k @ g @ B_bar_k
            g_new = LA.inv(A_k - sig)
            g_prev = g
            g = g_new * mix + (1 - mix) * g_prev
            # Relative change between successive mixed iterates
            diff = jnp.linalg.norm(g - g_prev) / (jnp.linalg.norm(g_prev) + 1e-12)
            return (count + 1, diff, g, g_prev)

        # Start from the bare (no self-energy) Green's function
        g_init = LA.inv(A_k) - 1j * jnp.eye(dim) * self.eta
        init_state = (0, jnp.inf, g_init, g_init)
        count, diff, g, _ = jax.lax.while_loop(cond_fun, body_fun, init_state)
        return g, count, diff

    # Shard k-point data across devices for parallel computation
    A_sharded = shard_array(A, axis=0)
    B_sharded = shard_array(B, axis=0)
    B_bar_sharded = shard_array(B_bar, axis=0)

    # Per-k iteration counts / residuals are returned but not used here.
    g_k, _counts, _diffs = jax.vmap(converge_single_k)(A_sharded, B_sharded, B_bar_sharded)

    # Inverse FT restricted to the 9 surface directions [0..8]
    surf_phases = phases[:, :9]
    G_AB_blocks = jnp.einsum('ka,kij,kb->abij', surf_phases, g_k, surf_phases.conj()) / nK2
    # (9, 9, dim, dim) block form -> (9*dim, 9*dim) matrix
    return G_AB_blocks.transpose(0, 2, 1, 3).reshape(9 * dim, 9 * dim)
def sigmaSurf(self, E, active_dirs=None, conv=1e-4, mix=0.1, G_AB=None):
    """
    Surface self-energy built from the G_AB propagator.

    Sigma = tau @ G_S @ bar_tau, where G_S is the sub-block of G_AB for the
    active directions, tau = [tau_a | ...] with tau_a = (E+i*eta)*S_a - V_a,
    and bar_tau the vertical stack of (E+i*eta)*S_a^H - V_a^H.

    Parameters
    ----------
    E : float
        Energy in eV, pre-shifted by the caller.
    active_dirs : list of int, optional
        Direction indices to include (a static Python list, not a traced JAX
        value). Default None means all 9 surface directions [0..8].
    conv : float, optional
        Dyson convergence criterion passed to gSurf() (default: 1e-4)
    mix : float, optional
        Dyson mixing factor passed to gSurf() (default: 0.1)
    G_AB : ndarray, optional
        Pre-computed propagator from gSurf(); computed here when None.

    Returns
    -------
    ndarray
        Self-energy matrix of shape (dim, dim).
    """
    dirs = list(range(9)) if active_dirs is None else active_dirs
    if G_AB is None:
        G_AB = self.gSurf(E, conv, mix)  # (9*dim) x (9*dim)

    z = E + self.eta * 1j
    tau = jnp.concatenate([z * self.Slist[a] - self.Vlist0[a] for a in dirs], axis=1)
    bar_tau = jnp.concatenate(
        [z * self.Slist[a].conj().T - self.Vlist0[a].conj().T for a in dirs], axis=0)

    # Orbital indices of each active direction's dim x dim block
    idx = jnp.concatenate([a * dim + jnp.arange(dim) for a in dirs])
    G_sub = G_AB[idx[:, None], idx[None, :]]
    return tau @ G_sub @ bar_tau
def crossTermQSurf(self, E, active_dirs=None, conv=1e-4, mix=0.1, G_AB=None):
    """
    Symmetrized surface cross-term Q_sym from the G_AB propagator.

        Q_sym = (tau @ G_sub @ S_LD + S_DL @ G_sub @ bar_tau) / 2

    over the active directions a, with z = E + i*eta:
        tau     = [z*S_a - V_a | ...]        (dim, nDirs*dim)
        bar_tau = [z*S_a^H - V_a^H ; ...]    (nDirs*dim, dim); uses z, not z*,
                                             so bar_tau is not tau^dagger
        S_LD    = [S_a^H ; ...]              (nDirs*dim, dim)
        S_DL    = [S_a | ...]                (dim, nDirs*dim)
    and G_sub the matching sub-block of G_AB. tau/bar_tau are the same
    couplings used by sigmaSurf().

    Parameters
    ----------
    E : float
        Energy in eV, pre-shifted by the caller.
    active_dirs : list of int, optional
        Static list of direction indices; None means [0..8].
    conv, mix : float, optional
        Dyson settings passed to gSurf() when G_AB is None.
    G_AB : ndarray, optional
        Pre-computed propagator from gSurf().

    Returns
    -------
    ndarray
        Symmetrized cross-term matrix of shape (dim, dim).
    """
    if active_dirs is None:
        active_dirs = list(range(9))
    if G_AB is None:
        G_AB = self.gSurf(E, conv, mix)

    z = E + self.eta * 1j
    tau = jnp.concatenate([z * self.Slist[a] - self.Vlist0[a] for a in active_dirs], axis=1)
    bar_tau = jnp.concatenate(
        [z * self.Slist[a].conj().T - self.Vlist0[a].conj().T for a in active_dirs], axis=0)
    S_LD = jnp.concatenate([self.Slist[a].conj().T for a in active_dirs], axis=0)
    S_DL = jnp.concatenate([self.Slist[a] for a in active_dirs], axis=1)

    row_idx = jnp.concatenate([jnp.arange(a * dim, (a + 1) * dim) for a in active_dirs])
    G_sub = G_AB[jnp.ix_(row_idx, row_idx)]

    Q_fwd = tau @ G_sub @ S_LD  # (dim, dim)
    Q_rev = S_DL @ G_sub @ bar_tau  # (dim, dim)
    return (Q_fwd + Q_rev) / 2
# Empty function for compatibility with density.py methods
def setF(self, F, mu1, mu2):
    """
    No-op kept for compatibility with density.py methods.

    This lattice model's bulk properties are fixed by the tight-binding
    parameters (H, Slist, Vlist) given at construction, so an updated Fock
    matrix or chemical potentials have nothing to change here.

    Parameters
    ----------
    F : ndarray
        Fock matrix (unused)
    mu1 : float
        First chemical potential (unused)
    mu2 : float
        Second chemical potential (unused)
    """
    return None
def sigmaTot(self, E):
    """
    Total self-energy for the center atom (dim x dim).

    Removes the current Fermi shift from E before calling sigmaBulk(), which
    works with the unshifted references H0/Vlist0.

    Parameters
    ----------
    E : float
        Energy in eV (in the Fermi-shifted frame)

    Returns
    -------
    ndarray
        Total self-energy matrix (dim x dim)
    """
    E_reference = E - self.dFermi
    return self.sigmaBulk(E_reference)
def sigma(self, E, i):
    """
    Self-energy for contact ``i``; this model has a single bulk contact.

    Parameters
    ----------
    E : float
        Energy in eV
    i : int
        Contact index. Only 0 is meaningful; the value is not checked.

    Returns
    -------
    ndarray
        Self-energy matrix (dim x dim), identical to sigmaTot(E)
    """
    del i  # single contact: the index does not select anything
    return self.sigmaTot(E)
def crossTermQBulk(self, E):
    """
    Bulk symmetrized cross-term Q_sym for Mulliken population correction.

        Q_sym = (tau @ G_AB @ S_LD + S_DL @ G_AB @ bar_tau) / 2

    with G_AB from gBulk() and, over all 12 neighbor directions n
    (z = E + i*eta): tau = [z*S_n - V_n | ...], bar_tau = [z*S_n^H - V_n^H ; ...],
    S_LD = [S_n^H ; ...], S_DL = [S_n | ...].

    Parameters
    ----------
    E : float
        Energy in eV, pre-shifted by the caller to remove dFermi.

    Returns
    -------
    ndarray
        Symmetrized cross-term matrix (dim x dim)
    """
    G_AB = self.gBulk(E)  # (12*dim) x (12*dim)
    z = E + self.eta * 1j
    neighbors = range(12)
    S_herm = [self.Slist[n].conj().T for n in neighbors]

    tau = jnp.hstack([z * self.Slist[n] - self.Vlist0[n] for n in neighbors])
    bar_tau = jnp.vstack([z * S_herm[n] - self.Vlist0[n].conj().T for n in neighbors])
    S_LD = jnp.vstack(S_herm)
    S_DL = jnp.hstack([self.Slist[n] for n in neighbors])

    forward = tau @ G_AB @ S_LD
    reverse = S_DL @ G_AB @ bar_tau
    return (forward + reverse) / 2
def crossTermQ(self, E, i):
    """
    Cross-term Q_sym for contact ``i`` (single bulk contact).

    Removes the current Fermi shift from E and calls crossTermQBulk().

    Parameters
    ----------
    E : float
        Energy in eV
    i : int
        Contact index (only 0 is meaningful; not checked)

    Returns
    -------
    ndarray
        Symmetrized cross-term matrix (dim x dim)
    """
    E_reference = E - self.dFermi
    return self.crossTermQBulk(E_reference)
def crossTermQTot(self, E):
    """
    Total cross-term Q_sym; with a single bulk contact this is contact 0.

    Parameters
    ----------
    E : float
        Energy in eV

    Returns
    -------
    ndarray
        Symmetrized cross-term matrix (dim x dim)
    """
    only_contact = 0
    return self.crossTermQ(E, only_contact)
def DOS(self, E, conv=1e-4, mix=0.1):
    """
    Density of states of the surface atom at energy E.

    With E' = E - dFermi and Sigma = sigmaSurf(E'):

        Gr  = [(E' + i*eta) I - H0 - Sigma]^-1
        DOS = -Im Tr(Gr) / pi

    Parameters
    ----------
    E : float
        Energy point in eV (in the Fermi-shifted frame)
    conv : float, optional
        Dyson convergence criterion for the surface self-energy (default: 1e-4)
    mix : float, optional
        Dyson mixing parameter for the surface self-energy (default: 0.1)

    Returns
    -------
    float
        Density of states at energy E (scalar array)
    """
    E_shifted = E - self.dFermi  # back to the unshifted H0 frame
    sig = self.sigmaSurf(E_shifted, conv=conv, mix=mix)  # dim x dim
    z = E_shifted + self.eta * 1j
    Gr = LA.inv(z * jnp.eye(dim) - self.H0 - sig)
    return -jnp.trace(Gr).imag / jnp.pi
def generate_band_plot(self, E_fermi=0.0, n_points=40, plot=True, save_path=None):
    """
    Band structure along the FCC path G-X-W-L-G-K.

    High-symmetry points are given in cubic FCC coordinates (units of 2*pi/a,
    a = sqrt(2) for a nearest-neighbor distance of 1). They are rotated into
    the [111] frame (x || [1-10], y || [11-2], z || [111]) and expressed in
    fractional coordinates of (b1_3D, b2_3D, b3_3D). At each k-point the
    Bloch sums H(k) = H + sum_n V_n exp(ik.R_n), S(k) = I + sum_n S_n
    exp(ik.R_n) use the Fermi-shifted H/Vlist. H|psi> = E S|psi> is then
    solved via the Hermitian Lowdin form S^-1/2 H S^-1/2.

    Parameters
    ----------
    E_fermi : float, optional
        Energy subtracted from all eigenvalues (default: 0.0)
    n_points : int, optional
        Points per path segment, both endpoints included (default: 40);
        must be at least 2
    plot : bool, optional
        Whether to generate a matplotlib plot (default: True)
    save_path : str, optional
        Path to save the plot (default: None, display only)

    Returns
    -------
    dict
        - 'distances': 1D array of cumulative k-path distances
        - 'bands': 2D array (n_kpoints x dim) of eigenvalues minus E_fermi
        - 'k_labels': high-symmetry labels ('' marks a segment boundary)
        - 'k_positions': path distances of the labels

    Raises
    ------
    ValueError
        If n_points < 2.
    """
    if n_points < 2:
        raise ValueError(f"n_points must be at least 2, got {n_points}")

    b1, b2, b3 = self.b1_3D, self.b2_3D, self.b3_3D

    def to_cart(k_frac):
        # Fractional reciprocal coordinates -> Cartesian k-vector
        return k_frac[0] * b1 + k_frac[1] * b2 + k_frac[2] * b3

    # FCC high-symmetry points for the [111]-frame rhombohedral primitive cell
    # (a1=vecs[0], a2=vecs[1], a3=vecs[3])
    B_inv = jnp.linalg.inv(jnp.column_stack([b1, b2, b3]))
    ex = jnp.array([1, -1, 0]) / jnp.sqrt(2)
    ey = jnp.array([1, 1, -2]) / jnp.sqrt(6)
    ez = jnp.array([1, 1, 1]) / jnp.sqrt(3)
    R = jnp.array([ex, ey, ez])
    scale = 2 * jnp.pi / jnp.sqrt(2)  # nearest-neighbor distance = 1 -> a = sqrt(2)

    def _kpt(cubic_vec):
        return B_inv @ (R @ (scale * jnp.array(cubic_vec)))

    k_points_special = {
        'G': jnp.zeros(3),
        'X': _kpt([1.0, 0.0, 0.0]),
        'W': _kpt([1.0, 0.5, 0.0]),
        'L': _kpt([0.5, 0.5, 0.5]),
        'K': _kpt([0.75, 0.75, 0.0]),
    }
    path = ['G', 'X', 'W', 'L', 'G', 'K']

    # Walk the path once; the same running arc length feeds both the band
    # x-axis and the high-symmetry tick positions.
    k_path_cart = []
    path_dist = []
    k_labels = []
    k_positions = []
    distance = 0.0
    prev_cart = None
    for seg, (start_lbl, end_lbl) in enumerate(zip(path[:-1], path[1:])):
        k_start = k_points_special[start_lbl]
        k_end = k_points_special[end_lbl]
        if seg > 0:
            k_labels.append('')
            k_positions.append(distance)
        k_labels.append(start_lbl)
        k_positions.append(distance)
        for j in range(n_points):
            t = j / (n_points - 1)
            k_cart = to_cart(k_start + t * (k_end - k_start))
            if prev_cart is not None:
                distance = distance + jnp.linalg.norm(k_cart - prev_cart)
            k_path_cart.append(k_cart)
            path_dist.append(distance)
            prev_cart = k_cart
    k_labels.append(path[-1])
    k_positions.append(distance)
    distances = jnp.array(path_dist)

    bands = []
    for k_cart in k_path_cart:
        # Bloch sums over all neighbors at this k-point
        phases = jnp.exp(1j * (self.vecs @ k_cart))
        H_k = self.H + jnp.einsum('n,nij->ij', phases, self.Vlist)
        S_k = jnp.eye(dim) + jnp.einsum('n,nij->ij', phases, self.Slist)
        # Lowdin: X H X with X = S^-1/2 is Hermitian, so eigh is valid
        # (S^-1 H would not be Hermitian for non-trivial S).
        X_k = fractional_matrix_power(S_k, -0.5)
        evals = LA.eigh(X_k @ H_k @ X_k)[0]
        bands.append(jnp.sort(jnp.real(evals)) - E_fermi)
    bands = jnp.array(bands)  # (n_kpoints, n_bands)

    if plot:
        try:
            import matplotlib.pyplot as plt
            fig, ax = plt.subplots(figsize=(10, 6))
            for i in range(dim):
                ax.plot(distances, bands[:, i], 'b-', linewidth=1.5)
            ax.axhline(0, color='r', linestyle='--', linewidth=2, label='E_F')
            ax.set_xticks(k_positions)
            ax.set_xticklabels(k_labels, fontsize=12)
            for pos in k_positions:
                ax.axvline(pos, color='k', linestyle='-', linewidth=0.5, alpha=0.3)
            ax.set_ylabel('Energy - E$_F$ (eV)', fontsize=12)
            ax.set_title('Band Structure (surfGAt3D)', fontsize=14, fontweight='bold')
            ax.set_ylim([-8, 5])
            ax.grid(True, alpha=0.3, axis='y')
            ax.legend(fontsize=10)
            plt.tight_layout()
            if save_path:
                plt.savefig(save_path, dpi=300, bbox_inches='tight')
                print(f"Band structure plot saved to: {save_path}")
            else:
                plt.show()
        except ImportError:
            print("Warning: matplotlib not available, skipping plot generation")

    return {
        'distances': distances,
        'bands': bands,
        'k_labels': k_labels,
        'k_positions': k_positions,
    }
# Calculate fermi energy to a specified tolerance (delegates to getFermiContact)
def calcFermi(self, ne, tol=1e-5):
    """
    Find the bulk Fermi energy that yields ``ne`` electrons.

    Delegates to getFermiContact (density.py) with maxcycles=1000 and this
    object's temperature, then stores the result as ``self.fermi``. It also
    becomes the reference ``self.fermi0`` when no reference exists yet.

    Parameters
    ----------
    ne : float
        Target number of electrons
    tol : float, optional
        Convergence tolerance passed as ``conv`` (default: 1e-5)

    Returns
    -------
    float
        Fermi energy in eV
    """
    print('Calculating Bulk Lattice Fermi Energy...')
    fermi = getFermiContact(self, ne, conv=tol, maxcycles=1000, T=self.T)
    self.fermi = fermi
    if self.fermi0 is None:
        self.fermi0 = fermi
    return fermi