1use crate::error::InterpolateError;
30
/// A pointwise residual used as a soft constraint by [`PhysicsInformedInterp`].
///
/// Implementations receive only the position `(x, y)` and the interpolant's value `u`
/// there — no derivatives — so the constraint can depend on `u` but not on its gradient or
/// Laplacian. The magnitude of the residual at the collocation points is what
/// [`PhysicsInformedInterp::pde_residual_norm`] reports.
///
/// The `Send + Sync` bound allows a residual to be shared between threads.
pub trait PdeResidual: Send + Sync {
    /// Residual at `(x, y)` for interpolant value `u`; zero means the constraint holds there.
    fn residual(&self, x: f64, y: f64, u: f64) -> f64;
}
44
/// Constant-target residual: its [`PdeResidual`] implementation returns `u - f`.
///
/// NOTE(review): despite the name, this residual contains no Laplacian (no derivative
/// terms); it only pulls the interpolant toward the constant value `f`.
#[derive(Debug, Clone, Copy)]
pub struct LaplaceResidual {
    /// Target value; the residual vanishes where `u == f`.
    pub f: f64,
}
60
61impl PdeResidual for LaplaceResidual {
62 fn residual(&self, _x: f64, _y: f64, u: f64) -> f64 {
63 u - self.f
64 }
65}
66
/// Tuning parameters for [`PhysicsInformedInterp`].
///
/// The `Default` impl uses `pde_weight = 1.0`, `n_collocation = 16`, `rbf_epsilon = 1.0`,
/// `max_iter = 200` and `tol = 1e-8`.
#[derive(Debug, Clone)]
pub struct PhysicsInterpConfig {
    /// Weight of the squared PDE residual relative to the squared data-fit error.
    pub pde_weight: f64,
    /// Requested number of collocation points at which the residual is penalised.
    /// NOTE(review): the point generator in this module rounds this up to a perfect square.
    pub n_collocation: usize,
    /// Shape parameter of the Gaussian basis `exp(-(rbf_epsilon * r)^2)`; larger values give
    /// narrower basis functions.
    pub rbf_epsilon: f64,
    /// Iteration cap.
    /// NOTE(review): not read by the fitting code in this module — confirm whether anything uses it.
    pub max_iter: usize,
    /// Convergence tolerance.
    /// NOTE(review): not read by the fitting code in this module — confirm whether anything uses it.
    pub tol: f64,
}
86
87impl Default for PhysicsInterpConfig {
88 fn default() -> Self {
89 Self {
90 pde_weight: 1.0,
91 n_collocation: 16,
92 rbf_epsilon: 1.0,
93 max_iter: 200,
94 tol: 1e-8,
95 }
96 }
97}
98
/// Gaussian-RBF interpolator for scattered 2-D data whose fit also penalises a
/// user-supplied [`PdeResidual`] at collocation points inside the data's bounding box.
///
/// Construct with `new`, train with `fit`, then query with `evaluate`.
#[derive(Debug)]
pub struct PhysicsInformedInterp {
    /// Settings used by `fit` and by every evaluation (`rbf_epsilon`).
    config: PhysicsInterpConfig,
    /// Training locations; they are also the centres of the RBF basis functions.
    data_points: Vec<[f64; 2]>,
    /// Observed values at `data_points`, kept for the data term of `total_loss`.
    data_values: Vec<f64>,
    /// One weight per centre in `data_points`; empty until a `fit` succeeds.
    rbf_weights: Vec<f64>,
    /// Points at which the residual was penalised during the last successful `fit`.
    collocation_points: Vec<[f64; 2]>,
}
138
139impl PhysicsInformedInterp {
140 pub fn new(config: PhysicsInterpConfig) -> Self {
142 Self {
143 config,
144 data_points: Vec::new(),
145 data_values: Vec::new(),
146 rbf_weights: Vec::new(),
147 collocation_points: Vec::new(),
148 }
149 }
150
151 pub fn fit<P: PdeResidual>(
157 &mut self,
158 points: &[[f64; 2]],
159 values: &[f64],
160 pde: &P,
161 ) -> Result<(), InterpolateError> {
162 let nd = points.len();
163 if nd == 0 {
164 return Err(InterpolateError::InsufficientData(
165 "physics_interp: at least one data point required".into(),
166 ));
167 }
168 if values.len() != nd {
169 return Err(InterpolateError::ShapeMismatch {
170 expected: nd.to_string(),
171 actual: values.len().to_string(),
172 object: "values".into(),
173 });
174 }
175 if self.config.rbf_epsilon <= 0.0 {
176 return Err(InterpolateError::InvalidInput {
177 message: "physics_interp: rbf_epsilon must be positive".into(),
178 });
179 }
180
181 let coll_pts = generate_collocation_points(points, self.config.n_collocation);
183 let nc = coll_pts.len();
184
185 let nb = nd; let sqrt_lam = self.config.pde_weight.sqrt();
192 let n_rows = nd + nc;
193
194 let mut phi_aug: Vec<f64> = vec![0.0; n_rows * nb];
195 let mut rhs: Vec<f64> = vec![0.0; n_rows];
196
197 for i in 0..nd {
199 for j in 0..nb {
200 let r = dist2(&points[i], &points[j]);
201 phi_aug[i * nb + j] = gaussian_rbf(r, self.config.rbf_epsilon);
202 }
203 rhs[i] = values[i];
204 }
205
206 for (ci, cp) in coll_pts.iter().enumerate() {
208 let row = nd + ci;
209 let u_approx_dummy = 0.0_f64; let target = pde.residual(cp[0], cp[1], u_approx_dummy);
211 for j in 0..nb {
212 let r = dist2(cp, &points[j]);
213 phi_aug[row * nb + j] = sqrt_lam * gaussian_rbf(r, self.config.rbf_epsilon);
214 }
215 rhs[row] = sqrt_lam * target;
216 }
217
218 let w = solve_normal_equations(&phi_aug, &rhs, n_rows, nb)?;
220
221 self.data_points = points.to_vec();
222 self.data_values = values.to_vec();
223 self.rbf_weights = w;
224 self.collocation_points = coll_pts;
225 Ok(())
226 }
227
228 pub fn evaluate(&self, query_points: &[[f64; 2]]) -> Result<Vec<f64>, InterpolateError> {
230 if self.rbf_weights.is_empty() {
231 return Err(InterpolateError::InvalidState(
232 "physics_interp: interpolator not fitted — call fit() first".into(),
233 ));
234 }
235 let out = query_points
236 .iter()
237 .map(|q| {
238 self.data_points
239 .iter()
240 .zip(self.rbf_weights.iter())
241 .map(|(p, &w)| {
242 let r = dist2(q, p);
243 w * gaussian_rbf(r, self.config.rbf_epsilon)
244 })
245 .sum()
246 })
247 .collect();
248 Ok(out)
249 }
250
251 pub fn pde_residual_norm<P: PdeResidual>(&self, pde: &P) -> f64 {
256 if self.rbf_weights.is_empty() || self.collocation_points.is_empty() {
257 return 0.0;
258 }
259 let sum_sq: f64 = self
260 .collocation_points
261 .iter()
262 .map(|cp| {
263 let u: f64 = self
264 .data_points
265 .iter()
266 .zip(self.rbf_weights.iter())
267 .map(|(p, &w)| {
268 let r = dist2(cp, p);
269 w * gaussian_rbf(r, self.config.rbf_epsilon)
270 })
271 .sum();
272 let r = pde.residual(cp[0], cp[1], u);
273 r * r
274 })
275 .sum();
276 (sum_sq / self.collocation_points.len() as f64).sqrt()
277 }
278
279 pub fn total_loss<P: PdeResidual>(&self, pde: &P) -> f64 {
281 if self.rbf_weights.is_empty() {
282 return f64::INFINITY;
283 }
284 let data_mse: f64 = if self.data_points.is_empty() {
286 0.0
287 } else {
288 let ss: f64 = self
289 .data_points
290 .iter()
291 .zip(self.data_values.iter())
292 .map(|(p, &y)| {
293 let u: f64 = self
294 .data_points
295 .iter()
296 .zip(self.rbf_weights.iter())
297 .map(|(q, &w)| {
298 let r = dist2(p, q);
299 w * gaussian_rbf(r, self.config.rbf_epsilon)
300 })
301 .sum();
302 (u - y) * (u - y)
303 })
304 .sum();
305 ss / self.data_points.len() as f64
306 };
307
308 let pde_norm = self.pde_residual_norm(pde);
310 data_mse + self.config.pde_weight * pde_norm * pde_norm
311 }
312}
313
/// Euclidean distance between two 2-D points.
///
/// Despite the `2` in the name, this returns the plain distance `sqrt(dx² + dy²)`,
/// not the squared distance.
#[inline]
fn dist2(a: &[f64; 2], b: &[f64; 2]) -> f64 {
    let [ax, ay] = *a;
    let [bx, by] = *b;
    let (ex, ey) = (ax - bx, ay - by);
    f64::sqrt(ex * ex + ey * ey)
}
325
/// Gaussian radial basis function `exp(-(epsilon * r)²)`.
///
/// Equals 1 at `r = 0` and decays toward 0; a larger `epsilon` gives a narrower bump.
#[inline]
fn gaussian_rbf(r: f64, epsilon: f64) -> f64 {
    let scaled = r * epsilon;
    f64::exp(-(scaled * scaled))
}
332
/// Builds a regular `side × side` grid of collocation points covering the bounding box of
/// `pts`, inset by 10 % of the box extent on every side.
///
/// `side = max(1, ceil(sqrt(n_coll)))`, so the grid holds `side²` points: `n_coll` is rounded
/// up to the next perfect square (e.g. 5 → 9). Each point sits at the centre of its grid
/// cell. Points are ordered with the x index in the outer loop and the y index in the inner
/// loop. Returns an empty vector when `pts` is empty or `n_coll == 0`.
///
/// A zero-width extent is treated as `1e-10` wide, so along such an axis every point lies
/// within about `1e-11` of the shared coordinate.
fn generate_collocation_points(pts: &[[f64; 2]], n_coll: usize) -> Vec<[f64; 2]> {
    let [fx, fy] = match pts.first() {
        Some(&p) if n_coll != 0 => p,
        _ => return Vec::new(),
    };

    let (lo_x, hi_x, lo_y, hi_y) = pts.iter().fold((fx, fx, fy, fy), |(lx, hx, ly, hy), p| {
        (lx.min(p[0]), hx.max(p[0]), ly.min(p[1]), hy.max(p[1]))
    });

    // Inset by 10 % of the extent; the 1e-10 floor keeps a degenerate axis well defined.
    let inset_x = (hi_x - lo_x).max(1e-10) * 0.1;
    let inset_y = (hi_y - lo_y).max(1e-10) * 0.1;
    let (x0, x1) = (lo_x + inset_x, hi_x - inset_x);
    let (y0, y1) = (lo_y + inset_y, hi_y - inset_y);

    let side = ((n_coll as f64).sqrt().ceil() as usize).max(1);
    let cells = side as f64;

    let mut grid = Vec::with_capacity(side * side);
    grid.extend((0..side).flat_map(move |i| {
        let x = x0 + (x1 - x0) * (i as f64 + 0.5) / cells;
        (0..side).map(move |j| [x, y0 + (y1 - y0) * (j as f64 + 0.5) / cells])
    }));
    grid
}
366
367fn solve_normal_equations(
371 phi: &[f64],
372 rhs: &[f64],
373 n_rows: usize,
374 n_cols: usize,
375) -> Result<Vec<f64>, InterpolateError> {
376 let mut ata: Vec<f64> = vec![0.0; n_cols * n_cols];
378 let mut atb: Vec<f64> = vec![0.0; n_cols];
380
381 for k in 0..n_rows {
382 let row = &phi[k * n_cols..(k + 1) * n_cols];
383 for i in 0..n_cols {
384 atb[i] += row[i] * rhs[k];
385 for j in 0..n_cols {
386 ata[i * n_cols + j] += row[i] * row[j];
387 }
388 }
389 }
390
391 let reg = 1e-12;
393 for i in 0..n_cols {
394 ata[i * n_cols + i] += reg;
395 }
396
397 crate::gpu_rbf::solve_linear_system(&ata, &atb, n_cols)
399}
400
#[cfg(test)]
mod tests {
    use super::*;

    /// Test configuration with a fixed RBF shape parameter of 2.0.
    fn make_config(pde_weight: f64, n_coll: usize) -> PhysicsInterpConfig {
        PhysicsInterpConfig {
            pde_weight,
            n_collocation: n_coll,
            rbf_epsilon: 2.0,
            max_iter: 100,
            tol: 1e-8,
        }
    }

    /// With no PDE weight the fit reduces to plain RBF interpolation of the data.
    #[test]
    fn test_zero_pde_weight_is_standard_rbf() {
        let points = vec![[0.0_f64, 0.0], [1.0, 0.0], [0.5, 0.8], [0.3, 0.3]];
        let values = vec![1.0, 2.0, 1.5, 0.8];

        let mut interp = PhysicsInformedInterp::new(make_config(0.0, 4));
        let pde = LaplaceResidual { f: 0.0 };
        interp.fit(&points, &values, &pde).expect("fit failed");

        let out = interp.evaluate(&points).expect("eval failed");
        for (got, &exp) in out.iter().zip(values.iter()) {
            assert!(
                (got - exp).abs() < 5e-4,
                "zero pde_weight: got {got:.6} expected {exp:.6}"
            );
        }
    }

    /// Raising the PDE weight must not increase the RMS residual.
    #[test]
    fn test_higher_pde_weight_reduces_residual() {
        let points = vec![[0.0_f64, 0.0], [1.0, 0.0], [0.5, 1.0]];
        let values = vec![0.0, 0.0, 0.0];
        let pde = LaplaceResidual { f: 0.0 };
        let mut low = PhysicsInformedInterp::new(make_config(0.01, 4));
        let mut high = PhysicsInformedInterp::new(make_config(100.0, 4));

        low.fit(&points, &values, &pde).expect("fit low failed");
        high.fit(&points, &values, &pde).expect("fit high failed");

        let r_low = low.pde_residual_norm(&pde);
        let r_high = high.pde_residual_norm(&pde);

        assert!(
            r_high <= r_low + 1e-6,
            "higher pde_weight should reduce residual: low={r_low:.6} high={r_high:.6}"
        );
    }

    /// A tiny PDE weight keeps the fit close to the training data.
    #[test]
    fn test_evaluate_at_training_points() {
        let points = vec![[0.1_f64, 0.1], [0.9, 0.1], [0.5, 0.9]];
        let values = vec![1.0, 3.0, 2.0];

        let pde = LaplaceResidual { f: 0.5 };
        let mut interp = PhysicsInformedInterp::new(make_config(1e-4, 4));
        interp.fit(&points, &values, &pde).expect("fit failed");

        let out = interp.evaluate(&points).expect("eval failed");
        for (got, &exp) in out.iter().zip(values.iter()) {
            assert!(
                (got - exp).abs() < 0.5,
                "evaluate at training point: got {got:.4} expected {exp:.4}"
            );
        }
    }

    /// `LaplaceResidual` returns exactly `u - f`.
    #[test]
    fn test_laplace_residual_formula() {
        let pde = LaplaceResidual { f: 3.0 };
        for u in [0.0, 1.0, 3.0, -2.5, 7.0] {
            let r = pde.residual(0.5, 0.5, u);
            assert!(
                (r - (u - 3.0)).abs() < 1e-15,
                "LaplaceResidual: got {r}, expected {:.1}",
                u - 3.0
            );
        }
    }

    /// The combined loss is a sum of squares and so can never be negative.
    #[test]
    fn test_total_loss_non_negative() {
        let points = vec![[0.0_f64, 0.0], [1.0, 1.0]];
        let values = vec![0.0, 1.0];
        let pde = LaplaceResidual { f: 0.0 };
        let mut interp = PhysicsInformedInterp::new(make_config(1.0, 4));
        interp.fit(&points, &values, &pde).expect("fit failed");
        let loss = interp.total_loss(&pde);
        assert!(loss >= 0.0, "total_loss must be non-negative, got {loss}");
    }

    /// Fitting with no data points is rejected.
    #[test]
    fn test_fit_rejects_empty_data() {
        let pde = LaplaceResidual { f: 0.0 };
        let mut interp = PhysicsInformedInterp::new(make_config(1.0, 4));
        let res = interp.fit(&[], &[], &pde);
        assert!(matches!(res, Err(InterpolateError::InsufficientData(_))));
    }

    /// A `values` slice whose length differs from `points` is rejected.
    #[test]
    fn test_fit_rejects_length_mismatch() {
        let pde = LaplaceResidual { f: 0.0 };
        let mut interp = PhysicsInformedInterp::new(make_config(1.0, 4));
        let res = interp.fit(&[[0.0, 0.0], [1.0, 0.0]], &[1.0], &pde);
        assert!(matches!(res, Err(InterpolateError::ShapeMismatch { .. })));
    }

    /// A non-positive RBF shape parameter is rejected.
    #[test]
    fn test_fit_rejects_non_positive_epsilon() {
        let pde = LaplaceResidual { f: 0.0 };
        let mut cfg = make_config(1.0, 4);
        cfg.rbf_epsilon = 0.0;
        let mut interp = PhysicsInformedInterp::new(cfg);
        let res = interp.fit(&[[0.0, 0.0]], &[1.0], &pde);
        assert!(matches!(res, Err(InterpolateError::InvalidInput { .. })));
    }

    /// An unfitted interpolator refuses to evaluate and reports neutral/infinite metrics.
    #[test]
    fn test_unfitted_state() {
        let pde = LaplaceResidual { f: 0.0 };
        let interp = PhysicsInformedInterp::new(make_config(1.0, 4));
        assert!(matches!(
            interp.evaluate(&[[0.0, 0.0]]),
            Err(InterpolateError::InvalidState(_))
        ));
        assert_eq!(interp.pde_residual_norm(&pde), 0.0);
        assert!(interp.total_loss(&pde).is_infinite());
    }

    /// A failed refit leaves the previously fitted model intact.
    #[test]
    fn test_failed_fit_keeps_previous_state() {
        let pde = LaplaceResidual { f: 0.0 };
        let points = vec![[0.0_f64, 0.0], [1.0, 0.0], [0.5, 0.8]];
        let values = vec![1.0, 2.0, 1.5];
        let mut interp = PhysicsInformedInterp::new(make_config(0.0, 4));
        interp.fit(&points, &values, &pde).expect("fit failed");
        let before = interp.evaluate(&points).expect("eval failed");

        assert!(interp.fit(&points, &values[..2], &pde).is_err());

        let after = interp.evaluate(&points).expect("eval failed");
        assert_eq!(before, after);
    }

    /// Evaluating an empty query set yields an empty result.
    #[test]
    fn test_evaluate_empty_query() {
        let pde = LaplaceResidual { f: 0.0 };
        let mut interp = PhysicsInformedInterp::new(make_config(1.0, 4));
        interp
            .fit(&[[0.0, 0.0], [1.0, 1.0]], &[0.0, 1.0], &pde)
            .expect("fit failed");
        let out = interp.evaluate(&[]).expect("eval failed");
        assert!(out.is_empty());
    }
}
507}