sklears_kernel_approximation/multi_scale_rbf.rs

//! Multi-scale RBF kernel approximation methods
//!
//! This module implements multi-scale RBF kernel approximation, which uses multiple bandwidth
//! parameters to capture patterns at different scales. This yields better approximation
//! quality for data whose structure spans several scales.

use scirs2_core::ndarray::{s, Array1, Array2};
use scirs2_core::random::essentials::{Normal as RandNormal, Uniform as RandUniform};
use scirs2_core::random::rngs::StdRng as RealStdRng;
use scirs2_core::random::Rng;
use scirs2_core::random::{thread_rng, SeedableRng};
use sklears_core::{
    error::{Result, SklearsError},
    traits::{Fit, Trained, Transform, Untrained},
    types::Float,
};
use std::marker::PhantomData;

/// Multi-scale bandwidth selection strategy
#[derive(Debug, Clone, Copy)]
pub enum BandwidthStrategy {
    /// Manual specification of gamma values
    Manual,
    /// Logarithmic spacing: gamma_i = gamma_min * (gamma_max/gamma_min)^(i/(n_scales-1))
    LogarithmicSpacing,
    /// Linear spacing: gamma_i = gamma_min + i * (gamma_max - gamma_min) / (n_scales - 1)
    LinearSpacing,
    /// Geometric progression: gamma_i = gamma_min * ratio^i
    GeometricProgression,
    /// Adaptive spacing based on data characteristics
    Adaptive,
}
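
// Worked example (illustrative only, not part of the API): for `gamma_range(0.1, 10.0)`
// and `n_scales = 3`, the automatic strategies produce:
//
//   LogarithmicSpacing:   [0.1, 1.0, 10.0]   (even spacing in ln(gamma))
//   LinearSpacing:        [0.1, 5.05, 10.0]
//   GeometricProgression: [0.1, 1.0, 10.0]   (ratio = (10.0 / 0.1)^(1/2) = 10)
//
// With both endpoints fixed, the geometric progression coincides with
// logarithmic spacing.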

/// Feature combination strategy for multi-scale features
#[derive(Debug, Clone, Copy)]
pub enum CombinationStrategy {
    /// Concatenate features from all scales
    Concatenation,
    /// Weighted average of features across scales
    WeightedAverage,
    /// Maximum response across scales
    MaxPooling,
    /// Average pooling across scales
    AveragePooling,
    /// Attention-based combination
    Attention,
}

/// Multi-scale RBF sampler that generates random Fourier features at multiple scales
///
/// This sampler generates RBF kernel approximations using multiple bandwidth parameters
/// (gamma values) to capture patterns at different scales. Each scale captures different
/// frequency characteristics of the data, providing a more comprehensive representation.
///
/// # Mathematical Background
///
/// For multiple scales with bandwidths γ₁, γ₂, ..., γₖ, the multi-scale RBF kernel is:
/// K(x,y) = Σᵢ wᵢ * exp(-γᵢ||x-y||²)
///
/// The random Fourier features for scale i are:
/// zᵢⱼ(x) = √(2/nᵢ) * cos(ωᵢⱼᵀx + bᵢⱼ)
/// where ωᵢⱼ ~ N(0, 2γᵢI) and bᵢⱼ ~ Uniform[0, 2π]
///
/// # Examples
///
/// ```rust,ignore
/// use sklears_kernel_approximation::{MultiScaleRBFSampler, BandwidthStrategy, CombinationStrategy};
/// use sklears_core::traits::{Transform, Fit, Untrained};
/// use scirs2_core::ndarray::array;
///
/// let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
///
/// let sampler = MultiScaleRBFSampler::new(100)
///     .n_scales(3)
///     .gamma_range(0.1, 10.0)
///     .bandwidth_strategy(BandwidthStrategy::LogarithmicSpacing)
///     .combination_strategy(CombinationStrategy::Concatenation);
///
/// let fitted = sampler.fit(&x, &()).unwrap();
/// let features = fitted.transform(&x).unwrap();
/// ```
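///
/// # Feature dimensionality
///
/// With [`CombinationStrategy::Concatenation`] the output has
/// `n_scales * n_components_per_scale` columns; every other combination strategy
/// combines the per-scale feature blocks elementwise and keeps
/// `n_components_per_scale` columns. A minimal sketch (mirroring the unit tests below):
///
/// ```rust,ignore
/// use scirs2_core::ndarray::array;
///
/// let x = array![[1.0, 2.0], [3.0, 4.0]];
/// let sampler = MultiScaleRBFSampler::new(10)
///     .n_scales(3)
///     .combination_strategy(CombinationStrategy::AveragePooling);
///
/// let features = sampler.fit(&x, &()).unwrap().transform(&x).unwrap();
/// assert_eq!(features.dim(), (2, 10)); // pooling keeps a single 10-column block
/// ```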
#[derive(Debug, Clone)]
pub struct MultiScaleRBFSampler<State = Untrained> {
    /// Number of components per scale
    pub n_components_per_scale: usize,
    /// Number of scales
    pub n_scales: usize,
    /// Minimum gamma value
    pub gamma_min: Float,
    /// Maximum gamma value
    pub gamma_max: Float,
    /// Manual gamma values (used when strategy is Manual)
    pub manual_gammas: Vec<Float>,
    /// Bandwidth selection strategy
    pub bandwidth_strategy: BandwidthStrategy,
    /// Feature combination strategy
    pub combination_strategy: CombinationStrategy,
    /// Scale weights for weighted combination
    pub scale_weights: Vec<Float>,
    /// Random seed for reproducibility
    pub random_state: Option<u64>,

    // Fitted attributes
    gammas_: Option<Vec<Float>>,
    random_weights_: Option<Vec<Array2<Float>>>,
    random_offsets_: Option<Vec<Array1<Float>>>,
    attention_weights_: Option<Array1<Float>>, // For attention-based combination

    // State marker
    _state: PhantomData<State>,
}

impl MultiScaleRBFSampler<Untrained> {
    /// Create a new multi-scale RBF sampler
    ///
    /// # Arguments
    /// * `n_components_per_scale` - Number of random features per scale
    pub fn new(n_components_per_scale: usize) -> Self {
        Self {
            n_components_per_scale,
            n_scales: 3,
            gamma_min: 0.1,
            gamma_max: 10.0,
            manual_gammas: vec![],
            bandwidth_strategy: BandwidthStrategy::LogarithmicSpacing,
            combination_strategy: CombinationStrategy::Concatenation,
            scale_weights: vec![],
            random_state: None,
            gammas_: None,
            random_weights_: None,
            random_offsets_: None,
            attention_weights_: None,
            _state: PhantomData,
        }
    }

    /// Set the number of scales
    pub fn n_scales(mut self, n_scales: usize) -> Self {
        self.n_scales = n_scales;
        self
    }

    /// Set the gamma range for automatic bandwidth selection
    pub fn gamma_range(mut self, gamma_min: Float, gamma_max: Float) -> Self {
        self.gamma_min = gamma_min;
        self.gamma_max = gamma_max;
        self
    }

    /// Set manual gamma values
    pub fn manual_gammas(mut self, gammas: Vec<Float>) -> Self {
        self.n_scales = gammas.len();
        self.manual_gammas = gammas;
        self.bandwidth_strategy = BandwidthStrategy::Manual;
        self
    }

    /// Set the bandwidth selection strategy
    pub fn bandwidth_strategy(mut self, strategy: BandwidthStrategy) -> Self {
        self.bandwidth_strategy = strategy;
        self
    }

    /// Set the feature combination strategy
    pub fn combination_strategy(mut self, strategy: CombinationStrategy) -> Self {
        self.combination_strategy = strategy;
        self
    }

    /// Set weights for different scales (used in weighted combination)
    pub fn scale_weights(mut self, weights: Vec<Float>) -> Self {
        self.scale_weights = weights;
        self
    }

    /// Set the random state for reproducibility
    pub fn random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }

    /// Compute gamma values based on the selected strategy
    fn compute_gammas(&self, x: &Array2<Float>) -> Result<Vec<Float>> {
        match self.bandwidth_strategy {
            BandwidthStrategy::Manual => {
                if self.manual_gammas.is_empty() {
                    return Err(SklearsError::InvalidParameter {
                        name: "manual_gammas".to_string(),
                        reason: "manual gammas not provided".to_string(),
                    });
                }
                Ok(self.manual_gammas.clone())
            }
            BandwidthStrategy::LogarithmicSpacing => {
                let mut gammas = Vec::with_capacity(self.n_scales);
                if self.n_scales == 1 {
                    gammas.push((self.gamma_min * self.gamma_max).sqrt());
                } else {
                    let log_min = self.gamma_min.ln();
                    let log_max = self.gamma_max.ln();
                    for i in 0..self.n_scales {
                        let t = i as Float / (self.n_scales - 1) as Float;
                        let log_gamma = log_min + t * (log_max - log_min);
                        gammas.push(log_gamma.exp());
                    }
                }
                Ok(gammas)
            }
            BandwidthStrategy::LinearSpacing => {
                let mut gammas = Vec::with_capacity(self.n_scales);
                if self.n_scales == 1 {
                    gammas.push((self.gamma_min + self.gamma_max) / 2.0);
                } else {
                    for i in 0..self.n_scales {
                        let t = i as Float / (self.n_scales - 1) as Float;
                        let gamma = self.gamma_min + t * (self.gamma_max - self.gamma_min);
                        gammas.push(gamma);
                    }
                }
                Ok(gammas)
            }
            BandwidthStrategy::GeometricProgression => {
                let mut gammas = Vec::with_capacity(self.n_scales);
                let ratio = if self.n_scales == 1 {
                    1.0
                } else {
                    (self.gamma_max / self.gamma_min).powf(1.0 / (self.n_scales - 1) as Float)
                };
                for i in 0..self.n_scales {
                    let gamma = self.gamma_min * ratio.powi(i as i32);
                    gammas.push(gamma);
                }
                Ok(gammas)
            }
            BandwidthStrategy::Adaptive => {
                // Adaptive bandwidth selection based on data characteristics
                self.compute_adaptive_gammas(x)
            }
        }
    }

    /// Compute adaptive gamma values based on data characteristics
    fn compute_adaptive_gammas(&self, x: &Array2<Float>) -> Result<Vec<Float>> {
        let (n_samples, _n_features) = x.dim();

        if n_samples < 2 {
            return Err(SklearsError::InvalidInput(
                "Need at least 2 samples for adaptive bandwidth selection".to_string(),
            ));
        }

        // Compute pairwise distances for a subset of points
        let n_subset = n_samples.min(100);
        let mut distances = Vec::new();

        for i in 0..n_subset {
            for j in (i + 1)..n_subset {
                let diff = &x.row(i) - &x.row(j);
                let dist_sq = diff.mapv(|v| v * v).sum();
                distances.push(dist_sq.sqrt());
            }
        }

        if distances.is_empty() {
            return Ok(vec![1.0; self.n_scales]);
        }

        // Sort distances
        distances.sort_by(|a, b| a.partial_cmp(b).unwrap());

        // Use percentiles to determine scales
        let mut gammas = Vec::with_capacity(self.n_scales);
        for i in 0..self.n_scales {
            let percentile = if self.n_scales == 1 {
                0.5
            } else {
                i as Float / (self.n_scales - 1) as Float
            };

            let idx = ((distances.len() - 1) as Float * percentile) as usize;
            let characteristic_distance = distances[idx];

            // gamma = 1 / (2 * sigma^2), where sigma is related to characteristic distance
            let gamma = if characteristic_distance > 0.0 {
                1.0 / (2.0 * characteristic_distance * characteristic_distance)
            } else {
                1.0
            };

            gammas.push(gamma);
        }

        Ok(gammas)
    }
}
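
// Sketch of the adaptive heuristic above (illustrative numbers, assuming three scales):
// if the 0th/50th/100th percentiles of the sampled pairwise distances are
// d = [1.0, 2.0, 4.0], then gamma = 1 / (2 * d^2) gives gammas = [0.5, 0.125, 0.03125].
// Small distances (fine structure) map to large gammas, large distances (coarse
// structure) to small gammas, so the gammas come out in non-increasing order.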

impl Fit<Array2<Float>, ()> for MultiScaleRBFSampler<Untrained> {
    type Fitted = MultiScaleRBFSampler<Trained>;

    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
        let (n_samples, n_features) = x.dim();

        if n_samples == 0 || n_features == 0 {
            return Err(SklearsError::InvalidInput(
                "Input array is empty".to_string(),
            ));
        }

        if self.n_scales == 0 {
            return Err(SklearsError::InvalidParameter {
                name: "n_scales".to_string(),
                reason: "must be positive".to_string(),
            });
        }

        let mut rng = match self.random_state {
            Some(seed) => RealStdRng::seed_from_u64(seed),
            None => RealStdRng::from_seed(thread_rng().gen()),
        };

        // Compute gamma values for each scale
        let gammas = self.compute_gammas(x)?;

        // Generate random weights and offsets for each scale
        let mut random_weights = Vec::with_capacity(self.n_scales);
        let mut random_offsets = Vec::with_capacity(self.n_scales);

        for &gamma in &gammas {
            // Generate random weights ~ N(0, 2*gamma*I)
            let std_dev = (2.0 * gamma).sqrt();
            let mut weights = Array2::zeros((self.n_components_per_scale, n_features));
            for i in 0..self.n_components_per_scale {
                for j in 0..n_features {
                    weights[[i, j]] =
                        rng.sample::<Float, _>(RandNormal::new(0.0, std_dev).map_err(|e| {
                            SklearsError::NumericalError(format!(
                                "Error creating normal distribution: {}",
                                e
                            ))
                        })?);
                }
            }

            // Generate random offsets ~ Uniform[0, 2π]
            let mut offsets = Array1::zeros(self.n_components_per_scale);
            for i in 0..self.n_components_per_scale {
                offsets[i] = rng
                    .sample::<Float, _>(RandUniform::new(0.0, 2.0 * std::f64::consts::PI).unwrap());
            }

            random_weights.push(weights);
            random_offsets.push(offsets);
        }

        // Compute attention weights if using attention-based combination
        let attention_weights =
            if matches!(self.combination_strategy, CombinationStrategy::Attention) {
                Some(compute_attention_weights(&gammas)?)
            } else {
                None
            };

        Ok(MultiScaleRBFSampler {
            n_components_per_scale: self.n_components_per_scale,
            n_scales: self.n_scales,
            gamma_min: self.gamma_min,
            gamma_max: self.gamma_max,
            manual_gammas: self.manual_gammas,
            bandwidth_strategy: self.bandwidth_strategy,
            combination_strategy: self.combination_strategy,
            scale_weights: self.scale_weights,
            random_state: self.random_state,
            gammas_: Some(gammas),
            random_weights_: Some(random_weights),
            random_offsets_: Some(random_offsets),
            attention_weights_: attention_weights,
            _state: PhantomData,
        })
    }
}

/// Compute attention weights based on gamma values
fn compute_attention_weights(gammas: &[Float]) -> Result<Array1<Float>> {
    // Simple attention mechanism: higher gamma (smaller scale) gets more weight
    let weights: Vec<Float> = gammas.iter().map(|&g| g.ln()).collect();
    let weights_array = Array1::from(weights);

    // Softmax normalization
    let max_weight = weights_array
        .iter()
        .fold(Float::NEG_INFINITY, |a, &b| a.max(b));
    let exp_weights = weights_array.mapv(|w| (w - max_weight).exp());
    let sum_exp = exp_weights.sum();

    Ok(exp_weights.mapv(|w| w / sum_exp))
}
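
// Note on the softmax above (illustrative): since the logits are ln(gamma_i),
// softmax(ln(gamma_i)) simplifies to gamma_i / sum_j(gamma_j). For example,
// gammas = [0.1, 1.0, 10.0] yields attention weights of roughly
// [0.009, 0.090, 0.901], i.e. the finest scale (largest gamma) dominates.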

impl Transform<Array2<Float>> for MultiScaleRBFSampler<Trained> {
    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
        let _gammas = self
            .gammas_
            .as_ref()
            .ok_or_else(|| SklearsError::NotFitted {
                operation: "transform".to_string(),
            })?;

        let random_weights =
            self.random_weights_
                .as_ref()
                .ok_or_else(|| SklearsError::NotFitted {
                    operation: "transform".to_string(),
                })?;

        let random_offsets =
            self.random_offsets_
                .as_ref()
                .ok_or_else(|| SklearsError::NotFitted {
                    operation: "transform".to_string(),
                })?;

        let (_n_samples, n_features) = x.dim();

        // Generate features for each scale
        let mut scale_features = Vec::with_capacity(self.n_scales);

        for i in 0..self.n_scales {
            let weights = &random_weights[i];
            let offsets = &random_offsets[i];

            if n_features != weights.ncols() {
                return Err(SklearsError::InvalidInput(format!(
                    "Input has {} features, expected {}",
                    n_features,
                    weights.ncols()
                )));
            }

            // Compute X @ W.T + b
            let projection = x.dot(&weights.t()) + offsets;

            // Apply cosine transformation and normalize
            let normalization = (2.0 / weights.nrows() as Float).sqrt();
            let features = projection.mapv(|v| v.cos() * normalization);

            scale_features.push(features);
        }

        // Combine features across scales
        match self.combination_strategy {
            CombinationStrategy::Concatenation => self.concatenate_features(scale_features),
            CombinationStrategy::WeightedAverage => self.weighted_average_features(scale_features),
            CombinationStrategy::MaxPooling => self.max_pooling_features(scale_features),
            CombinationStrategy::AveragePooling => self.average_pooling_features(scale_features),
            CombinationStrategy::Attention => self.attention_combine_features(scale_features),
        }
    }
}
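
// Per-scale feature map implemented above (sketch): for scale i with D =
// n_components_per_scale rows of weights W_i and offsets b_i,
//     Z_i = sqrt(2 / D) * cos(X W_i^T + b_i),   shape (n_samples, D),
// and in expectation Z_i(x) · Z_i(y) ≈ exp(-gamma_i * ||x - y||^2), the RBF
// kernel at that scale (a standard random-Fourier-feature identity).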

impl MultiScaleRBFSampler<Trained> {
    /// Concatenate features from all scales
    fn concatenate_features(&self, scale_features: Vec<Array2<Float>>) -> Result<Array2<Float>> {
        if scale_features.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No scale features to concatenate".to_string(),
            ));
        }

        let n_samples = scale_features[0].nrows();
        let total_features: usize = scale_features.iter().map(|f| f.ncols()).sum();

        let mut result = Array2::zeros((n_samples, total_features));
        let mut col_offset = 0;

        for features in scale_features {
            let n_cols = features.ncols();
            result
                .slice_mut(s![.., col_offset..col_offset + n_cols])
                .assign(&features);
            col_offset += n_cols;
        }

        Ok(result)
    }

    /// Compute weighted average of features across scales
    fn weighted_average_features(
        &self,
        scale_features: Vec<Array2<Float>>,
    ) -> Result<Array2<Float>> {
        if scale_features.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No scale features to average".to_string(),
            ));
        }

        let weights = if self.scale_weights.is_empty() {
            // Equal weights
            vec![1.0 / self.n_scales as Float; self.n_scales]
        } else {
            // Normalize provided weights
            let sum: Float = self.scale_weights.iter().sum();
            self.scale_weights.iter().map(|&w| w / sum).collect()
        };

        let mut result = scale_features[0].clone() * weights[0];
        for (i, features) in scale_features.iter().enumerate().skip(1) {
            result = result + features * weights[i];
        }

        Ok(result)
    }

    /// Apply max pooling across scales
    fn max_pooling_features(&self, scale_features: Vec<Array2<Float>>) -> Result<Array2<Float>> {
        if scale_features.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No scale features for max pooling".to_string(),
            ));
        }

        let mut result = scale_features[0].clone();
        for features in scale_features.iter().skip(1) {
            for ((i, j), val) in features.indexed_iter() {
                if *val > result[[i, j]] {
                    result[[i, j]] = *val;
                }
            }
        }

        Ok(result)
    }

    /// Apply average pooling across scales
    fn average_pooling_features(
        &self,
        scale_features: Vec<Array2<Float>>,
    ) -> Result<Array2<Float>> {
        if scale_features.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No scale features for average pooling".to_string(),
            ));
        }

        let mut result = scale_features[0].clone();
        for features in scale_features.iter().skip(1) {
            result += features;
        }

        result.mapv_inplace(|x| x / self.n_scales as Float);
        Ok(result)
    }

    /// Apply attention-based combination of features
    fn attention_combine_features(
        &self,
        scale_features: Vec<Array2<Float>>,
    ) -> Result<Array2<Float>> {
        if scale_features.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No scale features for attention combination".to_string(),
            ));
        }

        let attention_weights =
            self.attention_weights_
                .as_ref()
                .ok_or_else(|| SklearsError::NotFitted {
                    operation: "attention combination".to_string(),
                })?;

        let mut result = scale_features[0].clone() * attention_weights[0];
        for (i, features) in scale_features.iter().enumerate().skip(1) {
            result = result + features * attention_weights[i];
        }

        Ok(result)
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_multi_scale_rbf_sampler_basic() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let sampler = MultiScaleRBFSampler::new(10)
            .n_scales(3)
            .gamma_range(0.1, 10.0)
            .bandwidth_strategy(BandwidthStrategy::LogarithmicSpacing)
            .combination_strategy(CombinationStrategy::Concatenation)
            .random_state(42);

        let fitted = sampler.fit(&x, &()).unwrap();
        let features = fitted.transform(&x).unwrap();

        // Concatenation should give 3 scales * 10 components = 30 features
        assert_eq!(features.shape(), &[3, 30]);

        // Check that features are bounded (cosine function)
        for &val in features.iter() {
            assert!(val >= -2.0 && val <= 2.0);
        }
    }

    #[test]
    fn test_different_bandwidth_strategies() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let strategies = [
            BandwidthStrategy::LogarithmicSpacing,
            BandwidthStrategy::LinearSpacing,
            BandwidthStrategy::GeometricProgression,
            BandwidthStrategy::Adaptive,
        ];

        for strategy in &strategies {
            let sampler = MultiScaleRBFSampler::new(5)
                .n_scales(3)
                .gamma_range(0.1, 10.0)
                .bandwidth_strategy(*strategy)
                .random_state(42);

            let fitted = sampler.fit(&x, &()).unwrap();
            let features = fitted.transform(&x).unwrap();

            assert_eq!(features.shape(), &[2, 15]); // 3 scales * 5 components
        }
    }

    #[test]
    fn test_different_combination_strategies() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let strategies = [
            (CombinationStrategy::Concatenation, 30), // 3 scales * 10 components
            (CombinationStrategy::WeightedAverage, 10), // Same as single scale
            (CombinationStrategy::MaxPooling, 10),
            (CombinationStrategy::AveragePooling, 10),
            (CombinationStrategy::Attention, 10),
        ];

        for (strategy, expected_features) in &strategies {
            let sampler = MultiScaleRBFSampler::new(10)
                .n_scales(3)
                .combination_strategy(*strategy)
                .random_state(42);

            let fitted = sampler.fit(&x, &()).unwrap();
            let features = fitted.transform(&x).unwrap();

            assert_eq!(features.shape(), &[3, *expected_features]);
        }
    }
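
    /// Illustrative sketch (added): `compute_attention_weights` is a softmax over
    /// ln(gamma), so the resulting weights should be non-negative, sum to one, and
    /// favour the largest gamma.
    #[test]
    fn test_attention_weights_are_normalized() {
        let gammas = vec![0.1, 1.0, 10.0];
        let weights = compute_attention_weights(&gammas).unwrap();

        assert_eq!(weights.len(), 3);
        assert!(weights.iter().all(|&w| w >= 0.0));
        assert_abs_diff_eq!(weights.sum(), 1.0, epsilon = 1e-10);
        assert!(weights[2] > weights[1] && weights[1] > weights[0]);
    }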

    #[test]
    fn test_manual_gammas() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let manual_gammas = vec![0.1, 1.0, 10.0];

        let sampler = MultiScaleRBFSampler::new(8)
            .manual_gammas(manual_gammas.clone())
            .random_state(42);

        let fitted = sampler.fit(&x, &()).unwrap();
        let features = fitted.transform(&x).unwrap();

        assert_eq!(features.shape(), &[2, 24]); // 3 scales * 8 components
        assert_eq!(fitted.gammas_.as_ref().unwrap(), &manual_gammas);
    }

    #[test]
    fn test_scale_weights() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let weights = vec![1.0, 2.0, 0.5];

        let sampler = MultiScaleRBFSampler::new(10)
            .n_scales(3)
            .combination_strategy(CombinationStrategy::WeightedAverage)
            .scale_weights(weights.clone())
            .random_state(42);

        let fitted = sampler.fit(&x, &()).unwrap();
        let features = fitted.transform(&x).unwrap();

        assert_eq!(features.shape(), &[2, 10]);
    }

    #[test]
    fn test_reproducibility() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let sampler1 = MultiScaleRBFSampler::new(20)
            .n_scales(4)
            .bandwidth_strategy(BandwidthStrategy::LogarithmicSpacing)
            .combination_strategy(CombinationStrategy::Concatenation)
            .random_state(123);

        let sampler2 = MultiScaleRBFSampler::new(20)
            .n_scales(4)
            .bandwidth_strategy(BandwidthStrategy::LogarithmicSpacing)
            .combination_strategy(CombinationStrategy::Concatenation)
            .random_state(123);

        let fitted1 = sampler1.fit(&x, &()).unwrap();
        let fitted2 = sampler2.fit(&x, &()).unwrap();

        let features1 = fitted1.transform(&x).unwrap();
        let features2 = fitted2.transform(&x).unwrap();

        for (f1, f2) in features1.iter().zip(features2.iter()) {
            assert_abs_diff_eq!(f1, f2, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_adaptive_bandwidth() {
        let x = array![
            [1.0, 1.0],
            [1.1, 1.1],
            [5.0, 5.0],
            [5.1, 5.1],
            [10.0, 10.0],
            [10.1, 10.1]
        ];

        let sampler = MultiScaleRBFSampler::new(15)
            .n_scales(3)
            .bandwidth_strategy(BandwidthStrategy::Adaptive)
            .random_state(42);

        let fitted = sampler.fit(&x, &()).unwrap();
        let features = fitted.transform(&x).unwrap();

        assert_eq!(features.shape(), &[6, 45]); // 3 scales * 15 components

        // Check that adaptive gammas were computed
        let gammas = fitted.gammas_.as_ref().unwrap();
        assert_eq!(gammas.len(), 3);
        assert!(gammas.iter().all(|&g| g > 0.0));
    }
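
    /// Illustrative sketch (added): in the adaptive strategy, larger distance
    /// percentiles map to smaller gammas (gamma = 1 / (2 * d^2)), so the computed
    /// gammas should come out in non-increasing order.
    #[test]
    fn test_adaptive_gammas_are_non_increasing() {
        let x = array![[0.0, 0.0], [1.0, 1.0], [4.0, 4.0], [9.0, 9.0]];

        let sampler = MultiScaleRBFSampler::new(5)
            .n_scales(3)
            .bandwidth_strategy(BandwidthStrategy::Adaptive);

        let gammas = sampler.compute_adaptive_gammas(&x).unwrap();
        assert_eq!(gammas.len(), 3);
        assert!(gammas.windows(2).all(|pair| pair[0] >= pair[1]));
    }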

    #[test]
    fn test_error_handling() {
        // Empty input
        let empty = Array2::<Float>::zeros((0, 0));
        let sampler = MultiScaleRBFSampler::new(10);
        assert!(sampler.clone().fit(&empty, &()).is_err());

        // Zero scales
        let x = array![[1.0, 2.0]];
        let invalid_sampler = MultiScaleRBFSampler::new(10).n_scales(0);
        assert!(invalid_sampler.fit(&x, &()).is_err());

        // Dimension mismatch in transform
        let x_train = array![[1.0, 2.0], [3.0, 4.0]];
        let x_test = array![[1.0, 2.0, 3.0]]; // Wrong number of features

        let fitted = sampler.fit(&x_train, &()).unwrap();
        assert!(fitted.transform(&x_test).is_err());
    }

    #[test]
    fn test_single_scale() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let sampler = MultiScaleRBFSampler::new(15)
            .n_scales(1)
            .gamma_range(1.0, 1.0)
            .random_state(42);

        let fitted = sampler.fit(&x, &()).unwrap();
        let features = fitted.transform(&x).unwrap();

        assert_eq!(features.shape(), &[2, 15]);

        let gammas = fitted.gammas_.as_ref().unwrap();
        assert_eq!(gammas.len(), 1);
    }
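
    /// Illustrative sketch (added): with a single scale, the inner product of the
    /// random Fourier features is a Monte Carlo estimate of the RBF kernel
    /// exp(-gamma * ||x - y||^2); the tolerance is deliberately loose.
    #[test]
    fn test_single_scale_kernel_approximation_sketch() {
        let x = array![[0.0, 0.0], [1.0, 0.5]];
        let gamma: Float = 0.5;

        let sampler = MultiScaleRBFSampler::new(2000)
            .n_scales(1)
            .gamma_range(gamma, gamma)
            .combination_strategy(CombinationStrategy::Concatenation)
            .random_state(7);

        let fitted = sampler.fit(&x, &()).unwrap();
        let z = fitted.transform(&x).unwrap();

        // Approximate vs. exact kernel value for the two samples.
        let approx = z.row(0).dot(&z.row(1));
        let dist_sq: Float = 1.0 * 1.0 + 0.5 * 0.5;
        let exact = (-gamma * dist_sq).exp();

        assert!((approx - exact).abs() < 0.2);
    }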

    #[test]
    fn test_gamma_computation_strategies() {
        let sampler = MultiScaleRBFSampler::new(10)
            .n_scales(4)
            .gamma_range(0.1, 10.0);

        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        // Test logarithmic spacing
        let log_sampler = sampler
            .clone()
            .bandwidth_strategy(BandwidthStrategy::LogarithmicSpacing);
        let log_gammas = log_sampler.compute_gammas(&x).unwrap();
        assert_eq!(log_gammas.len(), 4);
        assert_abs_diff_eq!(log_gammas[0], 0.1, epsilon = 1e-10);
        assert_abs_diff_eq!(log_gammas[3], 10.0, epsilon = 1e-10);

        // Test linear spacing
        let lin_sampler = sampler
            .clone()
            .bandwidth_strategy(BandwidthStrategy::LinearSpacing);
        let lin_gammas = lin_sampler.compute_gammas(&x).unwrap();
        assert_eq!(lin_gammas.len(), 4);
        assert_abs_diff_eq!(lin_gammas[0], 0.1, epsilon = 1e-10);
        assert_abs_diff_eq!(lin_gammas[3], 10.0, epsilon = 1e-10);

        // Test geometric progression
        let geo_sampler = sampler
            .clone()
            .bandwidth_strategy(BandwidthStrategy::GeometricProgression);
        let geo_gammas = geo_sampler.compute_gammas(&x).unwrap();
        assert_eq!(geo_gammas.len(), 4);
        assert_abs_diff_eq!(geo_gammas[0], 0.1, epsilon = 1e-10);
        assert_abs_diff_eq!(geo_gammas[3], 10.0, epsilon = 1e-10);
    }
}
823}