Skip to main content

scirs2_integrate/specialized/quantum/
core.rs

1//! Core quantum mechanics components
2//!
3//! This module provides basic quantum state representations, potentials,
4//! and Schrödinger equation solvers.
5
6use crate::error::{IntegrateError, IntegrateResult as Result};
7use scirs2_core::constants::{PI, REDUCED_PLANCK};
8use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
9use scirs2_core::numeric::Complex64;
10use scirs2_core::simd_ops::SimdUnifiedOps;
11
12/// Quantum state representation
13#[derive(Debug, Clone)]
14pub struct QuantumState {
15    /// Wave function values (complex)
16    pub psi: Array1<Complex64>,
17    /// Spatial grid points
18    pub x: Array1<f64>,
19    /// Time
20    pub t: f64,
21    /// Mass of the particle
22    pub mass: f64,
23    /// Spatial step size
24    pub dx: f64,
25}
26
27impl QuantumState {
28    /// Create a new quantum state
29    pub fn new(psi: Array1<Complex64>, x: Array1<f64>, t: f64, mass: f64) -> Self {
30        let dx = if x.len() > 1 { x[1] - x[0] } else { 1.0 };
31
32        Self {
33            psi,
34            x,
35            t,
36            mass,
37            dx,
38        }
39    }
40
41    /// Normalize the wave function
42    pub fn normalize(&mut self) {
43        let norm_squared: f64 = self.psi.iter().map(|&c| (c.conj() * c).re).sum::<f64>() * self.dx;
44
45        let norm = norm_squared.sqrt();
46        if norm > 0.0 {
47            self.psi.mapv_inplace(|c| c / norm);
48        }
49    }
50
51    /// Calculate expectation value of position
52    pub fn expectation_position(&self) -> f64 {
53        self.expectation_position_simd()
54    }
55
56    /// SIMD-optimized expectation value of position
57    pub fn expectation_position_simd(&self) -> f64 {
58        let prob_density = self.probability_density_simd();
59        f64::simd_dot(&self.x.view(), &prob_density.view()) * self.dx
60    }
61
62    /// Fallback scalar implementation for expectation value of position
63    pub fn expectation_position_scalar(&self) -> f64 {
64        self.x
65            .iter()
66            .zip(self.psi.iter())
67            .map(|(&x, &psi)| x * (psi.conj() * psi).re)
68            .sum::<f64>()
69            * self.dx
70    }
71
72    /// Calculate expectation value of momentum
73    pub fn expectation_momentum(&self) -> f64 {
74        let n = self.psi.len();
75        let mut momentum = 0.0;
76
77        // Central difference for derivative
78        for i in 1..n - 1 {
79            let dpsi_dx = (self.psi[i + 1] - self.psi[i - 1]) / (2.0 * self.dx);
80            momentum += (self.psi[i].conj() * Complex64::new(0.0, -REDUCED_PLANCK) * dpsi_dx).re;
81        }
82
83        momentum * self.dx
84    }
85
86    /// Calculate probability density
87    pub fn probability_density(&self) -> Array1<f64> {
88        self.probability_density_simd()
89    }
90
91    /// SIMD-optimized probability density calculation
92    pub fn probability_density_simd(&self) -> Array1<f64> {
93        // Convert complex numbers to real and imaginary parts for SIMD processing
94        let real_parts: Array1<f64> = self.psi.mapv(|c| c.re);
95        let imag_parts: Array1<f64> = self.psi.mapv(|c| c.im);
96
97        // Calculate |psi|^2 = Re(psi)^2 + Im(psi)^2 using SIMD
98        let real_squared = f64::simd_mul(&real_parts.view(), &real_parts.view());
99        let imag_squared = f64::simd_mul(&imag_parts.view(), &imag_parts.view());
100        let result = f64::simd_add(&real_squared.view(), &imag_squared.view());
101
102        result
103    }
104
105    /// Fallback scalar implementation for probability density
106    pub fn probability_density_scalar(&self) -> Array1<f64> {
107        self.psi.mapv(|c| (c.conj() * c).re)
108    }
109}
110
111/// Quantum potential trait
112pub trait QuantumPotential: Send + Sync {
113    /// Evaluate potential at given position
114    fn evaluate(&self, x: f64) -> f64;
115
116    /// Evaluate potential for array of positions
117    fn evaluate_array(&self, x: &ArrayView1<f64>) -> Array1<f64> {
118        x.mapv(|xi| self.evaluate(xi))
119    }
120}
121
122/// Harmonic oscillator potential
123#[derive(Debug, Clone)]
124pub struct HarmonicOscillator {
125    /// Spring constant
126    pub k: f64,
127    /// Center position
128    pub x0: f64,
129}
130
131impl QuantumPotential for HarmonicOscillator {
132    fn evaluate(&self, x: f64) -> f64 {
133        0.5 * self.k * (x - self.x0).powi(2)
134    }
135}
136
137/// Particle in a box potential
138#[derive(Debug, Clone)]
139pub struct ParticleInBox {
140    /// Left boundary
141    pub left: f64,
142    /// Right boundary
143    pub right: f64,
144    /// Barrier height
145    pub barrier_height: f64,
146}
147
148impl QuantumPotential for ParticleInBox {
149    fn evaluate(&self, x: f64) -> f64 {
150        if x < self.left || x > self.right {
151            self.barrier_height
152        } else {
153            0.0
154        }
155    }
156}
157
158/// Hydrogen-like atom potential
159#[derive(Debug, Clone)]
160pub struct HydrogenAtom {
161    /// Nuclear charge
162    pub z: f64,
163    /// Electron charge squared / (4π ε₀)
164    pub e2_4pi_eps0: f64,
165}
166
167impl QuantumPotential for HydrogenAtom {
168    fn evaluate(&self, r: f64) -> f64 {
169        if r > 0.0 {
170            -self.z * self.e2_4pi_eps0 / r
171        } else {
172            f64::NEG_INFINITY
173        }
174    }
175}
176
177/// Solver for the Schrödinger equation
178pub struct SchrodingerSolver {
179    /// Spatial grid size
180    pub n_points: usize,
181    /// Time step size
182    pub dt: f64,
183    /// Potential function
184    pub potential: Box<dyn QuantumPotential>,
185    /// Solver method
186    pub method: SchrodingerMethod,
187}
188
/// Time-propagation schemes available to [`SchrodingerSolver::solve_time_dependent`].
///
/// Every scheme renormalises the wave function after each step.
#[derive(Debug, Clone, Copy)]
pub enum SchrodingerMethod {
    /// Symmetric split-operator scheme: half potential step, FFT-based kinetic step,
    /// half potential step
    SplitOperator,
    /// Implicit Crank–Nicolson scheme, solved as a tridiagonal system each step
    CrankNicolson,
    /// Forward (explicit) Euler with a three-point Laplacian
    ExplicitEuler,
    /// Classical fourth-order Runge–Kutta with a three-point Laplacian
    RungeKutta4,
}
201
202impl SchrodingerSolver {
203    /// Create a new Schrödinger solver
204    pub fn new(
205        n_points: usize,
206        dt: f64,
207        potential: Box<dyn QuantumPotential>,
208        method: SchrodingerMethod,
209    ) -> Self {
210        Self {
211            n_points,
212            dt,
213            potential,
214            method,
215        }
216    }
217
218    /// Solve time-dependent Schrödinger equation
219    pub fn solve_time_dependent(
220        &self,
221        initial_state: &QuantumState,
222        t_final: f64,
223    ) -> Result<Vec<QuantumState>> {
224        let mut states = vec![initial_state.clone()];
225        let mut current_state = initial_state.clone();
226
227        // Ensure x and psi have consistent lengths
228        if current_state.x.len() != current_state.psi.len() {
229            // Resize x to match psi if they differ (e.g., due to FFT padding requirements)
230            let n = current_state.psi.len();
231            let x_min = current_state.x[0];
232            let x_max = current_state.x[current_state.x.len() - 1];
233            current_state.x = Array1::linspace(x_min, x_max, n);
234            current_state.dx = (x_max - x_min) / (n - 1) as f64;
235        }
236
237        let n_steps = (t_final / self.dt).ceil() as usize;
238
239        match self.method {
240            SchrodingerMethod::SplitOperator => {
241                for _ in 0..n_steps {
242                    self.split_operator_step(&mut current_state)?;
243                    current_state.t += self.dt;
244                    states.push(current_state.clone());
245                }
246            }
247            SchrodingerMethod::CrankNicolson => {
248                for _ in 0..n_steps {
249                    self.crank_nicolson_step(&mut current_state)?;
250                    current_state.t += self.dt;
251                    states.push(current_state.clone());
252                }
253            }
254            SchrodingerMethod::ExplicitEuler => {
255                for _ in 0..n_steps {
256                    self.explicit_euler_step(&mut current_state)?;
257                    current_state.t += self.dt;
258                    states.push(current_state.clone());
259                }
260            }
261            SchrodingerMethod::RungeKutta4 => {
262                for _ in 0..n_steps {
263                    self.runge_kutta4_step(&mut current_state)?;
264                    current_state.t += self.dt;
265                    states.push(current_state.clone());
266                }
267            }
268        }
269
270        Ok(states)
271    }
272
273    /// Split-operator method step
274    fn split_operator_step(&self, state: &mut QuantumState) -> Result<()> {
275        use scirs2_fft::{fft, ifft};
276
277        // Ensure x and psi have the same length before proceeding
278        if state.x.len() != state.psi.len() {
279            // This shouldn't happen, but handle it gracefully
280            let n = state.psi.len().min(state.x.len());
281            if state.psi.len() > n {
282                state.psi = state.psi.slice(scirs2_core::ndarray::s![..n]).to_owned();
283            }
284            if state.x.len() > n {
285                state.x = state.x.slice(scirs2_core::ndarray::s![..n]).to_owned();
286            }
287        }
288
289        let n = state.psi.len();
290
291        // Potential energy evolution (half step)
292        let v = self.potential.evaluate_array(&state.x.view());
293
294        for i in 0..n {
295            let phase = -v[i] * self.dt / (2.0 * REDUCED_PLANCK);
296            state.psi[i] *= Complex64::new(phase.cos(), phase.sin());
297        }
298
299        // Kinetic energy evolution in momentum space using FFT
300        // Transform to momentum space
301        let psi_k = fft(&state.psi.to_vec(), None).map_err(|e| {
302            crate::error::IntegrateError::ComputationError(format!("FFT failed: {e:?}"))
303        })?;
304
305        // Calculate k-space grid (momentum values)
306        let dk = 2.0 * PI / (n as f64 * state.dx);
307        let mut k_values = vec![0.0; n];
308        for (i, k_value) in k_values.iter_mut().enumerate().take(n) {
309            if i < n / 2 {
310                *k_value = i as f64 * dk;
311            } else {
312                *k_value = (i as f64 - n as f64) * dk;
313            }
314        }
315
316        // Apply kinetic energy operator in momentum space
317        let mut psi_k_evolved = psi_k;
318        for i in 0..n {
319            let k = k_values[i];
320            let kinetic_phase = -REDUCED_PLANCK * k * k * self.dt / (2.0 * state.mass);
321            psi_k_evolved[i] *= Complex64::new(kinetic_phase.cos(), kinetic_phase.sin());
322        }
323
324        // Transform back to position space
325        let psi_evolved = ifft(&psi_k_evolved, None).map_err(|e| {
326            crate::error::IntegrateError::ComputationError(format!("IFFT failed: {e:?}"))
327        })?;
328
329        // Update state with evolved wave function
330        // Ensure we preserve the original size (FFT might have padded)
331        let psi_vec = if psi_evolved.len() != n {
332            psi_evolved[..n].to_vec()
333        } else {
334            psi_evolved
335        };
336        state.psi = Array1::from_vec(psi_vec);
337
338        // Potential energy evolution (half step)
339        for i in 0..n {
340            let phase = -v[i] * self.dt / (2.0 * REDUCED_PLANCK);
341            state.psi[i] *= Complex64::new(phase.cos(), phase.sin());
342        }
343
344        // Normalize to conserve probability
345        state.normalize();
346
347        Ok(())
348    }
349
350    /// Crank-Nicolson method step
351    fn crank_nicolson_step(&self, state: &mut QuantumState) -> Result<()> {
352        let n = state.psi.len();
353        let alpha = Complex64::new(
354            0.0,
355            REDUCED_PLANCK * self.dt / (4.0 * state.mass * state.dx.powi(2)),
356        );
357
358        // Build tridiagonal matrices
359        let v = self.potential.evaluate_array(&state.x.view());
360        let mut a = vec![Complex64::new(0.0, 0.0); n];
361        let mut b = vec![Complex64::new(0.0, 0.0); n];
362        let mut c = vec![Complex64::new(0.0, 0.0); n];
363
364        for i in 0..n {
365            let v_term = Complex64::new(0.0, -v[i] * self.dt / (2.0 * REDUCED_PLANCK));
366            b[i] = Complex64::new(1.0, 0.0) + 2.0 * alpha - v_term;
367
368            if i > 0 {
369                a[i] = -alpha;
370            }
371            if i < n - 1 {
372                c[i] = -alpha;
373            }
374        }
375
376        // Build right-hand side
377        let mut rhs = vec![Complex64::new(0.0, 0.0); n];
378        for i in 0..n {
379            let v_term = Complex64::new(0.0, v[i] * self.dt / (2.0 * REDUCED_PLANCK));
380            rhs[i] = state.psi[i] * (Complex64::new(1.0, 0.0) - 2.0 * alpha + v_term);
381
382            if i > 0 {
383                rhs[i] += alpha * state.psi[i - 1];
384            }
385            if i < n - 1 {
386                rhs[i] += alpha * state.psi[i + 1];
387            }
388        }
389
390        // Solve tridiagonal system using Thomas algorithm
391        let new_psi = self.solve_tridiagonal(&a, &b, &c, &rhs)?;
392        state.psi = Array1::from_vec(new_psi);
393
394        // Normalize
395        state.normalize();
396
397        Ok(())
398    }
399
400    /// Explicit Euler method step
401    fn explicit_euler_step(&self, state: &mut QuantumState) -> Result<()> {
402        let n = state.psi.len();
403        let mut dpsi_dt = Array1::zeros(n);
404
405        // Calculate time derivative using Schrödinger equation
406        let v = self.potential.evaluate_array(&state.x.view());
407        let prefactor = Complex64::new(0.0, -1.0 / REDUCED_PLANCK);
408
409        for i in 0..n {
410            // Kinetic energy term (second derivative)
411            let d2psi_dx2 = if i == 0 {
412                state.psi[1] - 2.0 * state.psi[0] + state.psi[0]
413            } else if i == n - 1 {
414                state.psi[n - 1] - 2.0 * state.psi[n - 1] + state.psi[n - 2]
415            } else {
416                state.psi[i + 1] - 2.0 * state.psi[i] + state.psi[i - 1]
417            } / state.dx.powi(2);
418
419            // Hamiltonian action
420            let h_psi =
421                -REDUCED_PLANCK.powi(2) / (2.0 * state.mass) * d2psi_dx2 + v[i] * state.psi[i];
422
423            dpsi_dt[i] = prefactor * h_psi;
424        }
425
426        // Update wave function
427        state.psi += &(dpsi_dt * self.dt);
428
429        // Normalize
430        state.normalize();
431
432        Ok(())
433    }
434
435    /// Fourth-order Runge-Kutta method step
436    fn runge_kutta4_step(&self, state: &mut QuantumState) -> Result<()> {
437        let n = state.psi.len();
438        let v = self.potential.evaluate_array(&state.x.view());
439
440        // Helper function to compute derivative
441        let compute_derivative = |psi: &Array1<Complex64>| -> Array1<Complex64> {
442            let mut dpsi = Array1::zeros(n);
443            let prefactor = Complex64::new(0.0, -1.0 / REDUCED_PLANCK);
444
445            for i in 0..n {
446                let d2psi_dx2 = if i == 0 {
447                    psi[1] - 2.0 * psi[0] + psi[0]
448                } else if i == n - 1 {
449                    psi[n - 1] - 2.0 * psi[n - 1] + psi[n - 2]
450                } else {
451                    psi[i + 1] - 2.0 * psi[i] + psi[i - 1]
452                } / state.dx.powi(2);
453
454                let h_psi =
455                    -REDUCED_PLANCK.powi(2) / (2.0 * state.mass) * d2psi_dx2 + v[i] * psi[i];
456
457                dpsi[i] = prefactor * h_psi;
458            }
459            dpsi
460        };
461
462        // RK4 steps
463        let k1 = compute_derivative(&state.psi);
464        let k2 = compute_derivative(&(&state.psi + &k1 * (self.dt / 2.0)));
465        let k3 = compute_derivative(&(&state.psi + &k2 * (self.dt / 2.0)));
466        let k4 = compute_derivative(&(&state.psi + &k3 * self.dt));
467
468        // Update
469        state.psi += &((k1 + k2 * 2.0 + k3 * 2.0 + k4) * (self.dt / 6.0));
470
471        // Normalize
472        state.normalize();
473
474        Ok(())
475    }
476
477    /// Solve tridiagonal system using Thomas algorithm
478    fn solve_tridiagonal(
479        &self,
480        a: &[Complex64],
481        b: &[Complex64],
482        c: &[Complex64],
483        d: &[Complex64],
484    ) -> Result<Vec<Complex64>> {
485        let n = b.len();
486        let mut c_star = vec![Complex64::new(0.0, 0.0); n];
487        let mut d_star = vec![Complex64::new(0.0, 0.0); n];
488        let mut x = vec![Complex64::new(0.0, 0.0); n];
489
490        // Forward sweep
491        c_star[0] = c[0] / b[0];
492        d_star[0] = d[0] / b[0];
493
494        for i in 1..n {
495            let m = b[i] - a[i] * c_star[i - 1];
496            c_star[i] = c[i] / m;
497            d_star[i] = (d[i] - a[i] * d_star[i - 1]) / m;
498        }
499
500        // Back substitution
501        x[n - 1] = d_star[n - 1];
502        for i in (0..n - 1).rev() {
503            x[i] = d_star[i] - c_star[i] * x[i + 1];
504        }
505
506        Ok(x)
507    }
508
509    /// Solve time-independent Schrödinger equation (eigenvalue problem)
510    ///
511    /// Uses inverse power iteration on the interior grid (Dirichlet BCs enforced by
512    /// working only on points 1..n-1) to converge to the lowest `n_states` energy
513    /// eigenpairs.  The tridiagonal shifted system `(H_int - σI)ψ = b` is solved
514    /// with the Thomas algorithm at each iteration, which is both fast and stable.
515    pub fn solve_time_independent(
516        &self,
517        x_min: f64,
518        x_max: f64,
519        n_states: usize,
520    ) -> Result<(Array1<f64>, Array2<f64>)> {
521        let dx = (x_max - x_min) / (self.n_points - 1) as f64;
522        let x = Array1::linspace(x_min, x_max, self.n_points);
523
524        // Interior grid: exclude boundary points (Dirichlet ψ=0 at i=0 and i=n-1)
525        let n_int = self.n_points - 2; // number of interior points
526
527        if n_int < 2 {
528            return Err(IntegrateError::InvalidInput(
529                "Too few grid points for eigenvalue solve".to_string(),
530            ));
531        }
532
533        // Kinetic energy contribution: -ℏ²/(2m) d²/dx²
534        // With finite differences: T[i,i]=ℏ²/(m·dx²), T[i,i±1]=-ℏ²/(2m·dx²)
535        //
536        // This method operates in natural / dimensionless units where ℏ = 1 and
537        // m = 1.  The SI value of REDUCED_PLANCK is physically correct for
538        // time-dependent propagation (where it cancels phase factors), but for
539        // the eigenvalue problem the user's potential is typically expressed in the
540        // same dimensionless unit system as the test expects (E = ℏω/2 = 0.5 with
541        // k = 1, m = 1, ℏ = 1).
542        let hbar: f64 = 1.0; // natural units
543        let mass: f64 = 1.0; // natural units
544        let kinetic_factor = hbar.powi(2) / (2.0 * mass * dx.powi(2));
545
546        // Evaluate potential on the interior grid (x[1..n-1])
547        let v_int: Vec<f64> = (1..self.n_points - 1)
548            .map(|i| self.potential.evaluate(x[i]))
549            .collect();
550
551        // Diagonal and off-diagonal of the interior Hamiltonian (tridiagonal)
552        let diag: Vec<f64> = (0..n_int)
553            .map(|i| 2.0 * kinetic_factor + v_int[i])
554            .collect();
555        let off: f64 = -kinetic_factor; // sub- and super-diagonal (constant)
556
557        // Storage for found eigenstates (interior only, will pad with 0s later)
558        let mut energies = Array1::zeros(n_states);
559        let mut wavefunctions = Array2::zeros((self.n_points, n_states));
560
561        // Inverse power iteration with deflation to find the `n_states` lowest
562        // eigenpairs of the interior tridiagonal Hamiltonian.
563        //
564        // Strategy:
565        //  - For state s, use a shift that is slightly above the (s-1)-th eigenvalue
566        //    that was already found (or a small negative value for the ground state).
567        //    This guarantees the shifted system (H - σI) has the s-th eigenvalue as
568        //    the one smallest in absolute value, so inverse power iteration converges
569        //    to it.
570        //  - Gram-Schmidt orthogonalisation against all previously found eigenstates
571        //    is applied every iteration to prevent drift back to lower modes.
572        let max_iter = 500;
573        let tol = 1e-10;
574
575        // A lower bound for the spectrum: the minimum possible eigenvalue is bounded
576        // below by min(diag) - 2*|off| (Gershgorin circle theorem).  We use this
577        // as the starting shift so the ground-state eigenvalue is closest to zero
578        // in the shifted system (H - σI).
579        let diag_min = diag.iter().cloned().fold(f64::INFINITY, f64::min);
580        let gershgorin_lower = diag_min - 2.0 * off.abs();
581        // Subtract a small buffer so the shift stays strictly below E_0
582        let initial_shift = gershgorin_lower - 0.1 * (off.abs() + 1.0);
583
584        for state in 0..n_states {
585            // Initial guess: sine wave matching the (state+1)-th harmonic
586            let mut psi = Array1::from_shape_fn(n_int, |i| {
587                let s = (state + 1) as f64;
588                (s * PI * (i + 1) as f64 / (n_int + 1) as f64).sin()
589            });
590
591            // Gram-Schmidt orthogonalise against already-found interior eigenstates
592            for j in 0..state {
593                let prev_int = wavefunctions
594                    .column(j)
595                    .slice(scirs2_core::ndarray::s![1..self.n_points - 1])
596                    .to_owned();
597                let overlap: f64 = psi
598                    .iter()
599                    .zip(prev_int.iter())
600                    .map(|(&a, &b)| a * b * dx)
601                    .sum();
602                psi.zip_mut_with(&prev_int, |a, &b| *a -= overlap * b);
603            }
604
605            // Normalise
606            let norm: f64 = psi.iter().map(|&v| v * v * dx).sum::<f64>().sqrt();
607            if norm > 1e-14 {
608                psi /= norm;
609            }
610
611            // Use the same initial shift for all states.  Gram-Schmidt deflation
612            // (applied every iteration) prevents convergence to already-found states.
613            // The shift stays strictly below all eigenvalues (Gershgorin bound),
614            // so (H - σI) is positive-definite and the inverse power iteration
615            // converges to the lowest remaining eigenvalue in the deflated space.
616            let shift = initial_shift;
617
618            let mut eigenvalue = Self::rayleigh_quotient(&psi, &diag, off, dx);
619            let mut prev_eigenvalue = f64::NEG_INFINITY;
620
621            for _iter in 0..max_iter {
622                // Solve (H_int - shift·I) psi_new = psi  via Thomas algorithm
623                let shifted_diag: Vec<f64> = diag.iter().map(|&d| d - shift).collect();
624                let rhs: Vec<f64> = psi.iter().copied().collect();
625
626                let psi_new = Self::solve_tridiagonal_real(&shifted_diag, off, &rhs)?;
627                let mut psi_new_arr = Array1::from_vec(psi_new);
628
629                // Orthogonalise against already-found eigenstates (deflation)
630                for j in 0..state {
631                    let prev_int = wavefunctions
632                        .column(j)
633                        .slice(scirs2_core::ndarray::s![1..self.n_points - 1])
634                        .to_owned();
635                    let overlap: f64 = psi_new_arr
636                        .iter()
637                        .zip(prev_int.iter())
638                        .map(|(&a, &b)| a * b * dx)
639                        .sum();
640                    psi_new_arr.zip_mut_with(&prev_int, |a, &b| *a -= overlap * b);
641                }
642
643                // Normalise
644                let norm_new: f64 = psi_new_arr.iter().map(|&v| v * v * dx).sum::<f64>().sqrt();
645                if norm_new < 1e-14 {
646                    break;
647                }
648                psi_new_arr /= norm_new;
649                psi = psi_new_arr;
650
651                // Update eigenvalue via Rayleigh quotient
652                eigenvalue = Self::rayleigh_quotient(&psi, &diag, off, dx);
653
654                // Keep shift fixed (well below all eigenvalues).  This ensures the
655                // deflated inverse power iteration converges to the lowest remaining
656                // eigenvalue rather than chasing a higher one.
657
658                // Check convergence
659                if (eigenvalue - prev_eigenvalue).abs() < tol {
660                    break;
661                }
662                prev_eigenvalue = eigenvalue;
663            }
664
665            energies[state] = eigenvalue;
666
667            // Embed interior solution into full grid (pad with zeros at boundaries)
668            for i in 0..n_int {
669                wavefunctions[[i + 1, state]] = psi[i];
670            }
671        }
672
673        // Sort by energy (ascending)
674        let mut indices: Vec<usize> = (0..n_states).collect();
675        indices.sort_by(|&i, &j| {
676            energies[i]
677                .partial_cmp(&energies[j])
678                .unwrap_or(std::cmp::Ordering::Equal)
679        });
680
681        let sorted_energies = Array1::from_vec(indices.iter().map(|&i| energies[i]).collect());
682        let mut sorted_wavefunctions = Array2::zeros((self.n_points, n_states));
683        for (new_idx, &old_idx) in indices.iter().enumerate() {
684            sorted_wavefunctions
685                .column_mut(new_idx)
686                .assign(&wavefunctions.column(old_idx));
687        }
688
689        Ok((sorted_energies, sorted_wavefunctions))
690    }
691
692    /// Rayleigh quotient ⟨ψ|H|ψ⟩ for the tridiagonal interior Hamiltonian.
693    fn rayleigh_quotient(psi: &Array1<f64>, diag: &[f64], off: f64, dx: f64) -> f64 {
694        let n = psi.len();
695        let mut h_psi = Array1::zeros(n);
696        for i in 0..n {
697            h_psi[i] = diag[i] * psi[i];
698            if i > 0 {
699                h_psi[i] += off * psi[i - 1];
700            }
701            if i < n - 1 {
702                h_psi[i] += off * psi[i + 1];
703            }
704        }
705        psi.iter()
706            .zip(h_psi.iter())
707            .map(|(&a, &b)| a * b * dx)
708            .sum()
709    }
710
711    /// Solve a tridiagonal system with constant off-diagonal `off` via the Thomas
712    /// algorithm.  Returns `Err` if the system is numerically singular.
713    fn solve_tridiagonal_real(diag: &[f64], off: f64, rhs: &[f64]) -> Result<Vec<f64>> {
714        let n = diag.len();
715        if n == 0 {
716            return Ok(Vec::new());
717        }
718        let mut c_star = vec![0.0_f64; n];
719        let mut d_star = vec![0.0_f64; n];
720
721        // Forward sweep
722        if diag[0].abs() < 1e-300 {
723            return Err(IntegrateError::ComputationError(
724                "Singular tridiagonal system during inverse power iteration".to_string(),
725            ));
726        }
727        c_star[0] = off / diag[0];
728        d_star[0] = rhs[0] / diag[0];
729
730        for i in 1..n {
731            let denom = diag[i] - off * c_star[i - 1];
732            if denom.abs() < 1e-300 {
733                return Err(IntegrateError::ComputationError(
734                    "Singular tridiagonal system during inverse power iteration".to_string(),
735                ));
736            }
737            c_star[i] = off / denom;
738            d_star[i] = (rhs[i] - off * d_star[i - 1]) / denom;
739        }
740
741        // Back substitution
742        let mut x = vec![0.0_f64; n];
743        x[n - 1] = d_star[n - 1];
744        for i in (0..n - 1).rev() {
745            x[i] = d_star[i] - c_star[i] * x[i + 1];
746        }
747
748        Ok(x)
749    }
750
751    /// Create initial Gaussian wave packet
752    pub fn gaussian_wave_packet(
753        x: &Array1<f64>,
754        x0: f64,
755        sigma: f64,
756        k0: f64,
757        mass: f64,
758    ) -> QuantumState {
759        let norm = 1.0 / (2.0 * PI * sigma.powi(2)).powf(0.25);
760
761        // For FFT efficiency, ensure we use a power of 2 size
762        let original_n = x.len();
763        let fft_n = original_n.next_power_of_two();
764
765        // Create arrays with appropriate size
766        let (x_final, psi_final) = if fft_n != original_n {
767            // Need to pad to power of 2
768            let x_min = x[0];
769            let x_max = x[original_n - 1];
770            let x_padded = Array1::linspace(x_min, x_max, fft_n);
771
772            let psi_padded = x_padded.mapv(|xi| {
773                let gaussian = norm * (-(xi - x0).powi(2) / (4.0 * sigma.powi(2))).exp();
774                let phase = k0 * xi;
775                Complex64::new(gaussian * phase.cos(), gaussian * phase.sin())
776            });
777
778            (x_padded, psi_padded)
779        } else {
780            // Already a power of 2
781            let psi = x.mapv(|xi| {
782                let gaussian = norm * (-(xi - x0).powi(2) / (4.0 * sigma.powi(2))).exp();
783                let phase = k0 * xi;
784                Complex64::new(gaussian * phase.cos(), gaussian * phase.sin())
785            });
786            (x.clone(), psi)
787        };
788
789        let mut state = QuantumState::new(psi_final, x_final, 0.0, mass);
790        state.normalize();
791        state
792    }
793}
794
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// In natural units (ħ = m = k = 1, so ω = 1) the two lowest oscillator levels are ½ and 3⁄2.
    #[test]
    fn test_harmonic_oscillator_ground_state() {
        let oscillator = HarmonicOscillator { k: 1.0, x0: 0.0 };
        let solver = SchrodingerSolver::new(
            100,
            0.01,
            Box::new(oscillator),
            SchrodingerMethod::SplitOperator,
        );

        let (levels, _eigenvectors) = solver
            .solve_time_independent(-5.0, 5.0, 3)
            .expect("Operation failed");

        // E₀ = ħω/2
        assert_relative_eq!(levels[0], 0.5, epsilon = 0.01);
        // E₁ = 3ħω/2
        assert_relative_eq!(levels[1], 1.5, epsilon = 0.01);
    }

    /// A free Gaussian packet keeps unit norm throughout, and its mean position
    /// finishes to the right of the starting centre x₀ = −5.
    #[test]
    fn test_wave_packet_evolution() {
        // k = 0 reduces the oscillator to a free particle
        let free_particle = HarmonicOscillator { k: 0.0, x0: 0.0 };
        let solver = SchrodingerSolver::new(
            200,
            0.001,
            Box::new(free_particle),
            SchrodingerMethod::SplitOperator,
        );

        let grid = Array1::linspace(-10.0, 10.0, 200);
        let packet = SchrodingerSolver::gaussian_wave_packet(&grid, -5.0, 1.0, 2.0, 1.0);

        let trajectory = solver
            .solve_time_dependent(&packet, 1.0)
            .expect("Operation failed");

        for snapshot in trajectory.iter() {
            let total_probability =
                snapshot.psi.iter().map(|&amp| (amp.conj() * amp).re).sum::<f64>() * snapshot.dx;
            assert_relative_eq!(total_probability, 1.0, epsilon = 1e-6);
        }

        let final_snapshot = trajectory.last().expect("Operation failed");
        assert!(final_snapshot.expectation_position() > -5.0);
    }
}