scirs2_integrate/
bvp_extended.rs

1//! Extended Boundary Value Problem solvers with Robin and mixed boundary conditions
2//!
3//! This module extends the basic BVP solver to support more general boundary
4//! conditions including Robin conditions (a*u + b*u' = c) and complex mixed
5//! boundary conditions.
6
7use crate::bvp::{BVPOptions, BVPResult};
8use crate::common::IntegrateFloat;
9use crate::error::{IntegrateError, IntegrateResult};
10use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
11
12/// Boundary condition types for extended BVP solver
13#[derive(Debug, Clone)]
14pub enum BoundaryConditionType<F: IntegrateFloat> {
15    /// Dirichlet: u = value
16    Dirichlet { value: F },
17    /// Neumann: u' = value  
18    Neumann { value: F },
19    /// Robin: a*u + b*u' = c
20    Robin { a: F, b: F, c: F },
21    /// Periodic: u(a) = u(b), u'(a) = u'(b)
22    Periodic,
23}
24
25impl<F: IntegrateFloat> BoundaryConditionType<F> {
26    /// Evaluate the boundary condition residual
27    pub fn evaluate_residual(
28        &self,
29        x: F,
30        y: ArrayView1<F>,
31        dydt: ArrayView1<F>,
32        component: usize,
33    ) -> F {
34        match self {
35            BoundaryConditionType::Dirichlet { value } => y[component] - *value,
36            BoundaryConditionType::Neumann { value } => dydt[component] - *value,
37            BoundaryConditionType::Robin { a, b, c } => {
38                *a * y[component] + *b * dydt[component] - *c
39            }
40            BoundaryConditionType::Periodic => {
41                // This is handled specially in the solver
42                F::zero()
43            }
44        }
45    }
46
47    /// Get derivative of residual with respect to y[component]
48    pub fn derivative_y(&self, component: usize) -> F {
49        match self {
50            BoundaryConditionType::Dirichlet { .. } => F::one(),
51            BoundaryConditionType::Neumann { .. } => F::zero(),
52            BoundaryConditionType::Robin { a, .. } => *a,
53            BoundaryConditionType::Periodic => F::one(),
54        }
55    }
56
57    /// Get derivative of residual with respect to dydt[component]  
58    pub fn derivative_dydt(&self, component: usize) -> F {
59        match self {
60            BoundaryConditionType::Dirichlet { .. } => F::zero(),
61            BoundaryConditionType::Neumann { .. } => F::one(),
62            BoundaryConditionType::Robin { b, .. } => *b,
63            BoundaryConditionType::Periodic => F::zero(),
64        }
65    }
66}
67
68/// Extended boundary condition specification
69#[derive(Debug, Clone)]
70pub struct ExtendedBoundaryConditions<F: IntegrateFloat> {
71    /// Boundary conditions at left endpoint (x = a)
72    pub left: Vec<BoundaryConditionType<F>>,
73    /// Boundary conditions at right endpoint (x = b)  
74    pub right: Vec<BoundaryConditionType<F>>,
75    /// Whether the problem has periodic boundary conditions
76    pub is_periodic: bool,
77}
78
79impl<F: IntegrateFloat> ExtendedBoundaryConditions<F> {
80    /// Create Dirichlet boundary conditions
81    pub fn dirichlet(_left_values: Vec<F>, rightvalues: Vec<F>) -> Self {
82        let left = _left_values
83            .into_iter()
84            .map(|value| BoundaryConditionType::Dirichlet { value })
85            .collect();
86
87        let right = rightvalues
88            .into_iter()
89            .map(|value| BoundaryConditionType::Dirichlet { value })
90            .collect();
91
92        Self {
93            left,
94            right,
95            is_periodic: false,
96        }
97    }
98
99    /// Create Neumann boundary conditions
100    pub fn neumann(_left_values: Vec<F>, rightvalues: Vec<F>) -> Self {
101        let left = _left_values
102            .into_iter()
103            .map(|value| BoundaryConditionType::Neumann { value })
104            .collect();
105
106        let right = rightvalues
107            .into_iter()
108            .map(|value| BoundaryConditionType::Neumann { value })
109            .collect();
110
111        Self {
112            left,
113            right,
114            is_periodic: false,
115        }
116    }
117
118    /// Create Robin boundary conditions: a*u + b*u' = c
119    pub fn robin(
120        left_coeffs: Vec<(F, F, F)>, // (a, b, c) for each component
121        right_coeffs: Vec<(F, F, F)>,
122    ) -> Self {
123        let left = left_coeffs
124            .into_iter()
125            .map(|(a, b, c)| BoundaryConditionType::Robin { a, b, c })
126            .collect();
127
128        let right = right_coeffs
129            .into_iter()
130            .map(|(a, b, c)| BoundaryConditionType::Robin { a, b, c })
131            .collect();
132
133        Self {
134            left,
135            right,
136            is_periodic: false,
137        }
138    }
139
140    /// Create periodic boundary conditions
141    pub fn periodic(dimension: usize) -> Self {
142        let condition = BoundaryConditionType::Periodic;
143        Self {
144            left: vec![condition.clone(); dimension],
145            right: vec![condition; dimension],
146            is_periodic: true,
147        }
148    }
149
150    /// Create mixed boundary conditions
151    pub fn mixed(
152        left: Vec<BoundaryConditionType<F>>,
153        right: Vec<BoundaryConditionType<F>>,
154    ) -> Self {
155        Self {
156            left,
157            right,
158            is_periodic: false,
159        }
160    }
161}
162
163/// Solve BVP with extended boundary condition support
164#[allow(dead_code)]
165pub fn solve_bvp_extended<F, FunType>(
166    fun: FunType,
167    x_span: [F; 2],
168    boundary_conditions: ExtendedBoundaryConditions<F>,
169    n_points: usize,
170    options: Option<BVPOptions<F>>,
171) -> IntegrateResult<BVPResult<F>>
172where
173    F: IntegrateFloat,
174    FunType: Fn(F, ArrayView1<F>) -> Array1<F> + Copy,
175{
176    let [a, b] = x_span;
177
178    if a >= b {
179        return Err(IntegrateError::ValueError(
180            "Invalid interval: left bound must be less than right bound".to_string(),
181        ));
182    }
183
184    let ndim = boundary_conditions.left.len();
185    if boundary_conditions.right.len() != ndim {
186        return Err(IntegrateError::ValueError(
187            "Left and right boundary _conditions must have same dimension".to_string(),
188        ));
189    }
190
191    // Generate uniform mesh
192    let mesh: Vec<F> = (0..n_points)
193        .map(|i| a + (b - a) * F::from_usize(i).unwrap() / F::from_usize(n_points - 1).unwrap())
194        .collect();
195
196    // Generate initial guess (zero solution for now - could be improved)
197    let mut y_init = Vec::with_capacity(n_points);
198    for _i in 0..n_points {
199        y_init.push(Array1::zeros(ndim));
200    }
201
202    // Apply initial guess based on boundary _conditions
203    match boundary_conditions.left[0] {
204        BoundaryConditionType::Dirichlet { value } => {
205            if let BoundaryConditionType::Dirichlet { value: right_value } =
206                boundary_conditions.right[0]
207            {
208                // Linear interpolation between boundary values
209                for (i, y_val) in y_init.iter_mut().enumerate().take(n_points) {
210                    let t = F::from_usize(i).unwrap() / F::from_usize(n_points - 1).unwrap();
211                    y_val[0] = value * (F::one() - t) + right_value * t;
212                }
213            }
214        }
215        _ => {
216            // For other boundary conditions, use zero initial guess
217        }
218    }
219
220    // Create boundary condition function for the main BVP solver
221    let bc_func = create_extended_bc_function(boundary_conditions, fun, a, b);
222
223    // Solve using the main BVP solver
224    crate::bvp::solve_bvp(fun, bc_func, Some(mesh), y_init, options)
225}
226
227/// Create boundary condition function for extended boundary conditions
228#[allow(dead_code)]
229fn create_extended_bc_function<F, FunType>(
230    boundary_conditions: ExtendedBoundaryConditions<F>,
231    fun: FunType,
232    a: F,
233    b: F,
234) -> impl Fn(ArrayView1<F>, ArrayView1<F>) -> Array1<F>
235where
236    F: IntegrateFloat,
237    FunType: Fn(F, ArrayView1<F>) -> Array1<F> + Copy,
238{
239    move |ya: ArrayView1<F>, yb: ArrayView1<F>| {
240        let ndim = ya.len();
241
242        if boundary_conditions.is_periodic {
243            // For periodic boundary _conditions: u(a) = u(b), u'(a) = u'(b)
244            let f_a = fun(a, ya);
245            let f_b = fun(b, yb);
246
247            let mut residuals = Array1::zeros(ndim * 2);
248            for i in 0..ndim {
249                residuals[i] = ya[i] - yb[i]; // u(a) = u(b)
250                residuals[i + ndim] = f_a[i] - f_b[i]; // u'(a) = u'(b)
251            }
252            residuals
253        } else {
254            // General boundary _conditions
255            let f_a = fun(a, ya);
256            let f_b = fun(b, yb);
257
258            let mut residuals = Array1::zeros(ndim * 2);
259
260            // Left boundary _conditions
261            for (i, bc) in boundary_conditions.left.iter().enumerate() {
262                residuals[i] = bc.evaluate_residual(a, ya, f_a.view(), i);
263            }
264
265            // Right boundary _conditions
266            for (i, bc) in boundary_conditions.right.iter().enumerate() {
267                residuals[i + ndim] = bc.evaluate_residual(b, yb, f_b.view(), i);
268            }
269
270            residuals
271        }
272    }
273}
274
275/// Robin boundary condition builder for convenience
276#[derive(Debug, Clone)]
277pub struct RobinBC<F: IntegrateFloat> {
278    /// Coefficient of u
279    pub a: F,
280    /// Coefficient of u'
281    pub b: F,
282    /// Right-hand side value
283    pub c: F,
284}
285
286impl<F: IntegrateFloat> RobinBC<F> {
287    /// Create new Robin boundary condition: a*u + b*u' = c
288    pub fn new(a: F, b: F, c: F) -> Self {
289        Self { a, b, c }
290    }
291
292    /// Create Dirichlet condition: u = c
293    pub fn dirichlet(c: F) -> Self {
294        Self {
295            a: F::one(),
296            b: F::zero(),
297            c,
298        }
299    }
300
301    /// Create Neumann condition: u' = c
302    pub fn neumann(c: F) -> Self {
303        Self {
304            a: F::zero(),
305            b: F::one(),
306            c,
307        }
308    }
309
310    /// Create insulated boundary condition: u' = 0
311    pub fn insulated() -> Self {
312        Self::neumann(F::zero())
313    }
314
315    /// Create convective boundary condition: u' + h*(u - u_env) = 0
316    /// where h is heat transfer coefficient and u_env is environment temperature
317    pub fn convective(h: F, uenv: F) -> Self {
318        Self {
319            a: h,
320            b: F::one(),
321            c: h * uenv,
322        }
323    }
324}
325
326/// Multipoint boundary value problem support
327#[derive(Debug, Clone)]
328pub struct MultipointBVP<F: IntegrateFloat> {
329    /// Interior points where conditions are specified
330    pub interior_points: Vec<F>,
331    /// Boundary conditions at interior points
332    pub interior_conditions: Vec<Vec<BoundaryConditionType<F>>>,
333}
334
335impl<F: IntegrateFloat> Default for MultipointBVP<F> {
336    fn default() -> Self {
337        Self::new()
338    }
339}
340
341impl<F: IntegrateFloat> MultipointBVP<F> {
342    /// Create new multipoint BVP
343    pub fn new() -> Self {
344        Self {
345            interior_points: Vec::new(),
346            interior_conditions: Vec::new(),
347        }
348    }
349
350    /// Add interior point with conditions
351    pub fn add_interior_point(&mut self, x: F, conditions: Vec<BoundaryConditionType<F>>) {
352        self.interior_points.push(x);
353        self.interior_conditions.push(conditions);
354    }
355}
356
357/// Solve multipoint boundary value problem
358#[allow(dead_code)]
359pub fn solve_multipoint_bvp<F, FunType>(
360    fun: FunType,
361    x_span: [F; 2],
362    boundary_conditions: ExtendedBoundaryConditions<F>,
363    multipoint: MultipointBVP<F>,
364    n_points: usize,
365    options: Option<BVPOptions<F>>,
366) -> IntegrateResult<BVPResult<F>>
367where
368    F: IntegrateFloat,
369    FunType: Fn(F, ArrayView1<F>) -> Array1<F> + Copy,
370{
371    if multipoint.interior_points.is_empty() {
372        // No interior points, solve as regular BVP
373        solve_bvp_extended(fun, x_span, boundary_conditions, n_points, options)
374    } else {
375        // Multipoint BVP using segmented collocation approach
376
377        // Validate interior _points are within domain and sorted
378        let [a, b] = x_span;
379        let mut all_points = vec![a];
380        all_points.extend(multipoint.interior_points.clone());
381        all_points.push(b);
382
383        // Sort and check uniqueness
384        for i in 1..all_points.len() {
385            if all_points[i] <= all_points[i - 1] {
386                return Err(IntegrateError::ValueError(
387                    "Interior _points must be unique and in ascending order".to_string(),
388                ));
389            }
390        }
391
392        // Determine dimensions
393        let ndim = boundary_conditions.left.len();
394        let n_segments = all_points.len() - 1;
395        let points_per_segment = (n_points - 1) / n_segments + 1;
396
397        // Build global mesh
398        let mut global_mesh = Vec::new();
399        for i in 0..n_segments {
400            let segment_start = all_points[i];
401            let segment_end = all_points[i + 1];
402            let n_seg_points = if i == n_segments - 1 {
403                points_per_segment
404            } else {
405                points_per_segment - 1 // Avoid duplicating interior _points
406            };
407
408            for j in 0..n_seg_points {
409                let t = F::from_usize(j).unwrap() / F::from_usize(n_seg_points - 1).unwrap();
410                let x = segment_start + (segment_end - segment_start) * t;
411                global_mesh.push(x);
412            }
413        }
414
415        // Initialize solution with zeros
416        let total_points = global_mesh.len();
417        let mut y_solution: Array2<F> = Array2::zeros((total_points, ndim));
418
419        // Apply boundary _conditions at endpoints
420        apply_initial_boundary_values(&boundary_conditions, &mut y_solution, total_points, ndim);
421
422        // Set up collocation system
423        let options = options.unwrap_or_default();
424        let mut residuals = vec![F::zero(); total_points * ndim];
425        let mut max_residual = F::zero();
426
427        // Newton's method for solving the collocation system
428        for _iter in 0..options.max_iter {
429            // Compute residuals at all collocation _points
430            compute_multipoint_residuals(
431                &fun,
432                &global_mesh,
433                &y_solution,
434                &boundary_conditions,
435                &multipoint,
436                &mut residuals,
437                ndim,
438            )?;
439
440            // Check convergence
441            max_residual =
442                residuals
443                    .iter()
444                    .map(|&r| r.abs())
445                    .fold(F::zero(), |a, b| if a > b { a } else { b });
446
447            if max_residual < options.tol {
448                // Converged
449                let x = global_mesh.clone();
450                let y = transpose_solution(y_solution);
451
452                return Ok(BVPResult {
453                    x: x.to_vec(),
454                    y,
455                    n_iter: _iter + 1,
456                    success: true,
457                    message: Some("Converged".to_string()),
458                    residual_norm: max_residual,
459                });
460            }
461
462            // Compute Jacobian and solve linear system
463            let jacobian = compute_multipoint_jacobian(
464                &fun,
465                &global_mesh,
466                &y_solution,
467                &boundary_conditions,
468                &multipoint,
469                ndim,
470                F::from(1e-6).unwrap(), // Default jacobian epsilon
471            )?;
472
473            // Solve J * delta_y = -residuals
474            let delta_y = solve_sparse_system(&jacobian, &residuals)?;
475
476            // Update solution
477            for (i, delta) in delta_y.iter().enumerate() {
478                let row = i / ndim;
479                let col = i % ndim;
480                y_solution[[row, col]] -= *delta;
481            }
482        }
483
484        // Did not converge
485        let x = global_mesh;
486        let y = transpose_solution(y_solution);
487
488        Ok(BVPResult {
489            x,
490            y,
491            n_iter: options.max_iter,
492            success: false,
493            message: Some("Did not converge within max iterations".to_string()),
494            residual_norm: max_residual,
495        })
496    }
497}
498
499/// Apply initial boundary values to solution array
500#[allow(dead_code)]
501fn apply_initial_boundary_values<F: IntegrateFloat>(
502    boundary_conditions: &ExtendedBoundaryConditions<F>,
503    y_solution: &mut Array2<F>,
504    n_points: usize,
505    _ndim: usize,
506) {
507    // Apply Dirichlet _conditions at boundaries if available
508    for (dim, bc) in boundary_conditions.left.iter().enumerate() {
509        if let BoundaryConditionType::Dirichlet { value } = bc {
510            y_solution[[0, dim]] = *value;
511        }
512    }
513
514    for (dim, bc) in boundary_conditions.right.iter().enumerate() {
515        if let BoundaryConditionType::Dirichlet { value } = bc {
516            y_solution[[n_points - 1, dim]] = *value;
517        }
518    }
519}
520
521/// Compute residuals for multipoint BVP
522#[allow(dead_code)]
523fn compute_multipoint_residuals<F: IntegrateFloat, FunType>(
524    fun: &FunType,
525    mesh: &[F],
526    y_solution: &Array2<F>,
527    boundary_conditions: &ExtendedBoundaryConditions<F>,
528    multipoint: &MultipointBVP<F>,
529    residuals: &mut [F],
530    ndim: usize,
531) -> IntegrateResult<()>
532where
533    FunType: Fn(F, ArrayView1<F>) -> Array1<F>,
534{
535    let n_points = mesh.len();
536    let h = mesh[1] - mesh[0]; // Assuming uniform spacing for simplicity
537
538    // Interior point residuals (differential equations)
539    for i in 1..n_points - 1 {
540        let y_prev = y_solution.row(i - 1);
541        let y_curr = y_solution.row(i);
542        let y_next = y_solution.row(i + 1);
543
544        // Compute derivatives using central differences
545        let dydt = (&y_next - &y_prev) / (F::from_f64(2.0).unwrap() * h);
546
547        // Evaluate ODE
548        let f_val = fun(mesh[i], y_curr);
549
550        // Residual: dy/dt - f(t, y) = 0
551        for j in 0..ndim {
552            residuals[i * ndim + j] = dydt[j] - f_val[j];
553        }
554    }
555
556    // Boundary condition residuals
557    apply_boundary_residuals(
558        boundary_conditions,
559        y_solution,
560        residuals,
561        n_points,
562        ndim,
563        h,
564    );
565
566    // Interior point condition residuals
567    apply_interior_residuals(multipoint, mesh, y_solution, residuals, ndim, h)?;
568
569    Ok(())
570}
571
572/// Apply boundary condition residuals
573#[allow(dead_code)]
574fn apply_boundary_residuals<F: IntegrateFloat>(
575    boundary_conditions: &ExtendedBoundaryConditions<F>,
576    y_solution: &Array2<F>,
577    residuals: &mut [F],
578    n_points: usize,
579    ndim: usize,
580    h: F,
581) {
582    // Left boundary
583    let y_left = y_solution.row(0);
584    let y_left_next = y_solution.row(1);
585    let dydt_left = (&y_left_next - &y_left) / h;
586
587    for (dim, bc) in boundary_conditions.left.iter().enumerate() {
588        residuals[dim] = bc.evaluate_residual(F::zero(), y_left, dydt_left.view(), dim);
589    }
590
591    // Right boundary
592    let y_right = y_solution.row(n_points - 1);
593    let y_right_prev = y_solution.row(n_points - 2);
594    let dydt_right = (&y_right - &y_right_prev) / h;
595
596    for (dim, bc) in boundary_conditions.right.iter().enumerate() {
597        residuals[(n_points - 1) * ndim + dim] =
598            bc.evaluate_residual(F::zero(), y_right, dydt_right.view(), dim);
599    }
600}
601
602/// Apply interior point condition residuals
603#[allow(dead_code)]
604fn apply_interior_residuals<F: IntegrateFloat>(
605    multipoint: &MultipointBVP<F>,
606    mesh: &[F],
607    y_solution: &Array2<F>,
608    residuals: &mut [F],
609    ndim: usize,
610    h: F,
611) -> IntegrateResult<()> {
612    // Find indices of interior condition points
613    for (point_idx, &interior_x) in multipoint.interior_points.iter().enumerate() {
614        // Find closest mesh point
615        let mesh_idx = mesh
616            .iter()
617            .position(|&x| (x - interior_x).abs() < F::from_f64(1e-10).unwrap())
618            .ok_or_else(|| {
619                IntegrateError::ValueError("Interior point not found in mesh".to_string())
620            })?;
621
622        let y_at_point = y_solution.row(mesh_idx);
623
624        // Compute derivative at interior point
625        let dydt_at_point = if mesh_idx > 0 && mesh_idx < mesh.len() - 1 {
626            let y_prev = y_solution.row(mesh_idx - 1);
627            let y_next = y_solution.row(mesh_idx + 1);
628            (&y_next - &y_prev) / (F::from_f64(2.0).unwrap() * h)
629        } else {
630            // Use one-sided difference at boundaries
631            if mesh_idx == 0 {
632                let y_next = y_solution.row(1);
633                (&y_next - &y_at_point) / h
634            } else {
635                let y_prev = y_solution.row(mesh_idx - 1);
636                (&y_at_point - &y_prev) / h
637            }
638        };
639
640        // Apply each condition at this interior point
641        for (cond_idx, condition) in multipoint.interior_conditions[point_idx].iter().enumerate() {
642            residuals[mesh_idx * ndim + cond_idx] =
643                condition.evaluate_residual(interior_x, y_at_point, dydt_at_point.view(), cond_idx);
644        }
645    }
646
647    Ok(())
648}
649
650/// Compute Jacobian for multipoint BVP (simplified version)
651#[allow(dead_code)]
652fn compute_multipoint_jacobian<F: IntegrateFloat, FunType>(
653    fun: &FunType,
654    mesh: &[F],
655    y_solution: &Array2<F>,
656    boundary_conditions: &ExtendedBoundaryConditions<F>,
657    multipoint: &MultipointBVP<F>,
658    ndim: usize,
659    eps: F,
660) -> IntegrateResult<Vec<Vec<F>>>
661where
662    FunType: Fn(F, ArrayView1<F>) -> Array1<F>,
663{
664    let n_points = mesh.len();
665    let total_size = n_points * ndim;
666    let mut jacobian = vec![vec![F::zero(); total_size]; total_size];
667
668    // Use finite differences to approximate Jacobian
669    let mut residuals_base = vec![F::zero(); total_size];
670    let mut residuals_pert = vec![F::zero(); total_size];
671
672    // Compute base residuals
673    compute_multipoint_residuals(
674        fun,
675        mesh,
676        y_solution,
677        boundary_conditions,
678        multipoint,
679        &mut residuals_base,
680        ndim,
681    )?;
682
683    // Perturb each variable and compute Jacobian columns
684    let mut y_pert = y_solution.clone();
685
686    for col in 0..total_size {
687        let row_idx = col / ndim;
688        let dim_idx = col % ndim;
689
690        // Perturb
691        let original = y_pert[[row_idx, dim_idx]];
692        y_pert[[row_idx, dim_idx]] = original + eps;
693
694        // Compute perturbed residuals
695        compute_multipoint_residuals(
696            fun,
697            mesh,
698            &y_pert,
699            boundary_conditions,
700            multipoint,
701            &mut residuals_pert,
702            ndim,
703        )?;
704
705        // Compute Jacobian column
706        for row in 0..total_size {
707            jacobian[row][col] = (residuals_pert[row] - residuals_base[row]) / eps;
708        }
709
710        // Restore original value
711        y_pert[[row_idx, dim_idx]] = original;
712    }
713
714    Ok(jacobian)
715}
716
717/// Solve sparse linear system (simplified dense solver)
718#[allow(dead_code)]
719fn solve_sparse_system<F: IntegrateFloat>(
720    jacobian: &[Vec<F>],
721    residuals: &[F],
722) -> IntegrateResult<Vec<F>> {
723    // Convert to dense matrix and use LU decomposition
724    // In a real implementation, we'd use a sparse solver
725    let n = jacobian.len();
726    let mut a = Array2::zeros((n, n));
727    let mut b = Array1::zeros(n);
728
729    for i in 0..n {
730        for j in 0..n {
731            a[[i, j]] = jacobian[i][j];
732        }
733        b[i] = residuals[i];
734    }
735
736    // Use scirs2-linalg for solving
737    // For now, use a simple Gaussian elimination
738    let solution = gaussian_elimination(a, b)?;
739
740    Ok(solution.to_vec())
741}
742
743/// Simple Gaussian elimination solver
744#[allow(dead_code)]
745fn gaussian_elimination<F: IntegrateFloat>(
746    mut a: Array2<F>,
747    mut b: Array1<F>,
748) -> IntegrateResult<Array1<F>> {
749    let n = a.nrows();
750
751    // Forward elimination
752    for k in 0..n - 1 {
753        // Find pivot
754        let mut max_row = k;
755        for i in k + 1..n {
756            if a[[i, k]].abs() > a[[max_row, k]].abs() {
757                max_row = i;
758            }
759        }
760
761        // Swap rows
762        if max_row != k {
763            for j in 0..n {
764                let temp = a[[k, j]];
765                a[[k, j]] = a[[max_row, j]];
766                a[[max_row, j]] = temp;
767            }
768            let temp = b[k];
769            b[k] = b[max_row];
770            b[max_row] = temp;
771        }
772
773        // Check for singular matrix
774        if a[[k, k]].abs() < F::from_f64(1e-12).unwrap() {
775            return Err(IntegrateError::ComputationError(
776                "Singular matrix in Gaussian elimination".to_string(),
777            ));
778        }
779
780        // Eliminate column
781        for i in k + 1..n {
782            let factor = a[[i, k]] / a[[k, k]];
783            for j in k + 1..n {
784                a[[i, j]] = a[[i, j]] - factor * a[[k, j]];
785            }
786            b[i] = b[i] - factor * b[k];
787        }
788    }
789
790    // Back substitution
791    let mut x = Array1::zeros(n);
792    for i in (0..n).rev() {
793        x[i] = b[i];
794        for j in i + 1..n {
795            x[i] = x[i] - a[[i, j]] * x[j];
796        }
797        x[i] /= a[[i, i]];
798    }
799
800    Ok(x)
801}
802
803/// Transpose solution from row-major to column-major format
804#[allow(dead_code)]
805fn transpose_solution<F: IntegrateFloat>(solution: Array2<F>) -> Vec<Array1<F>> {
806    let n_points = solution.nrows();
807    let _ndim = solution.ncols();
808
809    let mut result = Vec::with_capacity(n_points);
810    for i in 0..n_points {
811        result.push(solution.row(i).to_owned());
812    }
813
814    result
815}
816
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_robin_boundary_conditions() {
        // Test Robin BC builder
        let robin = RobinBC::new(2.0, 1.0, 3.0); // 2*u + 1*u' = 3
        assert_abs_diff_eq!(robin.a, 2.0);
        assert_abs_diff_eq!(robin.b, 1.0);
        assert_abs_diff_eq!(robin.c, 3.0);

        // Test convenience methods
        let dirichlet = RobinBC::dirichlet(5.0); // u = 5
        assert_abs_diff_eq!(dirichlet.a, 1.0);
        assert_abs_diff_eq!(dirichlet.b, 0.0);
        assert_abs_diff_eq!(dirichlet.c, 5.0);

        let neumann = RobinBC::neumann(2.0); // u' = 2
        assert_abs_diff_eq!(neumann.a, 0.0);
        assert_abs_diff_eq!(neumann.b, 1.0);
        assert_abs_diff_eq!(neumann.c, 2.0);

        let insulated: RobinBC<f64> = RobinBC::insulated(); // u' = 0
        assert_abs_diff_eq!(insulated.a, 0.0);
        assert_abs_diff_eq!(insulated.b, 1.0);
        assert_abs_diff_eq!(insulated.c, 0.0);
    }

    #[test]
    fn test_convective_boundary_condition() {
        // u' + 2*(u - 3) = 0  <=>  2*u + u' = 6
        let convective = RobinBC::convective(2.0, 3.0);
        assert_abs_diff_eq!(convective.a, 2.0);
        assert_abs_diff_eq!(convective.b, 1.0);
        assert_abs_diff_eq!(convective.c, 6.0);
    }

    #[test]
    fn test_boundary_condition_evaluation() {
        let y = Array1::from_vec(vec![2.0, 3.0]);
        let dydt = Array1::from_vec(vec![1.0, -1.0]);

        // Test Dirichlet: u = 5
        let dirichlet = BoundaryConditionType::Dirichlet { value: 5.0 };
        let residual = dirichlet.evaluate_residual(0.0, y.view(), dydt.view(), 0);
        assert_abs_diff_eq!(residual, -3.0); // 2 - 5 = -3

        // Test Neumann: u' = 0.5
        let neumann = BoundaryConditionType::Neumann { value: 0.5 };
        let residual = neumann.evaluate_residual(0.0, y.view(), dydt.view(), 0);
        assert_abs_diff_eq!(residual, 0.5); // 1 - 0.5 = 0.5

        // Test Robin: 2*u + 3*u' = 10
        let robin = BoundaryConditionType::Robin {
            a: 2.0,
            b: 3.0,
            c: 10.0,
        };
        let residual = robin.evaluate_residual(0.0, y.view(), dydt.view(), 0);
        assert_abs_diff_eq!(residual, -3.0); // 2*2 + 3*1 - 10 = -3
    }

    #[test]
    fn test_extended_boundary_conditions_creation() {
        // Test Dirichlet creation
        let dirichlet = ExtendedBoundaryConditions::dirichlet(vec![1.0, 2.0], vec![3.0, 4.0]);
        assert!(!dirichlet.is_periodic);
        assert_eq!(dirichlet.left.len(), 2);
        assert_eq!(dirichlet.right.len(), 2);

        // Test Robin creation
        let robin = ExtendedBoundaryConditions::robin(
            vec![(1.0, 0.0, 5.0), (0.0, 1.0, 2.0)], // u = 5, u' = 2
            vec![(1.0, 0.0, 3.0), (0.0, 1.0, 1.0)], // u = 3, u' = 1
        );
        assert!(!robin.is_periodic);
        assert_eq!(robin.left.len(), 2);
        assert_eq!(robin.right.len(), 2);

        // Test periodic creation
        let periodic: ExtendedBoundaryConditions<f64> = ExtendedBoundaryConditions::periodic(3);
        assert!(periodic.is_periodic);
        assert_eq!(periodic.left.len(), 3);
        assert_eq!(periodic.right.len(), 3);
    }

    #[test]
    fn test_gaussian_elimination_with_pivoting() {
        // The zero leading entry forces a row swap.
        // 0*x0 + x1 = 3, 2*x0 + x1 = 5  =>  x0 = 1, x1 = 3
        let a = Array2::from_shape_vec((2, 2), vec![0.0, 1.0, 2.0, 1.0]).unwrap();
        let b = Array1::from_vec(vec![3.0, 5.0]);
        let x = gaussian_elimination(a, b).unwrap();
        assert_abs_diff_eq!(x[0], 1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(x[1], 3.0, epsilon = 1e-12);
    }

    #[test]
    fn test_gaussian_elimination_rejects_singular_pivot() {
        let a = Array2::from_shape_vec((2, 2), vec![0.0, 0.0, 0.0, 1.0]).unwrap();
        let b = Array1::from_vec(vec![1.0, 1.0]);
        assert!(gaussian_elimination(a, b).is_err());
    }

    #[test]
    fn test_transpose_solution_splits_rows() {
        let matrix = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let rows = transpose_solution(matrix);
        assert_eq!(rows.len(), 3);
        assert_abs_diff_eq!(rows[1][0], 3.0);
        assert_abs_diff_eq!(rows[1][1], 4.0);
    }

    #[test]
    fn test_multipoint_add_interior_point() {
        let mut multipoint: MultipointBVP<f64> = MultipointBVP::default();
        multipoint.add_interior_point(0.5, vec![BoundaryConditionType::Dirichlet { value: 1.0 }]);
        assert_eq!(multipoint.interior_points, vec![0.5]);
        assert_eq!(multipoint.interior_conditions.len(), 1);
    }

    #[test]
    fn test_multipoint_rejects_out_of_order_points() {
        let bcs = ExtendedBoundaryConditions::dirichlet(vec![0.0], vec![1.0]);
        let mut multipoint = MultipointBVP::new();
        // 2.0 lies outside [0, 1], so the breakpoints are not increasing.
        multipoint.add_interior_point(2.0, vec![BoundaryConditionType::Dirichlet { value: 0.5 }]);
        let result = solve_multipoint_bvp(
            |_x: f64, y: ArrayView1<f64>| y.to_owned(),
            [0.0, 1.0],
            bcs,
            multipoint,
            11,
            None,
        );
        assert!(result.is_err());
    }
}