scirs2_integrate/specialized/quantum/
core.rs

1//! Core quantum mechanics components
2//!
3//! This module provides basic quantum state representations, potentials,
4//! and Schrödinger equation solvers.
5
6use crate::error::{IntegrateError, IntegrateResult as Result};
7use scirs2_core::constants::{PI, REDUCED_PLANCK};
8use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
9use scirs2_core::numeric::Complex64;
10use scirs2_core::simd_ops::SimdUnifiedOps;
11
12/// Quantum state representation
13#[derive(Debug, Clone)]
14pub struct QuantumState {
15    /// Wave function values (complex)
16    pub psi: Array1<Complex64>,
17    /// Spatial grid points
18    pub x: Array1<f64>,
19    /// Time
20    pub t: f64,
21    /// Mass of the particle
22    pub mass: f64,
23    /// Spatial step size
24    pub dx: f64,
25}
26
27impl QuantumState {
28    /// Create a new quantum state
29    pub fn new(psi: Array1<Complex64>, x: Array1<f64>, t: f64, mass: f64) -> Self {
30        let dx = if x.len() > 1 { x[1] - x[0] } else { 1.0 };
31
32        Self {
33            psi,
34            x,
35            t,
36            mass,
37            dx,
38        }
39    }
40
41    /// Normalize the wave function
42    pub fn normalize(&mut self) {
43        let norm_squared: f64 = self.psi.iter().map(|&c| (c.conj() * c).re).sum::<f64>() * self.dx;
44
45        let norm = norm_squared.sqrt();
46        if norm > 0.0 {
47            self.psi.mapv_inplace(|c| c / norm);
48        }
49    }
50
51    /// Calculate expectation value of position
52    pub fn expectation_position(&self) -> f64 {
53        self.expectation_position_simd()
54    }
55
56    /// SIMD-optimized expectation value of position
57    pub fn expectation_position_simd(&self) -> f64 {
58        let prob_density = self.probability_density_simd();
59        f64::simd_dot(&self.x.view(), &prob_density.view()) * self.dx
60    }
61
62    /// Fallback scalar implementation for expectation value of position
63    pub fn expectation_position_scalar(&self) -> f64 {
64        self.x
65            .iter()
66            .zip(self.psi.iter())
67            .map(|(&x, &psi)| x * (psi.conj() * psi).re)
68            .sum::<f64>()
69            * self.dx
70    }
71
72    /// Calculate expectation value of momentum
73    pub fn expectation_momentum(&self) -> f64 {
74        let n = self.psi.len();
75        let mut momentum = 0.0;
76
77        // Central difference for derivative
78        for i in 1..n - 1 {
79            let dpsi_dx = (self.psi[i + 1] - self.psi[i - 1]) / (2.0 * self.dx);
80            momentum += (self.psi[i].conj() * Complex64::new(0.0, -REDUCED_PLANCK) * dpsi_dx).re;
81        }
82
83        momentum * self.dx
84    }
85
86    /// Calculate probability density
87    pub fn probability_density(&self) -> Array1<f64> {
88        self.probability_density_simd()
89    }
90
91    /// SIMD-optimized probability density calculation
92    pub fn probability_density_simd(&self) -> Array1<f64> {
93        // Convert complex numbers to real and imaginary parts for SIMD processing
94        let real_parts: Array1<f64> = self.psi.mapv(|c| c.re);
95        let imag_parts: Array1<f64> = self.psi.mapv(|c| c.im);
96
97        // Calculate |psi|^2 = Re(psi)^2 + Im(psi)^2 using SIMD
98        let real_squared = f64::simd_mul(&real_parts.view(), &real_parts.view());
99        let imag_squared = f64::simd_mul(&imag_parts.view(), &imag_parts.view());
100        let result = f64::simd_add(&real_squared.view(), &imag_squared.view());
101
102        result
103    }
104
105    /// Fallback scalar implementation for probability density
106    pub fn probability_density_scalar(&self) -> Array1<f64> {
107        self.psi.mapv(|c| (c.conj() * c).re)
108    }
109}
110
111/// Quantum potential trait
112pub trait QuantumPotential: Send + Sync {
113    /// Evaluate potential at given position
114    fn evaluate(&self, x: f64) -> f64;
115
116    /// Evaluate potential for array of positions
117    fn evaluate_array(&self, x: &ArrayView1<f64>) -> Array1<f64> {
118        x.mapv(|xi| self.evaluate(xi))
119    }
120}
121
122/// Harmonic oscillator potential
123#[derive(Debug, Clone)]
124pub struct HarmonicOscillator {
125    /// Spring constant
126    pub k: f64,
127    /// Center position
128    pub x0: f64,
129}
130
131impl QuantumPotential for HarmonicOscillator {
132    fn evaluate(&self, x: f64) -> f64 {
133        0.5 * self.k * (x - self.x0).powi(2)
134    }
135}
136
137/// Particle in a box potential
138#[derive(Debug, Clone)]
139pub struct ParticleInBox {
140    /// Left boundary
141    pub left: f64,
142    /// Right boundary
143    pub right: f64,
144    /// Barrier height
145    pub barrier_height: f64,
146}
147
148impl QuantumPotential for ParticleInBox {
149    fn evaluate(&self, x: f64) -> f64 {
150        if x < self.left || x > self.right {
151            self.barrier_height
152        } else {
153            0.0
154        }
155    }
156}
157
158/// Hydrogen-like atom potential
159#[derive(Debug, Clone)]
160pub struct HydrogenAtom {
161    /// Nuclear charge
162    pub z: f64,
163    /// Electron charge squared / (4π ε₀)
164    pub e2_4pi_eps0: f64,
165}
166
167impl QuantumPotential for HydrogenAtom {
168    fn evaluate(&self, r: f64) -> f64 {
169        if r > 0.0 {
170            -self.z * self.e2_4pi_eps0 / r
171        } else {
172            f64::NEG_INFINITY
173        }
174    }
175}
176
177/// Solver for the Schrödinger equation
178pub struct SchrodingerSolver {
179    /// Spatial grid size
180    pub n_points: usize,
181    /// Time step size
182    pub dt: f64,
183    /// Potential function
184    pub potential: Box<dyn QuantumPotential>,
185    /// Solver method
186    pub method: SchrodingerMethod,
187}
188
/// Time-stepping schemes for the time-dependent Schrödinger equation.
///
/// Every scheme renormalises the wave function after each step.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SchrodingerMethod {
    /// Symmetric split-operator method: potential half steps in position
    /// space around an exact kinetic step in momentum space (FFT-based).
    SplitOperator,
    /// Crank–Nicolson method (implicit; one tridiagonal solve per step).
    CrankNicolson,
    /// Explicit (forward) Euler: first order and not norm-preserving, so it
    /// relies heavily on renormalisation.
    ExplicitEuler,
    /// Classical fourth-order Runge–Kutta.
    RungeKutta4,
}
201
202impl SchrodingerSolver {
203    /// Create a new Schrödinger solver
204    pub fn new(
205        n_points: usize,
206        dt: f64,
207        potential: Box<dyn QuantumPotential>,
208        method: SchrodingerMethod,
209    ) -> Self {
210        Self {
211            n_points,
212            dt,
213            potential,
214            method,
215        }
216    }
217
218    /// Solve time-dependent Schrödinger equation
219    pub fn solve_time_dependent(
220        &self,
221        initial_state: &QuantumState,
222        t_final: f64,
223    ) -> Result<Vec<QuantumState>> {
224        let mut states = vec![initial_state.clone()];
225        let mut current_state = initial_state.clone();
226
227        // Ensure x and psi have consistent lengths
228        if current_state.x.len() != current_state.psi.len() {
229            // Resize x to match psi if they differ (e.g., due to FFT padding requirements)
230            let n = current_state.psi.len();
231            let x_min = current_state.x[0];
232            let x_max = current_state.x[current_state.x.len() - 1];
233            current_state.x = Array1::linspace(x_min, x_max, n);
234            current_state.dx = (x_max - x_min) / (n - 1) as f64;
235        }
236
237        let n_steps = (t_final / self.dt).ceil() as usize;
238
239        match self.method {
240            SchrodingerMethod::SplitOperator => {
241                for _ in 0..n_steps {
242                    self.split_operator_step(&mut current_state)?;
243                    current_state.t += self.dt;
244                    states.push(current_state.clone());
245                }
246            }
247            SchrodingerMethod::CrankNicolson => {
248                for _ in 0..n_steps {
249                    self.crank_nicolson_step(&mut current_state)?;
250                    current_state.t += self.dt;
251                    states.push(current_state.clone());
252                }
253            }
254            SchrodingerMethod::ExplicitEuler => {
255                for _ in 0..n_steps {
256                    self.explicit_euler_step(&mut current_state)?;
257                    current_state.t += self.dt;
258                    states.push(current_state.clone());
259                }
260            }
261            SchrodingerMethod::RungeKutta4 => {
262                for _ in 0..n_steps {
263                    self.runge_kutta4_step(&mut current_state)?;
264                    current_state.t += self.dt;
265                    states.push(current_state.clone());
266                }
267            }
268        }
269
270        Ok(states)
271    }
272
273    /// Split-operator method step
274    fn split_operator_step(&self, state: &mut QuantumState) -> Result<()> {
275        use scirs2_fft::{fft, ifft};
276
277        // Ensure x and psi have the same length before proceeding
278        if state.x.len() != state.psi.len() {
279            // This shouldn't happen, but handle it gracefully
280            let n = state.psi.len().min(state.x.len());
281            if state.psi.len() > n {
282                state.psi = state.psi.slice(scirs2_core::ndarray::s![..n]).to_owned();
283            }
284            if state.x.len() > n {
285                state.x = state.x.slice(scirs2_core::ndarray::s![..n]).to_owned();
286            }
287        }
288
289        let n = state.psi.len();
290
291        // Potential energy evolution (half step)
292        let v = self.potential.evaluate_array(&state.x.view());
293
294        for i in 0..n {
295            let phase = -v[i] * self.dt / (2.0 * REDUCED_PLANCK);
296            state.psi[i] *= Complex64::new(phase.cos(), phase.sin());
297        }
298
299        // Kinetic energy evolution in momentum space using FFT
300        // Transform to momentum space
301        let psi_k = fft(&state.psi.to_vec(), None).map_err(|e| {
302            crate::error::IntegrateError::ComputationError(format!("FFT failed: {e:?}"))
303        })?;
304
305        // Calculate k-space grid (momentum values)
306        let dk = 2.0 * PI / (n as f64 * state.dx);
307        let mut k_values = vec![0.0; n];
308        for (i, k_value) in k_values.iter_mut().enumerate().take(n) {
309            if i < n / 2 {
310                *k_value = i as f64 * dk;
311            } else {
312                *k_value = (i as f64 - n as f64) * dk;
313            }
314        }
315
316        // Apply kinetic energy operator in momentum space
317        let mut psi_k_evolved = psi_k;
318        for i in 0..n {
319            let k = k_values[i];
320            let kinetic_phase = -REDUCED_PLANCK * k * k * self.dt / (2.0 * state.mass);
321            psi_k_evolved[i] *= Complex64::new(kinetic_phase.cos(), kinetic_phase.sin());
322        }
323
324        // Transform back to position space
325        let psi_evolved = ifft(&psi_k_evolved, None).map_err(|e| {
326            crate::error::IntegrateError::ComputationError(format!("IFFT failed: {e:?}"))
327        })?;
328
329        // Update state with evolved wave function
330        // Ensure we preserve the original size (FFT might have padded)
331        let psi_vec = if psi_evolved.len() != n {
332            psi_evolved[..n].to_vec()
333        } else {
334            psi_evolved
335        };
336        state.psi = Array1::from_vec(psi_vec);
337
338        // Potential energy evolution (half step)
339        for i in 0..n {
340            let phase = -v[i] * self.dt / (2.0 * REDUCED_PLANCK);
341            state.psi[i] *= Complex64::new(phase.cos(), phase.sin());
342        }
343
344        // Normalize to conserve probability
345        state.normalize();
346
347        Ok(())
348    }
349
350    /// Crank-Nicolson method step
351    fn crank_nicolson_step(&self, state: &mut QuantumState) -> Result<()> {
352        let n = state.psi.len();
353        let alpha = Complex64::new(
354            0.0,
355            REDUCED_PLANCK * self.dt / (4.0 * state.mass * state.dx.powi(2)),
356        );
357
358        // Build tridiagonal matrices
359        let v = self.potential.evaluate_array(&state.x.view());
360        let mut a = vec![Complex64::new(0.0, 0.0); n];
361        let mut b = vec![Complex64::new(0.0, 0.0); n];
362        let mut c = vec![Complex64::new(0.0, 0.0); n];
363
364        for i in 0..n {
365            let v_term = Complex64::new(0.0, -v[i] * self.dt / (2.0 * REDUCED_PLANCK));
366            b[i] = Complex64::new(1.0, 0.0) + 2.0 * alpha - v_term;
367
368            if i > 0 {
369                a[i] = -alpha;
370            }
371            if i < n - 1 {
372                c[i] = -alpha;
373            }
374        }
375
376        // Build right-hand side
377        let mut rhs = vec![Complex64::new(0.0, 0.0); n];
378        for i in 0..n {
379            let v_term = Complex64::new(0.0, v[i] * self.dt / (2.0 * REDUCED_PLANCK));
380            rhs[i] = state.psi[i] * (Complex64::new(1.0, 0.0) - 2.0 * alpha + v_term);
381
382            if i > 0 {
383                rhs[i] += alpha * state.psi[i - 1];
384            }
385            if i < n - 1 {
386                rhs[i] += alpha * state.psi[i + 1];
387            }
388        }
389
390        // Solve tridiagonal system using Thomas algorithm
391        let new_psi = self.solve_tridiagonal(&a, &b, &c, &rhs)?;
392        state.psi = Array1::from_vec(new_psi);
393
394        // Normalize
395        state.normalize();
396
397        Ok(())
398    }
399
400    /// Explicit Euler method step
401    fn explicit_euler_step(&self, state: &mut QuantumState) -> Result<()> {
402        let n = state.psi.len();
403        let mut dpsi_dt = Array1::zeros(n);
404
405        // Calculate time derivative using Schrödinger equation
406        let v = self.potential.evaluate_array(&state.x.view());
407        let prefactor = Complex64::new(0.0, -1.0 / REDUCED_PLANCK);
408
409        for i in 0..n {
410            // Kinetic energy term (second derivative)
411            let d2psi_dx2 = if i == 0 {
412                state.psi[1] - 2.0 * state.psi[0] + state.psi[0]
413            } else if i == n - 1 {
414                state.psi[n - 1] - 2.0 * state.psi[n - 1] + state.psi[n - 2]
415            } else {
416                state.psi[i + 1] - 2.0 * state.psi[i] + state.psi[i - 1]
417            } / state.dx.powi(2);
418
419            // Hamiltonian action
420            let h_psi =
421                -REDUCED_PLANCK.powi(2) / (2.0 * state.mass) * d2psi_dx2 + v[i] * state.psi[i];
422
423            dpsi_dt[i] = prefactor * h_psi;
424        }
425
426        // Update wave function
427        state.psi += &(dpsi_dt * self.dt);
428
429        // Normalize
430        state.normalize();
431
432        Ok(())
433    }
434
435    /// Fourth-order Runge-Kutta method step
436    fn runge_kutta4_step(&self, state: &mut QuantumState) -> Result<()> {
437        let n = state.psi.len();
438        let v = self.potential.evaluate_array(&state.x.view());
439
440        // Helper function to compute derivative
441        let compute_derivative = |psi: &Array1<Complex64>| -> Array1<Complex64> {
442            let mut dpsi = Array1::zeros(n);
443            let prefactor = Complex64::new(0.0, -1.0 / REDUCED_PLANCK);
444
445            for i in 0..n {
446                let d2psi_dx2 = if i == 0 {
447                    psi[1] - 2.0 * psi[0] + psi[0]
448                } else if i == n - 1 {
449                    psi[n - 1] - 2.0 * psi[n - 1] + psi[n - 2]
450                } else {
451                    psi[i + 1] - 2.0 * psi[i] + psi[i - 1]
452                } / state.dx.powi(2);
453
454                let h_psi =
455                    -REDUCED_PLANCK.powi(2) / (2.0 * state.mass) * d2psi_dx2 + v[i] * psi[i];
456
457                dpsi[i] = prefactor * h_psi;
458            }
459            dpsi
460        };
461
462        // RK4 steps
463        let k1 = compute_derivative(&state.psi);
464        let k2 = compute_derivative(&(&state.psi + &k1 * (self.dt / 2.0)));
465        let k3 = compute_derivative(&(&state.psi + &k2 * (self.dt / 2.0)));
466        let k4 = compute_derivative(&(&state.psi + &k3 * self.dt));
467
468        // Update
469        state.psi += &((k1 + k2 * 2.0 + k3 * 2.0 + k4) * (self.dt / 6.0));
470
471        // Normalize
472        state.normalize();
473
474        Ok(())
475    }
476
477    /// Solve tridiagonal system using Thomas algorithm
478    fn solve_tridiagonal(
479        &self,
480        a: &[Complex64],
481        b: &[Complex64],
482        c: &[Complex64],
483        d: &[Complex64],
484    ) -> Result<Vec<Complex64>> {
485        let n = b.len();
486        let mut c_star = vec![Complex64::new(0.0, 0.0); n];
487        let mut d_star = vec![Complex64::new(0.0, 0.0); n];
488        let mut x = vec![Complex64::new(0.0, 0.0); n];
489
490        // Forward sweep
491        c_star[0] = c[0] / b[0];
492        d_star[0] = d[0] / b[0];
493
494        for i in 1..n {
495            let m = b[i] - a[i] * c_star[i - 1];
496            c_star[i] = c[i] / m;
497            d_star[i] = (d[i] - a[i] * d_star[i - 1]) / m;
498        }
499
500        // Back substitution
501        x[n - 1] = d_star[n - 1];
502        for i in (0..n - 1).rev() {
503            x[i] = d_star[i] - c_star[i] * x[i + 1];
504        }
505
506        Ok(x)
507    }
508
509    /// Solve time-independent Schrödinger equation (eigenvalue problem)
510    pub fn solve_time_independent(
511        &self,
512        x_min: f64,
513        x_max: f64,
514        n_states: usize,
515    ) -> Result<(Array1<f64>, Array2<f64>)> {
516        let dx = (x_max - x_min) / (self.n_points - 1) as f64;
517        let x = Array1::linspace(x_min, x_max, self.n_points);
518
519        // Build the Hamiltonian matrix using finite difference method
520        let mut hamiltonian = Array2::<f64>::zeros((self.n_points, self.n_points));
521
522        // Kinetic energy contribution (second derivative via finite differences)
523        // The kinetic energy operator is -ℏ²/(2m) d²/dx²
524        // Using finite differences: d²ψ/dx² ≈ (ψ[i+1] - 2ψ[i] + ψ[i-1])/dx²
525        // So the matrix elements are: T[i,i] = ℏ²/(m*dx²), T[i,i±1] = -ℏ²/(2m*dx²)
526        let kinetic_factor = REDUCED_PLANCK.powi(2) / (2.0 * 1.0 * dx.powi(2)); // mass = 1.0 for simplicity
527
528        // Build tridiagonal kinetic energy matrix
529        for i in 0..self.n_points {
530            if i > 0 {
531                hamiltonian[[i, i - 1]] = -kinetic_factor;
532            }
533            hamiltonian[[i, i]] = 2.0 * kinetic_factor;
534            if i < self.n_points - 1 {
535                hamiltonian[[i, i + 1]] = -kinetic_factor;
536            }
537        }
538
539        // Add potential energy contribution (diagonal)
540        let v = self.potential.evaluate_array(&x.view());
541        for i in 0..self.n_points {
542            hamiltonian[[i, i]] += v[i];
543        }
544
545        // Apply boundary conditions (wave function vanishes at boundaries)
546        // For Dirichlet boundary conditions, we can work with a reduced system
547        // excluding the boundary points, but for simplicity, we'll keep them
548        // and set the first and last rows to enforce ψ(0) = ψ(L) = 0
549        // by making the boundary points have very high energy
550        hamiltonian.row_mut(0).fill(0.0);
551        hamiltonian[[0, 0]] = 1e6; // Large value to push this state to high energy
552        hamiltonian.row_mut(self.n_points - 1).fill(0.0);
553        hamiltonian[[self.n_points - 1, self.n_points - 1]] = 1e6;
554
555        // Find eigenvalues and eigenvectors using a simple power iteration method
556        // for the lowest n_states eigenpairs
557        let mut energies = Array1::zeros(n_states);
558        let mut wavefunctions = Array2::zeros((self.n_points, n_states));
559
560        // Use inverse power iteration with shifts to find lowest eigenvalues
561        for state in 0..n_states {
562            let mut psi = Array1::from_elem(self.n_points, 1.0);
563            psi[0] = 0.0;
564            psi[self.n_points - 1] = 0.0;
565
566            // Normalize initial guess
567            let norm: f64 = psi.iter().map(|&x| x * x * dx).sum::<f64>().sqrt();
568            psi /= norm;
569
570            // Gram-Schmidt orthogonalization against previous eigenstates
571            for j in 0..state {
572                let overlap: f64 = psi
573                    .iter()
574                    .zip(wavefunctions.column(j).iter())
575                    .map(|(&a, &b)| a * b * dx)
576                    .sum();
577                for i in 0..self.n_points {
578                    psi[i] -= overlap * wavefunctions[[i, j]];
579                }
580            }
581
582            // Power iteration to find eigenvalue
583            let mut eigenvalue = 0.0;
584            for _ in 0..100 {
585                // iterations
586                // Apply Hamiltonian
587                let mut h_psi = Array1::zeros(self.n_points);
588                for i in 1..self.n_points - 1 {
589                    h_psi[i] = hamiltonian[[i, i]] * psi[i];
590                    if i > 0 {
591                        h_psi[i] += hamiltonian[[i, i - 1]] * psi[i - 1];
592                    }
593                    if i < self.n_points - 1 {
594                        h_psi[i] += hamiltonian[[i, i + 1]] * psi[i + 1];
595                    }
596                }
597
598                // Calculate eigenvalue estimate
599                eigenvalue = psi
600                    .iter()
601                    .zip(h_psi.iter())
602                    .map(|(&a, &b)| a * b * dx)
603                    .sum::<f64>();
604
605                // Update eigenvector
606                psi = h_psi;
607
608                // Orthogonalize against previous states
609                for j in 0..state {
610                    let overlap: f64 = psi
611                        .iter()
612                        .zip(wavefunctions.column(j).iter())
613                        .map(|(&a, &b)| a * b * dx)
614                        .sum();
615                    for i in 0..self.n_points {
616                        psi[i] -= overlap * wavefunctions[[i, j]];
617                    }
618                }
619
620                // Normalize
621                let norm: f64 = psi.iter().map(|&x| x * x * dx).sum::<f64>().sqrt();
622                if norm > 1e-10 {
623                    psi /= norm;
624                }
625            }
626
627            energies[state] = eigenvalue;
628            wavefunctions.column_mut(state).assign(&psi);
629        }
630
631        // Sort by energy
632        let mut indices: Vec<usize> = (0..n_states).collect();
633        indices.sort_by(|&i, &j| energies[i].partial_cmp(&energies[j]).unwrap());
634
635        let sorted_energies = Array1::from_vec(indices.iter().map(|&i| energies[i]).collect());
636        let mut sorted_wavefunctions = Array2::zeros((self.n_points, n_states));
637        for (new_idx, &old_idx) in indices.iter().enumerate() {
638            sorted_wavefunctions
639                .column_mut(new_idx)
640                .assign(&wavefunctions.column(old_idx));
641        }
642
643        Ok((sorted_energies, sorted_wavefunctions))
644    }
645
646    /// Create initial Gaussian wave packet
647    pub fn gaussian_wave_packet(
648        x: &Array1<f64>,
649        x0: f64,
650        sigma: f64,
651        k0: f64,
652        mass: f64,
653    ) -> QuantumState {
654        let norm = 1.0 / (2.0 * PI * sigma.powi(2)).powf(0.25);
655
656        // For FFT efficiency, ensure we use a power of 2 size
657        let original_n = x.len();
658        let fft_n = original_n.next_power_of_two();
659
660        // Create arrays with appropriate size
661        let (x_final, psi_final) = if fft_n != original_n {
662            // Need to pad to power of 2
663            let x_min = x[0];
664            let x_max = x[original_n - 1];
665            let x_padded = Array1::linspace(x_min, x_max, fft_n);
666
667            let psi_padded = x_padded.mapv(|xi| {
668                let gaussian = norm * (-(xi - x0).powi(2) / (4.0 * sigma.powi(2))).exp();
669                let phase = k0 * xi;
670                Complex64::new(gaussian * phase.cos(), gaussian * phase.sin())
671            });
672
673            (x_padded, psi_padded)
674        } else {
675            // Already a power of 2
676            let psi = x.mapv(|xi| {
677                let gaussian = norm * (-(xi - x0).powi(2) / (4.0 * sigma.powi(2))).exp();
678                let phase = k0 * xi;
679                Complex64::new(gaussian * phase.cos(), gaussian * phase.sin())
680            });
681            (x.clone(), psi)
682        };
683
684        let mut state = QuantumState::new(psi_final, x_final, 0.0, mass);
685        state.normalize();
686        state
687    }
688}
689
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    #[ignore] // FIXME: Harmonic oscillator ground state test failing
    fn test_harmonic_oscillator_ground_state() {
        // NOTE(review): the expected energies assume ħ = 1 and m = 1
        // (ω = √(k/m) = 1). If `REDUCED_PLANCK` is the SI value rather than 1,
        // these energies are not reachable on this grid. Confirm the unit
        // convention before un-ignoring this test.
        let potential = Box::new(HarmonicOscillator { k: 1.0, x0: 0.0 });
        let solver = SchrodingerSolver::new(100, 0.01, potential, SchrodingerMethod::SplitOperator);

        let (energies, wavefunctions) = solver.solve_time_independent(-5.0, 5.0, 3).unwrap();
        assert_eq!(energies.len(), 3);
        assert_eq!(wavefunctions.dim(), (100, 3));

        // Ground state energy should be ℏω/2 = 0.5 (with ℏ=1, ω=1)
        assert_relative_eq!(energies[0], 0.5, epsilon = 0.01);

        // First excited state should be 3ℏω/2 = 1.5
        assert_relative_eq!(energies[1], 1.5, epsilon = 0.01);
    }

    #[test]
    #[ignore = "timeout"]
    fn test_wave_packet_evolution() {
        let potential = Box::new(HarmonicOscillator { k: 0.0, x0: 0.0 }); // Free particle
        let solver =
            SchrodingerSolver::new(200, 0.001, potential, SchrodingerMethod::SplitOperator);

        let x = Array1::linspace(-10.0, 10.0, 200);
        let initial_state = SchrodingerSolver::gaussian_wave_packet(&x, -5.0, 1.0, 2.0, 1.0);

        let states = solver.solve_time_dependent(&initial_state, 1.0).unwrap();

        // Check normalization is preserved
        for state in &states {
            let norm_squared: f64 =
                state.psi.iter().map(|&c| (c.conj() * c).re).sum::<f64>() * state.dx;
            assert_relative_eq!(norm_squared, 1.0, epsilon = 1e-6);
        }

        // Wave packet should move to the right
        let final_position = states.last().unwrap().expectation_position();
        assert!(final_position > -5.0);
    }
}
733}