Skip to main content

scirs2_interpolate/
physics_interp.rs

1//! Physics-informed interpolation with PDE residual penalty.
2//!
3//! Enforces PDE residuals as soft constraints during the RBF fitting process.
4//! The interpolant solves an augmented system
5//!
6//! ```text
7//! [ Φ_data       ]         [  y_data       ]
8//! [ √λ · Φ_coll  ] w   =   [ √λ · f_coll  ]
9//! ```
10//!
//! where Φ_data is the data-point RBF matrix, Φ_coll is the collocation-point
//! RBF matrix, λ = `pde_weight`, and f_coll holds, for each collocation point,
//! the value of u at which the PDE residual vanishes (zero for a homogeneous
//! residual such as r = u).
//!
//! Solving this stacked system in the least-squares sense minimises
//!
//! ```text
//!   ‖ Φ_data w − y ‖² + λ ‖ r(Φ_coll w) ‖²
//! ```
//!
//! exactly when r is affine in u with unit slope (r = u − f_coll); for other
//! residuals the penalty rows are an approximation.  Appending penalty rows to
//! the least-squares system avoids the need for a specialised constrained
//! solver.
21//!
22//! ## References
23//!
24//! - Kansa, E.J. (1990). *Multiquadrics — a scattered data approximation
25//!   scheme with applications to computational fluid-dynamics*.
26//! - Raissi, M., Perdikaris, P., Karniadakis, G.E. (2019). *Physics-informed
27//!   neural networks*.
28
29use crate::error::InterpolateError;
30
31// ---------------------------------------------------------------------------
32// PDE residual trait
33// ---------------------------------------------------------------------------
34
/// A PDE residual evaluated pointwise: r(x, y, u).
///
/// Implementors encode the PDE; the interpolation penalty minimises the norm
/// of the residual at the collocation points.  The residual only receives the
/// interpolated value `u` at `(x, y)` — no derivatives of `u` are supplied —
/// so differential operators such as ∇²u must be expressed (or approximated)
/// algebraically in terms of `x`, `y` and `u`; see [`LaplaceResidual`] for the
/// built-in zero-order surrogate.
///
/// The `Send + Sync` supertraits allow one residual to be shared across
/// threads.
pub trait PdeResidual: Send + Sync {
    /// Compute the PDE residual at point `(x, y)` given the interpolated
    /// value `u`.  Should return zero when the PDE is satisfied.
    fn residual(&self, x: f64, y: f64, u: f64) -> f64;
}
44
45// ---------------------------------------------------------------------------
46// Built-in residuals
47// ---------------------------------------------------------------------------
48
49/// Simplified Laplace residual: r(x, y, u) = u − f.
50///
51/// In a full implementation the Laplacian ∇²u would be approximated via
52/// finite differences on the RBF expansion.  Here we use the zero-order
53/// algebraic approximation r = u − f, which drives the fitted values towards
54/// f at collocation points.
55#[derive(Debug, Clone, Copy)]
56pub struct LaplaceResidual {
57    /// Right-hand side of the PDE  ∇²u = f.
58    pub f: f64,
59}
60
61impl PdeResidual for LaplaceResidual {
62    fn residual(&self, _x: f64, _y: f64, u: f64) -> f64 {
63        u - self.f
64    }
65}
66
67// ---------------------------------------------------------------------------
68// Configuration
69// ---------------------------------------------------------------------------
70
/// Configuration for [`PhysicsInformedInterp`].
///
/// Typically built with struct-update syntax on top of the defaults, e.g.
/// `PhysicsInterpConfig { pde_weight: 0.5, ..PhysicsInterpConfig::default() }`.
#[derive(Debug, Clone)]
pub struct PhysicsInterpConfig {
    /// Penalty weight λ applied to the PDE constraint rows.
    pub pde_weight: f64,
    /// Requested number of interior collocation points at which the PDE is
    /// enforced; they are laid out on a regular grid inside the bounding box
    /// of the data.
    pub n_collocation: usize,
    /// Shape parameter ε of the Gaussian RBF φ(r) = exp(-(ε r)²).
    pub rbf_epsilon: f64,
    /// Iteration cap, reserved for future iterative solvers.
    pub max_iter: usize,
    /// Convergence tolerance, reserved for future iterative solvers.
    pub tol: f64,
}

impl Default for PhysicsInterpConfig {
    /// λ = 1, 16 collocation points, ε = 1, 200 iterations, tolerance 1e-8.
    fn default() -> Self {
        let pde_weight = 1.0;
        let n_collocation = 16;
        let rbf_epsilon = 1.0;
        let max_iter = 200;
        let tol = 1e-8;
        Self {
            pde_weight,
            n_collocation,
            rbf_epsilon,
            max_iter,
            tol,
        }
    }
}
98
99// ---------------------------------------------------------------------------
100// Main struct
101// ---------------------------------------------------------------------------
102
103/// Physics-informed RBF interpolator.
104///
105/// Enforces a PDE constraint at a grid of collocation points by augmenting
106/// the standard RBF least-squares system with additional penalty rows.
107///
108/// # Example
109///
110/// ```rust
111/// use scirs2_interpolate::physics_interp::{
112///     PhysicsInformedInterp, PhysicsInterpConfig, LaplaceResidual,
113/// };
114///
115/// let config = PhysicsInterpConfig {
116///     pde_weight: 0.5,
117///     n_collocation: 9,
118///     rbf_epsilon: 2.0,
119///     ..PhysicsInterpConfig::default()
120/// };
121/// let mut interp = PhysicsInformedInterp::new(config);
122///
123/// let points = vec![[0.0_f64, 0.0], [1.0, 0.0], [0.5, 1.0]];
124/// let values = vec![0.0, 1.0, 0.5];
125/// let pde = LaplaceResidual { f: 0.0 };
126///
127/// interp.fit(&points, &values, &pde).expect("fit should succeed");
128/// let out = interp.evaluate(&points).expect("evaluate should succeed");
129/// ```
130#[derive(Debug)]
131pub struct PhysicsInformedInterp {
132    config: PhysicsInterpConfig,
133    data_points: Vec<[f64; 2]>,
134    data_values: Vec<f64>,
135    rbf_weights: Vec<f64>,
136    collocation_points: Vec<[f64; 2]>,
137}
138
139impl PhysicsInformedInterp {
140    /// Create a new interpolator with the given configuration.
141    pub fn new(config: PhysicsInterpConfig) -> Self {
142        Self {
143            config,
144            data_points: Vec::new(),
145            data_values: Vec::new(),
146            rbf_weights: Vec::new(),
147            collocation_points: Vec::new(),
148        }
149    }
150
151    /// Fit the physics-informed RBF to `points` / `values` with PDE `pde`.
152    ///
153    /// Internally, collocation points are placed on a regular grid inside the
154    /// bounding box of the data.  The combined least-squares system is solved
155    /// via normal equations (Φᵀ Φ w = Φᵀ y) using Gaussian elimination.
156    pub fn fit<P: PdeResidual>(
157        &mut self,
158        points: &[[f64; 2]],
159        values: &[f64],
160        pde: &P,
161    ) -> Result<(), InterpolateError> {
162        let nd = points.len();
163        if nd == 0 {
164            return Err(InterpolateError::InsufficientData(
165                "physics_interp: at least one data point required".into(),
166            ));
167        }
168        if values.len() != nd {
169            return Err(InterpolateError::ShapeMismatch {
170                expected: nd.to_string(),
171                actual: values.len().to_string(),
172                object: "values".into(),
173            });
174        }
175        if self.config.rbf_epsilon <= 0.0 {
176            return Err(InterpolateError::InvalidInput {
177                message: "physics_interp: rbf_epsilon must be positive".into(),
178            });
179        }
180
181        // Generate collocation points on a grid within the data bounding box
182        let coll_pts = generate_collocation_points(points, self.config.n_collocation);
183        let nc = coll_pts.len();
184
185        // Number of RBF basis centres = number of data points
186        let nb = nd; // basis centres are placed at data points
187
188        // Build augmented matrix Φ_aug ∈ R^{(nd + nc) × nb}
189        //   top nd rows: data constraints
190        //   bottom nc rows: PDE penalty (scaled by √λ)
191        let sqrt_lam = self.config.pde_weight.sqrt();
192        let n_rows = nd + nc;
193
194        let mut phi_aug: Vec<f64> = vec![0.0; n_rows * nb];
195        let mut rhs: Vec<f64> = vec![0.0; n_rows];
196
197        // Data rows
198        for i in 0..nd {
199            for j in 0..nb {
200                let r = dist2(&points[i], &points[j]);
201                phi_aug[i * nb + j] = gaussian_rbf(r, self.config.rbf_epsilon);
202            }
203            rhs[i] = values[i];
204        }
205
206        // Collocation rows (PDE penalty)
207        for (ci, cp) in coll_pts.iter().enumerate() {
208            let row = nd + ci;
209            let u_approx_dummy = 0.0_f64; // placeholder for residual target
210            let target = pde.residual(cp[0], cp[1], u_approx_dummy);
211            for j in 0..nb {
212                let r = dist2(cp, &points[j]);
213                phi_aug[row * nb + j] = sqrt_lam * gaussian_rbf(r, self.config.rbf_epsilon);
214            }
215            rhs[row] = sqrt_lam * target;
216        }
217
218        // Solve via normal equations: Φᵀ Φ w = Φᵀ rhs
219        let w = solve_normal_equations(&phi_aug, &rhs, n_rows, nb)?;
220
221        self.data_points = points.to_vec();
222        self.data_values = values.to_vec();
223        self.rbf_weights = w;
224        self.collocation_points = coll_pts;
225        Ok(())
226    }
227
228    /// Evaluate the fitted interpolant at `query_points`.
229    pub fn evaluate(&self, query_points: &[[f64; 2]]) -> Result<Vec<f64>, InterpolateError> {
230        if self.rbf_weights.is_empty() {
231            return Err(InterpolateError::InvalidState(
232                "physics_interp: interpolator not fitted — call fit() first".into(),
233            ));
234        }
235        let out = query_points
236            .iter()
237            .map(|q| {
238                self.data_points
239                    .iter()
240                    .zip(self.rbf_weights.iter())
241                    .map(|(p, &w)| {
242                        let r = dist2(q, p);
243                        w * gaussian_rbf(r, self.config.rbf_epsilon)
244                    })
245                    .sum()
246            })
247            .collect();
248        Ok(out)
249    }
250
251    /// Compute the RMS PDE residual norm at the collocation points.
252    ///
253    /// Returns 0.0 if no collocation points exist or the interpolant is not
254    /// fitted.
255    pub fn pde_residual_norm<P: PdeResidual>(&self, pde: &P) -> f64 {
256        if self.rbf_weights.is_empty() || self.collocation_points.is_empty() {
257            return 0.0;
258        }
259        let sum_sq: f64 = self
260            .collocation_points
261            .iter()
262            .map(|cp| {
263                let u: f64 = self
264                    .data_points
265                    .iter()
266                    .zip(self.rbf_weights.iter())
267                    .map(|(p, &w)| {
268                        let r = dist2(cp, p);
269                        w * gaussian_rbf(r, self.config.rbf_epsilon)
270                    })
271                    .sum();
272                let r = pde.residual(cp[0], cp[1], u);
273                r * r
274            })
275            .sum();
276        (sum_sq / self.collocation_points.len() as f64).sqrt()
277    }
278
279    /// Total loss = data_fit_mse + pde_weight * pde_residual_mse.
280    pub fn total_loss<P: PdeResidual>(&self, pde: &P) -> f64 {
281        if self.rbf_weights.is_empty() {
282            return f64::INFINITY;
283        }
284        // Data fit MSE
285        let data_mse: f64 = if self.data_points.is_empty() {
286            0.0
287        } else {
288            let ss: f64 = self
289                .data_points
290                .iter()
291                .zip(self.data_values.iter())
292                .map(|(p, &y)| {
293                    let u: f64 = self
294                        .data_points
295                        .iter()
296                        .zip(self.rbf_weights.iter())
297                        .map(|(q, &w)| {
298                            let r = dist2(p, q);
299                            w * gaussian_rbf(r, self.config.rbf_epsilon)
300                        })
301                        .sum();
302                    (u - y) * (u - y)
303                })
304                .sum();
305            ss / self.data_points.len() as f64
306        };
307
308        // PDE residual MSE (un-scaled)
309        let pde_norm = self.pde_residual_norm(pde);
310        data_mse + self.config.pde_weight * pde_norm * pde_norm
311    }
312}
313
314// ---------------------------------------------------------------------------
315// Internal free functions
316// ---------------------------------------------------------------------------
317
/// Euclidean distance between two 2D points.
///
/// Despite the `2` suffix this returns the distance itself, not its square;
/// `gaussian_rbf` squares the (scaled) distance internally.
#[inline]
fn dist2(a: &[f64; 2], b: &[f64; 2]) -> f64 {
    // `hypot` avoids the intermediate overflow/underflow of sqrt(dx² + dy²)
    // for very large or very small coordinate differences.
    (a[0] - b[0]).hypot(a[1] - b[1])
}
325
/// Gaussian radial basis function φ(r) = exp(−(ε r)²).
///
/// Equals 1 at `r = 0` and decays towards 0 as `ε r` grows.
#[inline]
fn gaussian_rbf(r: f64, epsilon: f64) -> f64 {
    let scaled = r * epsilon;
    f64::exp(-scaled * scaled)
}
332
/// Place collocation points on a regular grid inside the bounding box of `pts`.
///
/// The bounding box is shrunk by 10 % of its extent on every side, then filled
/// with a cell-centred `side × side` grid where `side = ⌈√n_coll⌉`.  The result
/// therefore holds at least `n_coll` points (exactly `n_coll` when it is a
/// perfect square).  Along an axis where every data point shares the same
/// coordinate, all collocation points take that coordinate.  Points are
/// ordered with x as the outer index and y as the inner index.
///
/// Returns an empty vector when `pts` is empty or `n_coll == 0`.
fn generate_collocation_points(pts: &[[f64; 2]], n_coll: usize) -> Vec<[f64; 2]> {
    if pts.is_empty() || n_coll == 0 {
        return Vec::new();
    }

    let (xmin, xmax, ymin, ymax) = pts.iter().fold(
        (
            f64::INFINITY,
            f64::NEG_INFINITY,
            f64::INFINITY,
            f64::NEG_INFINITY,
        ),
        |(x0, x1, y0, y1), p| (x0.min(p[0]), x1.max(p[0]), y0.min(p[1]), y1.max(p[1])),
    );

    // Inset by 10 % of the extent.  A zero-width axis gets no inset, so the
    // box can never invert (the old `.max(1e-10)` floor made xmin > xmax).
    let inset_x = (xmax - xmin) * 0.1;
    let inset_y = (ymax - ymin) * 0.1;
    let (x_lo, x_hi) = (xmin + inset_x, xmax - inset_x);
    let (y_lo, y_hi) = (ymin + inset_y, ymax - inset_y);

    let side = ((n_coll as f64).sqrt().ceil() as usize).max(1);
    let side_f = side as f64;
    let mut coll = Vec::with_capacity(side * side);
    for i in 0..side {
        let x = x_lo + (x_hi - x_lo) * (i as f64 + 0.5) / side_f;
        for j in 0..side {
            let y = y_lo + (y_hi - y_lo) * (j as f64 + 0.5) / side_f;
            coll.push([x, y]);
        }
    }
    coll
}
366
367/// Solve the over-determined system Φ w = rhs via normal equations Φᵀ Φ w = Φᵀ rhs.
368///
369/// `phi` is stored row-major with shape `(n_rows, n_cols)`.
370fn solve_normal_equations(
371    phi: &[f64],
372    rhs: &[f64],
373    n_rows: usize,
374    n_cols: usize,
375) -> Result<Vec<f64>, InterpolateError> {
376    // AtA = Φᵀ Φ  (n_cols × n_cols)
377    let mut ata: Vec<f64> = vec![0.0; n_cols * n_cols];
378    // Atb = Φᵀ rhs  (n_cols)
379    let mut atb: Vec<f64> = vec![0.0; n_cols];
380
381    for k in 0..n_rows {
382        let row = &phi[k * n_cols..(k + 1) * n_cols];
383        for i in 0..n_cols {
384            atb[i] += row[i] * rhs[k];
385            for j in 0..n_cols {
386                ata[i * n_cols + j] += row[i] * row[j];
387            }
388        }
389    }
390
391    // Add a small Tikhonov regulariser for numerical stability
392    let reg = 1e-12;
393    for i in 0..n_cols {
394        ata[i * n_cols + i] += reg;
395    }
396
397    // Solve AtA w = Atb via Gaussian elimination with partial pivoting
398    crate::gpu_rbf::solve_linear_system(&ata, &atb, n_cols)
399}
400
401// ---------------------------------------------------------------------------
402// Tests
403// ---------------------------------------------------------------------------
404
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::InterpolateError;

    /// Build a config with ε = 2 and the given penalty weight / grid size.
    fn make_config(pde_weight: f64, n_coll: usize) -> PhysicsInterpConfig {
        PhysicsInterpConfig {
            pde_weight,
            n_collocation: n_coll,
            rbf_epsilon: 2.0,
            max_iter: 100,
            tol: 1e-8,
        }
    }

    /// Assert element-wise that `got` matches `expected` within `tol`.
    fn assert_all_close(got: &[f64], expected: &[f64], tol: f64, what: &str) {
        assert_eq!(got.len(), expected.len(), "{what}: length mismatch");
        for (g, e) in got.iter().zip(expected) {
            assert!((g - e).abs() < tol, "{what}: got {g:.6} expected {e:.6}");
        }
    }

    /// With pde_weight = 0 the system degenerates to standard RBF — the fitted
    /// values at training points should reproduce the data within tolerance.
    #[test]
    fn test_zero_pde_weight_is_standard_rbf() {
        let points = vec![[0.0_f64, 0.0], [1.0, 0.0], [0.5, 0.8], [0.3, 0.3]];
        let values = vec![1.0, 2.0, 1.5, 0.8];

        let mut interp = PhysicsInformedInterp::new(make_config(0.0, 4));
        let pde = LaplaceResidual { f: 0.0 };
        interp.fit(&points, &values, &pde).expect("fit failed");

        let out = interp.evaluate(&points).expect("eval failed");
        assert_all_close(&out, &values, 5e-4, "zero pde_weight");
    }

    /// A higher pde_weight should drive the PDE residual norm lower when the
    /// PDE target is consistent with the data.
    #[test]
    fn test_higher_pde_weight_reduces_residual() {
        let points = vec![[0.0_f64, 0.0], [1.0, 0.0], [0.5, 1.0]];
        let values = vec![0.0, 0.0, 0.0];
        let pde = LaplaceResidual { f: 0.0 }; // u = 0 satisfies pde exactly

        let mut low = PhysicsInformedInterp::new(make_config(0.01, 4));
        let mut high = PhysicsInformedInterp::new(make_config(100.0, 4));

        low.fit(&points, &values, &pde).expect("fit low failed");
        high.fit(&points, &values, &pde).expect("fit high failed");

        let r_low = low.pde_residual_norm(&pde);
        let r_high = high.pde_residual_norm(&pde);

        // Higher weight should give equal or lower residual norm
        assert!(
            r_high <= r_low + 1e-6,
            "higher pde_weight should reduce residual: low={r_low:.6} high={r_high:.6}"
        );
    }

    /// Evaluate at training points must be within reasonable tolerance.
    #[test]
    fn test_evaluate_at_training_points() {
        let points = vec![[0.1_f64, 0.1], [0.9, 0.1], [0.5, 0.9]];
        let values = vec![1.0, 3.0, 2.0];

        let pde = LaplaceResidual { f: 0.5 };
        let mut interp = PhysicsInformedInterp::new(make_config(1e-4, 4));
        interp.fit(&points, &values, &pde).expect("fit failed");

        let out = interp.evaluate(&points).expect("eval failed");
        assert_all_close(&out, &values, 0.5, "evaluate at training point");
    }

    /// LaplaceResidual::residual(x, y, u) == u - f for any x, y.
    #[test]
    fn test_laplace_residual_formula() {
        let pde = LaplaceResidual { f: 3.0 };
        for u in [0.0, 1.0, 3.0, -2.5, 7.0] {
            let r = pde.residual(0.5, 0.5, u);
            assert!(
                (r - (u - 3.0)).abs() < 1e-15,
                "LaplaceResidual: got {r}, expected {:.1}",
                u - 3.0
            );
        }
    }

    /// total_loss should be non-negative.
    #[test]
    fn test_total_loss_non_negative() {
        let points = vec![[0.0_f64, 0.0], [1.0, 1.0]];
        let values = vec![0.0, 1.0];
        let pde = LaplaceResidual { f: 0.0 };
        let mut interp = PhysicsInformedInterp::new(make_config(1.0, 4));
        interp.fit(&points, &values, &pde).expect("fit failed");
        let loss = interp.total_loss(&pde);
        assert!(loss >= 0.0, "total_loss must be non-negative, got {loss}");
    }

    /// Fitting with no data points is rejected.
    #[test]
    fn test_fit_rejects_empty_points() {
        let mut interp = PhysicsInformedInterp::new(make_config(1.0, 4));
        let res = interp.fit(&[], &[], &LaplaceResidual { f: 0.0 });
        assert!(matches!(res, Err(InterpolateError::InsufficientData(_))));
    }

    /// Fitting with mismatched point/value counts is rejected.
    #[test]
    fn test_fit_rejects_length_mismatch() {
        let mut interp = PhysicsInformedInterp::new(make_config(1.0, 4));
        let res = interp.fit(&[[0.0, 0.0], [1.0, 1.0]], &[1.0], &LaplaceResidual { f: 0.0 });
        assert!(matches!(res, Err(InterpolateError::ShapeMismatch { .. })));
    }

    /// A non-positive RBF shape parameter is rejected.
    #[test]
    fn test_fit_rejects_non_positive_epsilon() {
        for eps in [0.0, -1.0] {
            let config = PhysicsInterpConfig {
                rbf_epsilon: eps,
                ..make_config(1.0, 4)
            };
            let mut interp = PhysicsInformedInterp::new(config);
            let res = interp.fit(&[[0.0, 0.0]], &[1.0], &LaplaceResidual { f: 0.0 });
            assert!(
                matches!(res, Err(InterpolateError::InvalidInput { .. })),
                "eps = {eps} should be rejected"
            );
        }
    }

    /// Before fitting: evaluate errors, residual norm is 0, loss is infinite.
    #[test]
    fn test_unfitted_interpolator() {
        let interp = PhysicsInformedInterp::new(make_config(1.0, 4));
        let pde = LaplaceResidual { f: 1.0 };
        assert!(matches!(
            interp.evaluate(&[[0.0, 0.0]]),
            Err(InterpolateError::InvalidState(_))
        ));
        assert_eq!(interp.pde_residual_norm(&pde), 0.0);
        assert!(interp.total_loss(&pde).is_infinite());
    }

    /// Evaluating a fitted interpolant on no query points yields no values.
    #[test]
    fn test_evaluate_empty_query() {
        let mut interp = PhysicsInformedInterp::new(make_config(1.0, 4));
        interp
            .fit(&[[0.0, 0.0], [1.0, 1.0]], &[0.0, 1.0], &LaplaceResidual { f: 0.0 })
            .expect("fit failed");
        assert!(interp.evaluate(&[]).expect("eval failed").is_empty());
    }

    /// `dist2` returns the Euclidean distance (not its square).
    #[test]
    fn test_dist2_is_euclidean_distance() {
        assert!((dist2(&[0.0, 0.0], &[3.0, 4.0]) - 5.0).abs() < 1e-12);
        assert_eq!(dist2(&[1.5, -2.0], &[1.5, -2.0]), 0.0);
    }

    /// The collocation grid has ⌈√n⌉² points, all strictly inside the data box.
    #[test]
    fn test_collocation_grid_size_and_placement() {
        let pts = [[0.0, 0.0], [1.0, 1.0]];
        assert_eq!(generate_collocation_points(&pts, 4).len(), 4);
        assert_eq!(generate_collocation_points(&pts, 5).len(), 9);
        assert!(generate_collocation_points(&pts, 0).is_empty());
        for p in generate_collocation_points(&pts, 9) {
            assert!(
                p[0] > 0.0 && p[0] < 1.0 && p[1] > 0.0 && p[1] < 1.0,
                "collocation point {p:?} outside the data box"
            );
        }
    }
}