scirs2_integrate/pde/spectral/
spectral_element.rs

//! Spectral Element Methods (SEM) for solving PDEs
//!
//! This module provides implementations of spectral element methods,
//! which combine the accuracy of spectral methods with the geometric
//! flexibility of finite element methods. This approach allows for
//! high-order polynomial approximations on complex geometries.
//!
//! Key features:
//! - High-order polynomial basis functions (nodal Lagrange polynomials)
//! - Domain decomposition into elements (quadrilaterals in 2D)
//! - Gauss-Lobatto-Legendre quadrature for integration
//! - Affine mapping of axis-aligned rectangular elements onto the reference square
//!   (curved, isoparametric elements are not yet supported)
//! - Exponential convergence for smooth solutions
14
15use scirs2_core::ndarray::{s, Array1, Array2, Array3, ArrayView1};
16use std::time::Instant;
17
18use crate::pde::spectral::{legendre_diff_matrix, legendre_points};
19use crate::pde::{
20    BoundaryCondition, BoundaryConditionType, BoundaryLocation, Domain, PDEError, PDEResult,
21    PDESolution, PDESolverInfo,
22};
23
24/// Quadrilateral element for 2D spectral element methods
25#[derive(Debug, Clone)]
26pub struct QuadElement {
27    /// Element ID
28    pub id: usize,
29
30    /// Global coordinates of element vertices (4 corners, counterclockwise ordering)
31    pub vertices: [(f64, f64); 4],
32
33    /// Global indices of nodes in this element
34    pub node_indices: Vec<usize>,
35
36    /// Element boundary conditions (if any)
37    pub boundary_edges: Vec<(usize, usize, Option<BoundaryConditionType>)>,
38}
39
40/// Spectral element mesh for 2D problems
41#[derive(Debug, Clone)]
42pub struct SpectralElementMesh2D {
43    /// Elements in the mesh
44    pub elements: Vec<QuadElement>,
45
46    /// Global node coordinates (x, y)
47    pub nodes: Vec<(f64, f64)>,
48
49    /// Global-to-local mapping for nodes
50    pub global_to_local: Vec<Vec<(usize, usize, usize)>>, // (element_id, i, j)
51
52    /// Boundary nodes with condition info
53    pub boundary_nodes: Vec<(usize, BoundaryConditionType)>,
54
55    /// Polynomial order in each direction
56    pub order: (usize, usize),
57
58    /// Total number of nodes in the mesh
59    pub num_nodes: usize,
60}
61
62impl SpectralElementMesh2D {
63    /// Create a new rectangular spectral element mesh
64    ///
65    /// # Arguments
66    ///
67    /// * `x_range` - Range for the x-coordinate domain [x_min, x_max]
68    /// * `y_range` - Range for the y-coordinate domain [y_min, y_max]
69    /// * `nx` - Number of elements in the x direction
70    /// * `ny` - Number of elements in the y direction
71    /// * `order` - Polynomial order in each element (p, p)
72    ///
73    /// # Returns
74    ///
75    /// * A structured rectangular mesh of quadrilateral elements
76    pub fn rectangular(
77        x_range: [f64; 2],
78        y_range: [f64; 2],
79        nx: usize,
80        ny: usize,
81        order: usize,
82    ) -> PDEResult<Self> {
83        if nx == 0 || ny == 0 {
84            return Err(PDEError::DomainError(
85                "Number of elements must be at least 1 in each direction".to_string(),
86            ));
87        }
88
89        if order < 1 {
90            return Err(PDEError::DomainError(
91                "Polynomial order must be at least 1".to_string(),
92            ));
93        }
94
95        let [x_min, x_max] = x_range;
96        let [y_min, y_max] = y_range;
97
98        let dx = (x_max - x_min) / nx as f64;
99        let dy = (y_max - y_min) / ny as f64;
100
101        // Number of nodes in each direction per element (order + 1)
102        let n = order + 1;
103
104        // Generate Gauss-Lobatto-Legendre points in 1D (scaled to [0, 1])
105        let (xi_pts_, weights) = legendre_points(n);
106        let xi = (xi_pts_ + 1.0) * 0.5;
107
108        // Create elements and nodes
109        let mut elements = Vec::with_capacity(nx * ny);
110        let mut nodes = Vec::new();
111        let mut boundary_nodes = Vec::new();
112
113        // Global node counter
114        let mut node_count = 0;
115
116        // Temporary global node indices for each element
117        let mut global_indices = Array2::<usize>::zeros((nx * n - (nx - 1), ny * n - (ny - 1)));
118
119        // First, create all nodes and assign global indices
120        for j in 0..(ny * n - (ny - 1)) {
121            for i in 0..(nx * n - (nx - 1)) {
122                // Determine if this is an element edge or interior node
123                let _is_edge = i % n == 0
124                    || i == nx * n - (nx - 1) - 1
125                    || j % n == 0
126                    || j == ny * n - (ny - 1) - 1;
127
128                // Compute physical coordinates
129                let x_idx = i / n; // Element index in x direction
130                let y_idx = j / n; // Element index in y direction
131
132                let local_i = i % n; // Local index within element
133                let local_j = j % n;
134
135                // Handle shared nodes between elements
136                if (local_i == 0 && x_idx > 0) || (local_j == 0 && y_idx > 0) {
137                    // This node is shared with a previous element
138                    // Skip node creation but still assign the global index
139                    if local_i == 0 && local_j == 0 && x_idx > 0 && y_idx > 0 {
140                        // Corner node shared by 4 elements
141                        global_indices[[i, j]] = global_indices[[i - n + 1, j - n + 1]];
142                    } else if local_i == 0 && x_idx > 0 {
143                        // Edge node shared horizontally
144                        global_indices[[i, j]] = global_indices[[i - n + 1, j]];
145                    } else if local_j == 0 && y_idx > 0 {
146                        // Edge node shared vertically
147                        global_indices[[i, j]] = global_indices[[i, j - n + 1]];
148                    }
149                    continue;
150                }
151
152                // Map to physical coordinates using element index and local coordinates
153                let x = x_min + (x_idx as f64 + xi[local_i]) * dx;
154                let y = y_min + (y_idx as f64 + xi[local_j]) * dy;
155
156                // Create node and store index
157                nodes.push((x, y));
158                global_indices[[i, j]] = node_count;
159
160                // Add to boundary nodes if on the domain boundary
161                if i == 0 || i == nx * n - (nx - 1) - 1 || j == 0 || j == ny * n - (ny - 1) - 1 {
162                    let bc_type = BoundaryConditionType::Dirichlet; // Default, will be updated later
163                    boundary_nodes.push((node_count, bc_type));
164                }
165
166                node_count += 1;
167            }
168        }
169
170        // Initialize global to local mapping
171        let mut global_to_local = vec![Vec::new(); node_count];
172
173        // Now create elements using the global node indices
174        for ey in 0..ny {
175            for ex in 0..nx {
176                let element_id = ey * nx + ex;
177
178                // Element vertices (corners)
179                let vertices = [
180                    (x_min + ex as f64 * dx, y_min + ey as f64 * dy),
181                    (x_min + (ex + 1) as f64 * dx, y_min + ey as f64 * dy),
182                    (x_min + (ex + 1) as f64 * dx, y_min + (ey + 1) as f64 * dy),
183                    (x_min + ex as f64 * dx, y_min + (ey + 1) as f64 * dy),
184                ];
185
186                // Collect all node indices for this element
187                let mut node_indices = Vec::with_capacity(n * n);
188
189                for j in 0..n {
190                    for i in 0..n {
191                        let global_i = ex * (n - 1) + i;
192                        let global_j = ey * (n - 1) + j;
193
194                        let idx = global_indices[[global_i, global_j]];
195                        node_indices.push(idx);
196
197                        // Update global to local mapping
198                        global_to_local[idx].push((element_id, i, j));
199                    }
200                }
201
202                // Determine element boundary edges
203                let mut boundary_edges = Vec::new();
204
205                // Check if element is on domain boundary
206                if ex == 0 {
207                    // Left boundary
208                    boundary_edges.push((0, 3, Some(BoundaryConditionType::Dirichlet)));
209                }
210                if ex == nx - 1 {
211                    // Right boundary
212                    boundary_edges.push((1, 2, Some(BoundaryConditionType::Dirichlet)));
213                }
214                if ey == 0 {
215                    // Bottom boundary
216                    boundary_edges.push((0, 1, Some(BoundaryConditionType::Dirichlet)));
217                }
218                if ey == ny - 1 {
219                    // Top boundary
220                    boundary_edges.push((2, 3, Some(BoundaryConditionType::Dirichlet)));
221                }
222
223                // Create the element
224                let element = QuadElement {
225                    id: element_id,
226                    vertices,
227                    node_indices,
228                    boundary_edges,
229                };
230
231                elements.push(element);
232            }
233        }
234
235        Ok(SpectralElementMesh2D {
236            elements,
237            nodes,
238            global_to_local,
239            boundary_nodes,
240            order: (order, order),
241            num_nodes: node_count,
242        })
243    }
244
245    /// Update the boundary conditions on the mesh
246    ///
247    /// # Arguments
248    ///
249    /// * `boundary_conditions` - Vector of boundary conditions to apply
250    ///
251    /// # Returns
252    ///
253    /// * `PDEResult<()>` - Result indicating success or error
254    pub fn set_boundary_conditions(
255        &mut self,
256        boundary_conditions: &[BoundaryCondition<f64>],
257    ) -> PDEResult<()> {
258        // Map boundary _conditions to mesh boundary nodes
259        for bc in boundary_conditions {
260            let nodes_to_update = match bc.location {
261                BoundaryLocation::Lower => match bc.dimension {
262                    0 => self
263                        .boundary_nodes
264                        .iter()
265                        .enumerate()
266                        .filter(|(_, (idx, _))| self.nodes[*idx].0 < 1e-10)
267                        .map(|(i_, _)| i_)
268                        .collect::<Vec<_>>(),
269                    1 => self
270                        .boundary_nodes
271                        .iter()
272                        .enumerate()
273                        .filter(|(_, (idx, _))| self.nodes[*idx].1 < 1e-10)
274                        .map(|(i_, _)| i_)
275                        .collect::<Vec<_>>(),
276                    _ => {
277                        return Err(PDEError::DomainError(format!(
278                            "Invalid dimension: {}",
279                            bc.dimension
280                        )))
281                    }
282                },
283                BoundaryLocation::Upper => match bc.dimension {
284                    0 => self
285                        .boundary_nodes
286                        .iter()
287                        .enumerate()
288                        .filter(|(_, (idx, _))| (self.nodes[*idx].0 - 1.0).abs() < 1e-10)
289                        .map(|(i_, _)| i_)
290                        .collect::<Vec<_>>(),
291                    1 => self
292                        .boundary_nodes
293                        .iter()
294                        .enumerate()
295                        .filter(|(_, (idx, _))| (self.nodes[*idx].1 - 1.0).abs() < 1e-10)
296                        .map(|(i_, _)| i_)
297                        .collect::<Vec<_>>(),
298                    _ => {
299                        return Err(PDEError::DomainError(format!(
300                            "Invalid dimension: {}",
301                            bc.dimension
302                        )))
303                    }
304                },
305            };
306
307            // Update boundary node types
308            for node_idx in nodes_to_update {
309                self.boundary_nodes[node_idx].1 = bc.bc_type;
310            }
311
312            // Update element boundary edges
313            for element in &mut self.elements {
314                for (_, _, bc_type) in &mut element.boundary_edges {
315                    if let Some(ref mut bc_type_val) = bc_type {
316                        // Check if this boundary edge is on the specified boundary
317                        let edge_matches = match bc.location {
318                            BoundaryLocation::Lower => match bc.dimension {
319                                0 => element.vertices[0].0 < 1e-10 && element.vertices[3].0 < 1e-10,
320                                1 => element.vertices[0].1 < 1e-10 && element.vertices[1].1 < 1e-10,
321                                _ => false,
322                            },
323                            BoundaryLocation::Upper => match bc.dimension {
324                                0 => {
325                                    (element.vertices[1].0 - 1.0).abs() < 1e-10
326                                        && (element.vertices[2].0 - 1.0).abs() < 1e-10
327                                }
328                                1 => {
329                                    (element.vertices[2].1 - 1.0).abs() < 1e-10
330                                        && (element.vertices[3].1 - 1.0).abs() < 1e-10
331                                }
332                                _ => false,
333                            },
334                        };
335
336                        if edge_matches {
337                            *bc_type_val = bc.bc_type;
338                        }
339                    }
340                }
341            }
342        }
343
344        Ok(())
345    }
346}
347
/// Configuration for the spectral element solvers.
#[derive(Debug, Clone)]
pub struct SpectralElementOptions {
    /// Polynomial order used inside each element, in both directions.
    pub order: usize,

    /// Number of elements along x.
    pub nx: usize,

    /// Number of elements along y.
    pub ny: usize,

    /// Iteration cap for iterative solvers.
    pub max_iterations: usize,

    /// Convergence tolerance.
    pub tolerance: f64,

    /// Whether the convergence history should be recorded.
    pub save_convergence_history: bool,

    /// Whether to print detailed progress information.
    pub verbose: bool,
}
372
373impl Default for SpectralElementOptions {
374    fn default() -> Self {
375        SpectralElementOptions {
376            order: 4,
377            nx: 4,
378            ny: 4,
379            max_iterations: 1000,
380            tolerance: 1e-10,
381            save_convergence_history: false,
382            verbose: false,
383        }
384    }
385}
386
387/// Result from spectral element method solution
388#[derive(Debug, Clone)]
389pub struct SpectralElementResult {
390    /// Solution values at each node
391    pub u: Array1<f64>,
392
393    /// Node coordinates (x, y)
394    pub nodes: Vec<(f64, f64)>,
395
396    /// Element connectivity
397    pub elements: Vec<QuadElement>,
398
399    /// Residual norm
400    pub residual_norm: f64,
401
402    /// Number of iterations performed
403    pub num_iterations: usize,
404
405    /// Computation time
406    pub computation_time: f64,
407
408    /// Convergence history
409    pub convergence_history: Option<Vec<f64>>,
410}
411
412/// 2D Poisson solver using spectral element method
413///
414/// Solves: ∇²u = f(x,y) with appropriate boundary conditions
415pub struct SpectralElementPoisson2D {
416    /// Computational domain
417    domain: Domain,
418
419    /// Source term function f(x, y)
420    source_term: Box<dyn Fn(f64, f64) -> f64 + Send + Sync>,
421
422    /// Boundary conditions
423    boundary_conditions: Vec<BoundaryCondition<f64>>,
424
425    /// Solver options
426    options: SpectralElementOptions,
427}
428
429impl SpectralElementPoisson2D {
430    /// Create a new spectral element Poisson solver
431    ///
432    /// # Arguments
433    ///
434    /// * `domain` - Computational domain
435    /// * `source_term` - Function f(x,y) for the Poisson equation ∇²u = f(x,y)
436    /// * `boundary_conditions` - Boundary conditions for the domain
437    /// * `options` - Solver options (or None for defaults)
438    ///
439    /// # Returns
440    ///
441    /// * `PDEResult<Self>` - New solver instance
442    pub fn new(
443        domain: Domain,
444        source_term: impl Fn(f64, f64) -> f64 + Send + Sync + 'static,
445        boundary_conditions: Vec<BoundaryCondition<f64>>,
446        options: Option<SpectralElementOptions>,
447    ) -> PDEResult<Self> {
448        // Validate domain
449        if domain.dimensions() != 2 {
450            return Err(PDEError::DomainError(
451                "Domain must be 2-dimensional for 2D spectral element solver".to_string(),
452            ));
453        }
454
455        // Validate boundary _conditions
456        for bc in &boundary_conditions {
457            if bc.dimension >= 2 {
458                return Err(PDEError::BoundaryConditions(format!(
459                    "Invalid dimension {} in boundary condition",
460                    bc.dimension
461                )));
462            }
463
464            match bc.bc_type {
465                BoundaryConditionType::Dirichlet => {}
466                BoundaryConditionType::Neumann => {}
467                BoundaryConditionType::Robin => {
468                    if bc.coefficients.is_none() {
469                        return Err(PDEError::BoundaryConditions(
470                            "Robin boundary _conditions require coefficients".to_string(),
471                        ));
472                    }
473                }
474                BoundaryConditionType::Periodic => {
475                    return Err(PDEError::BoundaryConditions(
476                        "Periodic boundary _conditions not implemented for spectral element method"
477                            .to_string(),
478                    ));
479                }
480            }
481        }
482
483        let options = options.unwrap_or_default();
484
485        Ok(SpectralElementPoisson2D {
486            domain,
487            source_term: Box::new(source_term),
488            boundary_conditions,
489            options,
490        })
491    }
492
493    /// Solve the Poisson equation using spectral element method
494    ///
495    /// # Returns
496    ///
497    /// * `PDEResult<SpectralElementResult>` - Solution result
498    pub fn solve(&self) -> PDEResult<SpectralElementResult> {
499        let start_time = Instant::now();
500
501        // Extract domain information
502        let x_range = &self.domain.ranges[0];
503        let y_range = &self.domain.ranges[1];
504
505        // Create spectral element mesh
506        let mut mesh = SpectralElementMesh2D::rectangular(
507            [x_range.start, x_range.end],
508            [y_range.start, y_range.end],
509            self.options.nx,
510            self.options.ny,
511            self.options.order,
512        )?;
513
514        // Set boundary conditions
515        mesh.set_boundary_conditions(&self.boundary_conditions)?;
516
517        // --- Local operations within each element ---
518
519        // Create differentiation matrices for reference element [-1, 1]²
520        let n = self.options.order + 1;
521        let d1_ref = legendre_diff_matrix(n);
522
523        // Reference Gauss-Lobatto-Legendre points and weights
524        let (xi, w) = legendre_points(n);
525
526        // Create stiffness and mass matrix for each element
527        let mut element_stiffness = vec![Array2::<f64>::zeros((n * n, n * n)); mesh.elements.len()];
528        let mut element_mass = vec![Array2::<f64>::zeros((n * n, n * n)); mesh.elements.len()];
529        let mut element_load = vec![Array1::<f64>::zeros(n * n); mesh.elements.len()];
530
531        for (e_idx, element) in mesh.elements.iter().enumerate() {
532            // Element vertices
533            let vertices = element.vertices;
534
535            // Element size
536            let dx = vertices[1].0 - vertices[0].0;
537            let dy = vertices[3].1 - vertices[0].1;
538
539            // Jacobian determinant (assuming rectangular elements)
540            let j_det = (dx * dy) / 4.0;
541
542            // Scaled differentiation matrices for physical element
543            let d1_x = d1_ref.mapv(|val| val * 2.0 / dx);
544            let d1_y = d1_ref.mapv(|val| val * 2.0 / dy);
545
546            // Precompute tensor products for efficiency
547            let mut dx_tensor = Array3::<f64>::zeros((n, n, n * n));
548            let mut dy_tensor = Array3::<f64>::zeros((n, n, n * n));
549
550            for j in 0..n {
551                for i in 0..n {
552                    let _node = j * n + i;
553
554                    // Fill tensor products
555                    for k in 0..n {
556                        // First pattern: dx_tensor[[i, j, k * n..(k + 1) * n]]
557                        dx_tensor
558                            .slice_mut(s![i, j, k * n..(k + 1) * n])
559                            .assign(&d1_x.slice(s![k, ..]));
560                        dy_tensor.slice_mut(s![i, j, k * n..(k + 1) * n]).fill(0.0);
561                    }
562
563                    // Second pattern: tensors at indices from k to n*n
564                    for k in 0..n {
565                        for idx in k..(n * n) {
566                            if idx % n == k {
567                                dx_tensor[[i, j, idx]] = 0.0;
568                                dy_tensor[[i, j, idx]] = d1_y[[k, idx / n]];
569                            }
570                        }
571                    }
572                }
573            }
574
575            // Compute element stiffness matrix: K_{ij} = ∫∫ (∇φ_i ⋅ ∇φ_j) dxdy
576            for i in 0..n * n {
577                for j in 0..n * n {
578                    // For Poisson equation: stiffness = integral of gradient dot product
579                    let mut stiffness_val = 0.0;
580
581                    for ni in 0..n {
582                        for nj in 0..n {
583                            // Compute ∇φ_i ⋅ ∇φ_j at quadrature point (ξ_ni, ξ_nj)
584                            let dx_i = dx_tensor[[ni, nj, i]];
585                            let dy_i = dy_tensor[[ni, nj, i]];
586                            let dx_j = dx_tensor[[ni, nj, j]];
587                            let dy_j = dy_tensor[[ni, nj, j]];
588
589                            // Gradient dot product: ∇φ_i ⋅ ∇φ_j
590                            let grad_dot = dx_i * dx_j + dy_i * dy_j;
591
592                            // Integrate with quadrature weights
593                            stiffness_val += grad_dot * w[ni] * w[nj] * j_det;
594                        }
595                    }
596
597                    element_stiffness[e_idx][[i, j]] = stiffness_val;
598                }
599            }
600
601            // Compute element mass matrix: M_{ij} = ∫∫ φ_i φ_j dxdy
602            for i in 0..n * n {
603                for j in 0..n * n {
604                    // For mass matrix: just the integral of basis function products
605                    let mut mass_val = 0.0;
606
607                    // Which local nodes do i and j correspond to?
608                    let i_local = (i % n, i / n);
609                    let _j_local = (j % n, j / n);
610
611                    // Mass matrix has nice tensor product structure for Lagrange polynomials
612                    // If nodes are far apart, the integral is zero (orthogonality)
613                    if i == j {
614                        // Diagonal element - we can use the quadrature weights directly
615                        mass_val = w[i_local.0] * w[i_local.1] * j_det;
616                    }
617
618                    element_mass[e_idx][[i, j]] = mass_val;
619                }
620            }
621
622            // Compute element load vector: f_i = ∫∫ f(x,y) φ_i dxdy
623            for i in 0..n * n {
624                let mut load_val = 0.0;
625
626                for ni in 0..n {
627                    for nj in 0..n {
628                        // Map reference coordinates to physical coordinates
629                        let x = vertices[0].0 + (xi[ni] + 1.0) * dx / 2.0;
630                        let y = vertices[0].1 + (xi[nj] + 1.0) * dy / 2.0;
631
632                        // Evaluate source term at quadrature point
633                        let source = (self.source_term)(x, y);
634
635                        // Local basis function value at quadrature point
636                        let i_local = (i % n, i / n);
637                        let basis_val = if i_local.0 == ni && i_local.1 == nj {
638                            1.0
639                        } else {
640                            0.0
641                        };
642
643                        // Integrate with quadrature weights
644                        load_val += source * basis_val * w[ni] * w[nj] * j_det;
645                    }
646                }
647
648                element_load[e_idx][i] = load_val;
649            }
650        }
651
652        // --- Global assembly ---
653
654        // Create global system: Au = b
655        let n_dof = mesh.num_nodes;
656        let mut global_matrix = Array2::<f64>::zeros((n_dof, n_dof));
657        let mut global_load = Array1::<f64>::zeros(n_dof);
658
659        // Assemble global matrix and load vector from element contributions
660        for (e_idx, element) in mesh.elements.iter().enumerate() {
661            for (i, &i_global) in element.node_indices.iter().enumerate() {
662                // Add element load to global load
663                global_load[i_global] += element_load[e_idx][i];
664
665                for (j, &j_global) in element.node_indices.iter().enumerate() {
666                    // Add element stiffness to global matrix
667                    global_matrix[[i_global, j_global]] += element_stiffness[e_idx][[i, j]];
668                }
669            }
670        }
671
672        // Apply boundary conditions
673        for &(node_idx, bc_type) in &mesh.boundary_nodes {
674            match bc_type {
675                BoundaryConditionType::Dirichlet => {
676                    // For Dirichlet, set matrix row to identity and load to value
677                    let (x, y) = mesh.nodes[node_idx];
678
679                    // Find the appropriate boundary condition value
680                    let mut bc_value = 0.0; // Default value
681
682                    for bc in &self.boundary_conditions {
683                        if bc.bc_type != BoundaryConditionType::Dirichlet {
684                            continue;
685                        }
686
687                        let is_on_boundary = match (bc.dimension, bc.location) {
688                            (0, BoundaryLocation::Lower) => x < 1e-10,
689                            (0, BoundaryLocation::Upper) => {
690                                (x - (x_range.end - x_range.start)).abs() < 1e-10
691                            }
692                            (1, BoundaryLocation::Lower) => y < 1e-10,
693                            (1, BoundaryLocation::Upper) => {
694                                (y - (y_range.end - y_range.start)).abs() < 1e-10
695                            }
696                            _ => false,
697                        };
698
699                        if is_on_boundary {
700                            bc_value = bc.value;
701                            break;
702                        }
703                    }
704
705                    // Clear row and set diagonal to 1
706                    for j in 0..n_dof {
707                        global_matrix[[node_idx, j]] = 0.0;
708                    }
709                    global_matrix[[node_idx, node_idx]] = 1.0;
710
711                    // Set load vector value
712                    global_load[node_idx] = bc_value;
713                }
714                BoundaryConditionType::Neumann => {
715                    // Neumann boundary conditions are natural in the weak form
716                    // They're already accounted for in the global assembly
717                    // unless there's a non-zero flux, in which case we need to
718                    // add boundary integrals
719
720                    // Find the appropriate boundary condition value
721                    let (x, y) = mesh.nodes[node_idx];
722
723                    for bc in &self.boundary_conditions {
724                        if bc.bc_type != BoundaryConditionType::Neumann {
725                            continue;
726                        }
727
728                        let is_on_boundary = match (bc.dimension, bc.location) {
729                            (0, BoundaryLocation::Lower) => x < 1e-10,
730                            (0, BoundaryLocation::Upper) => {
731                                (x - (x_range.end - x_range.start)).abs() < 1e-10
732                            }
733                            (1, BoundaryLocation::Lower) => y < 1e-10,
734                            (1, BoundaryLocation::Upper) => {
735                                (y - (y_range.end - y_range.start)).abs() < 1e-10
736                            }
737                            _ => false,
738                        };
739
740                        if is_on_boundary {
741                            let bc_value = bc.value;
742
743                            // For non-zero Neumann BC, we need to add the boundary integral
744                            if bc_value != 0.0 {
745                                // This implementation is simplified - in a complete solver
746                                // we would compute boundary integrals properly
747                                global_load[node_idx] += bc_value;
748                            }
749
750                            break;
751                        }
752                    }
753                }
754                BoundaryConditionType::Robin => {
755                    // Robin boundary conditions combine Dirichlet and Neumann
756                    // For simplicity, we approximate them here by considering
757                    // only the constant term of the form a*u + b*du/dn = c
758
759                    // Find the appropriate boundary condition coefficients
760                    let (x, y) = mesh.nodes[node_idx];
761
762                    for bc in &self.boundary_conditions {
763                        if bc.bc_type != BoundaryConditionType::Robin {
764                            continue;
765                        }
766
767                        let is_on_boundary = match (bc.dimension, bc.location) {
768                            (0, BoundaryLocation::Lower) => x < 1e-10,
769                            (0, BoundaryLocation::Upper) => {
770                                (x - (x_range.end - x_range.start)).abs() < 1e-10
771                            }
772                            (1, BoundaryLocation::Lower) => y < 1e-10,
773                            (1, BoundaryLocation::Upper) => {
774                                (y - (y_range.end - y_range.start)).abs() < 1e-10
775                            }
776                            _ => false,
777                        };
778
779                        if is_on_boundary {
780                            if let Some([a_b, c, _]) = bc.coefficients {
781                                // For Robin BCs: a*u + b*du/dn = c
782                                // We need to modify the matrix and load vector
783                                global_matrix[[node_idx, node_idx]] += a_b;
784                                global_load[node_idx] += c;
785                            }
786
787                            break;
788                        }
789                    }
790                }
791                _ => {
792                    return Err(PDEError::BoundaryConditions(
793                        "Unsupported boundary condition type".to_string(),
794                    ));
795                }
796            }
797        }
798
799        // Solve the linear system
800        let solution =
801            SpectralElementPoisson2D::solve_linear_system(&global_matrix, &global_load.view())?;
802
803        // Compute residual
804        let global_residual = {
805            let mut residual = global_matrix.dot(&solution) - &global_load;
806
807            // Exclude boundary points from residual calculation
808            for &(node_idx, bc_type) in &mesh.boundary_nodes {
809                if bc_type == BoundaryConditionType::Dirichlet {
810                    residual[node_idx] = 0.0;
811                }
812            }
813
814            residual
815        };
816
817        let residual_norm = (global_residual.iter().map(|&r| r * r).sum::<f64>()
818            / (n_dof - mesh.boundary_nodes.len()) as f64)
819            .sqrt();
820
821        let computation_time = start_time.elapsed().as_secs_f64();
822
823        Ok(SpectralElementResult {
824            u: solution,
825            nodes: mesh.nodes.clone(),
826            elements: mesh.elements.clone(),
827            residual_norm,
828            num_iterations: 1, // Direct solve
829            computation_time,
830            convergence_history: None,
831        })
832    }
833
834    /// Solve the linear system Ax = b
835    fn solve_linear_system(a: &Array2<f64>, b: &ArrayView1<f64>) -> PDEResult<Array1<f64>> {
836        let n = b.len();
837
838        // Simple Gaussian elimination with partial pivoting
839        // For a real implementation, use a specialized linear algebra library
840
841        // Create copies of A and b
842        let mut a_copy = a.clone();
843        let mut b_copy = b.to_owned();
844
845        // Forward elimination
846        for i in 0..n {
847            // Find pivot
848            let mut max_val = a_copy[[i, i]].abs();
849            let mut max_row = i;
850
851            for k in i + 1..n {
852                if a_copy[[k, i]].abs() > max_val {
853                    max_val = a_copy[[k, i]].abs();
854                    max_row = k;
855                }
856            }
857
858            // Check if matrix is singular
859            if max_val < 1e-10 {
860                return Err(PDEError::Other(
861                    "Matrix is singular or nearly singular".to_string(),
862                ));
863            }
864
865            // Swap rows if necessary
866            if max_row != i {
867                for j in i..n {
868                    let temp = a_copy[[i, j]];
869                    a_copy[[i, j]] = a_copy[[max_row, j]];
870                    a_copy[[max_row, j]] = temp;
871                }
872
873                let temp = b_copy[i];
874                b_copy[i] = b_copy[max_row];
875                b_copy[max_row] = temp;
876            }
877
878            // Eliminate below
879            for k in i + 1..n {
880                let factor = a_copy[[k, i]] / a_copy[[i, i]];
881
882                for j in i..n {
883                    a_copy[[k, j]] -= factor * a_copy[[i, j]];
884                }
885
886                b_copy[k] -= factor * b_copy[i];
887            }
888        }
889
890        // Back substitution
891        let mut x = Array1::zeros(n);
892        for i in (0..n).rev() {
893            let mut sum = 0.0;
894            for j in i + 1..n {
895                sum += a_copy[[i, j]] * x[j];
896            }
897
898            x[i] = (b_copy[i] - sum) / a_copy[[i, i]];
899        }
900
901        Ok(x)
902    }
903}
904
905impl From<SpectralElementResult> for PDESolution<f64> {
906    fn from(result: SpectralElementResult) -> Self {
907        // Extract node coordinates for grids
908        let mut x_coords = Vec::new();
909        let mut y_coords = Vec::new();
910
911        for &(x, y) in &result.nodes {
912            x_coords.push(x);
913            y_coords.push(y);
914        }
915
916        // Create unique sorted x and y coordinates for grid
917        x_coords.sort_by(|a, b| a.partial_cmp(b).unwrap());
918        y_coords.sort_by(|a, b| a.partial_cmp(b).unwrap());
919
920        x_coords.dedup_by(|a, b| (*a - *b).abs() < 1e-10);
921        y_coords.dedup_by(|a, b| (*a - *b).abs() < 1e-10);
922
923        let grids = vec![Array1::from_vec(x_coords), Array1::from_vec(y_coords)];
924
925        // Create solution values as a 2D array for each grid point
926        let mut values = Vec::new();
927        let n_points = result.u.len();
928        let u_reshaped = result.u.into_shape_with_order((n_points, 1)).unwrap();
929        values.push(u_reshaped);
930
931        // Create solver info
932        let info = PDESolverInfo {
933            num_iterations: result.num_iterations,
934            computation_time: result.computation_time,
935            residual_norm: Some(result.residual_norm),
936            convergence_history: result.convergence_history,
937            method: "Spectral Element Method".to_string(),
938        };
939
940        PDESolution {
941            grids,
942            values,
943            error_estimate: None,
944            info,
945        }
946    }
947}