Skip to main content

scirs2_optimize/surrogate/
rbf_surrogate.rs

1//! Radial Basis Function (RBF) Surrogate Model
2//!
3//! RBF surrogates interpolate a set of data points using radial basis functions.
4//! They are well-suited for expensive black-box optimization because they provide
5//! smooth interpolation and can handle high-dimensional problems.
6//!
7//! ## Kernel Types
8//!
9//! - **Polyharmonic**: r^k (k=1: linear, k=3: cubic, k=5: quintic)
10//! - **Multiquadric**: sqrt(r^2 + c^2)
11//! - **InverseMultiquadric**: 1 / sqrt(r^2 + c^2)
12//! - **ThinPlateSpline**: r^2 * ln(r)
13//! - **Gaussian**: exp(-r^2 / (2 * sigma^2))
14//!
15//! ## References
16//!
17//! - Buhmann, M.D. (2003). Radial Basis Functions: Theory and Implementations.
18//! - Gutmann, H.-M. (2001). A Radial Basis Function Method for Global Optimization.
19
20use super::{pairwise_sq_distances, solve_general, SurrogateModel};
21use crate::error::{OptimizeError, OptimizeResult};
22use scirs2_core::ndarray::{Array1, Array2};
23
/// Radial basis function used by [`RbfSurrogate`] to weight training points.
///
/// Each variant is a function of the Euclidean distance `r` between two points.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum RbfKernel {
    /// Polyharmonic spline `r^k` (`k = 1` linear, `k = 3` cubic, `k = 5` quintic).
    Polyharmonic(u32),
    /// Multiquadric `sqrt(r^2 + c^2)`.
    Multiquadric {
        /// Shape parameter `c`.
        shape_param: f64,
    },
    /// Inverse multiquadric `1 / sqrt(r^2 + c^2)`.
    InverseMultiquadric {
        /// Shape parameter `c`.
        shape_param: f64,
    },
    /// Thin-plate spline `r^2 * ln(r)`.
    ThinPlateSpline,
    /// Gaussian `exp(-r^2 / (2 * sigma^2))`.
    Gaussian {
        /// Bandwidth `sigma`; the formula divides by `sigma^2`, so presumably
        /// it is expected to be non-zero.
        sigma: f64,
    },
}
48
49impl Default for RbfKernel {
50    fn default() -> Self {
51        RbfKernel::Polyharmonic(3) // cubic
52    }
53}
54
55impl RbfKernel {
56    /// Evaluate the kernel function for a given squared distance
57    fn evaluate(&self, sq_dist: f64) -> f64 {
58        let r = sq_dist.sqrt();
59        match *self {
60            RbfKernel::Polyharmonic(k) => {
61                if r < 1e-30 {
62                    0.0
63                } else {
64                    r.powi(k as i32)
65                }
66            }
67            RbfKernel::Multiquadric { shape_param } => (sq_dist + shape_param * shape_param).sqrt(),
68            RbfKernel::InverseMultiquadric { shape_param } => {
69                1.0 / (sq_dist + shape_param * shape_param).sqrt()
70            }
71            RbfKernel::ThinPlateSpline => {
72                if r < 1e-30 {
73                    0.0
74                } else {
75                    sq_dist * r.ln()
76                }
77            }
78            RbfKernel::Gaussian { sigma } => (-sq_dist / (2.0 * sigma * sigma)).exp(),
79        }
80    }
81
82    /// Whether this kernel requires a polynomial tail for well-posedness
83    fn needs_polynomial_tail(&self) -> bool {
84        matches!(
85            self,
86            RbfKernel::Polyharmonic(_) | RbfKernel::ThinPlateSpline
87        )
88    }
89
90    /// Degree of polynomial tail needed
91    fn polynomial_degree(&self) -> usize {
92        match *self {
93            RbfKernel::Polyharmonic(k) => {
94                // For polyharmonic splines of order k, need polynomial of degree >= floor(k/2)
95                (k as usize) / 2
96            }
97            RbfKernel::ThinPlateSpline => 1,
98            _ => 0,
99        }
100    }
101}
102
103/// Options for RBF surrogate
104#[derive(Debug, Clone)]
105pub struct RbfOptions {
106    /// RBF kernel to use
107    pub kernel: RbfKernel,
108    /// Regularization parameter (nugget) for numerical stability
109    pub regularization: f64,
110    /// Whether to normalize the training data
111    pub normalize: bool,
112}
113
114impl Default for RbfOptions {
115    fn default() -> Self {
116        Self {
117            kernel: RbfKernel::default(),
118            regularization: 1e-10,
119            normalize: true,
120        }
121    }
122}
123
124/// RBF Surrogate Model
125pub struct RbfSurrogate {
126    options: RbfOptions,
127    /// Training points, shape (n_samples, n_features)
128    x_train: Option<Array2<f64>>,
129    /// Training values, shape (n_samples,)
130    y_train: Option<Array1<f64>>,
131    /// RBF weights (alpha)
132    weights: Option<Array1<f64>>,
133    /// Polynomial coefficients (if polynomial tail is used)
134    poly_coeffs: Option<Array1<f64>>,
135    /// Normalization parameters
136    x_mean: Option<Array1<f64>>,
137    x_std: Option<Array1<f64>>,
138    y_mean: f64,
139    y_std: f64,
140    /// Cached kernel matrix for uncertainty estimation
141    kernel_matrix: Option<Array2<f64>>,
142}
143
144impl RbfSurrogate {
145    /// Create a new RBF surrogate
146    pub fn new(options: RbfOptions) -> Self {
147        Self {
148            options,
149            x_train: None,
150            y_train: None,
151            weights: None,
152            poly_coeffs: None,
153            x_mean: None,
154            x_std: None,
155            y_mean: 0.0,
156            y_std: 1.0,
157            kernel_matrix: None,
158        }
159    }
160
161    /// Compute the kernel matrix for given points
162    fn compute_kernel_matrix(&self, x: &Array2<f64>) -> Array2<f64> {
163        let n = x.nrows();
164        let sq_dists = pairwise_sq_distances(x, x);
165        let mut kernel = Array2::zeros((n, n));
166        for i in 0..n {
167            for j in 0..n {
168                kernel[[i, j]] = self.options.kernel.evaluate(sq_dists[[i, j]]);
169            }
170        }
171        kernel
172    }
173
174    /// Compute kernel vector between a point and training points
175    fn compute_kernel_vector(&self, x: &Array1<f64>, x_train: &Array2<f64>) -> Array1<f64> {
176        let n = x_train.nrows();
177        let mut k_vec = Array1::zeros(n);
178        for i in 0..n {
179            let mut sq_dist = 0.0;
180            for j in 0..x.len() {
181                let diff = x[j] - x_train[[i, j]];
182                sq_dist += diff * diff;
183            }
184            k_vec[i] = self.options.kernel.evaluate(sq_dist);
185        }
186        k_vec
187    }
188
189    /// Build the polynomial matrix for the polynomial tail
190    fn build_polynomial_matrix(&self, x: &Array2<f64>, degree: usize) -> Array2<f64> {
191        let n = x.nrows();
192        let d = x.ncols();
193
194        if degree == 0 {
195            // Just a constant term
196            Array2::ones((n, 1))
197        } else if degree == 1 {
198            // Constant + linear terms
199            let ncols = 1 + d;
200            let mut p = Array2::zeros((n, ncols));
201            for i in 0..n {
202                p[[i, 0]] = 1.0;
203                for j in 0..d {
204                    p[[i, j + 1]] = x[[i, j]];
205                }
206            }
207            p
208        } else {
209            // For higher degrees, just use up to linear (simplification)
210            let ncols = 1 + d;
211            let mut p = Array2::zeros((n, ncols));
212            for i in 0..n {
213                p[[i, 0]] = 1.0;
214                for j in 0..d {
215                    p[[i, j + 1]] = x[[i, j]];
216                }
217            }
218            p
219        }
220    }
221
222    /// Normalize x data
223    fn normalize_x(&self, x: &Array2<f64>) -> Array2<f64> {
224        if let (Some(ref mean), Some(ref std)) = (&self.x_mean, &self.x_std) {
225            let mut normalized = x.clone();
226            for i in 0..x.nrows() {
227                for j in 0..x.ncols() {
228                    let s = if std[j] > 1e-30 { std[j] } else { 1.0 };
229                    normalized[[i, j]] = (x[[i, j]] - mean[j]) / s;
230                }
231            }
232            normalized
233        } else {
234            x.clone()
235        }
236    }
237
238    /// Normalize a single x point
239    fn normalize_x_point(&self, x: &Array1<f64>) -> Array1<f64> {
240        if let (Some(ref mean), Some(ref std)) = (&self.x_mean, &self.x_std) {
241            let mut normalized = x.clone();
242            for j in 0..x.len() {
243                let s = if std[j] > 1e-30 { std[j] } else { 1.0 };
244                normalized[j] = (x[j] - mean[j]) / s;
245            }
246            normalized
247        } else {
248            x.clone()
249        }
250    }
251}
252
253impl SurrogateModel for RbfSurrogate {
254    fn fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> OptimizeResult<()> {
255        let n = x.nrows();
256        let d = x.ncols();
257
258        if n < d + 1 {
259            return Err(OptimizeError::InvalidInput(format!(
260                "Need at least {} data points for {} dimensions, got {}",
261                d + 1,
262                d,
263                n
264            )));
265        }
266
267        // Compute normalization parameters
268        if self.options.normalize {
269            let mut x_mean = Array1::zeros(d);
270            let mut x_std = Array1::zeros(d);
271            for j in 0..d {
272                let mut sum = 0.0;
273                for i in 0..n {
274                    sum += x[[i, j]];
275                }
276                x_mean[j] = sum / n as f64;
277
278                let mut sq_sum = 0.0;
279                for i in 0..n {
280                    let diff = x[[i, j]] - x_mean[j];
281                    sq_sum += diff * diff;
282                }
283                x_std[j] = (sq_sum / n as f64).sqrt();
284                if x_std[j] < 1e-30 {
285                    x_std[j] = 1.0;
286                }
287            }
288
289            self.x_mean = Some(x_mean);
290            self.x_std = Some(x_std);
291
292            let y_sum: f64 = y.iter().sum();
293            self.y_mean = y_sum / n as f64;
294            let y_var: f64 = y.iter().map(|yi| (yi - self.y_mean).powi(2)).sum::<f64>() / n as f64;
295            self.y_std = y_var.sqrt();
296            if self.y_std < 1e-30 {
297                self.y_std = 1.0;
298            }
299        }
300
301        // Normalize training data
302        let x_norm = self.normalize_x(x);
303        let y_norm: Array1<f64> = if self.options.normalize {
304            y.mapv(|yi| (yi - self.y_mean) / self.y_std)
305        } else {
306            y.clone()
307        };
308
309        // Compute kernel matrix
310        let mut kernel = self.compute_kernel_matrix(&x_norm);
311
312        // Add regularization
313        for i in 0..n {
314            kernel[[i, i]] += self.options.regularization;
315        }
316
317        self.kernel_matrix = Some(kernel.clone());
318
319        if self.options.kernel.needs_polynomial_tail() {
320            let degree = self.options.kernel.polynomial_degree();
321            let p = self.build_polynomial_matrix(&x_norm, degree);
322            let m = p.ncols();
323
324            // Solve the augmented system:
325            // [K  P] [alpha]   [y]
326            // [P' 0] [beta ] = [0]
327            let total = n + m;
328            let mut aug = Array2::zeros((total, total));
329            for i in 0..n {
330                for j in 0..n {
331                    aug[[i, j]] = kernel[[i, j]];
332                }
333                for j in 0..m {
334                    aug[[i, n + j]] = p[[i, j]];
335                    aug[[n + j, i]] = p[[i, j]];
336                }
337            }
338
339            let mut rhs = Array1::zeros(total);
340            for i in 0..n {
341                rhs[i] = y_norm[i];
342            }
343
344            let solution = solve_general(&aug, &rhs)?;
345            self.weights = Some(solution.slice(scirs2_core::ndarray::s![..n]).to_owned());
346            self.poly_coeffs = Some(solution.slice(scirs2_core::ndarray::s![n..]).to_owned());
347        } else {
348            // Solve K * alpha = y
349            let weights = solve_general(&kernel, &y_norm)?;
350            self.weights = Some(weights);
351            self.poly_coeffs = None;
352        }
353
354        self.x_train = Some(x_norm);
355        self.y_train = Some(y_norm);
356
357        Ok(())
358    }
359
360    fn predict(&self, x: &Array1<f64>) -> OptimizeResult<f64> {
361        let x_train = self
362            .x_train
363            .as_ref()
364            .ok_or_else(|| OptimizeError::ComputationError("Model not fitted".to_string()))?;
365        let weights = self
366            .weights
367            .as_ref()
368            .ok_or_else(|| OptimizeError::ComputationError("Model not fitted".to_string()))?;
369
370        let x_norm = self.normalize_x_point(x);
371        let k_vec = self.compute_kernel_vector(&x_norm, x_train);
372
373        let mut prediction = k_vec.dot(weights);
374
375        // Add polynomial tail contribution
376        if let Some(ref poly_coeffs) = self.poly_coeffs {
377            let d = x_norm.len();
378            // Constant term
379            prediction += poly_coeffs[0];
380            // Linear terms
381            for j in 0..d.min(poly_coeffs.len() - 1) {
382                prediction += poly_coeffs[j + 1] * x_norm[j];
383            }
384        }
385
386        // Denormalize
387        if self.options.normalize {
388            prediction = prediction * self.y_std + self.y_mean;
389        }
390
391        Ok(prediction)
392    }
393
394    fn predict_with_uncertainty(&self, x: &Array1<f64>) -> OptimizeResult<(f64, f64)> {
395        let mean = self.predict(x)?;
396
397        // Estimate uncertainty using leave-one-out cross-validation approximation
398        // Simple heuristic: distance-based uncertainty
399        let x_train = self
400            .x_train
401            .as_ref()
402            .ok_or_else(|| OptimizeError::ComputationError("Model not fitted".to_string()))?;
403
404        let x_norm = self.normalize_x_point(x);
405        let n = x_train.nrows();
406
407        // Compute minimum distance to training points
408        let mut min_dist = f64::INFINITY;
409        let mut sum_inv_dist = 0.0;
410        for i in 0..n {
411            let mut sq_dist = 0.0;
412            for j in 0..x_norm.len() {
413                let diff = x_norm[j] - x_train[[i, j]];
414                sq_dist += diff * diff;
415            }
416            let dist = sq_dist.sqrt();
417            if dist < min_dist {
418                min_dist = dist;
419            }
420            if dist > 1e-30 {
421                sum_inv_dist += 1.0 / dist;
422            }
423        }
424
425        // Uncertainty is proportional to distance from training data
426        // Normalize by the average distance scale
427        let avg_inv_dist = if n > 0 { sum_inv_dist / n as f64 } else { 1.0 };
428        let uncertainty = if avg_inv_dist > 1e-30 {
429            min_dist * avg_inv_dist
430        } else {
431            min_dist
432        };
433
434        // Scale by y_std
435        let scaled_uncertainty = uncertainty * self.y_std;
436
437        Ok((mean, scaled_uncertainty.max(1e-10)))
438    }
439
440    fn n_samples(&self) -> usize {
441        self.x_train.as_ref().map_or(0, |x| x.nrows())
442    }
443
444    fn n_features(&self) -> usize {
445        self.x_train.as_ref().map_or(0, |x| x.ncols())
446    }
447
448    fn update(&mut self, x: &Array1<f64>, y: f64) -> OptimizeResult<()> {
449        // Refit with the new point added
450        let (new_x, new_y) = if let (Some(ref x_train), Some(ref y_train)) =
451            (&self.x_train, &self.y_train)
452        {
453            // Denormalize existing data
454            let d = x_train.ncols();
455            let n = x_train.nrows();
456            let mut x_denorm = Array2::zeros((n, d));
457            for i in 0..n {
458                for j in 0..d {
459                    if self.options.normalize {
460                        let s =
461                            self.x_std
462                                .as_ref()
463                                .map_or(1.0, |s| if s[j] > 1e-30 { s[j] } else { 1.0 });
464                        let m = self.x_mean.as_ref().map_or(0.0, |m| m[j]);
465                        x_denorm[[i, j]] = x_train[[i, j]] * s + m;
466                    } else {
467                        x_denorm[[i, j]] = x_train[[i, j]];
468                    }
469                }
470            }
471            let y_denorm: Array1<f64> = if self.options.normalize {
472                y_train.mapv(|yi| yi * self.y_std + self.y_mean)
473            } else {
474                y_train.clone()
475            };
476
477            // Append new point
478            let mut new_x = Array2::zeros((n + 1, d));
479            for i in 0..n {
480                for j in 0..d {
481                    new_x[[i, j]] = x_denorm[[i, j]];
482                }
483            }
484            for j in 0..d {
485                new_x[[n, j]] = x[j];
486            }
487
488            let mut new_y = Array1::zeros(n + 1);
489            for i in 0..n {
490                new_y[i] = y_denorm[i];
491            }
492            new_y[n] = y;
493
494            (new_x, new_y)
495        } else {
496            let d = x.len();
497            let mut new_x = Array2::zeros((1, d));
498            for j in 0..d {
499                new_x[[0, j]] = x[j];
500            }
501            let new_y = Array1::from_vec(vec![y]);
502            (new_x, new_y)
503        };
504
505        self.fit(&new_x, &new_y)
506    }
507}
508
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `(n, 1)` input matrix from a slice of scalars.
    fn column(values: &[f64]) -> Array2<f64> {
        Array2::from_shape_vec((values.len(), 1), values.to_vec()).expect("Array creation failed")
    }

    /// The four corners of the unit square, one point per row.
    fn unit_square() -> Array2<f64> {
        Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0])
            .expect("Array creation failed")
    }

    /// Shorthand for constructing a surrogate with explicit options.
    fn make_model(kernel: RbfKernel, regularization: f64, normalize: bool) -> RbfSurrogate {
        RbfSurrogate::new(RbfOptions {
            kernel,
            regularization,
            normalize,
        })
    }

    #[test]
    fn test_rbf_cubic_interpolation() {
        let xs = column(&[0.0, 1.0, 2.0, 3.0, 4.0]);
        let ys = Array1::from_vec(vec![0.0, 1.0, 4.0, 9.0, 16.0]);
        let mut model = make_model(RbfKernel::Polyharmonic(3), 1e-8, false);

        let fitted = model.fit(&xs, &ys);
        assert!(fitted.is_ok(), "RBF fit failed: {:?}", fitted.err());

        // The interpolant should reproduce the training targets.
        for (i, &expected) in ys.iter().enumerate() {
            let pred = model
                .predict(&Array1::from_vec(vec![i as f64]))
                .expect("Prediction failed");
            assert!(
                (pred - expected).abs() < 0.5,
                "Interpolation error at {}: pred={}, actual={}",
                i,
                pred,
                expected
            );
        }
    }

    #[test]
    fn test_rbf_gaussian() {
        let mut model = make_model(RbfKernel::Gaussian { sigma: 1.0 }, 1e-6, true);
        let targets = Array1::from_vec(vec![0.0, 1.0, 1.0, 2.0]);
        assert!(model.fit(&unit_square(), &targets).is_ok());

        // Centre of the square: expected to be roughly 1.0, the corner average.
        let pred = model.predict(&Array1::from_vec(vec![0.5, 0.5]));
        assert!(pred.is_ok());
        let val = pred.expect("Gaussian RBF prediction failed");
        assert!(val > -1.0 && val < 3.0, "Gaussian RBF prediction: {}", val);
    }

    #[test]
    fn test_rbf_multiquadric() {
        let mut model = make_model(RbfKernel::Multiquadric { shape_param: 1.0 }, 1e-8, false);
        let targets = Array1::from_vec(vec![1.0, 2.0, 5.0]);
        assert!(model.fit(&column(&[0.0, 1.0, 2.0]), &targets).is_ok());

        assert!(model.predict(&Array1::from_vec(vec![1.0])).is_ok());
    }

    #[test]
    fn test_rbf_thin_plate_spline() {
        let mut model = make_model(RbfKernel::ThinPlateSpline, 1e-6, false);
        let targets = Array1::from_vec(vec![0.0, 1.0, 1.0, 2.0]);
        assert!(model.fit(&unit_square(), &targets).is_ok());

        assert!(model.predict(&Array1::from_vec(vec![0.5, 0.5])).is_ok());
    }

    #[test]
    fn test_rbf_uncertainty() {
        let mut model = make_model(RbfKernel::Gaussian { sigma: 1.0 }, 1e-6, true);
        let targets = Array1::from_vec(vec![0.0, 1.0, 4.0]);
        model
            .fit(&column(&[0.0, 1.0, 2.0]), &targets)
            .expect("Fit failed");

        // Uncertainty at a training point should be lower than far away.
        let uncertainty_at = |point: f64| {
            let (_, unc) = model
                .predict_with_uncertainty(&Array1::from_vec(vec![point]))
                .expect("Prediction failed");
            unc
        };
        let unc_near = uncertainty_at(1.0);
        let unc_far = uncertainty_at(5.0);
        assert!(
            unc_far > unc_near,
            "Far point uncertainty ({}) should be greater than near point ({})",
            unc_far,
            unc_near
        );
    }

    #[test]
    fn test_rbf_update() {
        let mut model = RbfSurrogate::new(RbfOptions::default());
        let targets = Array1::from_vec(vec![0.0, 1.0, 4.0]);
        model
            .fit(&column(&[0.0, 1.0, 2.0]), &targets)
            .expect("Fit failed");
        assert_eq!(model.n_samples(), 3);

        // Adding one observation grows the training set by one.
        model
            .update(&Array1::from_vec(vec![3.0]), 9.0)
            .expect("Update failed");
        assert_eq!(model.n_samples(), 4);
    }

    #[test]
    fn test_rbf_inverse_multiquadric() {
        let mut model = make_model(
            RbfKernel::InverseMultiquadric { shape_param: 1.0 },
            1e-6,
            false,
        );
        let targets = Array1::from_vec(vec![1.0, 2.0, 5.0]);
        assert!(model.fit(&column(&[0.0, 1.0, 2.0]), &targets).is_ok());

        assert!(model.predict(&Array1::from_vec(vec![1.0])).is_ok());
    }
}