scirs2_integrate/pde/spectral/spectral_element.rs
1//! Spectral Element Methods (SEM) for solving PDEs
2//!
3//! This module provides implementations of spectral element methods,
4//! which combine the accuracy of spectral methods with the geometric
5//! flexibility of finite element methods. This approach allows for
6//! high-order polynomial approximations on complex geometries.
7//!
8//! Key features:
9//! - High-order polynomial basis functions (using nodal Lagrange polynomials)
10//! - Domain decomposition into elements (quadrilaterals in 2D)
11//! - Gauss-Lobatto-Legendre quadrature for integration
//! - Affine mapping of axis-aligned rectangular elements (curved, isoparametric elements are not yet supported)
13//! - Exponential convergence for smooth solutions
14
15use scirs2_core::ndarray::{s, Array1, Array2, Array3, ArrayView1};
16use std::time::Instant;
17
18use crate::pde::spectral::{legendre_diff_matrix, legendre_points};
19use crate::pde::{
20 BoundaryCondition, BoundaryConditionType, BoundaryLocation, Domain, PDEError, PDEResult,
21 PDESolution, PDESolverInfo,
22};
23
24/// Quadrilateral element for 2D spectral element methods
25#[derive(Debug, Clone)]
26pub struct QuadElement {
27 /// Element ID
28 pub id: usize,
29
30 /// Global coordinates of element vertices (4 corners, counterclockwise ordering)
31 pub vertices: [(f64, f64); 4],
32
33 /// Global indices of nodes in this element
34 pub node_indices: Vec<usize>,
35
36 /// Element boundary conditions (if any)
37 pub boundary_edges: Vec<(usize, usize, Option<BoundaryConditionType>)>,
38}
39
40/// Spectral element mesh for 2D problems
41#[derive(Debug, Clone)]
42pub struct SpectralElementMesh2D {
43 /// Elements in the mesh
44 pub elements: Vec<QuadElement>,
45
46 /// Global node coordinates (x, y)
47 pub nodes: Vec<(f64, f64)>,
48
49 /// Global-to-local mapping for nodes
50 pub global_to_local: Vec<Vec<(usize, usize, usize)>>, // (element_id, i, j)
51
52 /// Boundary nodes with condition info
53 pub boundary_nodes: Vec<(usize, BoundaryConditionType)>,
54
55 /// Polynomial order in each direction
56 pub order: (usize, usize),
57
58 /// Total number of nodes in the mesh
59 pub num_nodes: usize,
60}
61
62impl SpectralElementMesh2D {
63 /// Create a new rectangular spectral element mesh
64 ///
65 /// # Arguments
66 ///
67 /// * `x_range` - Range for the x-coordinate domain [x_min, x_max]
68 /// * `y_range` - Range for the y-coordinate domain [y_min, y_max]
69 /// * `nx` - Number of elements in the x direction
70 /// * `ny` - Number of elements in the y direction
71 /// * `order` - Polynomial order in each element (p, p)
72 ///
73 /// # Returns
74 ///
75 /// * A structured rectangular mesh of quadrilateral elements
76 pub fn rectangular(
77 x_range: [f64; 2],
78 y_range: [f64; 2],
79 nx: usize,
80 ny: usize,
81 order: usize,
82 ) -> PDEResult<Self> {
83 if nx == 0 || ny == 0 {
84 return Err(PDEError::DomainError(
85 "Number of elements must be at least 1 in each direction".to_string(),
86 ));
87 }
88
89 if order < 1 {
90 return Err(PDEError::DomainError(
91 "Polynomial order must be at least 1".to_string(),
92 ));
93 }
94
95 let [x_min, x_max] = x_range;
96 let [y_min, y_max] = y_range;
97
98 let dx = (x_max - x_min) / nx as f64;
99 let dy = (y_max - y_min) / ny as f64;
100
101 // Number of nodes in each direction per element (order + 1)
102 let n = order + 1;
103
104 // Generate Gauss-Lobatto-Legendre points in 1D (scaled to [0, 1])
105 let (xi_pts_, weights) = legendre_points(n);
106 let xi = (xi_pts_ + 1.0) * 0.5;
107
108 // Create elements and nodes
109 let mut elements = Vec::with_capacity(nx * ny);
110 let mut nodes = Vec::new();
111 let mut boundary_nodes = Vec::new();
112
113 // Global node counter
114 let mut node_count = 0;
115
116 // Temporary global node indices for each element
117 let mut global_indices = Array2::<usize>::zeros((nx * n - (nx - 1), ny * n - (ny - 1)));
118
119 // First, create all nodes and assign global indices
120 for j in 0..(ny * n - (ny - 1)) {
121 for i in 0..(nx * n - (nx - 1)) {
122 // Determine if this is an element edge or interior node
123 let _is_edge = i % n == 0
124 || i == nx * n - (nx - 1) - 1
125 || j % n == 0
126 || j == ny * n - (ny - 1) - 1;
127
128 // Compute physical coordinates
129 let x_idx = i / n; // Element index in x direction
130 let y_idx = j / n; // Element index in y direction
131
132 let local_i = i % n; // Local index within element
133 let local_j = j % n;
134
135 // Handle shared nodes between elements
136 if (local_i == 0 && x_idx > 0) || (local_j == 0 && y_idx > 0) {
137 // This node is shared with a previous element
138 // Skip node creation but still assign the global index
139 if local_i == 0 && local_j == 0 && x_idx > 0 && y_idx > 0 {
140 // Corner node shared by 4 elements
141 global_indices[[i, j]] = global_indices[[i - n + 1, j - n + 1]];
142 } else if local_i == 0 && x_idx > 0 {
143 // Edge node shared horizontally
144 global_indices[[i, j]] = global_indices[[i - n + 1, j]];
145 } else if local_j == 0 && y_idx > 0 {
146 // Edge node shared vertically
147 global_indices[[i, j]] = global_indices[[i, j - n + 1]];
148 }
149 continue;
150 }
151
152 // Map to physical coordinates using element index and local coordinates
153 let x = x_min + (x_idx as f64 + xi[local_i]) * dx;
154 let y = y_min + (y_idx as f64 + xi[local_j]) * dy;
155
156 // Create node and store index
157 nodes.push((x, y));
158 global_indices[[i, j]] = node_count;
159
160 // Add to boundary nodes if on the domain boundary
161 if i == 0 || i == nx * n - (nx - 1) - 1 || j == 0 || j == ny * n - (ny - 1) - 1 {
162 let bc_type = BoundaryConditionType::Dirichlet; // Default, will be updated later
163 boundary_nodes.push((node_count, bc_type));
164 }
165
166 node_count += 1;
167 }
168 }
169
170 // Initialize global to local mapping
171 let mut global_to_local = vec![Vec::new(); node_count];
172
173 // Now create elements using the global node indices
174 for ey in 0..ny {
175 for ex in 0..nx {
176 let element_id = ey * nx + ex;
177
178 // Element vertices (corners)
179 let vertices = [
180 (x_min + ex as f64 * dx, y_min + ey as f64 * dy),
181 (x_min + (ex + 1) as f64 * dx, y_min + ey as f64 * dy),
182 (x_min + (ex + 1) as f64 * dx, y_min + (ey + 1) as f64 * dy),
183 (x_min + ex as f64 * dx, y_min + (ey + 1) as f64 * dy),
184 ];
185
186 // Collect all node indices for this element
187 let mut node_indices = Vec::with_capacity(n * n);
188
189 for j in 0..n {
190 for i in 0..n {
191 let global_i = ex * (n - 1) + i;
192 let global_j = ey * (n - 1) + j;
193
194 let idx = global_indices[[global_i, global_j]];
195 node_indices.push(idx);
196
197 // Update global to local mapping
198 global_to_local[idx].push((element_id, i, j));
199 }
200 }
201
202 // Determine element boundary edges
203 let mut boundary_edges = Vec::new();
204
205 // Check if element is on domain boundary
206 if ex == 0 {
207 // Left boundary
208 boundary_edges.push((0, 3, Some(BoundaryConditionType::Dirichlet)));
209 }
210 if ex == nx - 1 {
211 // Right boundary
212 boundary_edges.push((1, 2, Some(BoundaryConditionType::Dirichlet)));
213 }
214 if ey == 0 {
215 // Bottom boundary
216 boundary_edges.push((0, 1, Some(BoundaryConditionType::Dirichlet)));
217 }
218 if ey == ny - 1 {
219 // Top boundary
220 boundary_edges.push((2, 3, Some(BoundaryConditionType::Dirichlet)));
221 }
222
223 // Create the element
224 let element = QuadElement {
225 id: element_id,
226 vertices,
227 node_indices,
228 boundary_edges,
229 };
230
231 elements.push(element);
232 }
233 }
234
235 Ok(SpectralElementMesh2D {
236 elements,
237 nodes,
238 global_to_local,
239 boundary_nodes,
240 order: (order, order),
241 num_nodes: node_count,
242 })
243 }
244
245 /// Update the boundary conditions on the mesh
246 ///
247 /// # Arguments
248 ///
249 /// * `boundary_conditions` - Vector of boundary conditions to apply
250 ///
251 /// # Returns
252 ///
253 /// * `PDEResult<()>` - Result indicating success or error
254 pub fn set_boundary_conditions(
255 &mut self,
256 boundary_conditions: &[BoundaryCondition<f64>],
257 ) -> PDEResult<()> {
258 // Map boundary _conditions to mesh boundary nodes
259 for bc in boundary_conditions {
260 let nodes_to_update = match bc.location {
261 BoundaryLocation::Lower => match bc.dimension {
262 0 => self
263 .boundary_nodes
264 .iter()
265 .enumerate()
266 .filter(|(_, (idx, _))| self.nodes[*idx].0 < 1e-10)
267 .map(|(i_, _)| i_)
268 .collect::<Vec<_>>(),
269 1 => self
270 .boundary_nodes
271 .iter()
272 .enumerate()
273 .filter(|(_, (idx, _))| self.nodes[*idx].1 < 1e-10)
274 .map(|(i_, _)| i_)
275 .collect::<Vec<_>>(),
276 _ => {
277 return Err(PDEError::DomainError(format!(
278 "Invalid dimension: {}",
279 bc.dimension
280 )))
281 }
282 },
283 BoundaryLocation::Upper => match bc.dimension {
284 0 => self
285 .boundary_nodes
286 .iter()
287 .enumerate()
288 .filter(|(_, (idx, _))| (self.nodes[*idx].0 - 1.0).abs() < 1e-10)
289 .map(|(i_, _)| i_)
290 .collect::<Vec<_>>(),
291 1 => self
292 .boundary_nodes
293 .iter()
294 .enumerate()
295 .filter(|(_, (idx, _))| (self.nodes[*idx].1 - 1.0).abs() < 1e-10)
296 .map(|(i_, _)| i_)
297 .collect::<Vec<_>>(),
298 _ => {
299 return Err(PDEError::DomainError(format!(
300 "Invalid dimension: {}",
301 bc.dimension
302 )))
303 }
304 },
305 };
306
307 // Update boundary node types
308 for node_idx in nodes_to_update {
309 self.boundary_nodes[node_idx].1 = bc.bc_type;
310 }
311
312 // Update element boundary edges
313 for element in &mut self.elements {
314 for (_, _, bc_type) in &mut element.boundary_edges {
315 if let Some(ref mut bc_type_val) = bc_type {
316 // Check if this boundary edge is on the specified boundary
317 let edge_matches = match bc.location {
318 BoundaryLocation::Lower => match bc.dimension {
319 0 => element.vertices[0].0 < 1e-10 && element.vertices[3].0 < 1e-10,
320 1 => element.vertices[0].1 < 1e-10 && element.vertices[1].1 < 1e-10,
321 _ => false,
322 },
323 BoundaryLocation::Upper => match bc.dimension {
324 0 => {
325 (element.vertices[1].0 - 1.0).abs() < 1e-10
326 && (element.vertices[2].0 - 1.0).abs() < 1e-10
327 }
328 1 => {
329 (element.vertices[2].1 - 1.0).abs() < 1e-10
330 && (element.vertices[3].1 - 1.0).abs() < 1e-10
331 }
332 _ => false,
333 },
334 };
335
336 if edge_matches {
337 *bc_type_val = bc.bc_type;
338 }
339 }
340 }
341 }
342 }
343
344 Ok(())
345 }
346}
347
348/// Options for spectral element methods
349#[derive(Debug, Clone)]
350pub struct SpectralElementOptions {
351 /// Polynomial order in each direction
352 pub order: usize,
353
354 /// Number of elements in x direction
355 pub nx: usize,
356
357 /// Number of elements in y direction
358 pub ny: usize,
359
360 /// Maximum iterations for iterative solvers
361 pub max_iterations: usize,
362
363 /// Tolerance for convergence
364 pub tolerance: f64,
365
366 /// Whether to save convergence history
367 pub save_convergence_history: bool,
368
369 /// Print detailed progress information
370 pub verbose: bool,
371}
372
373impl Default for SpectralElementOptions {
374 fn default() -> Self {
375 SpectralElementOptions {
376 order: 4,
377 nx: 4,
378 ny: 4,
379 max_iterations: 1000,
380 tolerance: 1e-10,
381 save_convergence_history: false,
382 verbose: false,
383 }
384 }
385}
386
387/// Result from spectral element method solution
388#[derive(Debug, Clone)]
389pub struct SpectralElementResult {
390 /// Solution values at each node
391 pub u: Array1<f64>,
392
393 /// Node coordinates (x, y)
394 pub nodes: Vec<(f64, f64)>,
395
396 /// Element connectivity
397 pub elements: Vec<QuadElement>,
398
399 /// Residual norm
400 pub residual_norm: f64,
401
402 /// Number of iterations performed
403 pub num_iterations: usize,
404
405 /// Computation time
406 pub computation_time: f64,
407
408 /// Convergence history
409 pub convergence_history: Option<Vec<f64>>,
410}
411
412/// 2D Poisson solver using spectral element method
413///
414/// Solves: ∇²u = f(x,y) with appropriate boundary conditions
415pub struct SpectralElementPoisson2D {
416 /// Computational domain
417 domain: Domain,
418
419 /// Source term function f(x, y)
420 source_term: Box<dyn Fn(f64, f64) -> f64 + Send + Sync>,
421
422 /// Boundary conditions
423 boundary_conditions: Vec<BoundaryCondition<f64>>,
424
425 /// Solver options
426 options: SpectralElementOptions,
427}
428
429impl SpectralElementPoisson2D {
430 /// Create a new spectral element Poisson solver
431 ///
432 /// # Arguments
433 ///
434 /// * `domain` - Computational domain
435 /// * `source_term` - Function f(x,y) for the Poisson equation ∇²u = f(x,y)
436 /// * `boundary_conditions` - Boundary conditions for the domain
437 /// * `options` - Solver options (or None for defaults)
438 ///
439 /// # Returns
440 ///
441 /// * `PDEResult<Self>` - New solver instance
442 pub fn new(
443 domain: Domain,
444 source_term: impl Fn(f64, f64) -> f64 + Send + Sync + 'static,
445 boundary_conditions: Vec<BoundaryCondition<f64>>,
446 options: Option<SpectralElementOptions>,
447 ) -> PDEResult<Self> {
448 // Validate domain
449 if domain.dimensions() != 2 {
450 return Err(PDEError::DomainError(
451 "Domain must be 2-dimensional for 2D spectral element solver".to_string(),
452 ));
453 }
454
455 // Validate boundary _conditions
456 for bc in &boundary_conditions {
457 if bc.dimension >= 2 {
458 return Err(PDEError::BoundaryConditions(format!(
459 "Invalid dimension {} in boundary condition",
460 bc.dimension
461 )));
462 }
463
464 match bc.bc_type {
465 BoundaryConditionType::Dirichlet => {}
466 BoundaryConditionType::Neumann => {}
467 BoundaryConditionType::Robin => {
468 if bc.coefficients.is_none() {
469 return Err(PDEError::BoundaryConditions(
470 "Robin boundary _conditions require coefficients".to_string(),
471 ));
472 }
473 }
474 BoundaryConditionType::Periodic => {
475 return Err(PDEError::BoundaryConditions(
476 "Periodic boundary _conditions not implemented for spectral element method"
477 .to_string(),
478 ));
479 }
480 }
481 }
482
483 let options = options.unwrap_or_default();
484
485 Ok(SpectralElementPoisson2D {
486 domain,
487 source_term: Box::new(source_term),
488 boundary_conditions,
489 options,
490 })
491 }
492
493 /// Solve the Poisson equation using spectral element method
494 ///
495 /// # Returns
496 ///
497 /// * `PDEResult<SpectralElementResult>` - Solution result
498 pub fn solve(&self) -> PDEResult<SpectralElementResult> {
499 let start_time = Instant::now();
500
501 // Extract domain information
502 let x_range = &self.domain.ranges[0];
503 let y_range = &self.domain.ranges[1];
504
505 // Create spectral element mesh
506 let mut mesh = SpectralElementMesh2D::rectangular(
507 [x_range.start, x_range.end],
508 [y_range.start, y_range.end],
509 self.options.nx,
510 self.options.ny,
511 self.options.order,
512 )?;
513
514 // Set boundary conditions
515 mesh.set_boundary_conditions(&self.boundary_conditions)?;
516
517 // --- Local operations within each element ---
518
519 // Create differentiation matrices for reference element [-1, 1]²
520 let n = self.options.order + 1;
521 let d1_ref = legendre_diff_matrix(n);
522
523 // Reference Gauss-Lobatto-Legendre points and weights
524 let (xi, w) = legendre_points(n);
525
526 // Create stiffness and mass matrix for each element
527 let mut element_stiffness = vec![Array2::<f64>::zeros((n * n, n * n)); mesh.elements.len()];
528 let mut element_mass = vec![Array2::<f64>::zeros((n * n, n * n)); mesh.elements.len()];
529 let mut element_load = vec![Array1::<f64>::zeros(n * n); mesh.elements.len()];
530
531 for (e_idx, element) in mesh.elements.iter().enumerate() {
532 // Element vertices
533 let vertices = element.vertices;
534
535 // Element size
536 let dx = vertices[1].0 - vertices[0].0;
537 let dy = vertices[3].1 - vertices[0].1;
538
539 // Jacobian determinant (assuming rectangular elements)
540 let j_det = (dx * dy) / 4.0;
541
542 // Scaled differentiation matrices for physical element
543 let d1_x = d1_ref.mapv(|val| val * 2.0 / dx);
544 let d1_y = d1_ref.mapv(|val| val * 2.0 / dy);
545
546 // Precompute tensor products for efficiency
547 let mut dx_tensor = Array3::<f64>::zeros((n, n, n * n));
548 let mut dy_tensor = Array3::<f64>::zeros((n, n, n * n));
549
550 for j in 0..n {
551 for i in 0..n {
552 let _node = j * n + i;
553
554 // Fill tensor products
555 for k in 0..n {
556 // First pattern: dx_tensor[[i, j, k * n..(k + 1) * n]]
557 dx_tensor
558 .slice_mut(s![i, j, k * n..(k + 1) * n])
559 .assign(&d1_x.slice(s![k, ..]));
560 dy_tensor.slice_mut(s![i, j, k * n..(k + 1) * n]).fill(0.0);
561 }
562
563 // Second pattern: tensors at indices from k to n*n
564 for k in 0..n {
565 for idx in k..(n * n) {
566 if idx % n == k {
567 dx_tensor[[i, j, idx]] = 0.0;
568 dy_tensor[[i, j, idx]] = d1_y[[k, idx / n]];
569 }
570 }
571 }
572 }
573 }
574
575 // Compute element stiffness matrix: K_{ij} = ∫∫ (∇φ_i ⋅ ∇φ_j) dxdy
576 for i in 0..n * n {
577 for j in 0..n * n {
578 // For Poisson equation: stiffness = integral of gradient dot product
579 let mut stiffness_val = 0.0;
580
581 for ni in 0..n {
582 for nj in 0..n {
583 // Compute ∇φ_i ⋅ ∇φ_j at quadrature point (ξ_ni, ξ_nj)
584 let dx_i = dx_tensor[[ni, nj, i]];
585 let dy_i = dy_tensor[[ni, nj, i]];
586 let dx_j = dx_tensor[[ni, nj, j]];
587 let dy_j = dy_tensor[[ni, nj, j]];
588
589 // Gradient dot product: ∇φ_i ⋅ ∇φ_j
590 let grad_dot = dx_i * dx_j + dy_i * dy_j;
591
592 // Integrate with quadrature weights
593 stiffness_val += grad_dot * w[ni] * w[nj] * j_det;
594 }
595 }
596
597 element_stiffness[e_idx][[i, j]] = stiffness_val;
598 }
599 }
600
601 // Compute element mass matrix: M_{ij} = ∫∫ φ_i φ_j dxdy
602 for i in 0..n * n {
603 for j in 0..n * n {
604 // For mass matrix: just the integral of basis function products
605 let mut mass_val = 0.0;
606
607 // Which local nodes do i and j correspond to?
608 let i_local = (i % n, i / n);
609 let _j_local = (j % n, j / n);
610
611 // Mass matrix has nice tensor product structure for Lagrange polynomials
612 // If nodes are far apart, the integral is zero (orthogonality)
613 if i == j {
614 // Diagonal element - we can use the quadrature weights directly
615 mass_val = w[i_local.0] * w[i_local.1] * j_det;
616 }
617
618 element_mass[e_idx][[i, j]] = mass_val;
619 }
620 }
621
622 // Compute element load vector: f_i = ∫∫ f(x,y) φ_i dxdy
623 for i in 0..n * n {
624 let mut load_val = 0.0;
625
626 for ni in 0..n {
627 for nj in 0..n {
628 // Map reference coordinates to physical coordinates
629 let x = vertices[0].0 + (xi[ni] + 1.0) * dx / 2.0;
630 let y = vertices[0].1 + (xi[nj] + 1.0) * dy / 2.0;
631
632 // Evaluate source term at quadrature point
633 let source = (self.source_term)(x, y);
634
635 // Local basis function value at quadrature point
636 let i_local = (i % n, i / n);
637 let basis_val = if i_local.0 == ni && i_local.1 == nj {
638 1.0
639 } else {
640 0.0
641 };
642
643 // Integrate with quadrature weights
644 load_val += source * basis_val * w[ni] * w[nj] * j_det;
645 }
646 }
647
648 element_load[e_idx][i] = load_val;
649 }
650 }
651
652 // --- Global assembly ---
653
654 // Create global system: Au = b
655 let n_dof = mesh.num_nodes;
656 let mut global_matrix = Array2::<f64>::zeros((n_dof, n_dof));
657 let mut global_load = Array1::<f64>::zeros(n_dof);
658
659 // Assemble global matrix and load vector from element contributions
660 for (e_idx, element) in mesh.elements.iter().enumerate() {
661 for (i, &i_global) in element.node_indices.iter().enumerate() {
662 // Add element load to global load
663 global_load[i_global] += element_load[e_idx][i];
664
665 for (j, &j_global) in element.node_indices.iter().enumerate() {
666 // Add element stiffness to global matrix
667 global_matrix[[i_global, j_global]] += element_stiffness[e_idx][[i, j]];
668 }
669 }
670 }
671
672 // Apply boundary conditions
673 for &(node_idx, bc_type) in &mesh.boundary_nodes {
674 match bc_type {
675 BoundaryConditionType::Dirichlet => {
676 // For Dirichlet, set matrix row to identity and load to value
677 let (x, y) = mesh.nodes[node_idx];
678
679 // Find the appropriate boundary condition value
680 let mut bc_value = 0.0; // Default value
681
682 for bc in &self.boundary_conditions {
683 if bc.bc_type != BoundaryConditionType::Dirichlet {
684 continue;
685 }
686
687 let is_on_boundary = match (bc.dimension, bc.location) {
688 (0, BoundaryLocation::Lower) => x < 1e-10,
689 (0, BoundaryLocation::Upper) => {
690 (x - (x_range.end - x_range.start)).abs() < 1e-10
691 }
692 (1, BoundaryLocation::Lower) => y < 1e-10,
693 (1, BoundaryLocation::Upper) => {
694 (y - (y_range.end - y_range.start)).abs() < 1e-10
695 }
696 _ => false,
697 };
698
699 if is_on_boundary {
700 bc_value = bc.value;
701 break;
702 }
703 }
704
705 // Clear row and set diagonal to 1
706 for j in 0..n_dof {
707 global_matrix[[node_idx, j]] = 0.0;
708 }
709 global_matrix[[node_idx, node_idx]] = 1.0;
710
711 // Set load vector value
712 global_load[node_idx] = bc_value;
713 }
714 BoundaryConditionType::Neumann => {
715 // Neumann boundary conditions are natural in the weak form
716 // They're already accounted for in the global assembly
717 // unless there's a non-zero flux, in which case we need to
718 // add boundary integrals
719
720 // Find the appropriate boundary condition value
721 let (x, y) = mesh.nodes[node_idx];
722
723 for bc in &self.boundary_conditions {
724 if bc.bc_type != BoundaryConditionType::Neumann {
725 continue;
726 }
727
728 let is_on_boundary = match (bc.dimension, bc.location) {
729 (0, BoundaryLocation::Lower) => x < 1e-10,
730 (0, BoundaryLocation::Upper) => {
731 (x - (x_range.end - x_range.start)).abs() < 1e-10
732 }
733 (1, BoundaryLocation::Lower) => y < 1e-10,
734 (1, BoundaryLocation::Upper) => {
735 (y - (y_range.end - y_range.start)).abs() < 1e-10
736 }
737 _ => false,
738 };
739
740 if is_on_boundary {
741 let bc_value = bc.value;
742
743 // For non-zero Neumann BC, we need to add the boundary integral
744 if bc_value != 0.0 {
745 // This implementation is simplified - in a complete solver
746 // we would compute boundary integrals properly
747 global_load[node_idx] += bc_value;
748 }
749
750 break;
751 }
752 }
753 }
754 BoundaryConditionType::Robin => {
755 // Robin boundary conditions combine Dirichlet and Neumann
756 // For simplicity, we approximate them here by considering
757 // only the constant term of the form a*u + b*du/dn = c
758
759 // Find the appropriate boundary condition coefficients
760 let (x, y) = mesh.nodes[node_idx];
761
762 for bc in &self.boundary_conditions {
763 if bc.bc_type != BoundaryConditionType::Robin {
764 continue;
765 }
766
767 let is_on_boundary = match (bc.dimension, bc.location) {
768 (0, BoundaryLocation::Lower) => x < 1e-10,
769 (0, BoundaryLocation::Upper) => {
770 (x - (x_range.end - x_range.start)).abs() < 1e-10
771 }
772 (1, BoundaryLocation::Lower) => y < 1e-10,
773 (1, BoundaryLocation::Upper) => {
774 (y - (y_range.end - y_range.start)).abs() < 1e-10
775 }
776 _ => false,
777 };
778
779 if is_on_boundary {
780 if let Some([a_b, c, _]) = bc.coefficients {
781 // For Robin BCs: a*u + b*du/dn = c
782 // We need to modify the matrix and load vector
783 global_matrix[[node_idx, node_idx]] += a_b;
784 global_load[node_idx] += c;
785 }
786
787 break;
788 }
789 }
790 }
791 _ => {
792 return Err(PDEError::BoundaryConditions(
793 "Unsupported boundary condition type".to_string(),
794 ));
795 }
796 }
797 }
798
799 // Solve the linear system
800 let solution =
801 SpectralElementPoisson2D::solve_linear_system(&global_matrix, &global_load.view())?;
802
803 // Compute residual
804 let global_residual = {
805 let mut residual = global_matrix.dot(&solution) - &global_load;
806
807 // Exclude boundary points from residual calculation
808 for &(node_idx, bc_type) in &mesh.boundary_nodes {
809 if bc_type == BoundaryConditionType::Dirichlet {
810 residual[node_idx] = 0.0;
811 }
812 }
813
814 residual
815 };
816
817 let residual_norm = (global_residual.iter().map(|&r| r * r).sum::<f64>()
818 / (n_dof - mesh.boundary_nodes.len()) as f64)
819 .sqrt();
820
821 let computation_time = start_time.elapsed().as_secs_f64();
822
823 Ok(SpectralElementResult {
824 u: solution,
825 nodes: mesh.nodes.clone(),
826 elements: mesh.elements.clone(),
827 residual_norm,
828 num_iterations: 1, // Direct solve
829 computation_time,
830 convergence_history: None,
831 })
832 }
833
834 /// Solve the linear system Ax = b
835 fn solve_linear_system(a: &Array2<f64>, b: &ArrayView1<f64>) -> PDEResult<Array1<f64>> {
836 let n = b.len();
837
838 // Simple Gaussian elimination with partial pivoting
839 // For a real implementation, use a specialized linear algebra library
840
841 // Create copies of A and b
842 let mut a_copy = a.clone();
843 let mut b_copy = b.to_owned();
844
845 // Forward elimination
846 for i in 0..n {
847 // Find pivot
848 let mut max_val = a_copy[[i, i]].abs();
849 let mut max_row = i;
850
851 for k in i + 1..n {
852 if a_copy[[k, i]].abs() > max_val {
853 max_val = a_copy[[k, i]].abs();
854 max_row = k;
855 }
856 }
857
858 // Check if matrix is singular
859 if max_val < 1e-10 {
860 return Err(PDEError::Other(
861 "Matrix is singular or nearly singular".to_string(),
862 ));
863 }
864
865 // Swap rows if necessary
866 if max_row != i {
867 for j in i..n {
868 let temp = a_copy[[i, j]];
869 a_copy[[i, j]] = a_copy[[max_row, j]];
870 a_copy[[max_row, j]] = temp;
871 }
872
873 let temp = b_copy[i];
874 b_copy[i] = b_copy[max_row];
875 b_copy[max_row] = temp;
876 }
877
878 // Eliminate below
879 for k in i + 1..n {
880 let factor = a_copy[[k, i]] / a_copy[[i, i]];
881
882 for j in i..n {
883 a_copy[[k, j]] -= factor * a_copy[[i, j]];
884 }
885
886 b_copy[k] -= factor * b_copy[i];
887 }
888 }
889
890 // Back substitution
891 let mut x = Array1::zeros(n);
892 for i in (0..n).rev() {
893 let mut sum = 0.0;
894 for j in i + 1..n {
895 sum += a_copy[[i, j]] * x[j];
896 }
897
898 x[i] = (b_copy[i] - sum) / a_copy[[i, i]];
899 }
900
901 Ok(x)
902 }
903}
904
905impl From<SpectralElementResult> for PDESolution<f64> {
906 fn from(result: SpectralElementResult) -> Self {
907 // Extract node coordinates for grids
908 let mut x_coords = Vec::new();
909 let mut y_coords = Vec::new();
910
911 for &(x, y) in &result.nodes {
912 x_coords.push(x);
913 y_coords.push(y);
914 }
915
916 // Create unique sorted x and y coordinates for grid
917 x_coords.sort_by(|a, b| a.partial_cmp(b).unwrap());
918 y_coords.sort_by(|a, b| a.partial_cmp(b).unwrap());
919
920 x_coords.dedup_by(|a, b| (*a - *b).abs() < 1e-10);
921 y_coords.dedup_by(|a, b| (*a - *b).abs() < 1e-10);
922
923 let grids = vec![Array1::from_vec(x_coords), Array1::from_vec(y_coords)];
924
925 // Create solution values as a 2D array for each grid point
926 let mut values = Vec::new();
927 let n_points = result.u.len();
928 let u_reshaped = result.u.into_shape_with_order((n_points, 1)).unwrap();
929 values.push(u_reshaped);
930
931 // Create solver info
932 let info = PDESolverInfo {
933 num_iterations: result.num_iterations,
934 computation_time: result.computation_time,
935 residual_norm: Some(result.residual_norm),
936 convergence_history: result.convergence_history,
937 method: "Spectral Element Method".to_string(),
938 };
939
940 PDESolution {
941 grids,
942 values,
943 error_estimate: None,
944 info,
945 }
946 }
947}