Skip to main content

sklears_kernel_approximation/
multi_scale_rbf.rs

1//! Multi-scale RBF kernel approximation methods
2//!
3//! This module implements multi-scale RBF kernel approximation that uses multiple bandwidth
4//! parameters to capture patterns at different scales. This provides better approximation
5//! quality for data with features at multiple scales.
6
7use scirs2_core::ndarray::{s, Array1, Array2};
8use scirs2_core::random::essentials::{Normal as RandNormal, Uniform as RandUniform};
9use scirs2_core::random::rngs::StdRng as RealStdRng;
10use scirs2_core::random::RngExt;
11use scirs2_core::random::{thread_rng, SeedableRng};
12use sklears_core::{
13    error::{Result, SklearsError},
14    traits::{Fit, Trained, Transform, Untrained},
15    types::Float,
16};
17use std::marker::PhantomData;
18
/// Strategy used to choose the per-scale RBF bandwidths (`gamma` values).
///
/// The range-based strategies (logarithmic, linear, geometric) read `gamma_min` and
/// `gamma_max` from `MultiScaleRBFSampler::gamma_range` and the number of scales from
/// `MultiScaleRBFSampler::n_scales`. With a single scale they fall back to one
/// representative value of the range, as documented on each variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum BandwidthStrategy {
    /// Use the gamma values supplied via `MultiScaleRBFSampler::manual_gammas`,
    /// one scale per value.
    Manual,
    /// Logarithmic spacing:
    /// `gamma_i = gamma_min * (gamma_max / gamma_min)^(i / (n_scales - 1))`.
    ///
    /// With one scale the geometric mean `sqrt(gamma_min * gamma_max)` is used.
    /// This is the default strategy.
    #[default]
    LogarithmicSpacing,
    /// Linear spacing: `gamma_i = gamma_min + i * (gamma_max - gamma_min) / (n_scales - 1)`.
    ///
    /// With one scale the arithmetic mean `(gamma_min + gamma_max) / 2` is used.
    LinearSpacing,
    /// Geometric progression: `gamma_i = gamma_min * ratio^i` with
    /// `ratio = (gamma_max / gamma_min)^(1 / (n_scales - 1))`.
    ///
    /// This yields the same values as [`BandwidthStrategy::LogarithmicSpacing`] up to
    /// rounding, except that with one scale it uses `gamma_min`.
    GeometricProgression,
    /// Data-driven bandwidths: percentiles of the pairwise distances `d` between (up to)
    /// the first 100 training samples, mapped to `gamma = 1 / (2 d^2)`.
    Adaptive,
}
34
/// Strategy used to merge the per-scale random Fourier feature blocks into one matrix.
///
/// Every scale produces an `(n_samples, n_components_per_scale)` block. Only
/// [`CombinationStrategy::Concatenation`] changes the output width; all other
/// strategies combine the blocks element-wise.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum CombinationStrategy {
    /// Concatenate the blocks of all scales side by side
    /// (`n_scales * n_components_per_scale` columns). This is the default.
    #[default]
    Concatenation,
    /// Weighted average across scales using `scale_weights` normalized by their sum;
    /// equal weights are used when no weights are configured.
    WeightedAverage,
    /// Element-wise maximum across scales.
    MaxPooling,
    /// Element-wise arithmetic mean across scales.
    AveragePooling,
    /// Weighted average with weights derived from the bandwidths: a softmax over
    /// `ln(gamma)`, i.e. weights proportional to `gamma`.
    Attention,
}
50
/// Multi-scale RBF sampler that generates random Fourier features at multiple scales
///
/// This sampler generates RBF kernel approximations using multiple bandwidth parameters
/// (gamma values) to capture patterns at different scales. Each scale captures different
/// frequency characteristics of the data, providing a more comprehensive representation.
///
/// The `State` type parameter is a type-state marker: the builder methods are available
/// on `MultiScaleRBFSampler<Untrained>`, and `fit` returns a
/// `MultiScaleRBFSampler<Trained>` that implements `Transform`.
///
/// # Mathematical Background
///
/// For multiple scales with bandwidths γ₁, γ₂, ..., γₖ, the multi-scale RBF kernel is:
/// K(x,y) = Σᵢ wᵢ * exp(-γᵢ||x-y||²)
///
/// The random Fourier features for each scale i are:
/// zᵢ(x) = √(2/nᵢ) * cos(Ωᵢx + bᵢ)
/// where nᵢ is the number of components per scale, every row ωᵢⱼ of Ωᵢ is drawn from
/// N(0, 2γᵢI) and every bᵢⱼ from Uniform(0, 2π). Only cosine features are produced, so
/// each scale contributes nᵢ columns before the per-scale blocks are merged according
/// to the configured [`CombinationStrategy`].
///
/// # Examples
///
/// ```rust,ignore
/// use sklears_kernel_approximation::{MultiScaleRBFSampler, BandwidthStrategy, CombinationStrategy};
/// use sklears_core::traits::{Transform, Fit, Untrained};
/// use scirs2_core::ndarray::array;
///
/// let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
///
/// let sampler = MultiScaleRBFSampler::new(100)
///     .n_scales(3)
///     .gamma_range(0.1, 10.0)
///     .bandwidth_strategy(BandwidthStrategy::LogarithmicSpacing)
///     .combination_strategy(CombinationStrategy::Concatenation);
///
/// let fitted = sampler.fit(&x, &()).unwrap();
/// let features = fitted.transform(&x).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct MultiScaleRBFSampler<State = Untrained> {
    /// Number of random features drawn for each scale
    pub n_components_per_scale: usize,
    /// Number of scales (not used by the manual bandwidth strategy, which has one
    /// scale per manual gamma)
    pub n_scales: usize,
    /// Lower end of the gamma range used by the spacing strategies
    pub gamma_min: Float,
    /// Upper end of the gamma range used by the spacing strategies
    pub gamma_max: Float,
    /// Manual gamma values (used when strategy is Manual)
    pub manual_gammas: Vec<Float>,
    /// Bandwidth selection strategy
    pub bandwidth_strategy: BandwidthStrategy,
    /// Feature combination strategy
    pub combination_strategy: CombinationStrategy,
    /// Per-scale weights for the weighted-average combination; they are divided by
    /// their sum before use, and an empty vector means equal weights
    pub scale_weights: Vec<Float>,
    /// Random seed for reproducibility; `None` seeds from the thread-local RNG
    pub random_state: Option<u64>,

    // Fitted attributes (all `None` until `fit` runs)
    /// Bandwidth used for each fitted scale
    gammas_: Option<Vec<Float>>,
    /// One `(n_components_per_scale, n_features)` projection matrix per scale,
    /// drawn from N(0, 2γ)
    random_weights_: Option<Vec<Array2<Float>>>,
    /// One length-`n_components_per_scale` phase-offset vector per scale,
    /// drawn from Uniform(0, 2π)
    random_offsets_: Option<Vec<Array1<Float>>>,
    attention_weights_: Option<Array1<Float>>, // For attention-based combination; only set when the strategy is Attention

    // State marker: `Untrained` before `fit`, `Trained` after
    _state: PhantomData<State>,
}
115
116impl MultiScaleRBFSampler<Untrained> {
117    /// Create a new multi-scale RBF sampler
118    ///
119    /// # Arguments
120    /// * `n_components_per_scale` - Number of random features per scale
121    pub fn new(n_components_per_scale: usize) -> Self {
122        Self {
123            n_components_per_scale,
124            n_scales: 3,
125            gamma_min: 0.1,
126            gamma_max: 10.0,
127            manual_gammas: vec![],
128            bandwidth_strategy: BandwidthStrategy::LogarithmicSpacing,
129            combination_strategy: CombinationStrategy::Concatenation,
130            scale_weights: vec![],
131            random_state: None,
132            gammas_: None,
133            random_weights_: None,
134            random_offsets_: None,
135            attention_weights_: None,
136            _state: PhantomData,
137        }
138    }
139
140    /// Set the number of scales
141    pub fn n_scales(mut self, n_scales: usize) -> Self {
142        self.n_scales = n_scales;
143        self
144    }
145
146    /// Set the gamma range for automatic bandwidth selection
147    pub fn gamma_range(mut self, gamma_min: Float, gamma_max: Float) -> Self {
148        self.gamma_min = gamma_min;
149        self.gamma_max = gamma_max;
150        self
151    }
152
153    /// Set manual gamma values
154    pub fn manual_gammas(mut self, gammas: Vec<Float>) -> Self {
155        self.n_scales = gammas.len();
156        self.manual_gammas = gammas;
157        self.bandwidth_strategy = BandwidthStrategy::Manual;
158        self
159    }
160
161    /// Set the bandwidth selection strategy
162    pub fn bandwidth_strategy(mut self, strategy: BandwidthStrategy) -> Self {
163        self.bandwidth_strategy = strategy;
164        self
165    }
166
167    /// Set the feature combination strategy
168    pub fn combination_strategy(mut self, strategy: CombinationStrategy) -> Self {
169        self.combination_strategy = strategy;
170        self
171    }
172
173    /// Set weights for different scales (used in weighted combination)
174    pub fn scale_weights(mut self, weights: Vec<Float>) -> Self {
175        self.scale_weights = weights;
176        self
177    }
178
179    /// Set the random state for reproducibility
180    pub fn random_state(mut self, seed: u64) -> Self {
181        self.random_state = Some(seed);
182        self
183    }
184
185    /// Compute gamma values based on the selected strategy
186    fn compute_gammas(&self, x: &Array2<Float>) -> Result<Vec<Float>> {
187        match self.bandwidth_strategy {
188            BandwidthStrategy::Manual => {
189                if self.manual_gammas.is_empty() {
190                    return Err(SklearsError::InvalidParameter {
191                        name: "manual_gammas".to_string(),
192                        reason: "manual gammas not provided".to_string(),
193                    });
194                }
195                Ok(self.manual_gammas.clone())
196            }
197            BandwidthStrategy::LogarithmicSpacing => {
198                let mut gammas = Vec::with_capacity(self.n_scales);
199                if self.n_scales == 1 {
200                    gammas.push((self.gamma_min * self.gamma_max).sqrt());
201                } else {
202                    let log_min = self.gamma_min.ln();
203                    let log_max = self.gamma_max.ln();
204                    for i in 0..self.n_scales {
205                        let t = i as Float / (self.n_scales - 1) as Float;
206                        let log_gamma = log_min + t * (log_max - log_min);
207                        gammas.push(log_gamma.exp());
208                    }
209                }
210                Ok(gammas)
211            }
212            BandwidthStrategy::LinearSpacing => {
213                let mut gammas = Vec::with_capacity(self.n_scales);
214                if self.n_scales == 1 {
215                    gammas.push((self.gamma_min + self.gamma_max) / 2.0);
216                } else {
217                    for i in 0..self.n_scales {
218                        let t = i as Float / (self.n_scales - 1) as Float;
219                        let gamma = self.gamma_min + t * (self.gamma_max - self.gamma_min);
220                        gammas.push(gamma);
221                    }
222                }
223                Ok(gammas)
224            }
225            BandwidthStrategy::GeometricProgression => {
226                let mut gammas = Vec::with_capacity(self.n_scales);
227                let ratio = if self.n_scales == 1 {
228                    1.0
229                } else {
230                    (self.gamma_max / self.gamma_min).powf(1.0 / (self.n_scales - 1) as Float)
231                };
232                for i in 0..self.n_scales {
233                    let gamma = self.gamma_min * ratio.powi(i as i32);
234                    gammas.push(gamma);
235                }
236                Ok(gammas)
237            }
238            BandwidthStrategy::Adaptive => {
239                // Adaptive bandwidth selection based on data characteristics
240                self.compute_adaptive_gammas(x)
241            }
242        }
243    }
244
245    /// Compute adaptive gamma values based on data characteristics
246    fn compute_adaptive_gammas(&self, x: &Array2<Float>) -> Result<Vec<Float>> {
247        let (n_samples, _n_features) = x.dim();
248
249        if n_samples < 2 {
250            return Err(SklearsError::InvalidInput(
251                "Need at least 2 samples for adaptive bandwidth selection".to_string(),
252            ));
253        }
254
255        // Compute pairwise distances for a subset of points
256        let n_subset = n_samples.min(100);
257        let mut distances = Vec::new();
258
259        for i in 0..n_subset {
260            for j in (i + 1)..n_subset {
261                let diff = &x.row(i) - &x.row(j);
262                let dist_sq = diff.mapv(|x| x * x).sum();
263                distances.push(dist_sq.sqrt());
264            }
265        }
266
267        if distances.is_empty() {
268            return Ok(vec![1.0; self.n_scales]);
269        }
270
271        // Sort distances
272        distances.sort_by(|a, b| a.partial_cmp(b).expect("operation should succeed"));
273
274        // Use percentiles to determine scales
275        let mut gammas = Vec::with_capacity(self.n_scales);
276        for i in 0..self.n_scales {
277            let percentile = if self.n_scales == 1 {
278                0.5
279            } else {
280                i as Float / (self.n_scales - 1) as Float
281            };
282
283            let idx = ((distances.len() - 1) as Float * percentile) as usize;
284            let characteristic_distance = distances[idx];
285
286            // gamma = 1 / (2 * sigma^2), where sigma is related to characteristic distance
287            let gamma = if characteristic_distance > 0.0 {
288                1.0 / (2.0 * characteristic_distance * characteristic_distance)
289            } else {
290                1.0
291            };
292
293            gammas.push(gamma);
294        }
295
296        Ok(gammas)
297    }
298}
299
300impl Fit<Array2<Float>, ()> for MultiScaleRBFSampler<Untrained> {
301    type Fitted = MultiScaleRBFSampler<Trained>;
302
303    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
304        let (n_samples, n_features) = x.dim();
305
306        if n_samples == 0 || n_features == 0 {
307            return Err(SklearsError::InvalidInput(
308                "Input array is empty".to_string(),
309            ));
310        }
311
312        if self.n_scales == 0 {
313            return Err(SklearsError::InvalidParameter {
314                name: "n_scales".to_string(),
315                reason: "must be positive".to_string(),
316            });
317        }
318
319        let mut rng = match self.random_state {
320            Some(seed) => RealStdRng::seed_from_u64(seed),
321            None => RealStdRng::from_seed(thread_rng().random()),
322        };
323
324        // Compute gamma values for each scale
325        let gammas = self.compute_gammas(x)?;
326
327        // Generate random weights and offsets for each scale
328        let mut random_weights = Vec::with_capacity(self.n_scales);
329        let mut random_offsets = Vec::with_capacity(self.n_scales);
330
331        for &gamma in &gammas {
332            // Generate random weights ~ N(0, 2*gamma*I)
333            let std_dev = (2.0 * gamma).sqrt();
334            let mut weights = Array2::zeros((self.n_components_per_scale, n_features));
335            for i in 0..self.n_components_per_scale {
336                for j in 0..n_features {
337                    weights[[i, j]] =
338                        rng.sample::<Float, _>(RandNormal::new(0.0, std_dev).map_err(|e| {
339                            SklearsError::NumericalError(format!(
340                                "Error creating normal distribution: {}",
341                                e
342                            ))
343                        })?);
344                }
345            }
346
347            // Generate random offsets ~ Uniform[0, 2π]
348            let mut offsets = Array1::zeros(self.n_components_per_scale);
349            for i in 0..self.n_components_per_scale {
350                offsets[i] = rng.sample::<Float, _>(
351                    RandUniform::new(0.0, 2.0 * std::f64::consts::PI)
352                        .expect("operation should succeed"),
353                );
354            }
355
356            random_weights.push(weights);
357            random_offsets.push(offsets);
358        }
359
360        // Compute attention weights if using attention-based combination
361        let attention_weights =
362            if matches!(self.combination_strategy, CombinationStrategy::Attention) {
363                Some(compute_attention_weights(&gammas)?)
364            } else {
365                None
366            };
367
368        Ok(MultiScaleRBFSampler {
369            n_components_per_scale: self.n_components_per_scale,
370            n_scales: self.n_scales,
371            gamma_min: self.gamma_min,
372            gamma_max: self.gamma_max,
373            manual_gammas: self.manual_gammas,
374            bandwidth_strategy: self.bandwidth_strategy,
375            combination_strategy: self.combination_strategy,
376            scale_weights: self.scale_weights,
377            random_state: self.random_state,
378            gammas_: Some(gammas),
379            random_weights_: Some(random_weights),
380            random_offsets_: Some(random_offsets),
381            attention_weights_: attention_weights,
382            _state: PhantomData,
383        })
384    }
385}
386
387/// Compute attention weights based on gamma values
388fn compute_attention_weights(gammas: &[Float]) -> Result<Array1<Float>> {
389    // Simple attention mechanism: higher gamma (smaller scale) gets more weight
390    let weights: Vec<Float> = gammas.iter().map(|&g| g.ln()).collect();
391    let weights_array = Array1::from(weights);
392
393    // Softmax normalization
394    let max_weight = weights_array
395        .iter()
396        .fold(Float::NEG_INFINITY, |a, &b| a.max(b));
397    let exp_weights = weights_array.mapv(|w| (w - max_weight).exp());
398    let sum_exp = exp_weights.sum();
399
400    Ok(exp_weights.mapv(|w| w / sum_exp))
401}
402
403impl Transform<Array2<Float>> for MultiScaleRBFSampler<Trained> {
404    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
405        let _gammas = self
406            .gammas_
407            .as_ref()
408            .ok_or_else(|| SklearsError::NotFitted {
409                operation: "transform".to_string(),
410            })?;
411
412        let random_weights =
413            self.random_weights_
414                .as_ref()
415                .ok_or_else(|| SklearsError::NotFitted {
416                    operation: "transform".to_string(),
417                })?;
418
419        let random_offsets =
420            self.random_offsets_
421                .as_ref()
422                .ok_or_else(|| SklearsError::NotFitted {
423                    operation: "transform".to_string(),
424                })?;
425
426        let (_n_samples, n_features) = x.dim();
427
428        // Generate features for each scale
429        let mut scale_features = Vec::with_capacity(self.n_scales);
430
431        for i in 0..self.n_scales {
432            let weights = &random_weights[i];
433            let offsets = &random_offsets[i];
434
435            if n_features != weights.ncols() {
436                return Err(SklearsError::InvalidInput(format!(
437                    "Input has {} features, expected {}",
438                    n_features,
439                    weights.ncols()
440                )));
441            }
442
443            // Compute X @ W.T + b
444            let projection = x.dot(&weights.t()) + offsets;
445
446            // Apply cosine transformation and normalize
447            let normalization = (2.0 / weights.nrows() as Float).sqrt();
448            let features = projection.mapv(|x| x.cos() * normalization);
449
450            scale_features.push(features);
451        }
452
453        // Combine features across scales
454        match self.combination_strategy {
455            CombinationStrategy::Concatenation => self.concatenate_features(scale_features),
456            CombinationStrategy::WeightedAverage => self.weighted_average_features(scale_features),
457            CombinationStrategy::MaxPooling => self.max_pooling_features(scale_features),
458            CombinationStrategy::AveragePooling => self.average_pooling_features(scale_features),
459            CombinationStrategy::Attention => self.attention_combine_features(scale_features),
460        }
461    }
462}
463
464impl MultiScaleRBFSampler<Trained> {
465    /// Concatenate features from all scales
466    fn concatenate_features(&self, scale_features: Vec<Array2<Float>>) -> Result<Array2<Float>> {
467        if scale_features.is_empty() {
468            return Err(SklearsError::InvalidInput(
469                "No scale features to concatenate".to_string(),
470            ));
471        }
472
473        let n_samples = scale_features[0].nrows();
474        let total_features: usize = scale_features.iter().map(|f| f.ncols()).sum();
475
476        let mut result = Array2::zeros((n_samples, total_features));
477        let mut col_offset = 0;
478
479        for features in scale_features {
480            let n_cols = features.ncols();
481            result
482                .slice_mut(s![.., col_offset..col_offset + n_cols])
483                .assign(&features);
484            col_offset += n_cols;
485        }
486
487        Ok(result)
488    }
489
490    /// Compute weighted average of features across scales
491    fn weighted_average_features(
492        &self,
493        scale_features: Vec<Array2<Float>>,
494    ) -> Result<Array2<Float>> {
495        if scale_features.is_empty() {
496            return Err(SklearsError::InvalidInput(
497                "No scale features to average".to_string(),
498            ));
499        }
500
501        let weights = if self.scale_weights.is_empty() {
502            // Equal weights
503            vec![1.0 / self.n_scales as Float; self.n_scales]
504        } else {
505            // Normalize provided weights
506            let sum: Float = self.scale_weights.iter().sum();
507            self.scale_weights.iter().map(|&w| w / sum).collect()
508        };
509
510        let mut result = scale_features[0].clone() * weights[0];
511        for (i, features) in scale_features.iter().enumerate().skip(1) {
512            result = result + features * weights[i];
513        }
514
515        Ok(result)
516    }
517
518    /// Apply max pooling across scales
519    fn max_pooling_features(&self, scale_features: Vec<Array2<Float>>) -> Result<Array2<Float>> {
520        if scale_features.is_empty() {
521            return Err(SklearsError::InvalidInput(
522                "No scale features for max pooling".to_string(),
523            ));
524        }
525
526        let mut result = scale_features[0].clone();
527        for features in scale_features.iter().skip(1) {
528            for ((i, j), val) in features.indexed_iter() {
529                if *val > result[[i, j]] {
530                    result[[i, j]] = *val;
531                }
532            }
533        }
534
535        Ok(result)
536    }
537
538    /// Apply average pooling across scales
539    fn average_pooling_features(
540        &self,
541        scale_features: Vec<Array2<Float>>,
542    ) -> Result<Array2<Float>> {
543        if scale_features.is_empty() {
544            return Err(SklearsError::InvalidInput(
545                "No scale features for average pooling".to_string(),
546            ));
547        }
548
549        let mut result = scale_features[0].clone();
550        for features in scale_features.iter().skip(1) {
551            result += features;
552        }
553
554        result.mapv_inplace(|x| x / self.n_scales as Float);
555        Ok(result)
556    }
557
558    /// Apply attention-based combination of features
559    fn attention_combine_features(
560        &self,
561        scale_features: Vec<Array2<Float>>,
562    ) -> Result<Array2<Float>> {
563        if scale_features.is_empty() {
564            return Err(SklearsError::InvalidInput(
565                "No scale features for attention combination".to_string(),
566            ));
567        }
568
569        let attention_weights =
570            self.attention_weights_
571                .as_ref()
572                .ok_or_else(|| SklearsError::NotFitted {
573                    operation: "attention combination".to_string(),
574                })?;
575
576        let mut result = scale_features[0].clone() * attention_weights[0];
577        for (i, features) in scale_features.iter().enumerate().skip(1) {
578            result = result + features * attention_weights[i];
579        }
580
581        Ok(result)
582    }
583}
584
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// Fit `sampler` on `data`, transform the same data, and return both results.
    /// Fails the test on any error.
    fn fit_and_transform(
        sampler: MultiScaleRBFSampler<Untrained>,
        data: &Array2<Float>,
    ) -> (MultiScaleRBFSampler<Trained>, Array2<Float>) {
        let fitted = sampler.fit(data, &()).expect("operation should succeed");
        let features = fitted.transform(data).expect("operation should succeed");
        (fitted, features)
    }

    #[test]
    fn test_multi_scale_rbf_sampler_basic() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let sampler = MultiScaleRBFSampler::new(10)
            .n_scales(3)
            .gamma_range(0.1, 10.0)
            .bandwidth_strategy(BandwidthStrategy::LogarithmicSpacing)
            .combination_strategy(CombinationStrategy::Concatenation)
            .random_state(42);

        let (_, features) = fit_and_transform(sampler, &data);

        // 3 scales x 10 components, side by side.
        assert_eq!(features.shape(), &[3, 30]);

        // Scaled cosines stay inside [-2, 2].
        assert!(features.iter().all(|&v| (-2.0..=2.0).contains(&v)));
    }

    #[test]
    fn test_different_bandwidth_strategies() {
        let data = array![[1.0, 2.0], [3.0, 4.0]];

        for strategy in [
            BandwidthStrategy::LogarithmicSpacing,
            BandwidthStrategy::LinearSpacing,
            BandwidthStrategy::GeometricProgression,
            BandwidthStrategy::Adaptive,
        ] {
            let sampler = MultiScaleRBFSampler::new(5)
                .n_scales(3)
                .gamma_range(0.1, 10.0)
                .bandwidth_strategy(strategy)
                .random_state(42);

            let (_, features) = fit_and_transform(sampler, &data);
            assert_eq!(features.shape(), &[2, 15]); // 3 scales * 5 components
        }
    }

    #[test]
    fn test_different_combination_strategies() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        // Only concatenation widens the output; the rest combine element-wise.
        let cases = [
            (CombinationStrategy::Concatenation, 30),
            (CombinationStrategy::WeightedAverage, 10),
            (CombinationStrategy::MaxPooling, 10),
            (CombinationStrategy::AveragePooling, 10),
            (CombinationStrategy::Attention, 10),
        ];

        for (strategy, width) in cases {
            let sampler = MultiScaleRBFSampler::new(10)
                .n_scales(3)
                .combination_strategy(strategy)
                .random_state(42);

            let (_, features) = fit_and_transform(sampler, &data);
            assert_eq!(features.shape(), &[3, width]);
        }
    }

    #[test]
    fn test_manual_gammas() {
        let data = array![[1.0, 2.0], [3.0, 4.0]];
        let expected_gammas = vec![0.1, 1.0, 10.0];

        let sampler = MultiScaleRBFSampler::new(8)
            .manual_gammas(expected_gammas.clone())
            .random_state(42);
        let (fitted, features) = fit_and_transform(sampler, &data);

        assert_eq!(features.shape(), &[2, 24]); // 3 scales * 8 components
        let fitted_gammas = fitted.gammas_.as_ref().expect("operation should succeed");
        assert_eq!(fitted_gammas, &expected_gammas);
    }

    #[test]
    fn test_scale_weights() {
        let data = array![[1.0, 2.0], [3.0, 4.0]];

        let sampler = MultiScaleRBFSampler::new(10)
            .n_scales(3)
            .combination_strategy(CombinationStrategy::WeightedAverage)
            .scale_weights(vec![1.0, 2.0, 0.5])
            .random_state(42);
        let (_, features) = fit_and_transform(sampler, &data);

        assert_eq!(features.shape(), &[2, 10]);
    }

    #[test]
    fn test_reproducibility() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let build = || {
            MultiScaleRBFSampler::new(20)
                .n_scales(4)
                .bandwidth_strategy(BandwidthStrategy::LogarithmicSpacing)
                .combination_strategy(CombinationStrategy::Concatenation)
                .random_state(123)
        };

        let fitted_a = build().fit(&data, &()).expect("operation should succeed");
        let fitted_b = build().fit(&data, &()).expect("operation should succeed");

        let features_a = fitted_a.transform(&data).expect("operation should succeed");
        let features_b = fitted_b.transform(&data).expect("operation should succeed");

        features_a
            .iter()
            .zip(features_b.iter())
            .for_each(|(a, b)| assert_abs_diff_eq!(a, b, epsilon = 1e-10));
    }

    #[test]
    fn test_adaptive_bandwidth() {
        // Three tight clusters at clearly different positions.
        let data = array![
            [1.0, 1.0],
            [1.1, 1.1],
            [5.0, 5.0],
            [5.1, 5.1],
            [10.0, 10.0],
            [10.1, 10.1]
        ];

        let sampler = MultiScaleRBFSampler::new(15)
            .n_scales(3)
            .bandwidth_strategy(BandwidthStrategy::Adaptive)
            .random_state(42);
        let (fitted, features) = fit_and_transform(sampler, &data);

        assert_eq!(features.shape(), &[6, 45]); // 3 scales * 15 components

        let gammas = fitted.gammas_.as_ref().expect("operation should succeed");
        assert_eq!(gammas.len(), 3);
        assert!(gammas.iter().all(|&g| g > 0.0));
    }

    #[test]
    fn test_error_handling() {
        let sampler = MultiScaleRBFSampler::new(10);

        // An empty input matrix is rejected.
        let empty = Array2::<Float>::zeros((0, 0));
        assert!(sampler.clone().fit(&empty, &()).is_err());

        // Zero scales are rejected.
        let single_row = array![[1.0, 2.0]];
        assert!(MultiScaleRBFSampler::new(10)
            .n_scales(0)
            .fit(&single_row, &())
            .is_err());

        // Transform rejects a feature-count mismatch.
        let train = array![[1.0, 2.0], [3.0, 4.0]];
        let wrong_width = array![[1.0, 2.0, 3.0]];
        let fitted = sampler
            .fit(&train, &())
            .expect("operation should succeed");
        assert!(fitted.transform(&wrong_width).is_err());
    }

    #[test]
    fn test_single_scale() {
        let data = array![[1.0, 2.0], [3.0, 4.0]];

        let sampler = MultiScaleRBFSampler::new(15)
            .n_scales(1)
            .gamma_range(1.0, 1.0)
            .random_state(42);
        let (fitted, features) = fit_and_transform(sampler, &data);

        assert_eq!(features.shape(), &[2, 15]);
        let gammas = fitted.gammas_.as_ref().expect("operation should succeed");
        assert_eq!(gammas.len(), 1);
    }

    #[test]
    fn test_gamma_computation_strategies() {
        let base = MultiScaleRBFSampler::new(10)
            .n_scales(4)
            .gamma_range(0.1, 10.0);
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        // Every range-based strategy must hit both ends of the range exactly.
        for strategy in [
            BandwidthStrategy::LogarithmicSpacing,
            BandwidthStrategy::LinearSpacing,
            BandwidthStrategy::GeometricProgression,
        ] {
            let gammas = base
                .clone()
                .bandwidth_strategy(strategy)
                .compute_gammas(&data)
                .expect("operation should succeed");
            assert_eq!(gammas.len(), 4);
            assert_abs_diff_eq!(gammas[0], 0.1, epsilon = 1e-10);
            assert_abs_diff_eq!(gammas[3], 10.0, epsilon = 1e-10);
        }
    }
}