1use super::{pairwise_sq_distances, solve_general, SurrogateModel};
21use crate::error::{OptimizeError, OptimizeResult};
22use scirs2_core::ndarray::{Array1, Array2};
23
/// Radial basis function φ(r) used by [`RbfSurrogate`], where `r` is the Euclidean distance
/// between two (possibly normalised) input points.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum RbfKernel {
    /// Polyharmonic spline of order `k`; `Polyharmonic(3)` is the cubic kernel φ(r) = r³.
    /// Fitted together with a polynomial tail.
    Polyharmonic(u32),
    /// Multiquadric kernel φ(r) = sqrt(r² + c²).
    Multiquadric {
        /// Shape parameter `c`.
        shape_param: f64,
    },
    /// Inverse multiquadric kernel φ(r) = 1 / sqrt(r² + c²).
    InverseMultiquadric {
        /// Shape parameter `c`.
        shape_param: f64,
    },
    /// Thin-plate spline φ(r) = r² ln r, fitted together with a linear polynomial tail.
    ThinPlateSpline,
    /// Gaussian kernel φ(r) = exp(-r² / (2σ²)).
    Gaussian {
        /// Width σ of the Gaussian.
        sigma: f64,
    },
}

impl Default for RbfKernel {
    /// The cubic polyharmonic spline, `Polyharmonic(3)`.
    fn default() -> Self {
        Self::Polyharmonic(3)
    }
}
54
55impl RbfKernel {
56 fn evaluate(&self, sq_dist: f64) -> f64 {
58 let r = sq_dist.sqrt();
59 match *self {
60 RbfKernel::Polyharmonic(k) => {
61 if r < 1e-30 {
62 0.0
63 } else {
64 r.powi(k as i32)
65 }
66 }
67 RbfKernel::Multiquadric { shape_param } => (sq_dist + shape_param * shape_param).sqrt(),
68 RbfKernel::InverseMultiquadric { shape_param } => {
69 1.0 / (sq_dist + shape_param * shape_param).sqrt()
70 }
71 RbfKernel::ThinPlateSpline => {
72 if r < 1e-30 {
73 0.0
74 } else {
75 sq_dist * r.ln()
76 }
77 }
78 RbfKernel::Gaussian { sigma } => (-sq_dist / (2.0 * sigma * sigma)).exp(),
79 }
80 }
81
82 fn needs_polynomial_tail(&self) -> bool {
84 matches!(
85 self,
86 RbfKernel::Polyharmonic(_) | RbfKernel::ThinPlateSpline
87 )
88 }
89
90 fn polynomial_degree(&self) -> usize {
92 match *self {
93 RbfKernel::Polyharmonic(k) => {
94 (k as usize) / 2
96 }
97 RbfKernel::ThinPlateSpline => 1,
98 _ => 0,
99 }
100 }
101}
102
103#[derive(Debug, Clone)]
105pub struct RbfOptions {
106 pub kernel: RbfKernel,
108 pub regularization: f64,
110 pub normalize: bool,
112}
113
114impl Default for RbfOptions {
115 fn default() -> Self {
116 Self {
117 kernel: RbfKernel::default(),
118 regularization: 1e-10,
119 normalize: true,
120 }
121 }
122}
123
124pub struct RbfSurrogate {
126 options: RbfOptions,
127 x_train: Option<Array2<f64>>,
129 y_train: Option<Array1<f64>>,
131 weights: Option<Array1<f64>>,
133 poly_coeffs: Option<Array1<f64>>,
135 x_mean: Option<Array1<f64>>,
137 x_std: Option<Array1<f64>>,
138 y_mean: f64,
139 y_std: f64,
140 kernel_matrix: Option<Array2<f64>>,
142}
143
144impl RbfSurrogate {
145 pub fn new(options: RbfOptions) -> Self {
147 Self {
148 options,
149 x_train: None,
150 y_train: None,
151 weights: None,
152 poly_coeffs: None,
153 x_mean: None,
154 x_std: None,
155 y_mean: 0.0,
156 y_std: 1.0,
157 kernel_matrix: None,
158 }
159 }
160
161 fn compute_kernel_matrix(&self, x: &Array2<f64>) -> Array2<f64> {
163 let n = x.nrows();
164 let sq_dists = pairwise_sq_distances(x, x);
165 let mut kernel = Array2::zeros((n, n));
166 for i in 0..n {
167 for j in 0..n {
168 kernel[[i, j]] = self.options.kernel.evaluate(sq_dists[[i, j]]);
169 }
170 }
171 kernel
172 }
173
174 fn compute_kernel_vector(&self, x: &Array1<f64>, x_train: &Array2<f64>) -> Array1<f64> {
176 let n = x_train.nrows();
177 let mut k_vec = Array1::zeros(n);
178 for i in 0..n {
179 let mut sq_dist = 0.0;
180 for j in 0..x.len() {
181 let diff = x[j] - x_train[[i, j]];
182 sq_dist += diff * diff;
183 }
184 k_vec[i] = self.options.kernel.evaluate(sq_dist);
185 }
186 k_vec
187 }
188
189 fn build_polynomial_matrix(&self, x: &Array2<f64>, degree: usize) -> Array2<f64> {
191 let n = x.nrows();
192 let d = x.ncols();
193
194 if degree == 0 {
195 Array2::ones((n, 1))
197 } else if degree == 1 {
198 let ncols = 1 + d;
200 let mut p = Array2::zeros((n, ncols));
201 for i in 0..n {
202 p[[i, 0]] = 1.0;
203 for j in 0..d {
204 p[[i, j + 1]] = x[[i, j]];
205 }
206 }
207 p
208 } else {
209 let ncols = 1 + d;
211 let mut p = Array2::zeros((n, ncols));
212 for i in 0..n {
213 p[[i, 0]] = 1.0;
214 for j in 0..d {
215 p[[i, j + 1]] = x[[i, j]];
216 }
217 }
218 p
219 }
220 }
221
222 fn normalize_x(&self, x: &Array2<f64>) -> Array2<f64> {
224 if let (Some(ref mean), Some(ref std)) = (&self.x_mean, &self.x_std) {
225 let mut normalized = x.clone();
226 for i in 0..x.nrows() {
227 for j in 0..x.ncols() {
228 let s = if std[j] > 1e-30 { std[j] } else { 1.0 };
229 normalized[[i, j]] = (x[[i, j]] - mean[j]) / s;
230 }
231 }
232 normalized
233 } else {
234 x.clone()
235 }
236 }
237
238 fn normalize_x_point(&self, x: &Array1<f64>) -> Array1<f64> {
240 if let (Some(ref mean), Some(ref std)) = (&self.x_mean, &self.x_std) {
241 let mut normalized = x.clone();
242 for j in 0..x.len() {
243 let s = if std[j] > 1e-30 { std[j] } else { 1.0 };
244 normalized[j] = (x[j] - mean[j]) / s;
245 }
246 normalized
247 } else {
248 x.clone()
249 }
250 }
251}
252
253impl SurrogateModel for RbfSurrogate {
254 fn fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> OptimizeResult<()> {
255 let n = x.nrows();
256 let d = x.ncols();
257
258 if n < d + 1 {
259 return Err(OptimizeError::InvalidInput(format!(
260 "Need at least {} data points for {} dimensions, got {}",
261 d + 1,
262 d,
263 n
264 )));
265 }
266
267 if self.options.normalize {
269 let mut x_mean = Array1::zeros(d);
270 let mut x_std = Array1::zeros(d);
271 for j in 0..d {
272 let mut sum = 0.0;
273 for i in 0..n {
274 sum += x[[i, j]];
275 }
276 x_mean[j] = sum / n as f64;
277
278 let mut sq_sum = 0.0;
279 for i in 0..n {
280 let diff = x[[i, j]] - x_mean[j];
281 sq_sum += diff * diff;
282 }
283 x_std[j] = (sq_sum / n as f64).sqrt();
284 if x_std[j] < 1e-30 {
285 x_std[j] = 1.0;
286 }
287 }
288
289 self.x_mean = Some(x_mean);
290 self.x_std = Some(x_std);
291
292 let y_sum: f64 = y.iter().sum();
293 self.y_mean = y_sum / n as f64;
294 let y_var: f64 = y.iter().map(|yi| (yi - self.y_mean).powi(2)).sum::<f64>() / n as f64;
295 self.y_std = y_var.sqrt();
296 if self.y_std < 1e-30 {
297 self.y_std = 1.0;
298 }
299 }
300
301 let x_norm = self.normalize_x(x);
303 let y_norm: Array1<f64> = if self.options.normalize {
304 y.mapv(|yi| (yi - self.y_mean) / self.y_std)
305 } else {
306 y.clone()
307 };
308
309 let mut kernel = self.compute_kernel_matrix(&x_norm);
311
312 for i in 0..n {
314 kernel[[i, i]] += self.options.regularization;
315 }
316
317 self.kernel_matrix = Some(kernel.clone());
318
319 if self.options.kernel.needs_polynomial_tail() {
320 let degree = self.options.kernel.polynomial_degree();
321 let p = self.build_polynomial_matrix(&x_norm, degree);
322 let m = p.ncols();
323
324 let total = n + m;
328 let mut aug = Array2::zeros((total, total));
329 for i in 0..n {
330 for j in 0..n {
331 aug[[i, j]] = kernel[[i, j]];
332 }
333 for j in 0..m {
334 aug[[i, n + j]] = p[[i, j]];
335 aug[[n + j, i]] = p[[i, j]];
336 }
337 }
338
339 let mut rhs = Array1::zeros(total);
340 for i in 0..n {
341 rhs[i] = y_norm[i];
342 }
343
344 let solution = solve_general(&aug, &rhs)?;
345 self.weights = Some(solution.slice(scirs2_core::ndarray::s![..n]).to_owned());
346 self.poly_coeffs = Some(solution.slice(scirs2_core::ndarray::s![n..]).to_owned());
347 } else {
348 let weights = solve_general(&kernel, &y_norm)?;
350 self.weights = Some(weights);
351 self.poly_coeffs = None;
352 }
353
354 self.x_train = Some(x_norm);
355 self.y_train = Some(y_norm);
356
357 Ok(())
358 }
359
360 fn predict(&self, x: &Array1<f64>) -> OptimizeResult<f64> {
361 let x_train = self
362 .x_train
363 .as_ref()
364 .ok_or_else(|| OptimizeError::ComputationError("Model not fitted".to_string()))?;
365 let weights = self
366 .weights
367 .as_ref()
368 .ok_or_else(|| OptimizeError::ComputationError("Model not fitted".to_string()))?;
369
370 let x_norm = self.normalize_x_point(x);
371 let k_vec = self.compute_kernel_vector(&x_norm, x_train);
372
373 let mut prediction = k_vec.dot(weights);
374
375 if let Some(ref poly_coeffs) = self.poly_coeffs {
377 let d = x_norm.len();
378 prediction += poly_coeffs[0];
380 for j in 0..d.min(poly_coeffs.len() - 1) {
382 prediction += poly_coeffs[j + 1] * x_norm[j];
383 }
384 }
385
386 if self.options.normalize {
388 prediction = prediction * self.y_std + self.y_mean;
389 }
390
391 Ok(prediction)
392 }
393
394 fn predict_with_uncertainty(&self, x: &Array1<f64>) -> OptimizeResult<(f64, f64)> {
395 let mean = self.predict(x)?;
396
397 let x_train = self
400 .x_train
401 .as_ref()
402 .ok_or_else(|| OptimizeError::ComputationError("Model not fitted".to_string()))?;
403
404 let x_norm = self.normalize_x_point(x);
405 let n = x_train.nrows();
406
407 let mut min_dist = f64::INFINITY;
409 let mut sum_inv_dist = 0.0;
410 for i in 0..n {
411 let mut sq_dist = 0.0;
412 for j in 0..x_norm.len() {
413 let diff = x_norm[j] - x_train[[i, j]];
414 sq_dist += diff * diff;
415 }
416 let dist = sq_dist.sqrt();
417 if dist < min_dist {
418 min_dist = dist;
419 }
420 if dist > 1e-30 {
421 sum_inv_dist += 1.0 / dist;
422 }
423 }
424
425 let avg_inv_dist = if n > 0 { sum_inv_dist / n as f64 } else { 1.0 };
428 let uncertainty = if avg_inv_dist > 1e-30 {
429 min_dist * avg_inv_dist
430 } else {
431 min_dist
432 };
433
434 let scaled_uncertainty = uncertainty * self.y_std;
436
437 Ok((mean, scaled_uncertainty.max(1e-10)))
438 }
439
440 fn n_samples(&self) -> usize {
441 self.x_train.as_ref().map_or(0, |x| x.nrows())
442 }
443
444 fn n_features(&self) -> usize {
445 self.x_train.as_ref().map_or(0, |x| x.ncols())
446 }
447
448 fn update(&mut self, x: &Array1<f64>, y: f64) -> OptimizeResult<()> {
449 let (new_x, new_y) = if let (Some(ref x_train), Some(ref y_train)) =
451 (&self.x_train, &self.y_train)
452 {
453 let d = x_train.ncols();
455 let n = x_train.nrows();
456 let mut x_denorm = Array2::zeros((n, d));
457 for i in 0..n {
458 for j in 0..d {
459 if self.options.normalize {
460 let s =
461 self.x_std
462 .as_ref()
463 .map_or(1.0, |s| if s[j] > 1e-30 { s[j] } else { 1.0 });
464 let m = self.x_mean.as_ref().map_or(0.0, |m| m[j]);
465 x_denorm[[i, j]] = x_train[[i, j]] * s + m;
466 } else {
467 x_denorm[[i, j]] = x_train[[i, j]];
468 }
469 }
470 }
471 let y_denorm: Array1<f64> = if self.options.normalize {
472 y_train.mapv(|yi| yi * self.y_std + self.y_mean)
473 } else {
474 y_train.clone()
475 };
476
477 let mut new_x = Array2::zeros((n + 1, d));
479 for i in 0..n {
480 for j in 0..d {
481 new_x[[i, j]] = x_denorm[[i, j]];
482 }
483 }
484 for j in 0..d {
485 new_x[[n, j]] = x[j];
486 }
487
488 let mut new_y = Array1::zeros(n + 1);
489 for i in 0..n {
490 new_y[i] = y_denorm[i];
491 }
492 new_y[n] = y;
493
494 (new_x, new_y)
495 } else {
496 let d = x.len();
497 let mut new_x = Array2::zeros((1, d));
498 for j in 0..d {
499 new_x[[0, j]] = x[j];
500 }
501 let new_y = Array1::from_vec(vec![y]);
502 (new_x, new_y)
503 };
504
505 self.fit(&new_x, &new_y)
506 }
507}
508
#[cfg(test)]
mod tests {
    use super::*;

    /// Row-major matrix fixture; all tests share the same failure message.
    fn matrix(rows: usize, cols: usize, values: Vec<f64>) -> Array2<f64> {
        Array2::from_shape_vec((rows, cols), values).expect("Array creation failed")
    }

    /// Corners of the unit square, used by the 2-D tests.
    fn unit_square() -> Array2<f64> {
        matrix(4, 2, vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0])
    }

    /// The 1-D points 0, 1, 2.
    fn three_points() -> Array2<f64> {
        matrix(3, 1, vec![0.0, 1.0, 2.0])
    }

    /// A single query point.
    fn point(coords: &[f64]) -> Array1<f64> {
        Array1::from_vec(coords.to_vec())
    }

    #[test]
    fn test_rbf_cubic_interpolation() {
        let inputs = matrix(5, 1, vec![0.0, 1.0, 2.0, 3.0, 4.0]);
        let targets = Array1::from_vec(vec![0.0, 1.0, 4.0, 9.0, 16.0]);
        let mut model = RbfSurrogate::new(RbfOptions {
            kernel: RbfKernel::Polyharmonic(3),
            regularization: 1e-8,
            normalize: false,
        });

        let outcome = model.fit(&inputs, &targets);
        assert!(outcome.is_ok(), "RBF fit failed: {:?}", outcome.err());

        // With tiny regularisation the interpolant should (nearly) reproduce every target.
        for (i, &actual) in targets.iter().enumerate() {
            let pred = model
                .predict(&point(&[i as f64]))
                .expect("Prediction failed");
            assert!(
                (pred - actual).abs() < 0.5,
                "Interpolation error at {}: pred={}, actual={}",
                i,
                pred,
                actual
            );
        }
    }

    #[test]
    fn test_rbf_gaussian() {
        let targets = Array1::from_vec(vec![0.0, 1.0, 1.0, 2.0]);
        let mut model = RbfSurrogate::new(RbfOptions {
            kernel: RbfKernel::Gaussian { sigma: 1.0 },
            regularization: 1e-6,
            normalize: true,
        });

        let outcome = model.fit(&unit_square(), &targets);
        assert!(outcome.is_ok());

        // Centre of the square: only a sanity range is checked.
        let pred = model.predict(&point(&[0.5, 0.5]));
        assert!(pred.is_ok());
        let val = pred.expect("Gaussian RBF prediction failed");
        assert!(val > -1.0 && val < 3.0, "Gaussian RBF prediction: {}", val);
    }

    #[test]
    fn test_rbf_multiquadric() {
        let targets = Array1::from_vec(vec![1.0, 2.0, 5.0]);
        let mut model = RbfSurrogate::new(RbfOptions {
            kernel: RbfKernel::Multiquadric { shape_param: 1.0 },
            regularization: 1e-8,
            normalize: false,
        });

        assert!(model.fit(&three_points(), &targets).is_ok());
        assert!(model.predict(&point(&[1.0])).is_ok());
    }

    #[test]
    fn test_rbf_thin_plate_spline() {
        let targets = Array1::from_vec(vec![0.0, 1.0, 1.0, 2.0]);
        let mut model = RbfSurrogate::new(RbfOptions {
            kernel: RbfKernel::ThinPlateSpline,
            regularization: 1e-6,
            normalize: false,
        });

        assert!(model.fit(&unit_square(), &targets).is_ok());
        assert!(model.predict(&point(&[0.5, 0.5])).is_ok());
    }

    #[test]
    fn test_rbf_uncertainty() {
        let targets = Array1::from_vec(vec![0.0, 1.0, 4.0]);
        let mut model = RbfSurrogate::new(RbfOptions {
            kernel: RbfKernel::Gaussian { sigma: 1.0 },
            regularization: 1e-6,
            normalize: true,
        });
        model.fit(&three_points(), &targets).expect("Fit failed");

        // A training point should be less uncertain than a point well outside the data.
        let (_, unc_near) = model
            .predict_with_uncertainty(&point(&[1.0]))
            .expect("Prediction failed");
        let (_, unc_far) = model
            .predict_with_uncertainty(&point(&[5.0]))
            .expect("Prediction failed");
        assert!(
            unc_far > unc_near,
            "Far point uncertainty ({}) should be greater than near point ({})",
            unc_far,
            unc_near
        );
    }

    #[test]
    fn test_rbf_update() {
        let targets = Array1::from_vec(vec![0.0, 1.0, 4.0]);
        let mut model = RbfSurrogate::new(RbfOptions::default());
        model.fit(&three_points(), &targets).expect("Fit failed");
        assert_eq!(model.n_samples(), 3);

        model
            .update(&point(&[3.0]), 9.0)
            .expect("Update failed");
        assert_eq!(model.n_samples(), 4);
    }

    #[test]
    fn test_rbf_inverse_multiquadric() {
        let targets = Array1::from_vec(vec![1.0, 2.0, 5.0]);
        let mut model = RbfSurrogate::new(RbfOptions {
            kernel: RbfKernel::InverseMultiquadric { shape_param: 1.0 },
            regularization: 1e-6,
            normalize: false,
        });

        assert!(model.fit(&three_points(), &targets).is_ok());
        assert!(model.predict(&point(&[1.0])).is_ok());
    }
}