1use crate::error::{IntegrateError, IntegrateResult as Result};
7use scirs2_core::constants::{PI, REDUCED_PLANCK};
8use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
9use scirs2_core::numeric::Complex64;
10use scirs2_core::simd_ops::SimdUnifiedOps;
11
12#[derive(Debug, Clone)]
14pub struct QuantumState {
15 pub psi: Array1<Complex64>,
17 pub x: Array1<f64>,
19 pub t: f64,
21 pub mass: f64,
23 pub dx: f64,
25}
26
27impl QuantumState {
28 pub fn new(psi: Array1<Complex64>, x: Array1<f64>, t: f64, mass: f64) -> Self {
30 let dx = if x.len() > 1 { x[1] - x[0] } else { 1.0 };
31
32 Self {
33 psi,
34 x,
35 t,
36 mass,
37 dx,
38 }
39 }
40
41 pub fn normalize(&mut self) {
43 let norm_squared: f64 = self.psi.iter().map(|&c| (c.conj() * c).re).sum::<f64>() * self.dx;
44
45 let norm = norm_squared.sqrt();
46 if norm > 0.0 {
47 self.psi.mapv_inplace(|c| c / norm);
48 }
49 }
50
51 pub fn expectation_position(&self) -> f64 {
53 self.expectation_position_simd()
54 }
55
56 pub fn expectation_position_simd(&self) -> f64 {
58 let prob_density = self.probability_density_simd();
59 f64::simd_dot(&self.x.view(), &prob_density.view()) * self.dx
60 }
61
62 pub fn expectation_position_scalar(&self) -> f64 {
64 self.x
65 .iter()
66 .zip(self.psi.iter())
67 .map(|(&x, &psi)| x * (psi.conj() * psi).re)
68 .sum::<f64>()
69 * self.dx
70 }
71
72 pub fn expectation_momentum(&self) -> f64 {
74 let n = self.psi.len();
75 let mut momentum = 0.0;
76
77 for i in 1..n - 1 {
79 let dpsi_dx = (self.psi[i + 1] - self.psi[i - 1]) / (2.0 * self.dx);
80 momentum += (self.psi[i].conj() * Complex64::new(0.0, -REDUCED_PLANCK) * dpsi_dx).re;
81 }
82
83 momentum * self.dx
84 }
85
86 pub fn probability_density(&self) -> Array1<f64> {
88 self.probability_density_simd()
89 }
90
91 pub fn probability_density_simd(&self) -> Array1<f64> {
93 let real_parts: Array1<f64> = self.psi.mapv(|c| c.re);
95 let imag_parts: Array1<f64> = self.psi.mapv(|c| c.im);
96
97 let real_squared = f64::simd_mul(&real_parts.view(), &real_parts.view());
99 let imag_squared = f64::simd_mul(&imag_parts.view(), &imag_parts.view());
100 let result = f64::simd_add(&real_squared.view(), &imag_squared.view());
101
102 result
103 }
104
105 pub fn probability_density_scalar(&self) -> Array1<f64> {
107 self.psi.mapv(|c| (c.conj() * c).re)
108 }
109}
110
111pub trait QuantumPotential: Send + Sync {
113 fn evaluate(&self, x: f64) -> f64;
115
116 fn evaluate_array(&self, x: &ArrayView1<f64>) -> Array1<f64> {
118 x.mapv(|xi| self.evaluate(xi))
119 }
120}
121
/// Harmonic potential `V(x) = 0.5 * k * (x - x0)^2`.
#[derive(Clone, Debug)]
pub struct HarmonicOscillator {
    /// Spring constant multiplying the squared displacement.
    pub k: f64,
    /// Position of the potential minimum.
    pub x0: f64,
}
130
131impl QuantumPotential for HarmonicOscillator {
132 fn evaluate(&self, x: f64) -> f64 {
133 0.5 * self.k * (x - self.x0).powi(2)
134 }
135}
136
/// Square-well potential: zero on `[left, right]` and `barrier_height`
/// everywhere else.
#[derive(Clone, Debug)]
pub struct ParticleInBox {
    /// Left edge of the well (inclusive).
    pub left: f64,
    /// Right edge of the well (inclusive).
    pub right: f64,
    /// Potential value returned outside the well.
    pub barrier_height: f64,
}
147
148impl QuantumPotential for ParticleInBox {
149 fn evaluate(&self, x: f64) -> f64 {
150 if x < self.left || x > self.right {
151 self.barrier_height
152 } else {
153 0.0
154 }
155 }
156}
157
/// Coulomb-type potential `V(r) = -z * e2_4pi_eps0 / r`.
#[derive(Clone, Debug)]
pub struct HydrogenAtom {
    /// Multiplier of the Coulomb term (presumably the nuclear charge number —
    /// confirm against callers).
    pub z: f64,
    /// Coupling constant of the Coulomb term (the name suggests `e^2 / (4 pi eps0)`;
    /// units are not established in this file).
    pub e2_4pi_eps0: f64,
}
166
167impl QuantumPotential for HydrogenAtom {
168 fn evaluate(&self, r: f64) -> f64 {
169 if r > 0.0 {
170 -self.z * self.e2_4pi_eps0 / r
171 } else {
172 f64::NEG_INFINITY
173 }
174 }
175}
176
177pub struct SchrodingerSolver {
179 pub n_points: usize,
181 pub dt: f64,
183 pub potential: Box<dyn QuantumPotential>,
185 pub method: SchrodingerMethod,
187}
188
/// Time-stepping schemes available to `SchrodingerSolver::solve_time_dependent`.
#[derive(Clone, Copy, Debug)]
pub enum SchrodingerMethod {
    /// FFT-based split-operator scheme (half potential step, kinetic step,
    /// half potential step).
    SplitOperator,
    /// Crank–Nicolson scheme, solved through a tridiagonal system.
    CrankNicolson,
    /// Forward Euler with a finite-difference second derivative.
    ExplicitEuler,
    /// Classical fourth-order Runge–Kutta with a finite-difference second derivative.
    RungeKutta4,
}
201
202impl SchrodingerSolver {
203 pub fn new(
205 n_points: usize,
206 dt: f64,
207 potential: Box<dyn QuantumPotential>,
208 method: SchrodingerMethod,
209 ) -> Self {
210 Self {
211 n_points,
212 dt,
213 potential,
214 method,
215 }
216 }
217
218 pub fn solve_time_dependent(
220 &self,
221 initial_state: &QuantumState,
222 t_final: f64,
223 ) -> Result<Vec<QuantumState>> {
224 let mut states = vec![initial_state.clone()];
225 let mut current_state = initial_state.clone();
226
227 if current_state.x.len() != current_state.psi.len() {
229 let n = current_state.psi.len();
231 let x_min = current_state.x[0];
232 let x_max = current_state.x[current_state.x.len() - 1];
233 current_state.x = Array1::linspace(x_min, x_max, n);
234 current_state.dx = (x_max - x_min) / (n - 1) as f64;
235 }
236
237 let n_steps = (t_final / self.dt).ceil() as usize;
238
239 match self.method {
240 SchrodingerMethod::SplitOperator => {
241 for _ in 0..n_steps {
242 self.split_operator_step(&mut current_state)?;
243 current_state.t += self.dt;
244 states.push(current_state.clone());
245 }
246 }
247 SchrodingerMethod::CrankNicolson => {
248 for _ in 0..n_steps {
249 self.crank_nicolson_step(&mut current_state)?;
250 current_state.t += self.dt;
251 states.push(current_state.clone());
252 }
253 }
254 SchrodingerMethod::ExplicitEuler => {
255 for _ in 0..n_steps {
256 self.explicit_euler_step(&mut current_state)?;
257 current_state.t += self.dt;
258 states.push(current_state.clone());
259 }
260 }
261 SchrodingerMethod::RungeKutta4 => {
262 for _ in 0..n_steps {
263 self.runge_kutta4_step(&mut current_state)?;
264 current_state.t += self.dt;
265 states.push(current_state.clone());
266 }
267 }
268 }
269
270 Ok(states)
271 }
272
273 fn split_operator_step(&self, state: &mut QuantumState) -> Result<()> {
275 use scirs2_fft::{fft, ifft};
276
277 if state.x.len() != state.psi.len() {
279 let n = state.psi.len().min(state.x.len());
281 if state.psi.len() > n {
282 state.psi = state.psi.slice(scirs2_core::ndarray::s![..n]).to_owned();
283 }
284 if state.x.len() > n {
285 state.x = state.x.slice(scirs2_core::ndarray::s![..n]).to_owned();
286 }
287 }
288
289 let n = state.psi.len();
290
291 let v = self.potential.evaluate_array(&state.x.view());
293
294 for i in 0..n {
295 let phase = -v[i] * self.dt / (2.0 * REDUCED_PLANCK);
296 state.psi[i] *= Complex64::new(phase.cos(), phase.sin());
297 }
298
299 let psi_k = fft(&state.psi.to_vec(), None).map_err(|e| {
302 crate::error::IntegrateError::ComputationError(format!("FFT failed: {e:?}"))
303 })?;
304
305 let dk = 2.0 * PI / (n as f64 * state.dx);
307 let mut k_values = vec![0.0; n];
308 for (i, k_value) in k_values.iter_mut().enumerate().take(n) {
309 if i < n / 2 {
310 *k_value = i as f64 * dk;
311 } else {
312 *k_value = (i as f64 - n as f64) * dk;
313 }
314 }
315
316 let mut psi_k_evolved = psi_k;
318 for i in 0..n {
319 let k = k_values[i];
320 let kinetic_phase = -REDUCED_PLANCK * k * k * self.dt / (2.0 * state.mass);
321 psi_k_evolved[i] *= Complex64::new(kinetic_phase.cos(), kinetic_phase.sin());
322 }
323
324 let psi_evolved = ifft(&psi_k_evolved, None).map_err(|e| {
326 crate::error::IntegrateError::ComputationError(format!("IFFT failed: {e:?}"))
327 })?;
328
329 let psi_vec = if psi_evolved.len() != n {
332 psi_evolved[..n].to_vec()
333 } else {
334 psi_evolved
335 };
336 state.psi = Array1::from_vec(psi_vec);
337
338 for i in 0..n {
340 let phase = -v[i] * self.dt / (2.0 * REDUCED_PLANCK);
341 state.psi[i] *= Complex64::new(phase.cos(), phase.sin());
342 }
343
344 state.normalize();
346
347 Ok(())
348 }
349
350 fn crank_nicolson_step(&self, state: &mut QuantumState) -> Result<()> {
352 let n = state.psi.len();
353 let alpha = Complex64::new(
354 0.0,
355 REDUCED_PLANCK * self.dt / (4.0 * state.mass * state.dx.powi(2)),
356 );
357
358 let v = self.potential.evaluate_array(&state.x.view());
360 let mut a = vec![Complex64::new(0.0, 0.0); n];
361 let mut b = vec![Complex64::new(0.0, 0.0); n];
362 let mut c = vec![Complex64::new(0.0, 0.0); n];
363
364 for i in 0..n {
365 let v_term = Complex64::new(0.0, -v[i] * self.dt / (2.0 * REDUCED_PLANCK));
366 b[i] = Complex64::new(1.0, 0.0) + 2.0 * alpha - v_term;
367
368 if i > 0 {
369 a[i] = -alpha;
370 }
371 if i < n - 1 {
372 c[i] = -alpha;
373 }
374 }
375
376 let mut rhs = vec![Complex64::new(0.0, 0.0); n];
378 for i in 0..n {
379 let v_term = Complex64::new(0.0, v[i] * self.dt / (2.0 * REDUCED_PLANCK));
380 rhs[i] = state.psi[i] * (Complex64::new(1.0, 0.0) - 2.0 * alpha + v_term);
381
382 if i > 0 {
383 rhs[i] += alpha * state.psi[i - 1];
384 }
385 if i < n - 1 {
386 rhs[i] += alpha * state.psi[i + 1];
387 }
388 }
389
390 let new_psi = self.solve_tridiagonal(&a, &b, &c, &rhs)?;
392 state.psi = Array1::from_vec(new_psi);
393
394 state.normalize();
396
397 Ok(())
398 }
399
400 fn explicit_euler_step(&self, state: &mut QuantumState) -> Result<()> {
402 let n = state.psi.len();
403 let mut dpsi_dt = Array1::zeros(n);
404
405 let v = self.potential.evaluate_array(&state.x.view());
407 let prefactor = Complex64::new(0.0, -1.0 / REDUCED_PLANCK);
408
409 for i in 0..n {
410 let d2psi_dx2 = if i == 0 {
412 state.psi[1] - 2.0 * state.psi[0] + state.psi[0]
413 } else if i == n - 1 {
414 state.psi[n - 1] - 2.0 * state.psi[n - 1] + state.psi[n - 2]
415 } else {
416 state.psi[i + 1] - 2.0 * state.psi[i] + state.psi[i - 1]
417 } / state.dx.powi(2);
418
419 let h_psi =
421 -REDUCED_PLANCK.powi(2) / (2.0 * state.mass) * d2psi_dx2 + v[i] * state.psi[i];
422
423 dpsi_dt[i] = prefactor * h_psi;
424 }
425
426 state.psi += &(dpsi_dt * self.dt);
428
429 state.normalize();
431
432 Ok(())
433 }
434
435 fn runge_kutta4_step(&self, state: &mut QuantumState) -> Result<()> {
437 let n = state.psi.len();
438 let v = self.potential.evaluate_array(&state.x.view());
439
440 let compute_derivative = |psi: &Array1<Complex64>| -> Array1<Complex64> {
442 let mut dpsi = Array1::zeros(n);
443 let prefactor = Complex64::new(0.0, -1.0 / REDUCED_PLANCK);
444
445 for i in 0..n {
446 let d2psi_dx2 = if i == 0 {
447 psi[1] - 2.0 * psi[0] + psi[0]
448 } else if i == n - 1 {
449 psi[n - 1] - 2.0 * psi[n - 1] + psi[n - 2]
450 } else {
451 psi[i + 1] - 2.0 * psi[i] + psi[i - 1]
452 } / state.dx.powi(2);
453
454 let h_psi =
455 -REDUCED_PLANCK.powi(2) / (2.0 * state.mass) * d2psi_dx2 + v[i] * psi[i];
456
457 dpsi[i] = prefactor * h_psi;
458 }
459 dpsi
460 };
461
462 let k1 = compute_derivative(&state.psi);
464 let k2 = compute_derivative(&(&state.psi + &k1 * (self.dt / 2.0)));
465 let k3 = compute_derivative(&(&state.psi + &k2 * (self.dt / 2.0)));
466 let k4 = compute_derivative(&(&state.psi + &k3 * self.dt));
467
468 state.psi += &((k1 + k2 * 2.0 + k3 * 2.0 + k4) * (self.dt / 6.0));
470
471 state.normalize();
473
474 Ok(())
475 }
476
477 fn solve_tridiagonal(
479 &self,
480 a: &[Complex64],
481 b: &[Complex64],
482 c: &[Complex64],
483 d: &[Complex64],
484 ) -> Result<Vec<Complex64>> {
485 let n = b.len();
486 let mut c_star = vec![Complex64::new(0.0, 0.0); n];
487 let mut d_star = vec![Complex64::new(0.0, 0.0); n];
488 let mut x = vec![Complex64::new(0.0, 0.0); n];
489
490 c_star[0] = c[0] / b[0];
492 d_star[0] = d[0] / b[0];
493
494 for i in 1..n {
495 let m = b[i] - a[i] * c_star[i - 1];
496 c_star[i] = c[i] / m;
497 d_star[i] = (d[i] - a[i] * d_star[i - 1]) / m;
498 }
499
500 x[n - 1] = d_star[n - 1];
502 for i in (0..n - 1).rev() {
503 x[i] = d_star[i] - c_star[i] * x[i + 1];
504 }
505
506 Ok(x)
507 }
508
509 pub fn solve_time_independent(
516 &self,
517 x_min: f64,
518 x_max: f64,
519 n_states: usize,
520 ) -> Result<(Array1<f64>, Array2<f64>)> {
521 let dx = (x_max - x_min) / (self.n_points - 1) as f64;
522 let x = Array1::linspace(x_min, x_max, self.n_points);
523
524 let n_int = self.n_points - 2; if n_int < 2 {
528 return Err(IntegrateError::InvalidInput(
529 "Too few grid points for eigenvalue solve".to_string(),
530 ));
531 }
532
533 let hbar: f64 = 1.0; let mass: f64 = 1.0; let kinetic_factor = hbar.powi(2) / (2.0 * mass * dx.powi(2));
545
546 let v_int: Vec<f64> = (1..self.n_points - 1)
548 .map(|i| self.potential.evaluate(x[i]))
549 .collect();
550
551 let diag: Vec<f64> = (0..n_int)
553 .map(|i| 2.0 * kinetic_factor + v_int[i])
554 .collect();
555 let off: f64 = -kinetic_factor; let mut energies = Array1::zeros(n_states);
559 let mut wavefunctions = Array2::zeros((self.n_points, n_states));
560
561 let max_iter = 500;
573 let tol = 1e-10;
574
575 let diag_min = diag.iter().cloned().fold(f64::INFINITY, f64::min);
580 let gershgorin_lower = diag_min - 2.0 * off.abs();
581 let initial_shift = gershgorin_lower - 0.1 * (off.abs() + 1.0);
583
584 for state in 0..n_states {
585 let mut psi = Array1::from_shape_fn(n_int, |i| {
587 let s = (state + 1) as f64;
588 (s * PI * (i + 1) as f64 / (n_int + 1) as f64).sin()
589 });
590
591 for j in 0..state {
593 let prev_int = wavefunctions
594 .column(j)
595 .slice(scirs2_core::ndarray::s![1..self.n_points - 1])
596 .to_owned();
597 let overlap: f64 = psi
598 .iter()
599 .zip(prev_int.iter())
600 .map(|(&a, &b)| a * b * dx)
601 .sum();
602 psi.zip_mut_with(&prev_int, |a, &b| *a -= overlap * b);
603 }
604
605 let norm: f64 = psi.iter().map(|&v| v * v * dx).sum::<f64>().sqrt();
607 if norm > 1e-14 {
608 psi /= norm;
609 }
610
611 let shift = initial_shift;
617
618 let mut eigenvalue = Self::rayleigh_quotient(&psi, &diag, off, dx);
619 let mut prev_eigenvalue = f64::NEG_INFINITY;
620
621 for _iter in 0..max_iter {
622 let shifted_diag: Vec<f64> = diag.iter().map(|&d| d - shift).collect();
624 let rhs: Vec<f64> = psi.iter().copied().collect();
625
626 let psi_new = Self::solve_tridiagonal_real(&shifted_diag, off, &rhs)?;
627 let mut psi_new_arr = Array1::from_vec(psi_new);
628
629 for j in 0..state {
631 let prev_int = wavefunctions
632 .column(j)
633 .slice(scirs2_core::ndarray::s![1..self.n_points - 1])
634 .to_owned();
635 let overlap: f64 = psi_new_arr
636 .iter()
637 .zip(prev_int.iter())
638 .map(|(&a, &b)| a * b * dx)
639 .sum();
640 psi_new_arr.zip_mut_with(&prev_int, |a, &b| *a -= overlap * b);
641 }
642
643 let norm_new: f64 = psi_new_arr.iter().map(|&v| v * v * dx).sum::<f64>().sqrt();
645 if norm_new < 1e-14 {
646 break;
647 }
648 psi_new_arr /= norm_new;
649 psi = psi_new_arr;
650
651 eigenvalue = Self::rayleigh_quotient(&psi, &diag, off, dx);
653
654 if (eigenvalue - prev_eigenvalue).abs() < tol {
660 break;
661 }
662 prev_eigenvalue = eigenvalue;
663 }
664
665 energies[state] = eigenvalue;
666
667 for i in 0..n_int {
669 wavefunctions[[i + 1, state]] = psi[i];
670 }
671 }
672
673 let mut indices: Vec<usize> = (0..n_states).collect();
675 indices.sort_by(|&i, &j| {
676 energies[i]
677 .partial_cmp(&energies[j])
678 .unwrap_or(std::cmp::Ordering::Equal)
679 });
680
681 let sorted_energies = Array1::from_vec(indices.iter().map(|&i| energies[i]).collect());
682 let mut sorted_wavefunctions = Array2::zeros((self.n_points, n_states));
683 for (new_idx, &old_idx) in indices.iter().enumerate() {
684 sorted_wavefunctions
685 .column_mut(new_idx)
686 .assign(&wavefunctions.column(old_idx));
687 }
688
689 Ok((sorted_energies, sorted_wavefunctions))
690 }
691
692 fn rayleigh_quotient(psi: &Array1<f64>, diag: &[f64], off: f64, dx: f64) -> f64 {
694 let n = psi.len();
695 let mut h_psi = Array1::zeros(n);
696 for i in 0..n {
697 h_psi[i] = diag[i] * psi[i];
698 if i > 0 {
699 h_psi[i] += off * psi[i - 1];
700 }
701 if i < n - 1 {
702 h_psi[i] += off * psi[i + 1];
703 }
704 }
705 psi.iter()
706 .zip(h_psi.iter())
707 .map(|(&a, &b)| a * b * dx)
708 .sum()
709 }
710
711 fn solve_tridiagonal_real(diag: &[f64], off: f64, rhs: &[f64]) -> Result<Vec<f64>> {
714 let n = diag.len();
715 if n == 0 {
716 return Ok(Vec::new());
717 }
718 let mut c_star = vec![0.0_f64; n];
719 let mut d_star = vec![0.0_f64; n];
720
721 if diag[0].abs() < 1e-300 {
723 return Err(IntegrateError::ComputationError(
724 "Singular tridiagonal system during inverse power iteration".to_string(),
725 ));
726 }
727 c_star[0] = off / diag[0];
728 d_star[0] = rhs[0] / diag[0];
729
730 for i in 1..n {
731 let denom = diag[i] - off * c_star[i - 1];
732 if denom.abs() < 1e-300 {
733 return Err(IntegrateError::ComputationError(
734 "Singular tridiagonal system during inverse power iteration".to_string(),
735 ));
736 }
737 c_star[i] = off / denom;
738 d_star[i] = (rhs[i] - off * d_star[i - 1]) / denom;
739 }
740
741 let mut x = vec![0.0_f64; n];
743 x[n - 1] = d_star[n - 1];
744 for i in (0..n - 1).rev() {
745 x[i] = d_star[i] - c_star[i] * x[i + 1];
746 }
747
748 Ok(x)
749 }
750
751 pub fn gaussian_wave_packet(
753 x: &Array1<f64>,
754 x0: f64,
755 sigma: f64,
756 k0: f64,
757 mass: f64,
758 ) -> QuantumState {
759 let norm = 1.0 / (2.0 * PI * sigma.powi(2)).powf(0.25);
760
761 let original_n = x.len();
763 let fft_n = original_n.next_power_of_two();
764
765 let (x_final, psi_final) = if fft_n != original_n {
767 let x_min = x[0];
769 let x_max = x[original_n - 1];
770 let x_padded = Array1::linspace(x_min, x_max, fft_n);
771
772 let psi_padded = x_padded.mapv(|xi| {
773 let gaussian = norm * (-(xi - x0).powi(2) / (4.0 * sigma.powi(2))).exp();
774 let phase = k0 * xi;
775 Complex64::new(gaussian * phase.cos(), gaussian * phase.sin())
776 });
777
778 (x_padded, psi_padded)
779 } else {
780 let psi = x.mapv(|xi| {
782 let gaussian = norm * (-(xi - x0).powi(2) / (4.0 * sigma.powi(2))).exp();
783 let phase = k0 * xi;
784 Complex64::new(gaussian * phase.cos(), gaussian * phase.sin())
785 });
786 (x.clone(), psi)
787 };
788
789 let mut state = QuantumState::new(psi_final, x_final, 0.0, mass);
790 state.normalize();
791 state
792 }
793}
794
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// The two lowest eigenvalues of the unit harmonic oscillator should be
    /// close to 0.5 and 1.5 (the eigenvalue solver uses hbar = mass = 1).
    #[test]
    fn test_harmonic_oscillator_ground_state() {
        let potential = Box::new(HarmonicOscillator { k: 1.0, x0: 0.0 });
        let solver = SchrodingerSolver::new(100, 0.01, potential, SchrodingerMethod::SplitOperator);

        let (energies, wavefunctions) = solver
            .solve_time_independent(-5.0, 5.0, 3)
            .expect("Operation failed");

        // One column per requested state, one row per grid point.
        assert_eq!(wavefunctions.dim(), (100, 3));

        assert_relative_eq!(energies[0], 0.5, epsilon = 0.01);

        assert_relative_eq!(energies[1], 1.5, epsilon = 0.01);
    }

    /// A free wave packet must stay normalized at every stored step and must
    /// end up to the right of its starting position.
    #[test]
    fn test_wave_packet_evolution() {
        // k = 0 turns the harmonic potential into a free particle.
        let potential = Box::new(HarmonicOscillator { k: 0.0, x0: 0.0 });
        let solver =
            SchrodingerSolver::new(200, 0.001, potential, SchrodingerMethod::SplitOperator);

        let x = Array1::linspace(-10.0, 10.0, 200);
        let initial_state = SchrodingerSolver::gaussian_wave_packet(&x, -5.0, 1.0, 2.0, 1.0);

        let states = solver
            .solve_time_dependent(&initial_state, 1.0)
            .expect("Operation failed");

        for state in &states {
            let norm_squared: f64 =
                state.psi.iter().map(|&c| (c.conj() * c).re).sum::<f64>() * state.dx;
            assert_relative_eq!(norm_squared, 1.0, epsilon = 1e-6);
        }

        let final_position = states
            .last()
            .expect("Operation failed")
            .expectation_position();
        assert!(final_position > -5.0);
    }
}