Skip to main content

sci_form/eht/
band_structure.rs

//! EHT band structure calculation for periodic systems.
//!
//! Computes the electronic band structure along high-symmetry k-point paths
//! through the Brillouin zone of crystalline materials under periodic
//! boundary conditions.
6
7use nalgebra::DMatrix;
8use serde::{Deserialize, Serialize};
9
10/// A k-point in reciprocal space.
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct KPoint {
13    /// Fractional coordinates in reciprocal space.
14    pub frac: [f64; 3],
15    /// Label (e.g., "Γ", "X", "M", "K").
16    pub label: Option<String>,
17    /// Linear path distance for plotting.
18    pub path_distance: f64,
19}
20
21/// Result of a band structure calculation.
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct BandStructure {
24    /// k-points along the path.
25    pub kpoints: Vec<KPoint>,
26    /// Eigenvalues at each k-point: bands[k_idx][band_idx] in eV.
27    pub bands: Vec<Vec<f64>>,
28    /// Number of bands (= number of basis functions per cell).
29    pub n_bands: usize,
30    /// Number of k-points.
31    pub n_kpoints: usize,
32    /// Fermi energy estimate (eV).
33    pub fermi_energy: f64,
34    /// Direct band gap (eV), if any.
35    pub direct_gap: Option<f64>,
36    /// Indirect band gap (eV), if any.
37    pub indirect_gap: Option<f64>,
38    /// High-symmetry point labels and their k-indices.
39    pub high_symmetry_points: Vec<(String, usize)>,
40}
41
42/// Configuration for band structure calculation.
43#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
44pub struct BandStructureConfig {
45    /// Number of k-points between high-symmetry points.
46    pub n_kpoints_per_segment: usize,
47    /// High-symmetry path (pairs of labels).
48    pub path: Vec<([f64; 3], String)>,
49}
50
51impl Default for BandStructureConfig {
52    fn default() -> Self {
53        // Default: Γ → X → M → Γ for cubic systems
54        Self {
55            n_kpoints_per_segment: 50,
56            path: vec![
57                ([0.0, 0.0, 0.0], "Γ".to_string()),
58                ([0.5, 0.0, 0.0], "X".to_string()),
59                ([0.5, 0.5, 0.0], "M".to_string()),
60                ([0.0, 0.0, 0.0], "Γ".to_string()),
61            ],
62        }
63    }
64}
65
66/// Compute electronic band structure using EHT with Bloch's theorem.
67///
68/// For a periodic system, the EHT Hamiltonian at k-point **k** is:
69///   H(k) = Σ_R H(R) exp(i k·R)
70///   S(k) = Σ_R S(R) exp(i k·R)
71///
72/// where R are lattice translation vectors.
73pub fn compute_band_structure(
74    elements: &[u8],
75    positions: &[[f64; 3]],
76    lattice: &[[f64; 3]; 3],
77    config: &BandStructureConfig,
78    n_electrons: usize,
79) -> Result<BandStructure, String> {
80    if elements.is_empty() {
81        return Err("No atoms provided".to_string());
82    }
83
84    // Generate k-point path
85    let kpoints = generate_kpath(&config.path, config.n_kpoints_per_segment);
86    let n_kpts = kpoints.len();
87
88    // Build real-space EHT matrices for the unit cell and nearest neighbors
89    let eht_result = crate::eht::solve_eht(elements, positions, None)?;
90    let n_basis = eht_result.energies.len();
91
92    // At each k-point, solve the generalized eigenvalue problem
93    let mut bands = Vec::with_capacity(n_kpts);
94    let mut high_sym = Vec::new();
95
96    for (k_idx, kpt) in kpoints.iter().enumerate() {
97        // Build H(k) and S(k) using Bloch phase factors
98        let (h_k, s_k) = build_bloch_matrices(elements, positions, lattice, &kpt.frac, n_basis);
99
100        // Solve generalized eigenvalue problem H(k) C = S(k) C ε
101        let eigenvalues = solve_generalized_eigen(&h_k, &s_k)?;
102        bands.push(eigenvalues);
103
104        if let Some(ref label) = kpt.label {
105            high_sym.push((label.clone(), k_idx));
106        }
107    }
108
109    // Estimate Fermi energy
110    let n_occupied = n_electrons / 2;
111    let fermi_energy = estimate_fermi_energy(&bands, n_occupied);
112
113    // Compute band gaps
114    let (direct_gap, indirect_gap) = compute_band_gaps(&bands, n_occupied);
115
116    Ok(BandStructure {
117        kpoints,
118        bands,
119        n_bands: n_basis,
120        n_kpoints: n_kpts,
121        fermi_energy,
122        direct_gap,
123        indirect_gap,
124        high_symmetry_points: high_sym,
125    })
126}
127
128/// Generate a k-point path through the Brillouin zone.
129fn generate_kpath(path: &[([f64; 3], String)], n_per_segment: usize) -> Vec<KPoint> {
130    let mut kpoints = Vec::new();
131    let mut path_dist = 0.0;
132
133    for i in 0..path.len() {
134        let (k, label) = &path[i];
135
136        if i == 0 {
137            kpoints.push(KPoint {
138                frac: *k,
139                label: Some(label.clone()),
140                path_distance: 0.0,
141            });
142            continue;
143        }
144
145        let (k_prev, _) = &path[i - 1];
146        let dk = [k[0] - k_prev[0], k[1] - k_prev[1], k[2] - k_prev[2]];
147        let seg_len = (dk[0] * dk[0] + dk[1] * dk[1] + dk[2] * dk[2]).sqrt();
148
149        for j in 1..=n_per_segment {
150            let t = j as f64 / n_per_segment as f64;
151            let frac = [
152                k_prev[0] + t * dk[0],
153                k_prev[1] + t * dk[1],
154                k_prev[2] + t * dk[2],
155            ];
156            let is_endpoint = j == n_per_segment;
157            path_dist += seg_len / n_per_segment as f64;
158
159            kpoints.push(KPoint {
160                frac,
161                label: if is_endpoint {
162                    Some(label.clone())
163                } else {
164                    None
165                },
166                path_distance: path_dist,
167            });
168        }
169    }
170
171    kpoints
172}
173
174/// Build Bloch Hamiltonian and overlap matrices at k-point.
175fn build_bloch_matrices(
176    elements: &[u8],
177    positions: &[[f64; 3]],
178    lattice: &[[f64; 3]; 3],
179    k: &[f64; 3],
180    n_basis: usize,
181) -> (DMatrix<f64>, DMatrix<f64>) {
182    // For the R=0 image, use the standard EHT matrices
183    let basis = crate::eht::basis::build_basis(elements, positions);
184    let s_0 = crate::eht::overlap::build_overlap_matrix(&basis);
185    let h_0 = crate::eht::hamiltonian::build_hamiltonian(&basis, &s_0, None);
186
187    let n = n_basis.min(s_0.nrows());
188
189    // Start with R=0 contribution (phase factor = 1 for k·0 = 0)
190    let mut h_k = DMatrix::zeros(n, n);
191    let mut s_k = DMatrix::zeros(n, n);
192
193    for i in 0..n {
194        for j in 0..n {
195            h_k[(i, j)] = h_0[(i, j)];
196            s_k[(i, j)] = s_0[(i, j)];
197        }
198    }
199
200    // Add contributions from nearest-neighbor cells R = ±a, ±b, ±c
201    let translations: Vec<[i32; 3]> = vec![
202        [1, 0, 0],
203        [-1, 0, 0],
204        [0, 1, 0],
205        [0, -1, 0],
206        [0, 0, 1],
207        [0, 0, -1],
208    ];
209
210    for r in &translations {
211        let phase = 2.0
212            * std::f64::consts::PI
213            * (k[0] * r[0] as f64 + k[1] * r[1] as f64 + k[2] * r[2] as f64);
214        let cos_phase = phase.cos();
215
216        // Build translated positions
217        let translated: Vec<[f64; 3]> = positions
218            .iter()
219            .map(|p| {
220                [
221                    p[0] + r[0] as f64 * lattice[0][0]
222                        + r[1] as f64 * lattice[1][0]
223                        + r[2] as f64 * lattice[2][0],
224                    p[1] + r[0] as f64 * lattice[0][1]
225                        + r[1] as f64 * lattice[1][1]
226                        + r[2] as f64 * lattice[2][1],
227                    p[2] + r[0] as f64 * lattice[0][2]
228                        + r[1] as f64 * lattice[1][2]
229                        + r[2] as f64 * lattice[2][2],
230                ]
231            })
232            .collect();
233
234        // Build inter-cell overlap using combined basis
235        let basis_r = crate::eht::basis::build_basis(elements, &translated);
236        // Compute cross-overlap S_{0R} by building overlap of combined [basis, basis_r]
237        // and extracting the off-diagonal block
238        let mut combined = basis.clone();
239        combined.extend_from_slice(&basis_r);
240        let s_combined = crate::eht::overlap::build_overlap_matrix(&combined);
241        let s_r = s_combined.view((0, n), (n, basis_r.len())).clone_owned();
242        let h_r = build_intercell_hamiltonian(&basis, &basis_r, &s_r);
243
244        let nr = n.min(s_r.nrows()).min(s_r.ncols());
245        for i in 0..nr {
246            for j in 0..nr {
247                h_k[(i, j)] += cos_phase * h_r[(i, j)];
248                s_k[(i, j)] += cos_phase * s_r[(i, j)];
249            }
250        }
251    }
252
253    (h_k, s_k)
254}
255
256/// Build inter-cell Hamiltonian using Wolfsberg-Helmholz approximation.
257fn build_intercell_hamiltonian(
258    basis_0: &[crate::eht::basis::AtomicOrbital],
259    basis_r: &[crate::eht::basis::AtomicOrbital],
260    s_0r: &DMatrix<f64>,
261) -> DMatrix<f64> {
262    let n = basis_0.len().min(s_0r.nrows());
263    let m = basis_r.len().min(s_0r.ncols());
264    let mut h = DMatrix::zeros(n, m);
265    let k_wh = 1.75; // Wolfsberg-Helmholz K parameter
266
267    for i in 0..n {
268        let hii = basis_0[i].vsip;
269        for j in 0..m {
270            let hjj = basis_r[j].vsip;
271            h[(i, j)] = 0.5 * k_wh * (hii + hjj) * s_0r[(i, j)];
272        }
273    }
274
275    h
276}
277
278/// Solve generalized eigenvalue problem H C = S C ε via Löwdin orthogonalization.
279fn solve_generalized_eigen(h: &DMatrix<f64>, s: &DMatrix<f64>) -> Result<Vec<f64>, String> {
280    let n = h.nrows();
281    if n == 0 {
282        return Ok(vec![]);
283    }
284
285    // S^{-1/2} via eigendecomposition
286    let s_eigen = nalgebra::SymmetricEigen::new(s.clone());
287    let mut s_inv_sqrt = DMatrix::zeros(n, n);
288
289    for (i, &eval) in s_eigen.eigenvalues.iter().enumerate() {
290        if eval > 1e-8 {
291            let inv_sqrt = 1.0 / eval.sqrt();
292            for j in 0..n {
293                for k in 0..n {
294                    s_inv_sqrt[(j, k)] +=
295                        inv_sqrt * s_eigen.eigenvectors[(j, i)] * s_eigen.eigenvectors[(k, i)];
296                }
297            }
298        }
299    }
300
301    // H' = S^{-1/2} H S^{-1/2}
302    let h_prime = &s_inv_sqrt * h * &s_inv_sqrt;
303    let eigen = nalgebra::SymmetricEigen::new(h_prime);
304
305    let mut eigenvalues: Vec<f64> = eigen.eigenvalues.iter().copied().collect();
306    eigenvalues.sort_by(|a, b| a.partial_cmp(b).unwrap());
307
308    Ok(eigenvalues)
309}
310
/// Estimate the Fermi energy by filling the computed band energies in order.
///
/// All eigenvalues from all k-points are pooled and sorted, and the lowest
/// `n_occupied * bands.len()` levels are treated as filled (every k-point
/// weighted equally). The estimate is the midpoint between the highest filled
/// and the lowest empty level. For an insulator this is the middle of the
/// fundamental gap, (VBM + CBM) / 2. When every level is filled, the highest
/// level is returned.
///
/// Because the k-points come from a path rather than a uniform Brillouin-zone
/// mesh, the result is an approximation for metals.
///
/// Returns 0.0 when `bands` is empty, holds no eigenvalues, or
/// `n_occupied == 0`.
fn estimate_fermi_energy(bands: &[Vec<f64>], n_occupied: usize) -> f64 {
    if bands.is_empty() || n_occupied == 0 {
        return 0.0;
    }

    let mut levels: Vec<f64> = bands.iter().flatten().copied().collect();
    // total_cmp gives a total order, so a stray NaN cannot panic the sort.
    levels.sort_unstable_by(f64::total_cmp);

    // n_filled >= 1 because both factors are >= 1 here.
    let n_filled = n_occupied.saturating_mul(bands.len());
    match (levels.get(n_filled - 1), levels.get(n_filled)) {
        (Some(&highest_filled), Some(&lowest_empty)) => (highest_filled + lowest_empty) / 2.0,
        // Everything is filled (or there is nothing at all).
        _ => levels.last().copied().unwrap_or(0.0),
    }
}
335
/// Compute the direct and indirect band gaps at the occupied/unoccupied boundary.
///
/// Within each k-point the bands are taken in ascending order. Band
/// `n_occupied - 1` is the highest occupied (valence) band and band
/// `n_occupied` the lowest unoccupied (conduction) band. k-points without a
/// band above the occupied ones are skipped.
///
/// Returns `(direct, indirect)`:
/// - `direct`: the minimum over k of `cb(k) - vb(k)`. It is `None` when that
///   minimum is not positive (the bands touch at some k-point) or when no
///   k-point qualifies.
/// - `indirect`: the conduction-band minimum minus the valence-band maximum
///   over all k-points. It is `None` when the two overlap in energy.
fn compute_band_gaps(bands: &[Vec<f64>], n_occupied: usize) -> (Option<f64>, Option<f64>) {
    if n_occupied == 0 {
        return (None, None);
    }

    // (valence top, conduction bottom) for each usable k-point.
    let edges: Vec<(f64, f64)> = bands
        .iter()
        .filter(|eigenvals| eigenvals.len() > n_occupied)
        .map(|eigenvals| (eigenvals[n_occupied - 1], eigenvals[n_occupied]))
        .collect();
    if edges.is_empty() {
        return (None, None);
    }

    // Zero gaps must take part in the minimum: a k-point where the bands touch
    // means there is no direct gap, not that the next k-point's gap applies.
    let min_direct = edges
        .iter()
        .map(|(vb, cb)| cb - vb)
        .fold(f64::INFINITY, f64::min);
    let max_vb = edges
        .iter()
        .map(|&(vb, _)| vb)
        .fold(f64::NEG_INFINITY, f64::max);
    let min_cb = edges
        .iter()
        .map(|&(_, cb)| cb)
        .fold(f64::INFINITY, f64::min);

    let direct = (min_direct > 0.0).then(|| min_direct);
    let indirect = (min_cb > max_vb).then(|| min_cb - max_vb);
    (direct, indirect)
}
380
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_generate_kpath() {
        let config = BandStructureConfig::default();
        let kpoints = generate_kpath(&config.path, 10);
        assert!(!kpoints.is_empty());
        // First point should be Γ
        assert_eq!(kpoints[0].label.as_deref(), Some("Γ"));
    }

    #[test]
    fn test_generate_kpath_layout() {
        let config = BandStructureConfig::default();
        let kpoints = generate_kpath(&config.path, 10);

        // One start point plus ten points for each of the three segments.
        assert_eq!(kpoints.len(), 31);

        let labelled: Vec<(usize, &str)> = kpoints
            .iter()
            .enumerate()
            .filter_map(|(i, k)| k.label.as_deref().map(|l| (i, l)))
            .collect();
        assert_eq!(labelled, vec![(0, "Γ"), (10, "X"), (20, "M"), (30, "Γ")]);

        assert_eq!(kpoints[10].frac, [0.5, 0.0, 0.0]);
        assert_eq!(kpoints[30].frac, [0.0, 0.0, 0.0]);
        assert!(kpoints
            .windows(2)
            .all(|w| w[1].path_distance >= w[0].path_distance));
    }

    #[test]
    fn test_band_gaps() {
        let bands = vec![vec![-5.0, -3.0, 1.0, 3.0], vec![-4.5, -2.5, 1.5, 3.5]];
        let (direct, indirect) = compute_band_gaps(&bands, 2);
        assert!(direct.is_some());
        assert!(indirect.is_some());
        assert!(indirect.unwrap() > 0.0);
    }

    #[test]
    fn test_band_gap_values() {
        let bands = vec![vec![-5.0, -3.0, 1.0, 3.0], vec![-4.5, -2.5, 1.5, 3.5]];
        let (direct, indirect) = compute_band_gaps(&bands, 2);
        // Vertical gap is 4.0 eV at both k-points; CBM (1.0) - VBM (-2.5) = 3.5 eV.
        assert!((direct.unwrap() - 4.0).abs() < 1e-12);
        assert!((indirect.unwrap() - 3.5).abs() < 1e-12);
    }

    #[test]
    fn test_fermi_energy_inside_gap() {
        let bands = vec![vec![-5.0, -3.0, 1.0, 3.0], vec![-4.5, -2.5, 1.5, 3.5]];
        let ef = estimate_fermi_energy(&bands, 2);
        // Valence-band maximum is -2.5 eV, conduction-band minimum is 1.0 eV.
        assert!(ef > -2.5 && ef < 1.0);
    }

    #[test]
    fn test_fermi_energy_degenerate_inputs() {
        assert_eq!(estimate_fermi_energy(&[], 1), 0.0);
        assert_eq!(estimate_fermi_energy(&[vec![-1.0, 2.0]], 0), 0.0);
    }
}