// scirs2_integrate/specialized/finance/solvers/advanced_pde.rs

//! Advanced PDE solvers.
//!
//! This module implements Chebyshev spectral collocation, radial basis function (RBF)
//! meshless interpolation, and geometric multigrid solvers for one-dimensional problems.

6use crate::error::{IntegrateError, IntegrateResult};
7use scirs2_core::ndarray::{Array1, Array2};
8use std::f64::consts::PI;
9
10/// Spectral solver using Chebyshev polynomials
11#[derive(Debug, Clone)]
12pub struct SpectralChebyshevSolver {
13    /// Number of Chebyshev nodes
14    pub n_nodes: usize,
15}
16
17impl SpectralChebyshevSolver {
18    /// Create a new spectral Chebyshev solver
19    pub fn new(n_nodes: usize) -> IntegrateResult<Self> {
20        if n_nodes < 3 {
21            return Err(IntegrateError::ValueError(
22                "Number of nodes must be at least 3".to_string(),
23            ));
24        }
25
26        Ok(Self { n_nodes })
27    }
28
29    /// Compute Chebyshev nodes (Gauss-Lobatto points)
30    pub fn chebyshev_nodes(&self) -> Array1<f64> {
31        let mut nodes = Array1::<f64>::zeros(self.n_nodes);
32        for i in 0..self.n_nodes {
33            nodes[i] = -(PI * i as f64 / (self.n_nodes - 1) as f64).cos();
34        }
35        nodes
36    }
37
38    /// Compute Chebyshev differentiation matrix
39    pub fn differentiation_matrix(&self) -> IntegrateResult<Array2<f64>> {
40        let n = self.n_nodes;
41        let mut d = Array2::<f64>::zeros((n, n));
42        let nodes = self.chebyshev_nodes();
43
44        for i in 0..n {
45            for j in 0..n {
46                if i != j {
47                    let mut c_i = 1.0;
48                    let mut c_j = 1.0;
49
50                    if i == 0 || i == n - 1 {
51                        c_i = 2.0;
52                    }
53                    if j == 0 || j == n - 1 {
54                        c_j = 2.0;
55                    }
56
57                    d[[i, j]] =
58                        (c_i / c_j) * (-1.0_f64).powi((i + j) as i32) / (nodes[i] - nodes[j]);
59                } else if i == 0 {
60                    d[[i, j]] = (2.0 * (n - 1) as f64 * (n - 1) as f64 + 1.0) / 6.0;
61                } else if i == n - 1 {
62                    d[[i, j]] = -(2.0 * (n - 1) as f64 * (n - 1) as f64 + 1.0) / 6.0;
63                } else {
64                    d[[i, j]] = -nodes[i] / (2.0 * (1.0 - nodes[i] * nodes[i]));
65                }
66            }
67        }
68
69        Ok(d)
70    }
71
72    /// Solve 1D Poisson equation: u'' = f(x)
73    pub fn solve_poisson<F>(
74        &self,
75        f: F,
76        boundary_left: f64,
77        boundary_right: f64,
78    ) -> IntegrateResult<Array1<f64>>
79    where
80        F: Fn(f64) -> f64,
81    {
82        let nodes = self.chebyshev_nodes();
83        let d = self.differentiation_matrix()?;
84
85        // Second derivative matrix: D² = D * D
86        let mut d2 = Array2::<f64>::zeros((self.n_nodes, self.n_nodes));
87        for i in 0..self.n_nodes {
88            for j in 0..self.n_nodes {
89                for k in 0..self.n_nodes {
90                    d2[[i, j]] += d[[i, k]] * d[[k, j]];
91                }
92            }
93        }
94
95        // Build right-hand side
96        let mut rhs = Array1::<f64>::zeros(self.n_nodes);
97        for i in 1..(self.n_nodes - 1) {
98            rhs[i] = f(nodes[i]);
99        }
100
101        // Apply boundary conditions
102        rhs[0] = boundary_left;
103        rhs[self.n_nodes - 1] = boundary_right;
104
105        // Modify matrix for boundary conditions
106        for j in 0..self.n_nodes {
107            d2[[0, j]] = if j == 0 { 1.0 } else { 0.0 };
108            d2[[self.n_nodes - 1, j]] = if j == self.n_nodes - 1 { 1.0 } else { 0.0 };
109        }
110
111        // Solve linear system (using simple Gaussian elimination)
112        self.solve_linear_system(&d2, &rhs)
113    }
114
115    /// Simple Gaussian elimination solver
116    fn solve_linear_system(
117        &self,
118        a: &Array2<f64>,
119        b: &Array1<f64>,
120    ) -> IntegrateResult<Array1<f64>> {
121        let n = b.len();
122        let mut aug = Array2::<f64>::zeros((n, n + 1));
123
124        // Create augmented matrix
125        for i in 0..n {
126            for j in 0..n {
127                aug[[i, j]] = a[[i, j]];
128            }
129            aug[[i, n]] = b[i];
130        }
131
132        // Forward elimination
133        for k in 0..n {
134            // Find pivot
135            let mut max_row = k;
136            for i in (k + 1)..n {
137                if aug[[i, k]].abs() > aug[[max_row, k]].abs() {
138                    max_row = i;
139                }
140            }
141
142            // Swap rows
143            for j in 0..=n {
144                let temp = aug[[k, j]];
145                aug[[k, j]] = aug[[max_row, j]];
146                aug[[max_row, j]] = temp;
147            }
148
149            if aug[[k, k]].abs() < 1e-14 {
150                return Err(IntegrateError::ValueError(
151                    "Singular matrix in Gaussian elimination".to_string(),
152                ));
153            }
154
155            // Eliminate
156            for i in (k + 1)..n {
157                let factor = aug[[i, k]] / aug[[k, k]];
158                for j in k..=n {
159                    aug[[i, j]] -= factor * aug[[k, j]];
160                }
161            }
162        }
163
164        // Back substitution
165        let mut x = Array1::<f64>::zeros(n);
166        for i in (0..n).rev() {
167            let mut sum = aug[[i, n]];
168            for j in (i + 1)..n {
169                sum -= aug[[i, j]] * x[j];
170            }
171            x[i] = sum / aug[[i, i]];
172        }
173
174        Ok(x)
175    }
176}
177
/// Radial basis function kernel used by [`RBFSolver`].
///
/// `ε` denotes the solver's shape parameter ([`RBFSolver::epsilon`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RBFType {
    /// Gaussian: φ(r) = exp(-ε²r²)
    Gaussian,
    /// Multiquadric: φ(r) = sqrt(1 + (εr)²)
    Multiquadric,
    /// Inverse multiquadric: φ(r) = 1/sqrt(1 + (εr)²)
    InverseMultiquadric,
    /// Thin plate spline: φ(r) = r² log(r), with φ(0) = 0; the shape parameter is not used
    ThinPlateSpline,
}

191#[derive(Debug, Clone)]
192pub struct RBFSolver {
193    /// RBF type
194    pub rbf_type: RBFType,
195    /// Shape parameter
196    pub epsilon: f64,
197    /// Number of collocation points
198    pub n_points: usize,
199}
200
201impl RBFSolver {
202    /// Create a new RBF solver
203    pub fn new(rbf_type: RBFType, epsilon: f64, n_points: usize) -> IntegrateResult<Self> {
204        if epsilon <= 0.0 {
205            return Err(IntegrateError::ValueError(
206                "Shape parameter must be positive".to_string(),
207            ));
208        }
209
210        if n_points < 2 {
211            return Err(IntegrateError::ValueError(
212                "Number of points must be at least 2".to_string(),
213            ));
214        }
215
216        Ok(Self {
217            rbf_type,
218            epsilon,
219            n_points,
220        })
221    }
222
223    /// Evaluate RBF
224    fn rbf(&self, r: f64) -> f64 {
225        match self.rbf_type {
226            RBFType::Gaussian => (-self.epsilon * self.epsilon * r * r).exp(),
227            RBFType::Multiquadric => (1.0 + (self.epsilon * r).powi(2)).sqrt(),
228            RBFType::InverseMultiquadric => 1.0 / (1.0 + (self.epsilon * r).powi(2)).sqrt(),
229            RBFType::ThinPlateSpline => {
230                if r > 0.0 {
231                    r * r * r.ln()
232                } else {
233                    0.0
234                }
235            }
236        }
237    }
238
239    /// Solve interpolation problem
240    pub fn interpolate(
241        &self,
242        points: &Array1<f64>,
243        values: &Array1<f64>,
244    ) -> IntegrateResult<Array1<f64>> {
245        if points.len() != values.len() {
246            return Err(IntegrateError::ValueError(
247                "Points and values must have same length".to_string(),
248            ));
249        }
250
251        let n = points.len();
252        let mut phi = Array2::<f64>::zeros((n, n));
253
254        // Build RBF matrix
255        for i in 0..n {
256            for j in 0..n {
257                let r = (points[i] - points[j]).abs();
258                phi[[i, j]] = self.rbf(r);
259            }
260        }
261
262        // Solve for weights
263        self.solve_linear_system(&phi, values)
264    }
265
266    /// Simple linear system solver (reusing spectral solver's method)
267    fn solve_linear_system(
268        &self,
269        a: &Array2<f64>,
270        b: &Array1<f64>,
271    ) -> IntegrateResult<Array1<f64>> {
272        let n = b.len();
273        let mut aug = Array2::<f64>::zeros((n, n + 1));
274
275        for i in 0..n {
276            for j in 0..n {
277                aug[[i, j]] = a[[i, j]];
278            }
279            aug[[i, n]] = b[i];
280        }
281
282        // Forward elimination with partial pivoting
283        for k in 0..n {
284            let mut max_row = k;
285            for i in (k + 1)..n {
286                if aug[[i, k]].abs() > aug[[max_row, k]].abs() {
287                    max_row = i;
288                }
289            }
290
291            for j in 0..=n {
292                let temp = aug[[k, j]];
293                aug[[k, j]] = aug[[max_row, j]];
294                aug[[max_row, j]] = temp;
295            }
296
297            if aug[[k, k]].abs() < 1e-14 {
298                return Err(IntegrateError::ValueError("Singular matrix".to_string()));
299            }
300
301            for i in (k + 1)..n {
302                let factor = aug[[i, k]] / aug[[k, k]];
303                for j in k..=n {
304                    aug[[i, j]] -= factor * aug[[k, j]];
305                }
306            }
307        }
308
309        // Back substitution
310        let mut x = Array1::<f64>::zeros(n);
311        for i in (0..n).rev() {
312            let mut sum = aug[[i, n]];
313            for j in (i + 1)..n {
314                sum -= aug[[i, j]] * x[j];
315            }
316            x[i] = sum / aug[[i, i]];
317        }
318
319        Ok(x)
320    }
321}
322
323/// Multigrid solver for fast convergence
324#[derive(Debug, Clone)]
325pub struct MultigridSolver {
326    /// Number of levels
327    pub n_levels: usize,
328    /// Number of pre-smoothing iterations
329    pub n_pre_smooth: usize,
330    /// Number of post-smoothing iterations
331    pub n_post_smooth: usize,
332}
333
334impl MultigridSolver {
335    /// Create a new multigrid solver
336    pub fn new(
337        n_levels: usize,
338        n_pre_smooth: usize,
339        n_post_smooth: usize,
340    ) -> IntegrateResult<Self> {
341        if n_levels == 0 {
342            return Err(IntegrateError::ValueError(
343                "Number of levels must be positive".to_string(),
344            ));
345        }
346
347        Ok(Self {
348            n_levels,
349            n_pre_smooth,
350            n_post_smooth,
351        })
352    }
353
354    /// V-cycle multigrid iteration
355    pub fn v_cycle(
356        &self,
357        u: &mut Array1<f64>,
358        f: &Array1<f64>,
359        dx: f64,
360        level: usize,
361    ) -> IntegrateResult<()> {
362        let n = u.len();
363
364        if level == 0 || n <= 3 {
365            // Coarsest level: solve directly
366            self.direct_solve(u, f, dx)?;
367            return Ok(());
368        }
369
370        // Pre-smoothing
371        for _ in 0..self.n_pre_smooth {
372            self.gauss_seidel_smooth(u, f, dx);
373        }
374
375        // Compute residual
376        let mut residual = Array1::<f64>::zeros(n);
377        for i in 1..(n - 1) {
378            let laplacian = (u[i - 1] - 2.0 * u[i] + u[i + 1]) / (dx * dx);
379            residual[i] = f[i] - laplacian;
380        }
381
382        // Restrict to coarse grid
383        let n_coarse = (n - 1) / 2 + 1;
384        let mut residual_coarse = Array1::<f64>::zeros(n_coarse);
385        for i in 1..(n_coarse - 1) {
386            residual_coarse[i] =
387                0.25 * residual[2 * i - 1] + 0.5 * residual[2 * i] + 0.25 * residual[2 * i + 1];
388        }
389
390        // Solve on coarse grid
391        let mut error_coarse = Array1::<f64>::zeros(n_coarse);
392        self.v_cycle(&mut error_coarse, &residual_coarse, 2.0 * dx, level - 1)?;
393
394        // Prolongate (interpolate) to fine grid
395        let mut error_fine = Array1::<f64>::zeros(n);
396        for i in 0..n_coarse {
397            if 2 * i < n {
398                error_fine[2 * i] = error_coarse[i];
399            }
400            if 2 * i + 1 < n && i + 1 < n_coarse {
401                error_fine[2 * i + 1] = 0.5 * (error_coarse[i] + error_coarse[i + 1]);
402            }
403        }
404
405        // Correct solution
406        for i in 0..n {
407            u[i] += error_fine[i];
408        }
409
410        // Post-smoothing
411        for _ in 0..self.n_post_smooth {
412            self.gauss_seidel_smooth(u, f, dx);
413        }
414
415        Ok(())
416    }
417
418    /// Gauss-Seidel smoother
419    fn gauss_seidel_smooth(&self, u: &mut Array1<f64>, f: &Array1<f64>, dx: f64) {
420        let n = u.len();
421        for i in 1..(n - 1) {
422            u[i] = 0.5 * (u[i - 1] + u[i + 1] - dx * dx * f[i]);
423        }
424    }
425
426    /// Direct solver for small systems
427    fn direct_solve(&self, u: &mut Array1<f64>, f: &Array1<f64>, dx: f64) -> IntegrateResult<()> {
428        let n = u.len();
429
430        // Build tridiagonal matrix
431        let mut a = Array1::<f64>::zeros(n);
432        let mut b = Array1::<f64>::zeros(n);
433        let mut c = Array1::<f64>::zeros(n);
434        let mut rhs = f.clone();
435
436        for i in 1..(n - 1) {
437            a[i] = 1.0 / (dx * dx);
438            b[i] = -2.0 / (dx * dx);
439            c[i] = 1.0 / (dx * dx);
440        }
441
442        // Boundary conditions
443        b[0] = 1.0;
444        c[0] = 0.0;
445        a[n - 1] = 0.0;
446        b[n - 1] = 1.0;
447        rhs[0] = u[0];
448        rhs[n - 1] = u[n - 1];
449
450        // Thomas algorithm
451        let mut c_prime = Array1::<f64>::zeros(n);
452        let mut d_prime = Array1::<f64>::zeros(n);
453
454        c_prime[0] = c[0] / b[0];
455        d_prime[0] = rhs[0] / b[0];
456
457        for i in 1..n {
458            let denom = b[i] - a[i] * c_prime[i - 1];
459            if i < n - 1 {
460                c_prime[i] = c[i] / denom;
461            }
462            d_prime[i] = (rhs[i] - a[i] * d_prime[i - 1]) / denom;
463        }
464
465        u[n - 1] = d_prime[n - 1];
466        for i in (0..n - 1).rev() {
467            u[i] = d_prime[i] - c_prime[i] * u[i + 1];
468        }
469
470        Ok(())
471    }
472
473    /// Solve Poisson equation using multigrid
474    pub fn solve_poisson<F>(
475        &self,
476        f: F,
477        x_range: (f64, f64),
478        n_grid: usize,
479        max_iterations: usize,
480    ) -> IntegrateResult<Array1<f64>>
481    where
482        F: Fn(f64) -> f64,
483    {
484        let dx = (x_range.1 - x_range.0) / (n_grid - 1) as f64;
485        let mut u = Array1::<f64>::zeros(n_grid);
486        let mut rhs = Array1::<f64>::zeros(n_grid);
487
488        // Set up right-hand side
489        for i in 1..(n_grid - 1) {
490            let x = x_range.0 + i as f64 * dx;
491            rhs[i] = f(x);
492        }
493
494        // Boundary conditions (homogeneous)
495        u[0] = 0.0;
496        u[n_grid - 1] = 0.0;
497
498        // Multigrid iterations
499        for _ in 0..max_iterations {
500            self.v_cycle(&mut u, &rhs, dx, self.n_levels - 1)?;
501        }
502
503        Ok(u)
504    }
505}
506
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_spectral_chebyshev_solver() {
        let solver = SpectralChebyshevSolver::new(10).unwrap();
        assert_eq!(solver.n_nodes, 10);
    }

    #[test]
    fn test_constructor_validation() {
        assert!(SpectralChebyshevSolver::new(2).is_err());
        assert!(RBFSolver::new(RBFType::Gaussian, 0.0, 5).is_err());
        assert!(RBFSolver::new(RBFType::Gaussian, -1.0, 5).is_err());
        assert!(RBFSolver::new(RBFType::Gaussian, 1.0, 1).is_err());
        assert!(MultigridSolver::new(0, 2, 2).is_err());
    }

    #[test]
    fn test_chebyshev_nodes() {
        let solver = SpectralChebyshevSolver::new(5).unwrap();
        let nodes = solver.chebyshev_nodes();

        // Nodes should be in [-1, 1]
        for &node in nodes.iter() {
            assert!((-1.0..=1.0).contains(&node));
        }

        // First and last nodes should be -1 and 1
        assert!((nodes[0] - (-1.0)).abs() < 1e-10);
        assert!((nodes[4] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_spectral_poisson() {
        let solver = SpectralChebyshevSolver::new(20).unwrap();

        // Solve u'' = -2 with u(-1) = 0, u(1) = 0; the exact solution is u(x) = 1 - x².
        // Collocation is exact for polynomials of this degree, so only round-off remains.
        let u = solver.solve_poisson(|_x: f64| -2.0, 0.0, 0.0).unwrap();
        let nodes = solver.chebyshev_nodes();

        for (i, (&ui, &x)) in u.iter().zip(nodes.iter()).enumerate() {
            let exact = 1.0 - x * x;
            assert!(
                (ui - exact).abs() < 1e-8,
                "At node {}: u={}, exact={}",
                i,
                ui,
                exact
            );
        }
    }

    #[test]
    fn test_rbf_solver_creation() {
        let solver = RBFSolver::new(RBFType::Gaussian, 1.0, 10).unwrap();
        assert_eq!(solver.rbf_type, RBFType::Gaussian);
    }

    #[test]
    fn test_rbf_interpolation() {
        let solver = RBFSolver::new(RBFType::Gaussian, 2.0, 5).unwrap();

        let points = Array1::from_vec(vec![0.0, 0.25, 0.5, 0.75, 1.0]);
        let values = Array1::from_vec(vec![0.0, 0.25, 0.5, 0.75, 1.0]);

        let weights = solver.interpolate(&points, &values).unwrap();
        assert_eq!(weights.len(), 5);

        // The weights must reproduce the data: Σ_j w_j φ(|x_i - x_j|) = y_i
        for i in 0..points.len() {
            let reconstructed: f64 = (0..points.len())
                .map(|j| weights[j] * solver.rbf((points[i] - points[j]).abs()))
                .sum();
            assert!(
                (reconstructed - values[i]).abs() < 1e-8,
                "At point {}: reconstructed={}, expected={}",
                i,
                reconstructed,
                values[i]
            );
        }
    }

    #[test]
    fn test_multigrid_solver_creation() {
        let solver = MultigridSolver::new(3, 2, 2).unwrap();
        assert_eq!(solver.n_levels, 3);
        assert_eq!(solver.n_pre_smooth, 2);
        assert_eq!(solver.n_post_smooth, 2);
    }

    #[test]
    fn test_multigrid_poisson() {
        let solver = MultigridSolver::new(3, 3, 3).unwrap();

        // Solve u'' = -2 with u(0) = 0, u(1) = 0; the exact solution is u(x) = x(1-x).
        // Central differences are exact for quadratics, so the converged discrete
        // solution matches u at every grid point.
        let n_grid = 33;
        let u = solver
            .solve_poisson(|_x: f64| -2.0, (0.0, 1.0), n_grid, 5)
            .unwrap();
        let dx = 1.0 / (n_grid - 1) as f64;

        for (i, &ui) in u.iter().enumerate() {
            let x = i as f64 * dx;
            let exact = x * (1.0 - x);
            assert!(
                (ui - exact).abs() < 1e-3,
                "At x={}: u={}, exact={}",
                x,
                ui,
                exact
            );
        }
    }

    #[test]
    fn test_rbf_types() {
        let solver_gaussian = RBFSolver::new(RBFType::Gaussian, 1.0, 5).unwrap();
        let solver_mq = RBFSolver::new(RBFType::Multiquadric, 1.0, 5).unwrap();
        let solver_imq = RBFSolver::new(RBFType::InverseMultiquadric, 1.0, 5).unwrap();
        let solver_tps = RBFSolver::new(RBFType::ThinPlateSpline, 1.0, 5).unwrap();

        // Test RBF evaluations
        assert!(solver_gaussian.rbf(0.0) > 0.0);
        assert!(solver_mq.rbf(1.0) > 1.0);
        assert!(solver_imq.rbf(1.0) < 1.0);
        assert_eq!(solver_tps.rbf(0.0), 0.0);
    }

    #[test]
    fn test_differentiation_matrix() {
        let solver = SpectralChebyshevSolver::new(5).unwrap();
        let d = solver.differentiation_matrix().unwrap();

        // Matrix should be square
        assert_eq!(d.nrows(), 5);
        assert_eq!(d.ncols(), 5);
    }

    #[test]
    fn test_gauss_seidel_convergence() {
        let solver = MultigridSolver::new(1, 10, 10).unwrap();

        let n = 11;
        let dx = 1.0 / (n - 1) as f64;
        // Non-zero interior initial guess with zero boundaries. With f = 0 the exact
        // solution is u = 0, so the smoother must actually drive the error down.
        let mut u = Array1::<f64>::ones(n);
        u[0] = 0.0;
        u[n - 1] = 0.0;
        let f = Array1::<f64>::zeros(n);

        // Gauss–Seidel shrinks the slowest mode by cos²(π dx) ≈ 0.905 per sweep,
        // so 300 sweeps reduce the error far below the tolerance.
        for _ in 0..300 {
            solver.gauss_seidel_smooth(&mut u, &f, dx);
        }

        for (i, &ui) in u.iter().enumerate() {
            assert!(ui.abs() < 1e-6, "u[{}] = {}", i, ui);
        }
    }
}