Skip to main content

scirs2_integrate/pde/finite_element/
mod.rs

1//! Finite Element Method (FEM) for solving PDEs
2//!
3//! This module provides implementations of the Finite Element Method for
4//! solving partial differential equations on structured and unstructured meshes.
5//!
6//! Key features:
7//! - Linear, quadratic, and cubic element types
8//! - Triangular elements for 2D problems
9//! - Mesh generation and manipulation
10//! - Support for irregular domains
11//! - Various boundary condition types
12
13pub mod higher_order;
14
15#[cfg(test)]
16mod higher_order_tests;
17
18use scirs2_core::ndarray::{Array1, Array2};
19use std::collections::HashMap;
20use std::time::Instant;
21
22use crate::pde::{
23    BoundaryCondition, BoundaryConditionType, BoundaryLocation, PDEError, PDEResult, PDESolution,
24    PDESolverInfo,
25};
26
27// Re-export higher-order functionality
28pub use higher_order::{
29    HigherOrderMeshGenerator, HigherOrderTriangle, ShapeFunctions, TriangularQuadrature,
30};
31
/// A location in the 2D plane, given by Cartesian coordinates.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Point {
    /// Horizontal coordinate.
    pub x: f64,

    /// Vertical coordinate.
    pub y: f64,
}

impl Point {
    /// Builds a point from its `x` and `y` coordinates.
    pub fn new(x: f64, y: f64) -> Self {
        Self { x, y }
    }

    /// Returns the Euclidean distance between `self` and `other`.
    pub fn distance(&self, other: &Point) -> f64 {
        let (dx, dy) = (self.x - other.x, self.y - other.y);
        (dx.powi(2) + dy.powi(2)).sqrt()
    }
}
53
/// A linear triangular element that refers to its three vertices by index.
#[derive(Debug, Clone)]
pub struct Triangle {
    /// Indices into the owning mesh's point list, one per vertex.
    pub nodes: [usize; 3],

    /// Optional user tag used to identify a sub-domain or boundary region.
    pub marker: Option<i32>,
}

impl Triangle {
    /// Builds an element from its vertex indices and an optional marker.
    pub fn new(nodes: [usize; 3], marker: Option<i32>) -> Self {
        Self { nodes, marker }
    }
}
70
71/// A mesh of triangular elements
72#[derive(Debug, Clone)]
73pub struct TriangularMesh {
74    /// Points/nodes in the mesh
75    pub points: Vec<Point>,
76
77    /// Triangular elements
78    pub elements: Vec<Triangle>,
79
80    /// Boundary edges (node indices for each edge)
81    pub boundary_edges: Vec<(usize, usize, Option<i32>)>,
82
83    /// Map from node index to its boundary condition type (if on boundary)
84    pub boundary_nodes: HashMap<usize, BoundaryNodeInfo>,
85}
86
87/// Information about a boundary node
88#[derive(Debug, Clone)]
89pub struct BoundaryNodeInfo {
90    /// Boundary type
91    pub bc_type: BoundaryConditionType,
92
93    /// Value for Dirichlet or flux for Neumann boundaries
94    pub value: f64,
95
96    /// Additional coefficients for Robin boundaries
97    pub coefficients: Option<[f64; 3]>,
98
99    /// Marker for boundary identification
100    pub marker: Option<i32>,
101}
102
103impl Default for TriangularMesh {
104    fn default() -> Self {
105        Self::new()
106    }
107}
108
109impl TriangularMesh {
110    /// Create a new empty triangular mesh
111    pub fn new() -> Self {
112        TriangularMesh {
113            points: Vec::new(),
114            elements: Vec::new(),
115            boundary_edges: Vec::new(),
116            boundary_nodes: HashMap::new(),
117        }
118    }
119
120    /// Generate a simple triangular mesh on a rectangular domain
121    pub fn generate_rectangular(
122        x_range: (f64, f64),
123        y_range: (f64, f64),
124        nx: usize,
125        ny: usize,
126    ) -> Self {
127        let mut mesh = TriangularMesh::new();
128
129        // Generate grid points
130        let dx = (x_range.1 - x_range.0) / (nx as f64);
131        let dy = (y_range.1 - y_range.0) / (ny as f64);
132
133        // Create points
134        for j in 0..=ny {
135            for i in 0..=nx {
136                let x = x_range.0 + i as f64 * dx;
137                let y = y_range.0 + j as f64 * dy;
138                mesh.points.push(Point::new(x, y));
139            }
140        }
141
142        // Create triangular elements
143        for j in 0..ny {
144            for i in 0..nx {
145                // Node indices at the corners of the grid cell
146                let n00 = j * (nx + 1) + i; // Bottom-left
147                let n10 = j * (nx + 1) + (i + 1); // Bottom-right
148                let n01 = (j + 1) * (nx + 1) + i; // Top-left
149                let n11 = (j + 1) * (nx + 1) + (i + 1); // Top-right
150
151                // Create two triangles per grid cell
152                // Triangle 1: Bottom-left, Bottom-right, Top-left
153                mesh.elements.push(Triangle::new([n00, n10, n01], None));
154
155                // Triangle 2: Top-right, Top-left, Bottom-right
156                mesh.elements.push(Triangle::new([n11, n01, n10], None));
157            }
158        }
159
160        // Identify boundary edges
161
162        // Bottom edge (y = y_range.0)
163        for i in 0..nx {
164            let n1 = i;
165            let n2 = i + 1;
166            mesh.boundary_edges.push((n1, n2, Some(1))); // Marker 1 for bottom
167        }
168
169        // Right edge (x = x_range.1)
170        for j in 0..ny {
171            let n1 = (j + 1) * (nx + 1) - 1;
172            let n2 = (j + 2) * (nx + 1) - 1;
173            mesh.boundary_edges.push((n1, n2, Some(2))); // Marker 2 for right
174        }
175
176        // Top edge (y = y_range.1)
177        for i in 0..nx {
178            let n1 = (ny + 1) * (nx + 1) - i - 1;
179            let n2 = (ny + 1) * (nx + 1) - i - 2;
180            mesh.boundary_edges.push((n1, n2, Some(3))); // Marker 3 for top
181        }
182
183        // Left edge (x = x_range.0)
184        for j in 0..ny {
185            let n1 = (ny - j) * (nx + 1);
186            let n2 = (ny - j - 1) * (nx + 1);
187            mesh.boundary_edges.push((n1, n2, Some(4))); // Marker 4 for left
188        }
189
190        mesh
191    }
192
193    /// Set boundary conditions based on boundary markers
194    pub fn set_boundary_conditions(
195        &mut self,
196        boundary_conditions: &[BoundaryCondition<f64>],
197    ) -> PDEResult<()> {
198        // Clear existing boundary nodes
199        self.boundary_nodes.clear();
200
201        // Process all boundary edges
202        for &(n1, n2, marker) in &self.boundary_edges {
203            // Find the matching boundary condition by marker
204            for bc in boundary_conditions {
205                // Map dimension and location to marker (simplified approach for example)
206                let bc_marker = match (bc.dimension, bc.location) {
207                    (1, BoundaryLocation::Lower) => Some(1), // Bottom
208                    (0, BoundaryLocation::Upper) => Some(2), // Right
209                    (1, BoundaryLocation::Upper) => Some(3), // Top
210                    (0, BoundaryLocation::Lower) => Some(4), // Left
211                    _ => None,
212                };
213
214                // If this boundary condition matches the edge marker
215                if bc_marker == marker {
216                    // Add both nodes of the edge to boundary_nodes
217                    let bc_info = BoundaryNodeInfo {
218                        bc_type: bc.bc_type,
219                        value: bc.value,
220                        coefficients: bc.coefficients,
221                        marker,
222                    };
223
224                    self.boundary_nodes.insert(n1, bc_info.clone());
225                    self.boundary_nodes.insert(n2, bc_info);
226                }
227            }
228        }
229
230        Ok(())
231    }
232
233    /// Compute area of a triangle
234    pub fn triangle_area(&self, element: &Triangle) -> f64 {
235        let [i, j, k] = element.nodes;
236        let pi = &self.points[i];
237        let pj = &self.points[j];
238        let pk = &self.points[k];
239
240        // Area using cross product
241        0.5 * ((pj.x - pi.x) * (pk.y - pi.y) - (pk.x - pi.x) * (pj.y - pi.y)).abs()
242    }
243
244    /// Compute shape function gradients for a linear triangular element
245    pub fn shape_function_gradients(&self, element: &Triangle) -> PDEResult<[Point; 3]> {
246        let [i, j, k] = element.nodes;
247        let pi = &self.points[i];
248        let pj = &self.points[j];
249        let pk = &self.points[k];
250
251        let area = self.triangle_area(element);
252        if area < 1e-10 {
253            return Err(PDEError::FiniteElementError(format!(
254                "Element has nearly zero area: {area}"
255            )));
256        }
257
258        // Linear shape function gradients
259        let gradients = [
260            Point::new((pj.y - pk.y) / (2.0 * area), (pk.x - pj.x) / (2.0 * area)),
261            Point::new((pk.y - pi.y) / (2.0 * area), (pi.x - pk.x) / (2.0 * area)),
262            Point::new((pi.y - pj.y) / (2.0 * area), (pj.x - pi.x) / (2.0 * area)),
263        ];
264
265        Ok(gradients)
266    }
267}
268
/// Polynomial order of the finite-element basis used on each triangle.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ElementType {
    /// First-order basis: 3 nodes per triangle (its vertices).
    Linear,

    /// Second-order basis: 6 nodes per triangle.
    Quadratic,

    /// Third-order basis: 10 nodes per triangle.
    Cubic,
}
281
282/// Options for finite element solvers
283#[derive(Debug, Clone)]
284pub struct FEMOptions {
285    /// Element type to use
286    pub element_type: ElementType,
287
288    /// Quadrature rule order (number of integration points)
289    pub quadrature_order: usize,
290
291    /// Maximum iterations for iterative solvers
292    pub max_iterations: usize,
293
294    /// Tolerance for convergence
295    pub tolerance: f64,
296
297    /// Whether to save convergence history
298    pub save_convergence_history: bool,
299
300    /// Print detailed progress information
301    pub verbose: bool,
302}
303
304impl Default for FEMOptions {
305    fn default() -> Self {
306        FEMOptions {
307            element_type: ElementType::Linear,
308            quadrature_order: 3, // 3-point rule suitable for quadratic functions
309            max_iterations: 1000,
310            tolerance: 1e-6,
311            save_convergence_history: false,
312            verbose: false,
313        }
314    }
315}
316
317/// Result of FEM solution
318#[derive(Debug, Clone)]
319pub struct FEMResult {
320    /// Solution values at nodes
321    pub u: Array1<f64>,
322
323    /// Mesh used for the solution
324    pub mesh: TriangularMesh,
325
326    /// Residual norm
327    pub residual_norm: f64,
328
329    /// Number of iterations performed
330    pub num_iterations: usize,
331
332    /// Computation time
333    pub computation_time: f64,
334
335    /// Convergence history
336    pub convergence_history: Option<Vec<f64>>,
337}
338
339/// Finite Element solver for Poisson's equation
340pub struct FEMPoissonSolver {
341    /// Mesh for finite element discretization
342    mesh: TriangularMesh,
343
344    /// Higher-order elements (if using non-linear elements)
345    higher_order_elements: Option<Vec<HigherOrderTriangle>>,
346
347    /// Additional points for higher-order elements
348    higher_order_points: Option<Vec<Point>>,
349
350    /// Source term function f(x, y)
351    source_term: Box<dyn Fn(f64, f64) -> f64 + Send + Sync>,
352
353    /// Boundary conditions
354    boundary_conditions: Vec<BoundaryCondition<f64>>,
355
356    /// Solver options
357    options: FEMOptions,
358}
359
360impl FEMPoissonSolver {
361    /// Create a new Finite Element solver for Poisson's equation
362    pub fn new(
363        mesh: TriangularMesh,
364        source_term: impl Fn(f64, f64) -> f64 + Send + Sync + 'static,
365        boundary_conditions: Vec<BoundaryCondition<f64>>,
366        options: Option<FEMOptions>,
367    ) -> PDEResult<Self> {
368        // Validate boundary _conditions
369        if boundary_conditions.is_empty() {
370            return Err(PDEError::BoundaryConditions(
371                "At least one boundary condition is required".to_string(),
372            ));
373        }
374
375        let opts = options.unwrap_or_default();
376
377        // Create higher-order elements if needed
378        let (higher_order_elements, higher_order_points) = match opts.element_type {
379            ElementType::Linear => (None, None),
380            ElementType::Quadratic => {
381                let (points, elements) = HigherOrderMeshGenerator::linear_to_quadratic(&mesh)?;
382                (Some(elements), Some(points))
383            }
384            ElementType::Cubic => {
385                let (points, elements) = HigherOrderMeshGenerator::linear_to_cubic(&mesh)?;
386                (Some(elements), Some(points))
387            }
388        };
389
390        Ok(FEMPoissonSolver {
391            mesh,
392            higher_order_elements,
393            higher_order_points,
394            source_term: Box::new(source_term),
395            boundary_conditions,
396            options: opts,
397        })
398    }
399
400    /// Solve Poisson's equation using the Finite Element Method
401    pub fn solve(&mut self) -> PDEResult<FEMResult> {
402        let start_time = Instant::now();
403
404        // Apply boundary conditions to the mesh
405        self.mesh
406            .set_boundary_conditions(&self.boundary_conditions)?;
407
408        // Number of nodes (degrees of freedom)
409        let _n = if let Some(ref higher_order_points) = self.higher_order_points {
410            higher_order_points.len()
411        } else {
412            self.mesh.points.len()
413        };
414
415        // Assemble stiffness matrix and load vector
416        let (mut a, mut b) = self.assemble_system()?;
417
418        // Apply Dirichlet boundary conditions
419        self.apply_dirichlet_boundary_conditions(&mut a, &mut b)?;
420
421        // Solve the linear system
422        let u = FEMPoissonSolver::solve_linear_system(&a, &b)?;
423
424        // Compute residual norm
425        let residual_norm = FEMPoissonSolver::compute_residual(&a, &b, &u);
426
427        let computation_time = start_time.elapsed().as_secs_f64();
428
429        Ok(FEMResult {
430            u,
431            mesh: self.mesh.clone(),
432            residual_norm,
433            num_iterations: 1, // Direct solver counts as one iteration
434            computation_time,
435            convergence_history: None,
436        })
437    }
438
439    /// Assemble the stiffness matrix and load vector for the FEM system
440    fn assemble_system(&self) -> PDEResult<(Array2<f64>, Array1<f64>)> {
441        let n = if let Some(ref higher_order_points) = self.higher_order_points {
442            higher_order_points.len()
443        } else {
444            self.mesh.points.len()
445        };
446
447        // Initialize stiffness matrix and load vector
448        let mut a = Array2::zeros((n, n));
449        let mut b = Array1::zeros(n);
450
451        match self.options.element_type {
452            ElementType::Linear => {
453                // Use existing linear element assembly
454                for element in &self.mesh.elements {
455                    let (a_e, b_e) = self.element_matrices_linear(element)?;
456
457                    // Assemble into global matrices
458                    let [i, j, k] = element.nodes;
459
460                    // Diagonal terms
461                    a[[i, i]] += a_e[0][0];
462                    a[[j, j]] += a_e[1][1];
463                    a[[k, k]] += a_e[2][2];
464
465                    // Off-diagonal terms
466                    a[[i, j]] += a_e[0][1];
467                    a[[i, k]] += a_e[0][2];
468                    a[[j, i]] += a_e[1][0];
469                    a[[j, k]] += a_e[1][2];
470                    a[[k, i]] += a_e[2][0];
471                    a[[k, j]] += a_e[2][1];
472
473                    // Load vector
474                    b[i] += b_e[0];
475                    b[j] += b_e[1];
476                    b[k] += b_e[2];
477                }
478            }
479            _ => {
480                // Use higher-order element assembly
481                if let Some(ref higher_order_elements) = self.higher_order_elements {
482                    for element in higher_order_elements {
483                        let (a_e, b_e) = self.element_matrices_higher_order(element)?;
484
485                        // Assemble into global matrices
486                        for (i, node_i) in element.nodes.iter().enumerate() {
487                            b[*node_i] += b_e[i];
488                            for (j, node_j) in element.nodes.iter().enumerate() {
489                                a[[*node_i, *node_j]] += a_e[[i, j]];
490                            }
491                        }
492                    }
493                }
494            }
495        }
496
497        Ok((a, b))
498    }
499
500    /// Compute element stiffness matrix and load vector for linear elements
501    fn element_matrices_linear(&self, element: &Triangle) -> PDEResult<([[f64; 3]; 3], [f64; 3])> {
502        // Get nodes
503        let [i, j, k] = element.nodes;
504        let pi = &self.mesh.points[i];
505        let pj = &self.mesh.points[j];
506        let pk = &self.mesh.points[k];
507
508        // Element area
509        let area = self.mesh.triangle_area(element);
510
511        // Shape function gradients
512        let gradients = self.mesh.shape_function_gradients(element)?;
513
514        // Stiffness matrix - For Poisson's equation: Integral of (∇φᵢ · ∇φⱼ) over _element
515        let mut a_e = [[0.0; 3]; 3];
516
517        for m in 0..3 {
518            for n in 0..3 {
519                // Dot product of shape function gradients
520                a_e[m][n] =
521                    (gradients[m].x * gradients[n].x + gradients[m].y * gradients[n].y) * area;
522            }
523        }
524
525        // Load vector - For Poisson's equation: Integral of (f · φᵢ) over _element
526        let mut b_e = [0.0; 3];
527
528        // Approximate the source term at the centroid of the triangle
529        let centroid_x = (pi.x + pj.x + pk.x) / 3.0;
530        let centroid_y = (pi.y + pj.y + pk.y) / 3.0;
531        let f_centroid = (self.source_term)(centroid_x, centroid_y);
532
533        // For linear elements, the integral of each shape function over the _element is area/3
534        b_e.iter_mut().for_each(|value| {
535            *value = f_centroid * (area / 3.0);
536        });
537
538        Ok((a_e, b_e))
539    }
540
541    /// Compute element stiffness matrix and load vector for higher-order elements
542    fn element_matrices_higher_order(
543        &self,
544        element: &HigherOrderTriangle,
545    ) -> PDEResult<(Array2<f64>, Array1<f64>)> {
546        let num_nodes = element.nodes.len();
547        let mut a_e = Array2::zeros((num_nodes, num_nodes));
548        let mut b_e = Array1::zeros(num_nodes);
549
550        // Get the points for this element type
551        let points = if let Some(ref ho_points) = self.higher_order_points {
552            ho_points
553        } else {
554            return Err(PDEError::FiniteElementError(
555                "Higher-order points not available".to_string(),
556            ));
557        };
558
559        // Get corner nodes to compute element area and coordinate transformation
560        let corner_nodes = element.corner_nodes();
561        let p1 = &points[corner_nodes[0]];
562        let p2 = &points[corner_nodes[1]];
563        let p3 = &points[corner_nodes[2]];
564
565        // Compute Jacobian for coordinate transformation from reference to physical element
566        let jacobian = Array2::from_shape_vec(
567            (2, 2),
568            vec![p2.x - p1.x, p3.x - p1.x, p2.y - p1.y, p3.y - p1.y],
569        )
570        .expect("Operation failed");
571
572        let det_j = jacobian[[0, 0]] * jacobian[[1, 1]] - jacobian[[0, 1]] * jacobian[[1, 0]];
573        if det_j.abs() < 1e-12 {
574            return Err(PDEError::FiniteElementError(
575                "Degenerate element with zero Jacobian determinant".to_string(),
576            ));
577        }
578
579        // Inverse of Jacobian
580        let inv_j = Array2::from_shape_vec(
581            (2, 2),
582            vec![
583                jacobian[[1, 1]] / det_j,
584                -jacobian[[0, 1]] / det_j,
585                -jacobian[[1, 0]] / det_j,
586                jacobian[[0, 0]] / det_j,
587            ],
588        )
589        .expect("Operation failed");
590
591        // Get quadrature rule
592        let (xi_coords, eta_coords, weights) =
593            TriangularQuadrature::get_rule(self.options.quadrature_order)?;
594
595        // Integrate over the element using quadrature
596        for q in 0..xi_coords.len() {
597            let xi = xi_coords[q];
598            let eta = eta_coords[q];
599            let weight = weights[q];
600
601            // Evaluate shape functions and their derivatives at quadrature point
602            let shape_funcs = ShapeFunctions::evaluate(element.element_type, xi, eta)?;
603            let (d_n_dxi, d_n_deta) =
604                ShapeFunctions::evaluate_derivatives(element.element_type, xi, eta)?;
605
606            // Transform derivatives from reference to physical coordinates
607            let mut d_n_dx = Array1::zeros(num_nodes);
608            let mut d_n_dy = Array1::zeros(num_nodes);
609
610            for i in 0..num_nodes {
611                d_n_dx[i] = inv_j[[0, 0]] * d_n_dxi[i] + inv_j[[0, 1]] * d_n_deta[i];
612                d_n_dy[i] = inv_j[[1, 0]] * d_n_dxi[i] + inv_j[[1, 1]] * d_n_deta[i];
613            }
614
615            // Compute physical coordinates of quadrature point for source term evaluation
616            let mut x_phys = 0.0;
617            let mut y_phys = 0.0;
618            for i in 0..num_nodes {
619                x_phys += shape_funcs[i] * points[element.nodes[i]].x;
620                y_phys += shape_funcs[i] * points[element.nodes[i]].y;
621            }
622
623            // Evaluate source term at quadrature point
624            let f_val = (self.source_term)(x_phys, y_phys);
625
626            // Add contributions to element matrices
627            for i in 0..num_nodes {
628                // Load vector: ∫ f * N_i * dV
629                b_e[i] += f_val * shape_funcs[i] * weight * det_j.abs();
630
631                for j in 0..num_nodes {
632                    // Stiffness matrix: ∫ (∇N_i · ∇N_j) * dV
633                    a_e[[i, j]] +=
634                        (d_n_dx[i] * d_n_dx[j] + d_n_dy[i] * d_n_dy[j]) * weight * det_j.abs();
635                }
636            }
637        }
638
639        Ok((a_e, b_e))
640    }
641
642    /// Apply Dirichlet boundary conditions to the system
643    fn apply_dirichlet_boundary_conditions(
644        &self,
645        a: &mut Array2<f64>,
646        b: &mut Array1<f64>,
647    ) -> PDEResult<()> {
648        let n = self.mesh.points.len();
649
650        // Loop over boundary nodes
651        for (&node_idx, bc_info) in &self.mesh.boundary_nodes {
652            if bc_info.bc_type == BoundaryConditionType::Dirichlet {
653                // Set row to identity
654                for j in 0..n {
655                    a[[node_idx, j]] = 0.0;
656                }
657                a[[node_idx, node_idx]] = 1.0;
658
659                // Set right-hand side to boundary value
660                b[node_idx] = bc_info.value;
661            } else if bc_info.bc_type == BoundaryConditionType::Neumann {
662                // Neumann boundary conditions are handled in the assembly process
663                // For linear elements on a flat boundary, this is equivalent to
664                // modifying the right-hand side vector
665
666                // Get all boundary edges containing this node
667                let boundary_edges: Vec<_> = self
668                    .mesh
669                    .boundary_edges
670                    .iter()
671                    .filter(|&&(n1, n2, _)| n1 == node_idx || n2 == node_idx)
672                    .collect();
673
674                // For each boundary edge, apply the Neumann condition
675                for &(n1, n2, _) in &boundary_edges {
676                    let other_node = if *n1 == node_idx { *n2 } else { *n1 };
677
678                    // Get the coordinates of the nodes
679                    let p1 = &self.mesh.points[node_idx];
680                    let p2 = &self.mesh.points[other_node];
681
682                    // Length of the edge
683                    let edge_length = p1.distance(p2);
684
685                    // Contribution to the load vector: g * (edge_length / 2)
686                    // where g is the Neumann boundary value
687                    b[node_idx] += bc_info.value * (edge_length / 2.0);
688                }
689            } else if bc_info.bc_type == BoundaryConditionType::Robin {
690                // Robin boundary conditions (a*u + b*∂u/∂n = c)
691                if let Some([a_coef, b_coef, c_coef]) = bc_info.coefficients {
692                    // Similar to Neumann, we need to find boundary edges
693                    let boundary_edges: Vec<_> = self
694                        .mesh
695                        .boundary_edges
696                        .iter()
697                        .filter(|&&(n1, n2, _)| n1 == node_idx || n2 == node_idx)
698                        .collect();
699
700                    for &(n1, n2, _) in &boundary_edges {
701                        let other_node = if *n1 == node_idx { *n2 } else { *n1 };
702
703                        // Get the coordinates of the nodes
704                        let p1 = &self.mesh.points[node_idx];
705                        let p2 = &self.mesh.points[other_node];
706
707                        // Length of the edge
708                        let edge_length = p1.distance(p2);
709
710                        // Contribution to the stiffness matrix and load vector
711                        // This is simplified - a more accurate implementation would
712                        // involve integrating along the boundary edge
713                        a[[node_idx, node_idx]] += a_coef * edge_length / 2.0;
714
715                        // Right-hand side contribution
716                        b[node_idx] += c_coef * edge_length / 2.0;
717                    }
718                }
719            }
720        }
721
722        Ok(())
723    }
724
725    /// Solve the linear system Ax = b
726    fn solve_linear_system(a: &Array2<f64>, b: &Array1<f64>) -> PDEResult<Array1<f64>> {
727        let n = b.len();
728
729        // Simple Gaussian elimination for demonstration purposes
730        // For a real implementation, use a sparse matrix solver library
731
732        // Create copies of A and b
733        let mut a_copy = a.clone();
734        let mut b_copy = b.clone();
735
736        // Forward elimination
737        for i in 0..n {
738            // Find pivot
739            let mut max_val = a_copy[[i, i]].abs();
740            let mut max_row = i;
741
742            for k in i + 1..n {
743                if a_copy[[k, i]].abs() > max_val {
744                    max_val = a_copy[[k, i]].abs();
745                    max_row = k;
746                }
747            }
748
749            // Check if matrix is singular
750            if max_val < 1e-10 {
751                return Err(PDEError::Other(
752                    "Matrix is singular or nearly singular".to_string(),
753                ));
754            }
755
756            // Swap rows if necessary
757            if max_row != i {
758                for j in i..n {
759                    let temp = a_copy[[i, j]];
760                    a_copy[[i, j]] = a_copy[[max_row, j]];
761                    a_copy[[max_row, j]] = temp;
762                }
763
764                let temp = b_copy[i];
765                b_copy[i] = b_copy[max_row];
766                b_copy[max_row] = temp;
767            }
768
769            // Eliminate below
770            for k in i + 1..n {
771                let factor = a_copy[[k, i]] / a_copy[[i, i]];
772
773                for j in i..n {
774                    a_copy[[k, j]] -= factor * a_copy[[i, j]];
775                }
776
777                b_copy[k] -= factor * b_copy[i];
778            }
779        }
780
781        // Back substitution
782        let mut x = Array1::zeros(n);
783        for i in (0..n).rev() {
784            let mut sum = 0.0;
785            for j in i + 1..n {
786                sum += a_copy[[i, j]] * x[j];
787            }
788
789            x[i] = (b_copy[i] - sum) / a_copy[[i, i]];
790        }
791
792        Ok(x)
793    }
794
795    /// Compute residual norm ||Ax - b||₂
796    fn compute_residual(a: &Array2<f64>, b: &Array1<f64>, x: &Array1<f64>) -> f64 {
797        let n = b.len();
798        let mut residual = 0.0;
799
800        for i in 0..n {
801            let mut row_sum = 0.0;
802            for j in 0..n {
803                row_sum += a[[i, j]] * x[j];
804            }
805
806            let diff = row_sum - b[i];
807            residual += diff * diff;
808        }
809
810        residual.sqrt()
811    }
812}
813
814/// Convert FEMResult to PDESolution
815impl From<FEMResult> for PDESolution<f64> {
816    fn from(result: FEMResult) -> Self {
817        let mut grids = Vec::new();
818        let n = result.mesh.points.len();
819
820        // Extract x and y coordinates as separate grids
821        let mut x_coords = Array1::zeros(n);
822        let mut y_coords = Array1::zeros(n);
823
824        for (i, point) in result.mesh.points.iter().enumerate() {
825            x_coords[i] = point.x;
826            y_coords[i] = point.y;
827        }
828
829        grids.push(x_coords);
830        grids.push(y_coords);
831
832        // Create solution values as a 2D array with one column
833        let mut values = Vec::new();
834        let u_reshaped = result
835            .u
836            .into_shape_with_order((n, 1))
837            .expect("Operation failed");
838        values.push(u_reshaped);
839
840        // Create solver info
841        let info = PDESolverInfo {
842            num_iterations: result.num_iterations,
843            computation_time: result.computation_time,
844            residual_norm: Some(result.residual_norm),
845            convergence_history: result.convergence_history,
846            method: "Finite Element Method".to_string(),
847        };
848
849        PDESolution {
850            grids,
851            values,
852            error_estimate: None,
853            info,
854        }
855    }
856}
857
858// Add PDE error types
859impl PDEError {
860    /// Create a finite element error
861    pub fn finite_element_error(msg: String) -> Self {
862        PDEError::Other(format!("Finite element error: {msg}"))
863    }
864}