Source code for gauNEGF.integrate

"""
JAX-powered Green's Functions Integration

Supports both retarded (Gr) and lesser (G<) Green's functions.

Author: William Livernois
"""

import jax
import jax.numpy as jnp
import numpy as np
import os
import time
import socket
import tempfile
import logging
from jax import jit

# Enable double precision for accurate comparisons with NumPy
# (set at import time, before any JAX arrays are created in this module)
jax.config.update("jax_enable_x64", True)

# Setup node-specific logging for integration operations.
# The host name and process id are embedded in the log file name below, so the
# file name differs per host and per process.
hostname = socket.gethostname()
pid = os.getpid()

# Logging switches come from the package configuration module
from gauNEGF.config import LOG_LEVEL, LOG_PERFORMANCE

# When LOG_PERFORMANCE is truthy the log file name is a bare relative path;
# otherwise the file is placed in the system temporary directory.
if LOG_PERFORMANCE:
    log_file = f'integrate_performance_{hostname}_{pid}.log'
else:
    temp_dir = tempfile.gettempdir()
    log_file = os.path.join(temp_dir, f'integrate_performance_{hostname}_{pid}.log')

# Translate the configured level name into a logging constant; a name that
# the logging module does not define falls back to DEBUG.
log_level = getattr(logging, LOG_LEVEL.upper(), logging.DEBUG)

# Module-wide logger used by every integration routine in this file
parallel_logger = logging.getLogger('gauNEGF.integrate')
parallel_logger.setLevel(log_level)

# Create file handler that appends (avoid duplicate handlers on reload)
if not parallel_logger.handlers:
    handler = logging.FileHandler(log_file, mode='a')
    handler.setFormatter(logging.Formatter(
        '%(asctime)s - %(levelname)s - %(message)s'
    ))
    parallel_logger.addHandler(handler)

# Record the JAX devices visible to this process when the module is imported
parallel_logger.info("JAX integration framework initialized")
parallel_logger.debug(f"Number of devices: {len(jax.devices())}")
parallel_logger.debug(f"Device List: {jax.devices()}")

# =============================================================================
# INTEGRATION-SPECIFIC CONSTANTS
# =============================================================================

# Memory budget in GB: a single vmap over all energies is used when the
# estimated memory is below this value; otherwise it sets the batch size.
MAX_VMAP_MEMORY_GB = 5.0              # Use vmap if estimated memory < this (GB)
# When True, the integrators block until the JAX result is ready before
# measuring elapsed time and returning.
FORCE_SYNCHRONOUS = True              # Force synchronous operation (for accurate timing)

# Memory calculation constants
MEMORY_PER_MATRIX_FACTOR = 16         # Bytes per complex128 element
# Number of bytes in one (decimal) GB; a byte count is divided by this value.
BYTES_TO_GB = 1e9                     # Conversion factor

# =============================================================================
# MODULE-LEVEL JIT FUNCTIONS (clean, no nesting)
# =============================================================================

# JIT-compiled G^R kernel: the self-energy is evaluated by the caller and
# passed in as an array.
@jit
def _gr_matrix_ops(sigTot, E, F, S):
    """Return the retarded Green's function ``(E*S - F - sigTot)^-1``.

    The inverse is obtained by solving the linear system against an identity
    matrix whose size is taken from ``F``.
    """
    dim = F.shape[0]
    identity = jnp.eye(dim)
    inverse_green = (E * S - F) - sigTot
    return jnp.linalg.solve(inverse_green, identity)

# JIT-compiled G< kernel: both self-energies are evaluated by the caller and
# passed in as arrays.
@jit
def _gless_matrix_ops(sig, sigTot, E, F, S):
    """Return the lesser Green's function ``Gr @ Gamma @ Ga``.

    ``Gr`` is the inverse of ``E*S - F - sigTot``, ``Ga`` is its conjugate
    transpose, and ``Gamma = 1j * (sig - sig^H)`` is built from ``sig``.
    """
    identity = jnp.eye(F.shape[0])
    retarded = jnp.linalg.solve(E * S - F - sigTot, identity)
    advanced = jnp.transpose(jnp.conj(retarded))
    sig_dagger = jnp.transpose(jnp.conj(sig))
    broadening = 1j * (sig - sig_dagger)
    left_product = retarded @ broadening
    return left_product @ advanced

def _GInt(weighted_func, F, S, g, Elist, weights, ind=None):
    """
    Sum a weighted, energy-dependent matrix integrand over an energy grid.

    ``weighted_func(E, weight, F_jax, S_jax)`` is evaluated for every
    energy/weight pair with ``jax.vmap`` and the resulting matrices are
    summed. When the estimated memory for all energies is below
    ``MAX_VMAP_MEMORY_GB`` a single vmap is used; otherwise the grid is cut
    into fixed-size batches that are processed sequentially with
    ``jax.lax.scan`` (vmap inside each batch), followed by one extra vmap
    over any leftover points.

    Parameters
    ----------
    weighted_func : callable
        Function ``(E, weight, F_jax, S_jax) -> matrix`` returning the
        weighted integrand for a single energy. It must be traceable by
        ``jax.vmap``.
    F : ndarray
        Square matrix (NxN) passed through to ``weighted_func``.
    S : ndarray
        Matrix with the same shape as ``F``, passed through to
        ``weighted_func``.
    g : object
        Not used inside this function; kept for signature compatibility.
    Elist : ndarray
        Energy grid (must provide ``.size`` and ``len``).
    weights : ndarray
        Weight for each energy; same number of elements as ``Elist``.
    ind : optional
        Not used inside this function; kept for signature compatibility.

    Returns
    -------
    jax array
        Sum over all energies of ``weighted_func`` (NxN). In the batched
        path the accumulator is created with a complex dtype.

    Raises
    ------
    AssertionError
        If ``Elist`` and ``weights`` differ in size, or ``F`` and ``S`` are
        not square matrices of identical shape.
    """
    assert Elist.size == weights.size, "Elist and weights must have the same length"
    assert F.shape == S.shape, "F and S must have the same shape"
    assert F.shape[0] == F.shape[1], "F and S must be square matrices"

    # perf_counter is monotonic, so the elapsed time cannot be distorted by
    # wall-clock adjustments. One timer is shared by both execution paths so
    # that they report the same span (array conversion + integration).
    start_time = time.perf_counter()

    # Convert to JAX arrays
    F_jax = jnp.array(F)
    S_jax = jnp.array(S)
    Elist_jax = jnp.array(Elist)
    weights_jax = jnp.array(weights)

    def batch_sum(E_batch, w_batch):
        """Evaluate the integrand for one batch of energies and sum it."""
        per_energy = jax.vmap(weighted_func, in_axes=(0, 0, None, None))(
            E_batch, w_batch, F_jax, S_jax
        )
        return jnp.sum(per_energy, axis=0)

    # Decision logic: one vmap when everything fits in the memory budget,
    # sequential fixed-size batches otherwise
    matrix_size = F.shape[0]
    num_energies = len(Elist)
    matrix_size_gb = (matrix_size * matrix_size * MEMORY_PER_MATRIX_FACTOR) / BYTES_TO_GB

    if num_energies * matrix_size_gb < MAX_VMAP_MEMORY_GB:
        parallel_logger.info(f"GInt using vmap: {matrix_size}x{matrix_size} matrix, {num_energies} energies, {num_energies*matrix_size_gb:.2f}GB")
        integrated = batch_sum(Elist_jax, weights_jax)
        if FORCE_SYNCHRONOUS:
            jax.block_until_ready(integrated)
        elapsed = time.perf_counter() - start_time
        parallel_logger.debug(f"GInt vmap completed in {elapsed:.3f}s")
        return integrated

    # At least one energy per batch, even if a single matrix exceeds the budget
    batch_size = max(1, int(MAX_VMAP_MEMORY_GB // matrix_size_gb))
    parallel_logger.info(f"GInt using batched mapping: {matrix_size}x{matrix_size} matrix, {num_energies} energies, Batch size: {batch_size} ({MAX_VMAP_MEMORY_GB:.2f}GB/batch)")

    def scan_fn(carry, inputs):
        # Accumulate the batch contribution; no per-step output is needed
        E_batch, w_batch = inputs
        return carry + batch_sum(E_batch, w_batch), None

    # Reshape into fixed-size batches; points that do not fill a whole batch
    # are kept aside as the tail
    n_batches = num_energies // batch_size
    n_full = n_batches * batch_size
    Elist_batched = Elist_jax[:n_full].reshape(n_batches, batch_size)
    weights_batched = weights_jax[:n_full].reshape(n_batches, batch_size)
    Elist_tail = Elist_jax[n_full:]
    weights_tail = weights_jax[n_full:]

    # Sanity check on the partition, computed from static shapes so it does
    # not force a device-to-host transfer when FORCE_SYNCHRONOUS is False
    total = Elist_batched.size + Elist_tail.shape[0]
    assert total == num_energies, f"Integration only used {total} points, expected {num_energies} points"

    # scan over batches (sequential), then vmap within each batch (parallel)
    result = jnp.zeros_like(F_jax, dtype=complex)
    result, _ = jax.lax.scan(scan_fn, result, (Elist_batched, weights_batched))
    if Elist_tail.shape[0] > 0:
        result = result + batch_sum(Elist_tail, weights_tail)

    if FORCE_SYNCHRONOUS:
        jax.block_until_ready(result)
    elapsed = time.perf_counter() - start_time
    parallel_logger.debug(f"GInt map completed in {elapsed:.3f}s")
    return result



def GrInt(F, S, g, Elist, weights):
    """
    Integrate retarded Green's function over energy using JAX parallelization.

    For every energy ``E`` the total self-energy ``g.sigmaTot(E)`` is
    evaluated, the retarded Green's function is computed with
    ``_gr_matrix_ops`` and multiplied by the matching weight; ``_GInt`` sums
    these weighted matrices over the whole grid.

    Parameters
    ----------
    F : ndarray
        Fock matrix (NxN)
    S : ndarray
        Overlap matrix (NxN)
    g : surfG object
        Surface Green's function calculator with sigmaTot(E) method
    Elist : ndarray
        Array of energies in eV (Mx1)
    weights : ndarray
        Array of weights for each energy (Mx1)

    Returns
    -------
    ndarray
        Integrated retarded Green's function (NxN)
    """
    def weighted_func_Gr(E, weight, F_jax, S_jax):
        # Weighted integrand for a single energy point
        sigTot = g.sigmaTot(E)
        Gr = _gr_matrix_ops(sigTot, E, F_jax, S_jax)
        return weight * Gr

    parallel_logger.info("Calculating G^R with GInt...")
    return _GInt(weighted_func_Gr, F, S, g, Elist, weights)
def GrLessInt(F, S, g, Elist, weights, ind=None):
    """
    Integrate lesser Green's function over energy using JAX parallelization.

    For every energy ``E`` the lesser Green's function is computed with
    ``_gless_matrix_ops`` and multiplied by the matching weight; ``_GInt``
    sums these weighted matrices over the whole grid. The broadening term is
    built from ``g.sigmaTot(E)`` when ``ind`` is None and from
    ``g.sigma(E, ind)`` otherwise; the retarded Green's function always uses
    ``g.sigmaTot(E)``.

    Parameters
    ----------
    F : ndarray
        Fock matrix (NxN)
    S : ndarray
        Overlap matrix (NxN)
    g : surfG object
        Surface Green's function calculator
    Elist : ndarray
        Array of energies in eV (Mx1)
    weights : ndarray
        Array of weights for each energy (Mx1)
    ind : int, optional
        Contact index for partial density calculation (default: None)

    Returns
    -------
    ndarray
        Integrated lesser Green's function (NxN)
    """
    # The choice of self-energy does not depend on the energy, so decide once
    # instead of inside the per-energy integrand
    useTot = ind is None

    def weighted_func_GrLess(E, weight, F_jax, S_jax):
        # Weighted integrand for a single energy point
        sigTot = g.sigmaTot(E)
        sigma = sigTot if useTot else g.sigma(E, ind)
        Gless = _gless_matrix_ops(sigma, sigTot, E, F_jax, S_jax)
        return weight * Gless

    parallel_logger.info("Calculating G< with GInt...")
    return _GInt(weighted_func_GrLess, F, S, g, Elist, weights, ind)