1use scirs2_core::ndarray::{Array1, Array2};
22
23use crate::pde::{PDEError, PDEResult};
24
/// Boundary condition applied at one end of a 1D finite-difference grid.
///
/// Derives `Copy` and `PartialEq` like the other option enums in this module,
/// so conditions can be compared with `==` and passed by value.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum FDBoundaryCondition {
    /// Fixed value at the boundary node: `u = value`.
    Dirichlet(f64),
    /// Prescribed slope `du/dx = value` in the +x direction, imposed with a
    /// first-order one-sided difference against the neighbouring node.
    Neumann(f64),
    /// Periodic wrap-around. The 1D solvers treat the domain as periodic only
    /// when this variant is given for both ends.
    Periodic,
}
39
/// Time-integration scheme selectable for the 1D heat solver.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum TimeSteppingMethod {
    /// Forward-in-time explicit update; the solver rejects step sizes that
    /// fail the CFL stability check.
    Explicit,
    /// Crank–Nicolson scheme, which solves a tridiagonal system every step.
    CrankNicolson,
}
48
/// Stationary iterative method used by the 2D Poisson solver.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum EllipticIterativeMethod {
    /// Jacobi sweep: every node is updated from the previous iterate.
    Jacobi,
    /// Gauss–Seidel sweep: nodes are updated in place.
    GaussSeidel,
    /// Successive over-relaxation; the payload is the relaxation factor
    /// `omega` blended with the Gauss–Seidel value.
    SOR(f64),
}
59
/// Outcome of a CFL-type stability check for an explicit scheme.
#[derive(Clone, Debug)]
pub struct CFLAnalysis {
    /// Dimensionless stability number computed from the grid and time step.
    pub cfl_number: f64,
    /// `true` when `cfl_number` does not exceed the scheme's limit.
    pub is_stable: bool,
    /// Largest time step that still satisfies the limit on this grid.
    pub max_stable_dt: f64,
    /// Human-readable summary of the check.
    pub description: String,
}
76
77pub fn cfl_heat_1d(alpha: f64, dx: f64, dt: f64) -> CFLAnalysis {
82 let cfl = alpha * dt / (dx * dx);
83 let max_stable_dt = 0.5 * dx * dx / alpha;
84 CFLAnalysis {
85 cfl_number: cfl,
86 is_stable: cfl <= 0.5,
87 max_stable_dt,
88 description: format!(
89 "Heat 1D: CFL = {cfl:.4e} (must be <= 0.5). Max stable dt = {max_stable_dt:.4e}"
90 ),
91 }
92}
93
94pub fn cfl_heat_2d(alpha: f64, dx: f64, dy: f64, dt: f64) -> CFLAnalysis {
98 let cfl = alpha * dt * (1.0 / (dx * dx) + 1.0 / (dy * dy));
99 let max_stable_dt = 0.5 / (alpha * (1.0 / (dx * dx) + 1.0 / (dy * dy)));
100 CFLAnalysis {
101 cfl_number: cfl,
102 is_stable: cfl <= 0.5,
103 max_stable_dt,
104 description: format!(
105 "Heat 2D: CFL = {cfl:.4e} (must be <= 0.5). Max stable dt = {max_stable_dt:.4e}"
106 ),
107 }
108}
109
110pub fn cfl_wave_1d(c: f64, dx: f64, dt: f64) -> CFLAnalysis {
114 let cfl = c * dt / dx;
115 let max_stable_dt = dx / c;
116 CFLAnalysis {
117 cfl_number: cfl,
118 is_stable: cfl <= 1.0,
119 max_stable_dt,
120 description: format!(
121 "Wave 1D: CFL = {cfl:.4e} (must be <= 1.0). Max stable dt = {max_stable_dt:.4e}"
122 ),
123 }
124}
125
126pub fn cfl_wave_2d(c: f64, dx: f64, dy: f64, dt: f64) -> CFLAnalysis {
130 let factor = (1.0 / (dx * dx) + 1.0 / (dy * dy)).sqrt();
131 let cfl = c * dt * factor;
132 let max_stable_dt = 1.0 / (c * factor);
133 CFLAnalysis {
134 cfl_number: cfl,
135 is_stable: cfl <= 1.0,
136 max_stable_dt,
137 description: format!(
138 "Wave 2D: CFL = {cfl:.4e} (must be <= 1.0). Max stable dt = {max_stable_dt:.4e}"
139 ),
140 }
141}
142
143#[derive(Debug, Clone)]
149pub struct HeatResult {
150 pub x: Array1<f64>,
152 pub t: Array1<f64>,
154 pub u: Array2<f64>,
156 pub cfl: Option<CFLAnalysis>,
158}
159
160pub fn solve_heat_1d(
177 alpha: f64,
178 x_range: [f64; 2],
179 t_range: [f64; 2],
180 nx: usize,
181 nt: usize,
182 initial_condition: &dyn Fn(f64) -> f64,
183 left_bc: &FDBoundaryCondition,
184 right_bc: &FDBoundaryCondition,
185 method: TimeSteppingMethod,
186) -> PDEResult<HeatResult> {
187 if alpha <= 0.0 {
188 return Err(PDEError::InvalidParameter(
189 "Thermal diffusivity alpha must be positive".to_string(),
190 ));
191 }
192 if nx < 3 {
193 return Err(PDEError::InvalidGrid(
194 "Need at least 3 spatial grid points".to_string(),
195 ));
196 }
197 if nt < 1 {
198 return Err(PDEError::InvalidParameter(
199 "Need at least 1 time step".to_string(),
200 ));
201 }
202
203 let dx = (x_range[1] - x_range[0]) / (nx as f64 - 1.0);
204 let dt = (t_range[1] - t_range[0]) / nt as f64;
205
206 let x = Array1::from_shape_fn(nx, |i| x_range[0] + i as f64 * dx);
208 let t = Array1::from_shape_fn(nt + 1, |i| t_range[0] + i as f64 * dt);
210
211 let mut u = Array2::zeros((nt + 1, nx));
213 for i in 0..nx {
214 u[[0, i]] = initial_condition(x[i]);
215 }
216 apply_bc_1d(&mut u, 0, left_bc, right_bc, dx);
218
219 let cfl = cfl_heat_1d(alpha, dx, dt);
220
221 match method {
222 TimeSteppingMethod::Explicit => {
223 if !cfl.is_stable {
224 return Err(PDEError::ComputationError(format!(
225 "Explicit scheme unstable: {}",
226 cfl.description
227 )));
228 }
229 let r = alpha * dt / (dx * dx);
230 for n in 0..nt {
231 let is_periodic = matches!(
233 (left_bc, right_bc),
234 (FDBoundaryCondition::Periodic, FDBoundaryCondition::Periodic)
235 );
236 for i in 1..nx - 1 {
237 u[[n + 1, i]] =
238 u[[n, i]] + r * (u[[n, i + 1]] - 2.0 * u[[n, i]] + u[[n, i - 1]]);
239 }
240 if is_periodic {
241 u[[n + 1, 0]] = u[[n, 0]] + r * (u[[n, 1]] - 2.0 * u[[n, 0]] + u[[n, nx - 2]]);
243 u[[n + 1, nx - 1]] = u[[n + 1, 0]];
244 } else {
245 apply_bc_1d(&mut u, n + 1, left_bc, right_bc, dx);
246 }
247 }
248 }
249 TimeSteppingMethod::CrankNicolson => {
250 let r = alpha * dt / (2.0 * dx * dx);
251 let is_periodic = matches!(
252 (left_bc, right_bc),
253 (FDBoundaryCondition::Periodic, FDBoundaryCondition::Periodic)
254 );
255 for n in 0..nt {
256 let mut rhs = Array1::zeros(nx);
258 for i in 1..nx - 1 {
259 rhs[i] = u[[n, i]] + r * (u[[n, i + 1]] - 2.0 * u[[n, i]] + u[[n, i - 1]]);
260 }
261 if is_periodic {
262 rhs[0] = u[[n, 0]] + r * (u[[n, 1]] - 2.0 * u[[n, 0]] + u[[n, nx - 2]]);
263 rhs[nx - 1] = rhs[0];
264 }
265
266 if is_periodic {
268 let solved = solve_periodic_tridiag(nx - 1, -r, 1.0 + 2.0 * r, -r, &rhs)?;
269 for i in 0..nx - 1 {
270 u[[n + 1, i]] = solved[i];
271 }
272 u[[n + 1, nx - 1]] = u[[n + 1, 0]];
273 } else {
274 let interior_size = nx - 2;
275 if interior_size == 0 {
276 apply_bc_1d(&mut u, n + 1, left_bc, right_bc, dx);
277 continue;
278 }
279 let mut rhs_interior = Array1::zeros(interior_size);
280 for i in 0..interior_size {
281 rhs_interior[i] = rhs[i + 1];
282 }
283 apply_cn_bc_adjustment(
285 &mut rhs_interior,
286 left_bc,
287 right_bc,
288 r,
289 &u,
290 n + 1,
291 nx,
292 dx,
293 );
294 let solved =
295 solve_tridiag(interior_size, -r, 1.0 + 2.0 * r, -r, &rhs_interior)?;
296 for i in 0..interior_size {
297 u[[n + 1, i + 1]] = solved[i];
298 }
299 apply_bc_1d(&mut u, n + 1, left_bc, right_bc, dx);
300 }
301 }
302 }
303 }
304
305 Ok(HeatResult {
306 x,
307 t,
308 u,
309 cfl: Some(cfl),
310 })
311}
312
313#[derive(Debug, Clone)]
319pub struct Heat2DResult {
320 pub x: Array1<f64>,
322 pub y: Array1<f64>,
324 pub t: Array1<f64>,
326 pub u: Vec<Array2<f64>>,
328 pub cfl: Option<CFLAnalysis>,
330}
331
332pub fn solve_heat_2d(
337 alpha: f64,
338 x_range: [f64; 2],
339 y_range: [f64; 2],
340 t_range: [f64; 2],
341 nx: usize,
342 ny: usize,
343 nt: usize,
344 initial_condition: &dyn Fn(f64, f64) -> f64,
345 bc_values: [f64; 4], save_every: usize,
347) -> PDEResult<Heat2DResult> {
348 if alpha <= 0.0 {
349 return Err(PDEError::InvalidParameter(
350 "Thermal diffusivity alpha must be positive".to_string(),
351 ));
352 }
353 if nx < 3 || ny < 3 {
354 return Err(PDEError::InvalidGrid(
355 "Need at least 3 grid points in each dimension".to_string(),
356 ));
357 }
358
359 let dx = (x_range[1] - x_range[0]) / (nx as f64 - 1.0);
360 let dy = (y_range[1] - y_range[0]) / (ny as f64 - 1.0);
361 let dt = (t_range[1] - t_range[0]) / nt as f64;
362
363 let cfl = cfl_heat_2d(alpha, dx, dy, dt);
364 if !cfl.is_stable {
365 return Err(PDEError::ComputationError(format!(
366 "Explicit scheme unstable: {}",
367 cfl.description
368 )));
369 }
370
371 let x = Array1::from_shape_fn(nx, |i| x_range[0] + i as f64 * dx);
372 let y = Array1::from_shape_fn(ny, |j| y_range[0] + j as f64 * dy);
373 let mut t_save = vec![t_range[0]];
374
375 let mut u_curr = Array2::zeros((ny, nx));
377 for j in 0..ny {
378 for i in 0..nx {
379 u_curr[[j, i]] = initial_condition(x[i], y[j]);
380 }
381 }
382 apply_dirichlet_2d(&mut u_curr, bc_values, nx, ny);
383
384 let save_every = if save_every == 0 { 1 } else { save_every };
385 let mut snapshots = vec![u_curr.clone()];
386
387 let rx = alpha * dt / (dx * dx);
388 let ry = alpha * dt / (dy * dy);
389
390 for n in 0..nt {
391 let mut u_next = u_curr.clone();
392 for j in 1..ny - 1 {
393 for i in 1..nx - 1 {
394 u_next[[j, i]] = u_curr[[j, i]]
395 + rx * (u_curr[[j, i + 1]] - 2.0 * u_curr[[j, i]] + u_curr[[j, i - 1]])
396 + ry * (u_curr[[j + 1, i]] - 2.0 * u_curr[[j, i]] + u_curr[[j - 1, i]]);
397 }
398 }
399 apply_dirichlet_2d(&mut u_next, bc_values, nx, ny);
400 u_curr = u_next;
401
402 if (n + 1) % save_every == 0 || n + 1 == nt {
403 snapshots.push(u_curr.clone());
404 t_save.push(t_range[0] + (n + 1) as f64 * dt);
405 }
406 }
407
408 Ok(Heat2DResult {
409 x,
410 y,
411 t: Array1::from_vec(t_save),
412 u: snapshots,
413 cfl: Some(cfl),
414 })
415}
416
417#[derive(Debug, Clone)]
423pub struct WaveResult {
424 pub x: Array1<f64>,
426 pub t: Array1<f64>,
428 pub u: Array2<f64>,
430 pub cfl: Option<CFLAnalysis>,
432}
433
434pub fn solve_wave_1d(
438 c: f64,
439 x_range: [f64; 2],
440 t_range: [f64; 2],
441 nx: usize,
442 nt: usize,
443 initial_displacement: &dyn Fn(f64) -> f64,
444 initial_velocity: &dyn Fn(f64) -> f64,
445 left_bc: &FDBoundaryCondition,
446 right_bc: &FDBoundaryCondition,
447) -> PDEResult<WaveResult> {
448 if c <= 0.0 {
449 return Err(PDEError::InvalidParameter(
450 "Wave speed c must be positive".to_string(),
451 ));
452 }
453 if nx < 3 {
454 return Err(PDEError::InvalidGrid(
455 "Need at least 3 spatial grid points".to_string(),
456 ));
457 }
458
459 let dx = (x_range[1] - x_range[0]) / (nx as f64 - 1.0);
460 let dt = (t_range[1] - t_range[0]) / nt as f64;
461
462 let cfl = cfl_wave_1d(c, dx, dt);
463 if !cfl.is_stable {
464 return Err(PDEError::ComputationError(format!(
465 "Explicit wave scheme unstable: {}",
466 cfl.description
467 )));
468 }
469
470 let r2 = (c * dt / dx) * (c * dt / dx);
471 let x = Array1::from_shape_fn(nx, |i| x_range[0] + i as f64 * dx);
472 let t = Array1::from_shape_fn(nt + 1, |i| t_range[0] + i as f64 * dt);
473
474 let mut u = Array2::zeros((nt + 1, nx));
475
476 for i in 0..nx {
478 u[[0, i]] = initial_displacement(x[i]);
479 }
480 apply_bc_1d(&mut u, 0, left_bc, right_bc, dx);
481
482 let is_periodic = matches!(
485 (left_bc, right_bc),
486 (FDBoundaryCondition::Periodic, FDBoundaryCondition::Periodic)
487 );
488 for i in 1..nx - 1 {
489 let d2u = u[[0, i + 1]] - 2.0 * u[[0, i]] + u[[0, i - 1]];
490 u[[1, i]] = u[[0, i]] + dt * initial_velocity(x[i]) + 0.5 * r2 * d2u;
491 }
492 if is_periodic {
493 let d2u = u[[0, 1]] - 2.0 * u[[0, 0]] + u[[0, nx - 2]];
494 u[[1, 0]] = u[[0, 0]] + dt * initial_velocity(x[0]) + 0.5 * r2 * d2u;
495 u[[1, nx - 1]] = u[[1, 0]];
496 } else {
497 apply_bc_1d(&mut u, 1, left_bc, right_bc, dx);
498 }
499
500 for n in 1..nt {
502 for i in 1..nx - 1 {
503 u[[n + 1, i]] = 2.0 * u[[n, i]] - u[[n - 1, i]]
504 + r2 * (u[[n, i + 1]] - 2.0 * u[[n, i]] + u[[n, i - 1]]);
505 }
506 if is_periodic {
507 u[[n + 1, 0]] = 2.0 * u[[n, 0]] - u[[n - 1, 0]]
508 + r2 * (u[[n, 1]] - 2.0 * u[[n, 0]] + u[[n, nx - 2]]);
509 u[[n + 1, nx - 1]] = u[[n + 1, 0]];
510 } else {
511 apply_bc_1d(&mut u, n + 1, left_bc, right_bc, dx);
512 }
513 }
514
515 Ok(WaveResult {
516 x,
517 t,
518 u,
519 cfl: Some(cfl),
520 })
521}
522
523pub fn solve_wave_2d(
527 c: f64,
528 x_range: [f64; 2],
529 y_range: [f64; 2],
530 t_range: [f64; 2],
531 nx: usize,
532 ny: usize,
533 nt: usize,
534 initial_displacement: &dyn Fn(f64, f64) -> f64,
535 initial_velocity: &dyn Fn(f64, f64) -> f64,
536 bc_value: f64,
537 save_every: usize,
538) -> PDEResult<Wave2DResult> {
539 if c <= 0.0 {
540 return Err(PDEError::InvalidParameter(
541 "Wave speed c must be positive".to_string(),
542 ));
543 }
544 if nx < 3 || ny < 3 {
545 return Err(PDEError::InvalidGrid(
546 "Need at least 3 grid points in each dimension".to_string(),
547 ));
548 }
549
550 let dx = (x_range[1] - x_range[0]) / (nx as f64 - 1.0);
551 let dy = (y_range[1] - y_range[0]) / (ny as f64 - 1.0);
552 let dt = (t_range[1] - t_range[0]) / nt as f64;
553
554 let cfl = cfl_wave_2d(c, dx, dy, dt);
555 if !cfl.is_stable {
556 return Err(PDEError::ComputationError(format!(
557 "Explicit 2D wave scheme unstable: {}",
558 cfl.description
559 )));
560 }
561
562 let rx2 = (c * dt / dx) * (c * dt / dx);
563 let ry2 = (c * dt / dy) * (c * dt / dy);
564
565 let x = Array1::from_shape_fn(nx, |i| x_range[0] + i as f64 * dx);
566 let y = Array1::from_shape_fn(ny, |j| y_range[0] + j as f64 * dy);
567
568 let save_every = if save_every == 0 { 1 } else { save_every };
569 let bc_vals = [bc_value; 4];
570
571 let mut u_prev = Array2::zeros((ny, nx));
573 let mut u_curr = Array2::zeros((ny, nx));
574
575 for j in 0..ny {
577 for i in 0..nx {
578 u_curr[[j, i]] = initial_displacement(x[i], y[j]);
579 }
580 }
581 apply_dirichlet_2d(&mut u_curr, bc_vals, nx, ny);
582
583 let mut snapshots = vec![u_curr.clone()];
584 let mut t_save = vec![t_range[0]];
585
586 for j in 1..ny - 1 {
588 for i in 1..nx - 1 {
589 let d2x = u_curr[[j, i + 1]] - 2.0 * u_curr[[j, i]] + u_curr[[j, i - 1]];
590 let d2y = u_curr[[j + 1, i]] - 2.0 * u_curr[[j, i]] + u_curr[[j - 1, i]];
591 u_prev[[j, i]] =
592 u_curr[[j, i]] + dt * initial_velocity(x[i], y[j]) + 0.5 * (rx2 * d2x + ry2 * d2y);
593 }
594 }
595 apply_dirichlet_2d(&mut u_prev, bc_vals, nx, ny);
596 std::mem::swap(&mut u_prev, &mut u_curr);
598
599 if save_every == 1 {
600 snapshots.push(u_curr.clone());
601 t_save.push(t_range[0] + dt);
602 }
603
604 for n in 1..nt {
606 let mut u_next = Array2::zeros((ny, nx));
607 for j in 1..ny - 1 {
608 for i in 1..nx - 1 {
609 let d2x = u_curr[[j, i + 1]] - 2.0 * u_curr[[j, i]] + u_curr[[j, i - 1]];
610 let d2y = u_curr[[j + 1, i]] - 2.0 * u_curr[[j, i]] + u_curr[[j - 1, i]];
611 u_next[[j, i]] = 2.0 * u_curr[[j, i]] - u_prev[[j, i]] + rx2 * d2x + ry2 * d2y;
612 }
613 }
614 apply_dirichlet_2d(&mut u_next, bc_vals, nx, ny);
615 u_prev = u_curr;
616 u_curr = u_next;
617
618 if (n + 1) % save_every == 0 || n + 1 == nt {
619 snapshots.push(u_curr.clone());
620 t_save.push(t_range[0] + (n + 1) as f64 * dt);
621 }
622 }
623
624 Ok(Wave2DResult {
625 x,
626 y,
627 t: Array1::from_vec(t_save),
628 u: snapshots,
629 cfl: Some(cfl),
630 })
631}
632
633#[derive(Debug, Clone)]
635pub struct Wave2DResult {
636 pub x: Array1<f64>,
638 pub y: Array1<f64>,
640 pub t: Array1<f64>,
642 pub u: Vec<Array2<f64>>,
644 pub cfl: Option<CFLAnalysis>,
646}
647
648#[derive(Debug, Clone)]
654pub struct PoissonResult {
655 pub x: Array1<f64>,
657 pub y: Array1<f64>,
659 pub u: Array2<f64>,
661 pub iterations: usize,
663 pub residual: f64,
665 pub convergence_history: Vec<f64>,
667}
668
669pub fn solve_poisson_2d(
673 source: &dyn Fn(f64, f64) -> f64,
674 x_range: [f64; 2],
675 y_range: [f64; 2],
676 nx: usize,
677 ny: usize,
678 bc_values: [f64; 4], method: EllipticIterativeMethod,
680 tol: f64,
681 max_iter: usize,
682) -> PDEResult<PoissonResult> {
683 if nx < 3 || ny < 3 {
684 return Err(PDEError::InvalidGrid(
685 "Need at least 3 grid points in each dimension".to_string(),
686 ));
687 }
688
689 let dx = (x_range[1] - x_range[0]) / (nx as f64 - 1.0);
690 let dy = (y_range[1] - y_range[0]) / (ny as f64 - 1.0);
691
692 let x = Array1::from_shape_fn(nx, |i| x_range[0] + i as f64 * dx);
693 let y = Array1::from_shape_fn(ny, |j| y_range[0] + j as f64 * dy);
694
695 let mut u = Array2::zeros((ny, nx));
696 apply_dirichlet_2d(&mut u, bc_values, nx, ny);
697
698 let dx2 = dx * dx;
699 let dy2 = dy * dy;
700 let denom = 2.0 * (1.0 / dx2 + 1.0 / dy2);
701
702 let mut convergence_history = Vec::with_capacity(max_iter);
703 let mut iterations = 0;
704 let mut residual = f64::MAX;
705
706 for iter in 0..max_iter {
707 match method {
708 EllipticIterativeMethod::Jacobi => {
709 let u_old = u.clone();
710 for j in 1..ny - 1 {
711 for i in 1..nx - 1 {
712 u[[j, i]] = ((u_old[[j, i + 1]] + u_old[[j, i - 1]]) / dx2
713 + (u_old[[j + 1, i]] + u_old[[j - 1, i]]) / dy2
714 - source(x[i], y[j]))
715 / denom;
716 }
717 }
718 }
719 EllipticIterativeMethod::GaussSeidel => {
720 for j in 1..ny - 1 {
721 for i in 1..nx - 1 {
722 u[[j, i]] = ((u[[j, i + 1]] + u[[j, i - 1]]) / dx2
723 + (u[[j + 1, i]] + u[[j - 1, i]]) / dy2
724 - source(x[i], y[j]))
725 / denom;
726 }
727 }
728 }
729 EllipticIterativeMethod::SOR(omega) => {
730 for j in 1..ny - 1 {
731 for i in 1..nx - 1 {
732 let gs_val = ((u[[j, i + 1]] + u[[j, i - 1]]) / dx2
733 + (u[[j + 1, i]] + u[[j - 1, i]]) / dy2
734 - source(x[i], y[j]))
735 / denom;
736 u[[j, i]] = (1.0 - omega) * u[[j, i]] + omega * gs_val;
737 }
738 }
739 }
740 }
741
742 let mut res_sum = 0.0;
744 for j in 1..ny - 1 {
745 for i in 1..nx - 1 {
746 let lap = (u[[j, i + 1]] - 2.0 * u[[j, i]] + u[[j, i - 1]]) / dx2
747 + (u[[j + 1, i]] - 2.0 * u[[j, i]] + u[[j - 1, i]]) / dy2;
748 let r = source(x[i], y[j]) - lap;
749 res_sum += r * r;
750 }
751 }
752 residual = (res_sum / ((nx - 2) * (ny - 2)) as f64).sqrt();
753 convergence_history.push(residual);
754 iterations = iter + 1;
755
756 if residual < tol {
757 break;
758 }
759 }
760
761 Ok(PoissonResult {
762 x,
763 y,
764 u,
765 iterations,
766 residual,
767 convergence_history,
768 })
769}
770
771fn apply_bc_1d(
777 u: &mut Array2<f64>,
778 time_idx: usize,
779 left_bc: &FDBoundaryCondition,
780 right_bc: &FDBoundaryCondition,
781 dx: f64,
782) {
783 let nx = u.shape()[1];
784 match left_bc {
785 FDBoundaryCondition::Dirichlet(val) => {
786 u[[time_idx, 0]] = *val;
787 }
788 FDBoundaryCondition::Neumann(val) => {
789 u[[time_idx, 0]] = u[[time_idx, 1]] - dx * val;
791 }
792 FDBoundaryCondition::Periodic => {
793 }
795 }
796 match right_bc {
797 FDBoundaryCondition::Dirichlet(val) => {
798 u[[time_idx, nx - 1]] = *val;
799 }
800 FDBoundaryCondition::Neumann(val) => {
801 u[[time_idx, nx - 1]] = u[[time_idx, nx - 2]] + dx * val;
803 }
804 FDBoundaryCondition::Periodic => {
805 }
807 }
808}
809
810fn apply_dirichlet_2d(u: &mut Array2<f64>, bc: [f64; 4], nx: usize, ny: usize) {
812 for j in 0..ny {
813 u[[j, 0]] = bc[0]; u[[j, nx - 1]] = bc[1]; }
816 for i in 0..nx {
817 u[[0, i]] = bc[2]; u[[ny - 1, i]] = bc[3]; }
820}
821
822#[allow(clippy::too_many_arguments)]
824fn apply_cn_bc_adjustment(
825 rhs: &mut Array1<f64>,
826 left_bc: &FDBoundaryCondition,
827 right_bc: &FDBoundaryCondition,
828 r: f64,
829 u: &Array2<f64>,
830 _time_idx: usize,
831 nx: usize,
832 dx: f64,
833) {
834 let interior_size = rhs.len();
835 if interior_size == 0 {
836 return;
837 }
838 match left_bc {
840 FDBoundaryCondition::Dirichlet(val) => {
841 rhs[0] += r * val;
842 }
843 FDBoundaryCondition::Neumann(val) => {
844 rhs[0] -= r * dx * val;
847 }
848 FDBoundaryCondition::Periodic => {}
849 }
850 match right_bc {
852 FDBoundaryCondition::Dirichlet(val) => {
853 rhs[interior_size - 1] += r * val;
854 }
855 FDBoundaryCondition::Neumann(val) => {
856 rhs[interior_size - 1] += r * dx * val;
857 }
858 FDBoundaryCondition::Periodic => {}
859 }
860 let _ = u; }
862
863fn solve_tridiag(
866 n: usize,
867 sub: f64,
868 diag: f64,
869 sup: f64,
870 rhs: &Array1<f64>,
871) -> PDEResult<Array1<f64>> {
872 if n == 0 {
873 return Ok(Array1::zeros(0));
874 }
875 let mut c_prime = vec![0.0; n];
876 let mut d_prime = vec![0.0; n];
877
878 c_prime[0] = sup / diag;
880 d_prime[0] = rhs[0] / diag;
881 for i in 1..n {
882 let m = diag - sub * c_prime[i - 1];
883 if m.abs() < 1e-15 {
884 return Err(PDEError::ComputationError(
885 "Zero pivot in tridiagonal solve".to_string(),
886 ));
887 }
888 c_prime[i] = if i < n - 1 { sup / m } else { 0.0 };
889 d_prime[i] = (rhs[i] - sub * d_prime[i - 1]) / m;
890 }
891
892 let mut x = Array1::zeros(n);
894 x[n - 1] = d_prime[n - 1];
895 for i in (0..n - 1).rev() {
896 x[i] = d_prime[i] - c_prime[i] * x[i + 1];
897 }
898 Ok(x)
899}
900
901fn solve_periodic_tridiag(
903 n: usize,
904 sub: f64,
905 diag: f64,
906 sup: f64,
907 rhs: &Array1<f64>,
908) -> PDEResult<Array1<f64>> {
909 if n < 3 {
910 return Err(PDEError::ComputationError(
911 "Periodic tridiagonal system needs at least 3 unknowns".to_string(),
912 ));
913 }
914
915 let gamma = -diag;
917 let d_mod = diag - gamma; let d_last = diag - sub * sup / gamma; let mut rhs_mod = rhs.clone();
922 let mut diag_arr = vec![diag; n];
931 diag_arr[0] = d_mod;
932 diag_arr[n - 1] = d_last;
933
934 let mut sub_arr = vec![sub; n];
935 sub_arr[0] = 0.0; let mut sup_arr = vec![sup; n];
937 sup_arr[n - 1] = 0.0; let y = solve_general_tridiag(&sub_arr, &diag_arr, &sup_arr, &rhs_mod)?;
941
942 let mut u_sm = Array1::zeros(n);
944 u_sm[0] = gamma;
945 u_sm[n - 1] = sup;
946 let z = solve_general_tridiag(&sub_arr, &diag_arr, &sup_arr, &u_sm)?;
947
948 let v0 = 1.0;
950 let vn = sub / gamma;
951
952 let numer = v0 * y[0] + vn * y[n - 1];
953 let denom_val = 1.0 + v0 * z[0] + vn * z[n - 1];
954
955 if denom_val.abs() < 1e-15 {
956 return Err(PDEError::ComputationError(
957 "Singular periodic tridiagonal system".to_string(),
958 ));
959 }
960
961 let factor = numer / denom_val;
962 let mut x = Array1::zeros(n);
963 for i in 0..n {
964 x[i] = y[i] - factor * z[i];
965 }
966
967 Ok(x)
968}
969
970fn solve_general_tridiag(
972 sub: &[f64],
973 diag: &[f64],
974 sup: &[f64],
975 rhs: &Array1<f64>,
976) -> PDEResult<Array1<f64>> {
977 let n = rhs.len();
978 if n == 0 {
979 return Ok(Array1::zeros(0));
980 }
981
982 let mut c_prime = vec![0.0; n];
983 let mut d_prime = vec![0.0; n];
984
985 if diag[0].abs() < 1e-15 {
986 return Err(PDEError::ComputationError(
987 "Zero pivot in general tridiagonal solve".to_string(),
988 ));
989 }
990 c_prime[0] = sup[0] / diag[0];
991 d_prime[0] = rhs[0] / diag[0];
992
993 for i in 1..n {
994 let m = diag[i] - sub[i] * c_prime[i - 1];
995 if m.abs() < 1e-15 {
996 return Err(PDEError::ComputationError(
997 "Zero pivot in general tridiagonal solve".to_string(),
998 ));
999 }
1000 c_prime[i] = if i < n - 1 { sup[i] / m } else { 0.0 };
1001 d_prime[i] = (rhs[i] - sub[i] * d_prime[i - 1]) / m;
1002 }
1003
1004 let mut x = Array1::zeros(n);
1005 x[n - 1] = d_prime[n - 1];
1006 for i in (0..n - 1).rev() {
1007 x[i] = d_prime[i] - c_prime[i] * x[i + 1];
1008 }
1009 Ok(x)
1010}
1011
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    // ---- CFL helpers -----------------------------------------------------

    #[test]
    fn test_cfl_heat_1d_stable() {
        // alpha * dt / dx^2 = 0.01 * 0.1 / 0.01 = 0.1
        let analysis = cfl_heat_1d(0.01, 0.1, 0.1);
        assert!(analysis.is_stable);
        assert!(analysis.cfl_number < 0.5 + 1e-10);
    }

    #[test]
    fn test_cfl_heat_1d_unstable() {
        // alpha * dt / dx^2 = 1.0 * 0.01 / 1e-4 = 100
        let analysis = cfl_heat_1d(1.0, 0.01, 0.01);
        assert!(!analysis.is_stable);
    }

    #[test]
    fn test_cfl_wave_1d_stable() {
        // c * dt / dx = 0.5
        let analysis = cfl_wave_1d(1.0, 0.1, 0.05);
        assert!(analysis.is_stable);
    }

    #[test]
    fn test_cfl_heat_2d() {
        let analysis = cfl_heat_2d(0.01, 0.1, 0.1, 0.1);
        assert!(analysis.is_stable);
    }

    #[test]
    fn test_cfl_wave_2d() {
        let analysis = cfl_wave_2d(1.0, 0.1, 0.1, 0.05);
        assert!(analysis.is_stable);
    }

    // ---- 1D heat equation ------------------------------------------------

    #[test]
    fn test_heat_1d_explicit_constant_ic() {
        // A constant field matching its Dirichlet data must not change.
        let wall = FDBoundaryCondition::Dirichlet(1.0);
        let res = solve_heat_1d(
            0.01,
            [0.0, 1.0],
            [0.0, 0.1],
            21,
            100,
            &|_x| 1.0,
            &wall,
            &wall,
            TimeSteppingMethod::Explicit,
        )
        .expect("Should succeed");

        let final_level = res.u.shape()[0] - 1;
        let last = res.u.row(final_level);
        last.iter().for_each(|&v| {
            assert!((v - 1.0).abs() < 1e-10);
        });
    }

    #[test]
    fn test_heat_1d_explicit_decay() {
        // sin(pi x) decays like exp(-pi^2 alpha t).
        let alpha = 0.01;
        let nx = 51;
        let nt = 5000;
        let zero = FDBoundaryCondition::Dirichlet(0.0);
        let res = solve_heat_1d(
            alpha,
            [0.0, 1.0],
            [0.0, 1.0],
            nx,
            nt,
            &|x| (PI * x).sin(),
            &zero,
            &zero,
            TimeSteppingMethod::Explicit,
        )
        .expect("Should succeed");

        let final_level = res.u.shape()[0] - 1;
        let last = res.u.row(final_level);
        let mid = nx / 2;
        let exact = (PI * 0.5).sin() * (-PI * PI * alpha * 1.0).exp();
        assert!(
            (last[mid] - exact).abs() < 0.02,
            "Got {}, expected {} (tol=0.02)",
            last[mid],
            exact
        );
    }

    #[test]
    fn test_heat_1d_crank_nicolson() {
        let alpha = 0.1;
        let nx = 21;
        let nt = 50;
        let zero = FDBoundaryCondition::Dirichlet(0.0);
        let res = solve_heat_1d(
            alpha,
            [0.0, 1.0],
            [0.0, 1.0],
            nx,
            nt,
            &|x| (PI * x).sin(),
            &zero,
            &zero,
            TimeSteppingMethod::CrankNicolson,
        )
        .expect("Should succeed");

        let final_level = res.u.shape()[0] - 1;
        let last = res.u.row(final_level);
        let mid = nx / 2;
        let exact = (PI * 0.5).sin() * (-PI * PI * alpha * 1.0).exp();
        assert!(
            (last[mid] - exact).abs() < 0.05,
            "CN got {}, expected {} (tol=0.05)",
            last[mid],
            exact
        );
    }

    #[test]
    fn test_heat_1d_neumann() {
        // Insulated ends with constant data: the field stays constant.
        let insulated = FDBoundaryCondition::Neumann(0.0);
        let res = solve_heat_1d(
            0.01,
            [0.0, 1.0],
            [0.0, 0.5],
            21,
            200,
            &|_| 1.0,
            &insulated,
            &insulated,
            TimeSteppingMethod::Explicit,
        )
        .expect("Should succeed");

        let final_level = res.u.shape()[0] - 1;
        let last = res.u.row(final_level);
        for &v in last.iter() {
            assert!(
                (v - 1.0).abs() < 0.01,
                "Neumann with constant IC should stay ~1.0, got {v}"
            );
        }
    }

    #[test]
    fn test_heat_1d_periodic() {
        // sin(2 pi x) decays like exp(-4 pi^2 alpha t) on a periodic domain.
        let alpha = 0.01;
        let nx = 41;
        let nt = 500;
        let res = solve_heat_1d(
            alpha,
            [0.0, 1.0],
            [0.0, 0.5],
            nx,
            nt,
            &|x| (2.0 * PI * x).sin(),
            &FDBoundaryCondition::Periodic,
            &FDBoundaryCondition::Periodic,
            TimeSteppingMethod::Explicit,
        )
        .expect("Should succeed");

        let final_level = res.u.shape()[0] - 1;
        let last = res.u.row(final_level);
        let decay = (-4.0 * PI * PI * alpha * 0.5).exp();
        // Node nx / 4 sits at x = 0.25.
        let mid = nx / 4;
        let exact = decay * (2.0 * PI * 0.25).sin();
        assert!(
            (last[mid] - exact).abs() < 0.05,
            "Periodic got {}, expected {exact} (tol=0.05)",
            last[mid]
        );
    }

    #[test]
    fn test_heat_explicit_unstable_rejected() {
        // Two steps over a unit interval is far beyond the explicit limit.
        let zero = FDBoundaryCondition::Dirichlet(0.0);
        let outcome = solve_heat_1d(
            1.0,
            [0.0, 1.0],
            [0.0, 1.0],
            11,
            2,
            &|_| 0.0,
            &zero,
            &zero,
            TimeSteppingMethod::Explicit,
        );
        assert!(outcome.is_err(), "Should reject unstable explicit scheme");
    }

    // ---- 2D heat equation ------------------------------------------------

    #[test]
    fn test_heat_2d_constant() {
        let n = 11;
        let res = solve_heat_2d(
            0.01,
            [0.0, 1.0],
            [0.0, 1.0],
            [0.0, 0.1],
            n,
            n,
            50,
            &|_, _| 1.0,
            [1.0, 1.0, 1.0, 1.0],
            50,
        )
        .expect("Should succeed");

        let last = &res.u[res.u.len() - 1];
        for j in 0..n {
            for i in 0..n {
                assert!(
                    (last[[j, i]] - 1.0).abs() < 1e-10,
                    "2D heat constant: [{j},{i}] = {}",
                    last[[j, i]]
                );
            }
        }
    }

    // ---- Wave equation ---------------------------------------------------

    #[test]
    fn test_wave_1d_standing() {
        // Standing wave sin(pi x) cos(pi c t) with fixed ends.
        let c = 1.0;
        let nx = 101;
        let nt = 200;
        let fixed = FDBoundaryCondition::Dirichlet(0.0);
        let res = solve_wave_1d(
            c,
            [0.0, 1.0],
            [0.0, 0.5],
            nx,
            nt,
            &|x| (PI * x).sin(),
            &|_x| 0.0,
            &fixed,
            &fixed,
        )
        .expect("Should succeed");

        let final_level = res.u.shape()[0] - 1;
        let last = res.u.row(final_level);
        let mid = nx / 2;
        let exact = (PI * 0.5).sin() * (PI * c * 0.5).cos();
        assert!(
            (last[mid] - exact).abs() < 0.05,
            "Wave got {}, expected {exact}",
            last[mid]
        );
    }

    #[test]
    fn test_wave_1d_periodic() {
        let c = 1.0;
        let nx = 101;
        let nt = 100;
        let outcome = solve_wave_1d(
            c,
            [0.0, 1.0],
            [0.0, 0.5],
            nx,
            nt,
            &|x| (2.0 * PI * x).sin(),
            &|_x| 0.0,
            &FDBoundaryCondition::Periodic,
            &FDBoundaryCondition::Periodic,
        );
        assert!(outcome.is_ok(), "Periodic wave should succeed");
    }

    #[test]
    fn test_wave_2d_basic() {
        let outcome = solve_wave_2d(
            1.0,
            [0.0, 1.0],
            [0.0, 1.0],
            [0.0, 0.1],
            21,
            21,
            50,
            &|x, y| (PI * x).sin() * (PI * y).sin(),
            &|_, _| 0.0,
            0.0,
            50,
        );
        assert!(outcome.is_ok(), "2D wave should succeed");
    }

    // ---- Poisson equation ------------------------------------------------

    #[test]
    fn test_poisson_zero_source() {
        // Laplace equation with unit boundary data has the solution u = 1.
        let n = 11;
        let res = solve_poisson_2d(
            &|_, _| 0.0,
            [0.0, 1.0],
            [0.0, 1.0],
            n,
            n,
            [1.0, 1.0, 1.0, 1.0],
            EllipticIterativeMethod::GaussSeidel,
            1e-8,
            5000,
        )
        .expect("Should succeed");

        for j in 0..n {
            for i in 0..n {
                assert!(
                    (res.u[[j, i]] - 1.0).abs() < 1e-4,
                    "Laplace [{j},{i}] = {} (expected 1.0)",
                    res.u[[j, i]]
                );
            }
        }
    }

    #[test]
    fn test_poisson_jacobi() {
        let res = solve_poisson_2d(
            &|_, _| -2.0,
            [0.0, 1.0],
            [0.0, 1.0],
            21,
            21,
            [0.0, 0.0, 0.0, 0.0],
            EllipticIterativeMethod::Jacobi,
            1e-6,
            10000,
        )
        .expect("Should succeed");

        // Centre node of the 21 x 21 grid.
        let mid = 10;
        assert!(
            res.u[[mid, mid]] > 0.0,
            "Center should be positive for negative source"
        );
    }

    #[test]
    fn test_poisson_sor() {
        // Same problem solved twice; over-relaxation must not need more sweeps.
        let homogeneous = [0.0, 0.0, 0.0, 0.0];

        let result_gs = solve_poisson_2d(
            &|_, _| -2.0,
            [0.0, 1.0],
            [0.0, 1.0],
            21,
            21,
            homogeneous,
            EllipticIterativeMethod::GaussSeidel,
            1e-6,
            10000,
        )
        .expect("GS should succeed");

        let result_sor = solve_poisson_2d(
            &|_, _| -2.0,
            [0.0, 1.0],
            [0.0, 1.0],
            21,
            21,
            homogeneous,
            EllipticIterativeMethod::SOR(1.5),
            1e-6,
            10000,
        )
        .expect("SOR should succeed");

        assert!(
            result_sor.iterations <= result_gs.iterations,
            "SOR ({}) should converge <= GS ({})",
            result_sor.iterations,
            result_gs.iterations
        );
    }
}