scirs2_integrate/specialized/quantum/
basis_sets.rs1use crate::error::{IntegrateError, IntegrateResult as Result};
7use scirs2_core::ndarray::{Array1, Array2};
8use scirs2_core::numeric::Complex64;
9use std::collections::HashMap;
10
11#[derive(Debug, Clone)]
13pub struct AdvancedBasisSets {
14 pub n_basis: usize,
16 pub basis_type: BasisSetType,
18 pub parameters: Vec<BasisParameter>,
20 pub overlap_matrix: Array2<f64>,
22}
23
24impl AdvancedBasisSets {
25 pub fn new(n_basis: usize, basistype: BasisSetType) -> Self {
27 let parameters = vec![BasisParameter::default(); n_basis];
28 let overlap_matrix = Array2::eye(n_basis);
29
30 Self {
31 n_basis,
32 basis_type: basistype,
33 parameters,
34 overlap_matrix,
35 }
36 }
37
38 pub fn generate_basis_functions(&self, coordinates: &Array2<f64>) -> Result<Array2<Complex64>> {
40 let n_points = coordinates.nrows();
41 let mut basis_functions = Array2::zeros((n_points, self.n_basis));
42
43 match self.basis_type {
44 BasisSetType::Gaussian => {
45 self.generate_gaussian_basis(coordinates, &mut basis_functions)?;
46 }
47 BasisSetType::SlaterType => {
48 self.generate_slater_basis(coordinates, &mut basis_functions)?;
49 }
50 BasisSetType::PlaneWave => {
51 self.generate_plane_wave_basis(coordinates, &mut basis_functions)?;
52 }
53 BasisSetType::Atomic => {
54 self.generate_atomic_basis(coordinates, &mut basis_functions)?;
55 }
56 }
57
58 Ok(basis_functions)
59 }
60
61 fn generate_gaussian_basis(
63 &self,
64 coordinates: &Array2<f64>,
65 basis_functions: &mut Array2<Complex64>,
66 ) -> Result<()> {
67 for (i, param) in self.parameters.iter().enumerate() {
68 for (j, coord_row) in coordinates
69 .axis_iter(scirs2_core::ndarray::Axis(0))
70 .enumerate()
71 {
72 let x = coord_row[0];
73 let y = if coord_row.len() > 1 {
74 coord_row[1]
75 } else {
76 0.0
77 };
78 let z = if coord_row.len() > 2 {
79 coord_row[2]
80 } else {
81 0.0
82 };
83
84 let r_squared = (x - param.center_x).powi(2)
85 + (y - param.center_y).powi(2)
86 + (z - param.center_z).powi(2);
87
88 let gaussian = (-param.exponent * r_squared).exp();
89 basis_functions[[j, i]] = Complex64::new(gaussian * param.normalization, 0.0);
90 }
91 }
92
93 Ok(())
94 }
95
96 fn generate_slater_basis(
98 &self,
99 coordinates: &Array2<f64>,
100 basis_functions: &mut Array2<Complex64>,
101 ) -> Result<()> {
102 for (i, param) in self.parameters.iter().enumerate() {
103 for (j, coord_row) in coordinates
104 .axis_iter(scirs2_core::ndarray::Axis(0))
105 .enumerate()
106 {
107 let x = coord_row[0];
108 let y = if coord_row.len() > 1 {
109 coord_row[1]
110 } else {
111 0.0
112 };
113 let z = if coord_row.len() > 2 {
114 coord_row[2]
115 } else {
116 0.0
117 };
118
119 let r = ((x - param.center_x).powi(2)
120 + (y - param.center_y).powi(2)
121 + (z - param.center_z).powi(2))
122 .sqrt();
123
124 let slater = r.powf(param.angular_momentum as f64) * (-param.exponent * r).exp();
125 basis_functions[[j, i]] = Complex64::new(slater * param.normalization, 0.0);
126 }
127 }
128
129 Ok(())
130 }
131
132 fn generate_plane_wave_basis(
134 &self,
135 coordinates: &Array2<f64>,
136 basis_functions: &mut Array2<Complex64>,
137 ) -> Result<()> {
138 use scirs2_core::constants::PI;
139
140 for (i, param) in self.parameters.iter().enumerate() {
141 for (j, coord_row) in coordinates
142 .axis_iter(scirs2_core::ndarray::Axis(0))
143 .enumerate()
144 {
145 let x = coord_row[0];
146 let y = if coord_row.len() > 1 {
147 coord_row[1]
148 } else {
149 0.0
150 };
151 let z = if coord_row.len() > 2 {
152 coord_row[2]
153 } else {
154 0.0
155 };
156
157 let k_dot_r = param.kx * x + param.ky * y + param.kz * z;
158 let plane_wave = Complex64::new(
159 (k_dot_r).cos() * param.normalization,
160 (k_dot_r).sin() * param.normalization,
161 );
162 basis_functions[[j, i]] = plane_wave;
163 }
164 }
165
166 Ok(())
167 }
168
169 fn generate_atomic_basis(
171 &self,
172 coordinates: &Array2<f64>,
173 basis_functions: &mut Array2<Complex64>,
174 ) -> Result<()> {
175 for (i, param) in self.parameters.iter().enumerate() {
177 for (j, coord_row) in coordinates
178 .axis_iter(scirs2_core::ndarray::Axis(0))
179 .enumerate()
180 {
181 let x = coord_row[0];
182 let y = if coord_row.len() > 1 {
183 coord_row[1]
184 } else {
185 0.0
186 };
187 let z = if coord_row.len() > 2 {
188 coord_row[2]
189 } else {
190 0.0
191 };
192
193 let r = ((x - param.center_x).powi(2)
194 + (y - param.center_y).powi(2)
195 + (z - param.center_z).powi(2))
196 .sqrt();
197
198 let radial = r.powf(param.angular_momentum as f64) * (-param.exponent * r).exp();
200 let orbital = radial * param.normalization;
201 basis_functions[[j, i]] = Complex64::new(orbital, 0.0);
202 }
203 }
204
205 Ok(())
206 }
207
208 pub fn calculate_overlap_matrix(&mut self, coordinates: &Array2<f64>) -> Result<()> {
210 let basis_functions = self.generate_basis_functions(coordinates)?;
211 let n_points = coordinates.nrows();
212
213 self.overlap_matrix = Array2::zeros((self.n_basis, self.n_basis));
214
215 for i in 0..self.n_basis {
216 for j in 0..self.n_basis {
217 let mut overlap = 0.0;
218 for k in 0..n_points {
219 overlap += (basis_functions[[k, i]].conj() * basis_functions[[k, j]]).re;
220 }
221 self.overlap_matrix[[i, j]] = overlap;
222 }
223 }
224
225 Ok(())
226 }
227
228 pub fn orthogonalize_basis(&mut self) -> Result<()> {
230 for i in 1..self.n_basis {
232 for j in 0..i {
233 let overlap = self.overlap_matrix[[i, j]];
234 if overlap.abs() > 1e-12 {
235 let norm_j = self.overlap_matrix[[j, j]].sqrt();
237 if norm_j > 1e-12 {
238 let projection_coeff = overlap / norm_j;
239 self.parameters[i].normalization -=
240 projection_coeff * self.parameters[j].normalization;
241 }
242 }
243 }
244 }
245
246 Ok(())
247 }
248
249 pub fn transform_basis(
251 &self,
252 transformation_matrix: &Array2<f64>,
253 ) -> Result<AdvancedBasisSets> {
254 if transformation_matrix.nrows() != self.n_basis
255 || transformation_matrix.ncols() != self.n_basis
256 {
257 return Err(IntegrateError::InvalidInput(
258 "Transformation matrix dimension mismatch".to_string(),
259 ));
260 }
261
262 let mut transformed_basis = self.clone();
263
264 for i in 0..self.n_basis {
266 let mut new_normalization = 0.0;
267 for j in 0..self.n_basis {
268 new_normalization +=
269 transformation_matrix[[i, j]] * self.parameters[j].normalization;
270 }
271 transformed_basis.parameters[i].normalization = new_normalization;
272 }
273
274 let overlap_transformed = transformation_matrix
276 .t()
277 .dot(&self.overlap_matrix)
278 .dot(transformation_matrix);
279 transformed_basis.overlap_matrix = overlap_transformed;
280
281 Ok(transformed_basis)
282 }
283}
284
/// Functional form of the basis functions in an `AdvancedBasisSets`.
///
/// The enum is fieldless, so it also derives `PartialEq`, `Eq` and `Hash`;
/// callers can compare basis types and use them as map or set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BasisSetType {
    /// Gaussian functions: `normalization * exp(-exponent * r^2)`.
    Gaussian,
    /// Slater-type functions: `normalization * r^l * exp(-exponent * r)`.
    SlaterType,
    /// Plane waves: `normalization * exp(i * k.r)`.
    PlaneWave,
    /// Atomic orbitals; currently evaluated with the same radial expression
    /// as `SlaterType`.
    Atomic,
}
297
/// Parameters of a single basis function.
///
/// Which fields are read depends on the basis type: the centre and exponent
/// are used by the Gaussian, Slater-type and atomic forms, the angular
/// momentum by the Slater-type and atomic forms, and the wave vector by plane
/// waves. The normalization is used by all of them.
#[derive(Debug, Clone, PartialEq)]
pub struct BasisParameter {
    /// Decay constant in the exponential factor.
    pub exponent: f64,
    /// Scalar factor applied to the evaluated function.
    pub normalization: f64,
    /// Angular momentum `l`; used as the power of `r` in the radial factor.
    pub angular_momentum: i32,
    /// x coordinate of the function centre.
    pub center_x: f64,
    /// y coordinate of the function centre.
    pub center_y: f64,
    /// z coordinate of the function centre.
    pub center_z: f64,
    /// x component of the plane-wave vector.
    pub kx: f64,
    /// y component of the plane-wave vector.
    pub ky: f64,
    /// z component of the plane-wave vector.
    pub kz: f64,
}
316
317impl Default for BasisParameter {
318 fn default() -> Self {
319 Self {
320 exponent: 1.0,
321 normalization: 1.0,
322 angular_momentum: 0,
323 center_x: 0.0,
324 center_y: 0.0,
325 center_z: 0.0,
326 kx: 0.0,
327 ky: 0.0,
328 kz: 0.0,
329 }
330 }
331}
332
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// A fresh basis set has one default parameter per function and an
    /// identity overlap matrix.
    #[test]
    fn test_basis_set_creation() {
        let basis = AdvancedBasisSets::new(5, BasisSetType::Gaussian);
        assert_eq!(basis.n_basis, 5);
        assert_eq!(basis.parameters.len(), 5);
        assert_eq!(basis.overlap_matrix.nrows(), 5);
        assert_eq!(basis.overlap_matrix.ncols(), 5);
        assert_eq!(basis.overlap_matrix, Array2::<f64>::eye(5));
    }

    /// Gaussian values are checked against `exp(-exponent * r^2)` by hand.
    #[test]
    fn test_gaussian_basis_generation() {
        let mut basis = AdvancedBasisSets::new(2, BasisSetType::Gaussian);

        basis.parameters[0].exponent = 1.0;
        basis.parameters[0].normalization = 1.0;
        basis.parameters[1].exponent = 2.0;
        basis.parameters[1].normalization = 1.0;
        basis.parameters[1].center_x = 1.0;

        let coordinates =
            Array2::from_shape_vec((3, 3), vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.5, 0.5, 0.0])
                .unwrap();

        let basis_functions = basis.generate_basis_functions(&coordinates);
        assert!(basis_functions.is_ok());

        let functions = basis_functions.unwrap();
        assert_eq!(functions.nrows(), 3);
        assert_eq!(functions.ncols(), 2);

        // Function 0: centre at the origin, exponent 1.
        assert_relative_eq!(functions[[0, 0]].re, 1.0, epsilon = 1e-12);
        assert_relative_eq!(functions[[1, 0]].re, (-1.0_f64).exp(), epsilon = 1e-12);
        assert_relative_eq!(functions[[2, 0]].re, (-0.5_f64).exp(), epsilon = 1e-12);

        // Function 1: centre at (1, 0, 0), exponent 2.
        assert_relative_eq!(functions[[0, 1]].re, (-2.0_f64).exp(), epsilon = 1e-12);
        assert_relative_eq!(functions[[1, 1]].re, 1.0, epsilon = 1e-12);
        assert_relative_eq!(functions[[2, 1]].re, (-1.0_f64).exp(), epsilon = 1e-12);

        // Gaussians are purely real.
        for value in functions.iter() {
            assert_eq!(value.im, 0.0);
        }
    }

    /// With two identical default Gaussians every overlap entry equals the
    /// sum of `exp(-2 x^2)` over the sample points.
    #[test]
    fn test_overlap_matrix_calculation() {
        let mut basis = AdvancedBasisSets::new(2, BasisSetType::Gaussian);

        let coordinates = Array2::from_shape_vec(
            (10, 3),
            vec![
                0.0, 0.0, 0.0, 0.1, 0.0, 0.0, 0.2, 0.0, 0.0, 0.3, 0.0, 0.0, 0.4, 0.0, 0.0, 0.5,
                0.0, 0.0, 0.6, 0.0, 0.0, 0.7, 0.0, 0.0, 0.8, 0.0, 0.0, 0.9, 0.0, 0.0,
            ],
        )
        .unwrap();

        let result = basis.calculate_overlap_matrix(&coordinates);
        assert!(result.is_ok());

        for i in 0..basis.n_basis {
            assert!(basis.overlap_matrix[[i, i]] > 0.0);
        }

        let expected: f64 = (0..10)
            .map(|k| {
                let x = 0.1 * f64::from(k);
                (-2.0 * x * x).exp()
            })
            .sum();
        for i in 0..basis.n_basis {
            for j in 0..basis.n_basis {
                assert_relative_eq!(basis.overlap_matrix[[i, j]], expected, epsilon = 1e-12);
            }
        }
        assert_eq!(basis.overlap_matrix[[0, 1]], basis.overlap_matrix[[1, 0]]);
    }

    /// A wrongly sized matrix is rejected; the identity leaves the set unchanged.
    #[test]
    fn test_transform_basis() {
        let basis = AdvancedBasisSets::new(2, BasisSetType::Gaussian);

        let mismatched = Array2::<f64>::eye(3);
        assert!(basis.transform_basis(&mismatched).is_err());

        let identity = Array2::<f64>::eye(2);
        let transformed = basis.transform_basis(&identity).unwrap();
        assert_eq!(transformed.n_basis, 2);
        for param in &transformed.parameters {
            assert_relative_eq!(param.normalization, 1.0, epsilon = 1e-12);
        }
        assert_eq!(transformed.overlap_matrix, Array2::<f64>::eye(2));
    }
}