1use nalgebra::DMatrix;
8use serde::{Deserialize, Serialize};
9
/// A single sampled point along a band-structure k-path.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KPoint {
    /// Fractional k-point coordinates; they enter the Bloch phase as `2π k·R`
    /// with integer lattice translations `R`.
    pub frac: [f64; 3],
    /// High-symmetry label (e.g. "Γ"); only path vertices carry one.
    pub label: Option<String>,
    /// Cumulative distance from the start of the path, measured in
    /// fractional (not Cartesian) k coordinates.
    pub path_distance: f64,
}
20
/// Result of an extended-Hückel band-structure calculation along a k-path.
///
/// Energies are in the units of the EHT orbital parameters (`vsip`),
/// presumably eV — TODO confirm against `crate::eht`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BandStructure {
    /// The sampled k-points, in path order.
    pub kpoints: Vec<KPoint>,
    /// `bands[k]` holds the eigenvalues at `kpoints[k]`, sorted ascending.
    pub bands: Vec<Vec<f64>>,
    /// Basis size reported by the EHT solve; individual `bands[k]` entries
    /// may be shorter.
    pub n_bands: usize,
    /// Number of entries in `kpoints` (and `bands`).
    pub n_kpoints: usize,
    /// Estimated Fermi level.
    pub fermi_energy: f64,
    /// Smallest gap between the highest occupied and lowest unoccupied band
    /// at the same k-point (`None` when it cannot be determined).
    pub direct_gap: Option<f64>,
    /// Conduction-band minimum minus valence-band maximum over all k-points
    /// (`None` when the band edges overlap).
    pub indirect_gap: Option<f64>,
    /// `(label, index into kpoints)` for every labelled path vertex.
    pub high_symmetry_points: Vec<(String, usize)>,
}
41
42#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
44pub struct BandStructureConfig {
45 pub n_kpoints_per_segment: usize,
47 pub path: Vec<([f64; 3], String)>,
49}
50
51impl Default for BandStructureConfig {
52 fn default() -> Self {
53 Self {
55 n_kpoints_per_segment: 50,
56 path: vec![
57 ([0.0, 0.0, 0.0], "Γ".to_string()),
58 ([0.5, 0.0, 0.0], "X".to_string()),
59 ([0.5, 0.5, 0.0], "M".to_string()),
60 ([0.0, 0.0, 0.0], "Γ".to_string()),
61 ],
62 }
63 }
64}
65
66pub fn compute_band_structure(
74 elements: &[u8],
75 positions: &[[f64; 3]],
76 lattice: &[[f64; 3]; 3],
77 config: &BandStructureConfig,
78 n_electrons: usize,
79) -> Result<BandStructure, String> {
80 if elements.is_empty() {
81 return Err("No atoms provided".to_string());
82 }
83
84 let kpoints = generate_kpath(&config.path, config.n_kpoints_per_segment);
86 let n_kpts = kpoints.len();
87
88 let eht_result = crate::eht::solve_eht(elements, positions, None)?;
90 let n_basis = eht_result.energies.len();
91
92 let mut bands = Vec::with_capacity(n_kpts);
94 let mut high_sym = Vec::new();
95
96 for (k_idx, kpt) in kpoints.iter().enumerate() {
97 let (h_k, s_k) = build_bloch_matrices(elements, positions, lattice, &kpt.frac, n_basis);
99
100 let eigenvalues = solve_generalized_eigen(&h_k, &s_k)?;
102 bands.push(eigenvalues);
103
104 if let Some(ref label) = kpt.label {
105 high_sym.push((label.clone(), k_idx));
106 }
107 }
108
109 let n_occupied = n_electrons / 2;
111 let fermi_energy = estimate_fermi_energy(&bands, n_occupied);
112
113 let (direct_gap, indirect_gap) = compute_band_gaps(&bands, n_occupied);
115
116 Ok(BandStructure {
117 kpoints,
118 bands,
119 n_bands: n_basis,
120 n_kpoints: n_kpts,
121 fermi_energy,
122 direct_gap,
123 indirect_gap,
124 high_symmetry_points: high_sym,
125 })
126}
127
128fn generate_kpath(path: &[([f64; 3], String)], n_per_segment: usize) -> Vec<KPoint> {
130 let mut kpoints = Vec::new();
131 let mut path_dist = 0.0;
132
133 for i in 0..path.len() {
134 let (k, label) = &path[i];
135
136 if i == 0 {
137 kpoints.push(KPoint {
138 frac: *k,
139 label: Some(label.clone()),
140 path_distance: 0.0,
141 });
142 continue;
143 }
144
145 let (k_prev, _) = &path[i - 1];
146 let dk = [k[0] - k_prev[0], k[1] - k_prev[1], k[2] - k_prev[2]];
147 let seg_len = (dk[0] * dk[0] + dk[1] * dk[1] + dk[2] * dk[2]).sqrt();
148
149 for j in 1..=n_per_segment {
150 let t = j as f64 / n_per_segment as f64;
151 let frac = [
152 k_prev[0] + t * dk[0],
153 k_prev[1] + t * dk[1],
154 k_prev[2] + t * dk[2],
155 ];
156 let is_endpoint = j == n_per_segment;
157 path_dist += seg_len / n_per_segment as f64;
158
159 kpoints.push(KPoint {
160 frac,
161 label: if is_endpoint {
162 Some(label.clone())
163 } else {
164 None
165 },
166 path_distance: path_dist,
167 });
168 }
169 }
170
171 kpoints
172}
173
174fn build_bloch_matrices(
176 elements: &[u8],
177 positions: &[[f64; 3]],
178 lattice: &[[f64; 3]; 3],
179 k: &[f64; 3],
180 n_basis: usize,
181) -> (DMatrix<f64>, DMatrix<f64>) {
182 let basis = crate::eht::basis::build_basis(elements, positions);
184 let s_0 = crate::eht::overlap::build_overlap_matrix(&basis);
185 let h_0 = crate::eht::hamiltonian::build_hamiltonian(&basis, &s_0, None);
186
187 let n = n_basis.min(s_0.nrows());
188
189 let mut h_k = DMatrix::zeros(n, n);
191 let mut s_k = DMatrix::zeros(n, n);
192
193 for i in 0..n {
194 for j in 0..n {
195 h_k[(i, j)] = h_0[(i, j)];
196 s_k[(i, j)] = s_0[(i, j)];
197 }
198 }
199
200 let translations: Vec<[i32; 3]> = vec![
202 [1, 0, 0],
203 [-1, 0, 0],
204 [0, 1, 0],
205 [0, -1, 0],
206 [0, 0, 1],
207 [0, 0, -1],
208 ];
209
210 for r in &translations {
211 let phase = 2.0
212 * std::f64::consts::PI
213 * (k[0] * r[0] as f64 + k[1] * r[1] as f64 + k[2] * r[2] as f64);
214 let cos_phase = phase.cos();
215
216 let translated: Vec<[f64; 3]> = positions
218 .iter()
219 .map(|p| {
220 [
221 p[0] + r[0] as f64 * lattice[0][0]
222 + r[1] as f64 * lattice[1][0]
223 + r[2] as f64 * lattice[2][0],
224 p[1] + r[0] as f64 * lattice[0][1]
225 + r[1] as f64 * lattice[1][1]
226 + r[2] as f64 * lattice[2][1],
227 p[2] + r[0] as f64 * lattice[0][2]
228 + r[1] as f64 * lattice[1][2]
229 + r[2] as f64 * lattice[2][2],
230 ]
231 })
232 .collect();
233
234 let basis_r = crate::eht::basis::build_basis(elements, &translated);
236 let mut combined = basis.clone();
239 combined.extend_from_slice(&basis_r);
240 let s_combined = crate::eht::overlap::build_overlap_matrix(&combined);
241 let s_r = s_combined.view((0, n), (n, basis_r.len())).clone_owned();
242 let h_r = build_intercell_hamiltonian(&basis, &basis_r, &s_r);
243
244 let nr = n.min(s_r.nrows()).min(s_r.ncols());
245 for i in 0..nr {
246 for j in 0..nr {
247 h_k[(i, j)] += cos_phase * h_r[(i, j)];
248 s_k[(i, j)] += cos_phase * s_r[(i, j)];
249 }
250 }
251 }
252
253 (h_k, s_k)
254}
255
256fn build_intercell_hamiltonian(
258 basis_0: &[crate::eht::basis::AtomicOrbital],
259 basis_r: &[crate::eht::basis::AtomicOrbital],
260 s_0r: &DMatrix<f64>,
261) -> DMatrix<f64> {
262 let n = basis_0.len().min(s_0r.nrows());
263 let m = basis_r.len().min(s_0r.ncols());
264 let mut h = DMatrix::zeros(n, m);
265 let k_wh = 1.75; for i in 0..n {
268 let hii = basis_0[i].vsip;
269 for j in 0..m {
270 let hjj = basis_r[j].vsip;
271 h[(i, j)] = 0.5 * k_wh * (hii + hjj) * s_0r[(i, j)];
272 }
273 }
274
275 h
276}
277
278fn solve_generalized_eigen(h: &DMatrix<f64>, s: &DMatrix<f64>) -> Result<Vec<f64>, String> {
280 let n = h.nrows();
281 if n == 0 {
282 return Ok(vec![]);
283 }
284
285 let s_eigen = nalgebra::SymmetricEigen::new(s.clone());
287 let mut s_inv_sqrt = DMatrix::zeros(n, n);
288
289 for (i, &eval) in s_eigen.eigenvalues.iter().enumerate() {
290 if eval > 1e-8 {
291 let inv_sqrt = 1.0 / eval.sqrt();
292 for j in 0..n {
293 for k in 0..n {
294 s_inv_sqrt[(j, k)] +=
295 inv_sqrt * s_eigen.eigenvectors[(j, i)] * s_eigen.eigenvectors[(k, i)];
296 }
297 }
298 }
299 }
300
301 let h_prime = &s_inv_sqrt * h * &s_inv_sqrt;
303 let eigen = nalgebra::SymmetricEigen::new(h_prime);
304
305 let mut eigenvalues: Vec<f64> = eigen.eigenvalues.iter().copied().collect();
306 eigenvalues.sort_by(|a, b| a.partial_cmp(b).unwrap());
307
308 Ok(eigenvalues)
309}
310
/// Estimates the Fermi level from the sampled bands by state counting.
///
/// The procedure:
/// 1. Pool the eigenvalues from every k-point and sort them.
/// 2. With `n_occupied` doubly occupied bands per k-point, fill the lowest
///    `n_occupied * bands.len()` levels.
/// 3. Place the Fermi level midway between the highest filled and the lowest
///    empty level.
///
/// For an insulator (every valence level below every conduction level) this
/// is exactly the midpoint between the global valence-band maximum and
/// conduction-band minimum. For a metal it is a path-sampled estimate.
///
/// Special cases:
/// * Returns `0.0` when there are no bands, no occupied states, or no finite
///   eigenvalues.
/// * Returns the highest level when every state is filled.
/// * Non-finite eigenvalues are ignored.
fn estimate_fermi_energy(bands: &[Vec<f64>], n_occupied: usize) -> f64 {
    if bands.is_empty() || n_occupied == 0 {
        return 0.0;
    }

    let mut levels: Vec<f64> = bands
        .iter()
        .flatten()
        .copied()
        .filter(|e| e.is_finite())
        .collect();
    levels.sort_by(f64::total_cmp);

    // n_filled >= 1 because both factors are non-zero here.
    let n_filled = n_occupied.saturating_mul(bands.len());
    match (levels.get(n_filled - 1), levels.get(n_filled)) {
        (Some(&highest_filled), Some(&lowest_empty)) => 0.5 * (highest_filled + lowest_empty),
        _ => levels.last().copied().unwrap_or(0.0),
    }
}
335
/// Computes the direct and indirect band gaps.
///
/// For every k-point with more than `n_occupied` eigenvalues, the valence edge
/// is band `n_occupied - 1` and the conduction edge is band `n_occupied`. The
/// eigenvalues are assumed to be sorted ascending.
///
/// * Direct gap: the smallest `cb(k) - vb(k)` over all k-points. `None` unless
///   it is strictly positive; a zero gap means the bands touch at some k.
/// * Indirect gap: the global `min cb - max vb`. `None` when the edges touch or
///   overlap.
///
/// Both are `None` when `n_occupied == 0` or when no k-point has an unoccupied
/// band. NaN edges are ignored.
fn compute_band_gaps(bands: &[Vec<f64>], n_occupied: usize) -> (Option<f64>, Option<f64>) {
    if n_occupied == 0 {
        return (None, None);
    }

    let edges: Vec<(f64, f64)> = bands
        .iter()
        .filter(|eigenvals| eigenvals.len() > n_occupied)
        .map(|eigenvals| (eigenvals[n_occupied - 1], eigenvals[n_occupied]))
        .collect();

    // f64::min/max skip NaN operands, and the infinite seeds mark "no data".
    let min_direct = edges
        .iter()
        .map(|&(vb, cb)| cb - vb)
        .fold(f64::INFINITY, f64::min);
    let max_vb = edges.iter().map(|&(vb, _)| vb).fold(f64::NEG_INFINITY, f64::max);
    let min_cb = edges.iter().map(|&(_, cb)| cb).fold(f64::INFINITY, f64::min);

    let direct = (min_direct.is_finite() && min_direct > 0.0).then_some(min_direct);
    let indirect = (max_vb.is_finite() && min_cb.is_finite() && min_cb > max_vb)
        .then_some(min_cb - max_vb);

    (direct, indirect)
}
380
#[cfg(test)]
mod tests {
    use super::*;

    /// The default Γ–X–M–Γ path sampled at 10 points per segment.
    #[test]
    fn test_generate_kpath() {
        let config = BandStructureConfig::default();
        let kpoints = generate_kpath(&config.path, 10);

        // One starting vertex plus 10 points for each of the 3 segments.
        assert_eq!(kpoints.len(), 31);
        assert_eq!(kpoints[0].label.as_deref(), Some("Γ"));
        assert_eq!(kpoints[0].path_distance, 0.0);

        // Only segment endpoints carry labels.
        let labelled: Vec<(usize, &str)> = kpoints
            .iter()
            .enumerate()
            .filter_map(|(i, k)| k.label.as_deref().map(|l| (i, l)))
            .collect();
        assert_eq!(labelled, vec![(0, "Γ"), (10, "X"), (20, "M"), (30, "Γ")]);

        // Labelled points sit exactly on the vertices.
        assert_eq!(kpoints[10].frac, [0.5, 0.0, 0.0]);
        assert_eq!(kpoints[20].frac, [0.5, 0.5, 0.0]);
        assert_eq!(kpoints[30].frac, [0.0, 0.0, 0.0]);

        // Distance never decreases and ends at |ΓX| + |XM| + |MΓ|.
        assert!(kpoints
            .windows(2)
            .all(|w| w[1].path_distance >= w[0].path_distance));
        let expected_total = 0.5 + 0.5 + 0.5_f64.sqrt();
        assert!((kpoints[30].path_distance - expected_total).abs() < 1e-9);
    }

    #[test]
    fn test_band_gaps() {
        let bands = vec![vec![-5.0, -3.0, 1.0, 3.0], vec![-4.5, -2.5, 1.5, 3.5]];
        let (direct, indirect) = compute_band_gaps(&bands, 2);
        // Both k-points have a 4.0 gap; VBM is -2.5 and CBM is 1.0.
        assert_eq!(direct, Some(4.0));
        assert_eq!(indirect, Some(3.5));
    }

    #[test]
    fn test_band_gaps_without_occupied_bands() {
        let bands = vec![vec![-5.0, -3.0, 1.0, 3.0]];
        assert_eq!(compute_band_gaps(&bands, 0), (None, None));
        assert_eq!(compute_band_gaps(&[], 1), (None, None));
    }
}