1use scirs2_core::ndarray::{s, Array1, Array2};
8use scirs2_core::random::essentials::{Normal as RandNormal, Uniform as RandUniform};
9use scirs2_core::random::rngs::StdRng as RealStdRng;
10use scirs2_core::random::Rng;
11use scirs2_core::random::{thread_rng, SeedableRng};
12use sklears_core::{
13 error::{Result, SklearsError},
14 traits::{Fit, Trained, Transform, Untrained},
15 types::Float,
16};
17use std::marker::PhantomData;
18
/// Strategy used to choose the RBF bandwidth (`gamma`) of every scale.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BandwidthStrategy {
    /// Use the gammas supplied through `manual_gammas`, one per scale.
    Manual,
    /// Space `n_scales` gammas evenly in log-space between `gamma_min` and
    /// `gamma_max` (the geometric mean of the range for a single scale).
    LogarithmicSpacing,
    /// Space `n_scales` gammas evenly in linear space between `gamma_min` and
    /// `gamma_max` (the arithmetic mean of the range for a single scale).
    LinearSpacing,
    /// `gamma_min * r^i` with the ratio `r` chosen so the last scale reaches
    /// `gamma_max` (just `gamma_min` for a single scale).
    GeometricProgression,
    /// Derive gammas from percentiles of the pairwise distances between (at
    /// most the first 100) training rows.
    Adaptive,
}
34
/// Strategy used to merge the per-scale feature blocks into one output.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CombinationStrategy {
    /// Place the per-scale blocks side by side; the output has
    /// `n_scales * n_components_per_scale` columns.
    Concatenation,
    /// Weighted sum of the blocks using `scale_weights` normalized to sum to
    /// one (uniform weights when none are given).
    WeightedAverage,
    /// Element-wise maximum across scales.
    MaxPooling,
    /// Element-wise mean across scales.
    AveragePooling,
    /// Weighted sum with weights proportional to each scale's gamma
    /// (a softmax over `ln(gamma)`).
    Attention,
}
50
51#[derive(Debug, Clone)]
85pub struct MultiScaleRBFSampler<State = Untrained> {
87 pub n_components_per_scale: usize,
89 pub n_scales: usize,
91 pub gamma_min: Float,
93 pub gamma_max: Float,
95 pub manual_gammas: Vec<Float>,
97 pub bandwidth_strategy: BandwidthStrategy,
99 pub combination_strategy: CombinationStrategy,
101 pub scale_weights: Vec<Float>,
103 pub random_state: Option<u64>,
105
106 gammas_: Option<Vec<Float>>,
108 random_weights_: Option<Vec<Array2<Float>>>,
109 random_offsets_: Option<Vec<Array1<Float>>>,
110 attention_weights_: Option<Array1<Float>>, _state: PhantomData<State>,
114}
115
116impl MultiScaleRBFSampler<Untrained> {
117 pub fn new(n_components_per_scale: usize) -> Self {
122 Self {
123 n_components_per_scale,
124 n_scales: 3,
125 gamma_min: 0.1,
126 gamma_max: 10.0,
127 manual_gammas: vec![],
128 bandwidth_strategy: BandwidthStrategy::LogarithmicSpacing,
129 combination_strategy: CombinationStrategy::Concatenation,
130 scale_weights: vec![],
131 random_state: None,
132 gammas_: None,
133 random_weights_: None,
134 random_offsets_: None,
135 attention_weights_: None,
136 _state: PhantomData,
137 }
138 }
139
140 pub fn n_scales(mut self, n_scales: usize) -> Self {
142 self.n_scales = n_scales;
143 self
144 }
145
146 pub fn gamma_range(mut self, gamma_min: Float, gamma_max: Float) -> Self {
148 self.gamma_min = gamma_min;
149 self.gamma_max = gamma_max;
150 self
151 }
152
153 pub fn manual_gammas(mut self, gammas: Vec<Float>) -> Self {
155 self.n_scales = gammas.len();
156 self.manual_gammas = gammas;
157 self.bandwidth_strategy = BandwidthStrategy::Manual;
158 self
159 }
160
161 pub fn bandwidth_strategy(mut self, strategy: BandwidthStrategy) -> Self {
163 self.bandwidth_strategy = strategy;
164 self
165 }
166
167 pub fn combination_strategy(mut self, strategy: CombinationStrategy) -> Self {
169 self.combination_strategy = strategy;
170 self
171 }
172
173 pub fn scale_weights(mut self, weights: Vec<Float>) -> Self {
175 self.scale_weights = weights;
176 self
177 }
178
179 pub fn random_state(mut self, seed: u64) -> Self {
181 self.random_state = Some(seed);
182 self
183 }
184
185 fn compute_gammas(&self, x: &Array2<Float>) -> Result<Vec<Float>> {
187 match self.bandwidth_strategy {
188 BandwidthStrategy::Manual => {
189 if self.manual_gammas.is_empty() {
190 return Err(SklearsError::InvalidParameter {
191 name: "manual_gammas".to_string(),
192 reason: "manual gammas not provided".to_string(),
193 });
194 }
195 Ok(self.manual_gammas.clone())
196 }
197 BandwidthStrategy::LogarithmicSpacing => {
198 let mut gammas = Vec::with_capacity(self.n_scales);
199 if self.n_scales == 1 {
200 gammas.push((self.gamma_min * self.gamma_max).sqrt());
201 } else {
202 let log_min = self.gamma_min.ln();
203 let log_max = self.gamma_max.ln();
204 for i in 0..self.n_scales {
205 let t = i as Float / (self.n_scales - 1) as Float;
206 let log_gamma = log_min + t * (log_max - log_min);
207 gammas.push(log_gamma.exp());
208 }
209 }
210 Ok(gammas)
211 }
212 BandwidthStrategy::LinearSpacing => {
213 let mut gammas = Vec::with_capacity(self.n_scales);
214 if self.n_scales == 1 {
215 gammas.push((self.gamma_min + self.gamma_max) / 2.0);
216 } else {
217 for i in 0..self.n_scales {
218 let t = i as Float / (self.n_scales - 1) as Float;
219 let gamma = self.gamma_min + t * (self.gamma_max - self.gamma_min);
220 gammas.push(gamma);
221 }
222 }
223 Ok(gammas)
224 }
225 BandwidthStrategy::GeometricProgression => {
226 let mut gammas = Vec::with_capacity(self.n_scales);
227 let ratio = if self.n_scales == 1 {
228 1.0
229 } else {
230 (self.gamma_max / self.gamma_min).powf(1.0 / (self.n_scales - 1) as Float)
231 };
232 for i in 0..self.n_scales {
233 let gamma = self.gamma_min * ratio.powi(i as i32);
234 gammas.push(gamma);
235 }
236 Ok(gammas)
237 }
238 BandwidthStrategy::Adaptive => {
239 self.compute_adaptive_gammas(x)
241 }
242 }
243 }
244
245 fn compute_adaptive_gammas(&self, x: &Array2<Float>) -> Result<Vec<Float>> {
247 let (n_samples, _n_features) = x.dim();
248
249 if n_samples < 2 {
250 return Err(SklearsError::InvalidInput(
251 "Need at least 2 samples for adaptive bandwidth selection".to_string(),
252 ));
253 }
254
255 let n_subset = n_samples.min(100);
257 let mut distances = Vec::new();
258
259 for i in 0..n_subset {
260 for j in (i + 1)..n_subset {
261 let diff = &x.row(i) - &x.row(j);
262 let dist_sq = diff.mapv(|x| x * x).sum();
263 distances.push(dist_sq.sqrt());
264 }
265 }
266
267 if distances.is_empty() {
268 return Ok(vec![1.0; self.n_scales]);
269 }
270
271 distances.sort_by(|a, b| a.partial_cmp(b).unwrap());
273
274 let mut gammas = Vec::with_capacity(self.n_scales);
276 for i in 0..self.n_scales {
277 let percentile = if self.n_scales == 1 {
278 0.5
279 } else {
280 i as Float / (self.n_scales - 1) as Float
281 };
282
283 let idx = ((distances.len() - 1) as Float * percentile) as usize;
284 let characteristic_distance = distances[idx];
285
286 let gamma = if characteristic_distance > 0.0 {
288 1.0 / (2.0 * characteristic_distance * characteristic_distance)
289 } else {
290 1.0
291 };
292
293 gammas.push(gamma);
294 }
295
296 Ok(gammas)
297 }
298}
299
300impl Fit<Array2<Float>, ()> for MultiScaleRBFSampler<Untrained> {
301 type Fitted = MultiScaleRBFSampler<Trained>;
302
303 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
304 let (n_samples, n_features) = x.dim();
305
306 if n_samples == 0 || n_features == 0 {
307 return Err(SklearsError::InvalidInput(
308 "Input array is empty".to_string(),
309 ));
310 }
311
312 if self.n_scales == 0 {
313 return Err(SklearsError::InvalidParameter {
314 name: "n_scales".to_string(),
315 reason: "must be positive".to_string(),
316 });
317 }
318
319 let mut rng = match self.random_state {
320 Some(seed) => RealStdRng::seed_from_u64(seed),
321 None => RealStdRng::from_seed(thread_rng().gen()),
322 };
323
324 let gammas = self.compute_gammas(x)?;
326
327 let mut random_weights = Vec::with_capacity(self.n_scales);
329 let mut random_offsets = Vec::with_capacity(self.n_scales);
330
331 for &gamma in &gammas {
332 let std_dev = (2.0 * gamma).sqrt();
334 let mut weights = Array2::zeros((self.n_components_per_scale, n_features));
335 for i in 0..self.n_components_per_scale {
336 for j in 0..n_features {
337 weights[[i, j]] =
338 rng.sample::<Float, _>(RandNormal::new(0.0, std_dev).map_err(|e| {
339 SklearsError::NumericalError(format!(
340 "Error creating normal distribution: {}",
341 e
342 ))
343 })?);
344 }
345 }
346
347 let mut offsets = Array1::zeros(self.n_components_per_scale);
349 for i in 0..self.n_components_per_scale {
350 offsets[i] = rng
351 .sample::<Float, _>(RandUniform::new(0.0, 2.0 * std::f64::consts::PI).unwrap());
352 }
353
354 random_weights.push(weights);
355 random_offsets.push(offsets);
356 }
357
358 let attention_weights =
360 if matches!(self.combination_strategy, CombinationStrategy::Attention) {
361 Some(compute_attention_weights(&gammas)?)
362 } else {
363 None
364 };
365
366 Ok(MultiScaleRBFSampler {
367 n_components_per_scale: self.n_components_per_scale,
368 n_scales: self.n_scales,
369 gamma_min: self.gamma_min,
370 gamma_max: self.gamma_max,
371 manual_gammas: self.manual_gammas,
372 bandwidth_strategy: self.bandwidth_strategy,
373 combination_strategy: self.combination_strategy,
374 scale_weights: self.scale_weights,
375 random_state: self.random_state,
376 gammas_: Some(gammas),
377 random_weights_: Some(random_weights),
378 random_offsets_: Some(random_offsets),
379 attention_weights_: attention_weights,
380 _state: PhantomData,
381 })
382 }
383}
384
385fn compute_attention_weights(gammas: &[Float]) -> Result<Array1<Float>> {
387 let weights: Vec<Float> = gammas.iter().map(|&g| g.ln()).collect();
389 let weights_array = Array1::from(weights);
390
391 let max_weight = weights_array
393 .iter()
394 .fold(Float::NEG_INFINITY, |a, &b| a.max(b));
395 let exp_weights = weights_array.mapv(|w| (w - max_weight).exp());
396 let sum_exp = exp_weights.sum();
397
398 Ok(exp_weights.mapv(|w| w / sum_exp))
399}
400
401impl Transform<Array2<Float>> for MultiScaleRBFSampler<Trained> {
402 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
403 let _gammas = self
404 .gammas_
405 .as_ref()
406 .ok_or_else(|| SklearsError::NotFitted {
407 operation: "transform".to_string(),
408 })?;
409
410 let random_weights =
411 self.random_weights_
412 .as_ref()
413 .ok_or_else(|| SklearsError::NotFitted {
414 operation: "transform".to_string(),
415 })?;
416
417 let random_offsets =
418 self.random_offsets_
419 .as_ref()
420 .ok_or_else(|| SklearsError::NotFitted {
421 operation: "transform".to_string(),
422 })?;
423
424 let (_n_samples, n_features) = x.dim();
425
426 let mut scale_features = Vec::with_capacity(self.n_scales);
428
429 for i in 0..self.n_scales {
430 let weights = &random_weights[i];
431 let offsets = &random_offsets[i];
432
433 if n_features != weights.ncols() {
434 return Err(SklearsError::InvalidInput(format!(
435 "Input has {} features, expected {}",
436 n_features,
437 weights.ncols()
438 )));
439 }
440
441 let projection = x.dot(&weights.t()) + offsets;
443
444 let normalization = (2.0 / weights.nrows() as Float).sqrt();
446 let features = projection.mapv(|x| x.cos() * normalization);
447
448 scale_features.push(features);
449 }
450
451 match self.combination_strategy {
453 CombinationStrategy::Concatenation => self.concatenate_features(scale_features),
454 CombinationStrategy::WeightedAverage => self.weighted_average_features(scale_features),
455 CombinationStrategy::MaxPooling => self.max_pooling_features(scale_features),
456 CombinationStrategy::AveragePooling => self.average_pooling_features(scale_features),
457 CombinationStrategy::Attention => self.attention_combine_features(scale_features),
458 }
459 }
460}
461
462impl MultiScaleRBFSampler<Trained> {
463 fn concatenate_features(&self, scale_features: Vec<Array2<Float>>) -> Result<Array2<Float>> {
465 if scale_features.is_empty() {
466 return Err(SklearsError::InvalidInput(
467 "No scale features to concatenate".to_string(),
468 ));
469 }
470
471 let n_samples = scale_features[0].nrows();
472 let total_features: usize = scale_features.iter().map(|f| f.ncols()).sum();
473
474 let mut result = Array2::zeros((n_samples, total_features));
475 let mut col_offset = 0;
476
477 for features in scale_features {
478 let n_cols = features.ncols();
479 result
480 .slice_mut(s![.., col_offset..col_offset + n_cols])
481 .assign(&features);
482 col_offset += n_cols;
483 }
484
485 Ok(result)
486 }
487
488 fn weighted_average_features(
490 &self,
491 scale_features: Vec<Array2<Float>>,
492 ) -> Result<Array2<Float>> {
493 if scale_features.is_empty() {
494 return Err(SklearsError::InvalidInput(
495 "No scale features to average".to_string(),
496 ));
497 }
498
499 let weights = if self.scale_weights.is_empty() {
500 vec![1.0 / self.n_scales as Float; self.n_scales]
502 } else {
503 let sum: Float = self.scale_weights.iter().sum();
505 self.scale_weights.iter().map(|&w| w / sum).collect()
506 };
507
508 let mut result = scale_features[0].clone() * weights[0];
509 for (i, features) in scale_features.iter().enumerate().skip(1) {
510 result = result + features * weights[i];
511 }
512
513 Ok(result)
514 }
515
516 fn max_pooling_features(&self, scale_features: Vec<Array2<Float>>) -> Result<Array2<Float>> {
518 if scale_features.is_empty() {
519 return Err(SklearsError::InvalidInput(
520 "No scale features for max pooling".to_string(),
521 ));
522 }
523
524 let mut result = scale_features[0].clone();
525 for features in scale_features.iter().skip(1) {
526 for ((i, j), val) in features.indexed_iter() {
527 if *val > result[[i, j]] {
528 result[[i, j]] = *val;
529 }
530 }
531 }
532
533 Ok(result)
534 }
535
536 fn average_pooling_features(
538 &self,
539 scale_features: Vec<Array2<Float>>,
540 ) -> Result<Array2<Float>> {
541 if scale_features.is_empty() {
542 return Err(SklearsError::InvalidInput(
543 "No scale features for average pooling".to_string(),
544 ));
545 }
546
547 let mut result = scale_features[0].clone();
548 for features in scale_features.iter().skip(1) {
549 result += features;
550 }
551
552 result.mapv_inplace(|x| x / self.n_scales as Float);
553 Ok(result)
554 }
555
556 fn attention_combine_features(
558 &self,
559 scale_features: Vec<Array2<Float>>,
560 ) -> Result<Array2<Float>> {
561 if scale_features.is_empty() {
562 return Err(SklearsError::InvalidInput(
563 "No scale features for attention combination".to_string(),
564 ));
565 }
566
567 let attention_weights =
568 self.attention_weights_
569 .as_ref()
570 .ok_or_else(|| SklearsError::NotFitted {
571 operation: "attention combination".to_string(),
572 })?;
573
574 let mut result = scale_features[0].clone() * attention_weights[0];
575 for (i, features) in scale_features.iter().enumerate().skip(1) {
576 result = result + features * attention_weights[i];
577 }
578
579 Ok(result)
580 }
581}
582
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// Largest magnitude a single random Fourier feature can take: `sqrt(2 / D)`.
    fn feature_bound(n_components: usize) -> Float {
        (2.0 / n_components as Float).sqrt()
    }

    #[test]
    fn test_multi_scale_rbf_sampler_basic() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let sampler = MultiScaleRBFSampler::new(10)
            .n_scales(3)
            .gamma_range(0.1, 10.0)
            .bandwidth_strategy(BandwidthStrategy::LogarithmicSpacing)
            .combination_strategy(CombinationStrategy::Concatenation)
            .random_state(42);

        let fitted = sampler.fit(&x, &()).unwrap();
        let features = fitted.transform(&x).unwrap();

        // 3 scales x 10 components, concatenated.
        assert_eq!(features.shape(), &[3, 30]);

        // Every feature is sqrt(2/D) * cos(..), so its magnitude is at most sqrt(2/D).
        let bound = feature_bound(10);
        for &val in features.iter() {
            assert!(
                val.abs() <= bound + 1e-12,
                "feature {} exceeds bound {}",
                val,
                bound
            );
        }
    }

    #[test]
    fn test_different_bandwidth_strategies() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let strategies = [
            BandwidthStrategy::LogarithmicSpacing,
            BandwidthStrategy::LinearSpacing,
            BandwidthStrategy::GeometricProgression,
            BandwidthStrategy::Adaptive,
        ];

        for strategy in &strategies {
            let sampler = MultiScaleRBFSampler::new(5)
                .n_scales(3)
                .gamma_range(0.1, 10.0)
                .bandwidth_strategy(*strategy)
                .random_state(42);

            let fitted = sampler.fit(&x, &()).unwrap();
            let features = fitted.transform(&x).unwrap();

            assert_eq!(features.shape(), &[2, 15]);
        }
    }

    #[test]
    fn test_different_combination_strategies() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        // Only concatenation widens the output; the others merge the scales.
        let strategies = [
            (CombinationStrategy::Concatenation, 30),
            (CombinationStrategy::WeightedAverage, 10),
            (CombinationStrategy::MaxPooling, 10),
            (CombinationStrategy::AveragePooling, 10),
            (CombinationStrategy::Attention, 10),
        ];

        for (strategy, expected_features) in &strategies {
            let sampler = MultiScaleRBFSampler::new(10)
                .n_scales(3)
                .combination_strategy(*strategy)
                .random_state(42);

            let fitted = sampler.fit(&x, &()).unwrap();
            let features = fitted.transform(&x).unwrap();

            assert_eq!(features.shape(), &[3, *expected_features]);
        }
    }

    #[test]
    fn test_pooled_features_stay_within_feature_bound() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let bound = feature_bound(10);

        let strategies = [
            CombinationStrategy::WeightedAverage,
            CombinationStrategy::MaxPooling,
            CombinationStrategy::AveragePooling,
            CombinationStrategy::Attention,
        ];

        for strategy in &strategies {
            let features = MultiScaleRBFSampler::new(10)
                .n_scales(3)
                .combination_strategy(*strategy)
                .random_state(3)
                .fit(&x, &())
                .unwrap()
                .transform(&x)
                .unwrap();

            // Each strategy is a convex combination (or maximum) of features
            // that are individually bounded by sqrt(2/D).
            for &val in features.iter() {
                assert!(
                    val.abs() <= bound + 1e-12,
                    "{:?}: feature {} exceeds bound {}",
                    strategy,
                    val,
                    bound
                );
            }
        }
    }

    #[test]
    fn test_manual_gammas() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let manual_gammas = vec![0.1, 1.0, 10.0];

        let sampler = MultiScaleRBFSampler::new(8)
            .manual_gammas(manual_gammas.clone())
            .random_state(42);

        let fitted = sampler.fit(&x, &()).unwrap();
        let features = fitted.transform(&x).unwrap();

        assert_eq!(features.shape(), &[2, 24]);
        assert_eq!(fitted.gammas_.as_ref().unwrap(), &manual_gammas);
    }

    #[test]
    fn test_attention_weights_proportional_to_gammas() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let fitted = MultiScaleRBFSampler::new(4)
            .manual_gammas(vec![1.0, 2.0, 5.0])
            .combination_strategy(CombinationStrategy::Attention)
            .random_state(7)
            .fit(&x, &())
            .unwrap();

        // softmax(ln gamma) == gamma / sum(gamma) == [1, 2, 5] / 8.
        let weights = fitted.attention_weights_.as_ref().unwrap();
        let expected = [0.125, 0.25, 0.625];
        assert_eq!(weights.len(), expected.len());
        for (w, e) in weights.iter().zip(expected.iter()) {
            assert_abs_diff_eq!(*w, *e, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_scale_weights() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let weights = vec![1.0, 2.0, 0.5];

        let sampler = MultiScaleRBFSampler::new(10)
            .n_scales(3)
            .combination_strategy(CombinationStrategy::WeightedAverage)
            .scale_weights(weights.clone())
            .random_state(42);

        let fitted = sampler.fit(&x, &()).unwrap();
        let features = fitted.transform(&x).unwrap();

        assert_eq!(features.shape(), &[2, 10]);
    }

    #[test]
    fn test_reproducibility() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let sampler1 = MultiScaleRBFSampler::new(20)
            .n_scales(4)
            .bandwidth_strategy(BandwidthStrategy::LogarithmicSpacing)
            .combination_strategy(CombinationStrategy::Concatenation)
            .random_state(123);

        let sampler2 = MultiScaleRBFSampler::new(20)
            .n_scales(4)
            .bandwidth_strategy(BandwidthStrategy::LogarithmicSpacing)
            .combination_strategy(CombinationStrategy::Concatenation)
            .random_state(123);

        let fitted1 = sampler1.fit(&x, &()).unwrap();
        let fitted2 = sampler2.fit(&x, &()).unwrap();

        let features1 = fitted1.transform(&x).unwrap();
        let features2 = fitted2.transform(&x).unwrap();

        for (f1, f2) in features1.iter().zip(features2.iter()) {
            assert_abs_diff_eq!(f1, f2, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_different_seeds_give_different_features() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let features_for = |seed: u64| {
            MultiScaleRBFSampler::new(10)
                .random_state(seed)
                .fit(&x, &())
                .unwrap()
                .transform(&x)
                .unwrap()
        };

        let features1 = features_for(1);
        let features2 = features_for(2);

        assert!(features1
            .iter()
            .zip(features2.iter())
            .any(|(a, b)| (a - b).abs() > 1e-8));
    }

    #[test]
    fn test_adaptive_bandwidth() {
        let x = array![
            [1.0, 1.0],
            [1.1, 1.1],
            [5.0, 5.0],
            [5.1, 5.1],
            [10.0, 10.0],
            [10.1, 10.1]
        ];

        let sampler = MultiScaleRBFSampler::new(15)
            .n_scales(3)
            .bandwidth_strategy(BandwidthStrategy::Adaptive)
            .random_state(42);

        let fitted = sampler.fit(&x, &()).unwrap();
        let features = fitted.transform(&x).unwrap();

        assert_eq!(features.shape(), &[6, 45]);

        let gammas = fitted.gammas_.as_ref().unwrap();
        assert_eq!(gammas.len(), 3);
        assert!(gammas.iter().all(|&g| g > 0.0));
    }

    #[test]
    fn test_error_handling() {
        // Empty input is rejected.
        let empty = Array2::<Float>::zeros((0, 0));
        let sampler = MultiScaleRBFSampler::new(10);
        assert!(sampler.clone().fit(&empty, &()).is_err());

        // Zero scales is rejected.
        let x = array![[1.0, 2.0]];
        let invalid_sampler = MultiScaleRBFSampler::new(10).n_scales(0);
        assert!(invalid_sampler.fit(&x, &()).is_err());

        // Transforming data with a different feature count is rejected.
        let x_train = array![[1.0, 2.0], [3.0, 4.0]];
        let x_test = array![[1.0, 2.0, 3.0]];
        let fitted = sampler.fit(&x_train, &()).unwrap();
        assert!(fitted.transform(&x_test).is_err());
    }

    #[test]
    fn test_single_scale() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let sampler = MultiScaleRBFSampler::new(15)
            .n_scales(1)
            .gamma_range(1.0, 1.0)
            .random_state(42);

        let fitted = sampler.fit(&x, &()).unwrap();
        let features = fitted.transform(&x).unwrap();

        assert_eq!(features.shape(), &[2, 15]);

        let gammas = fitted.gammas_.as_ref().unwrap();
        assert_eq!(gammas.len(), 1);
    }

    #[test]
    fn test_gamma_computation_strategies() {
        let sampler = MultiScaleRBFSampler::new(10)
            .n_scales(4)
            .gamma_range(0.1, 10.0);

        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        // All spacing strategies must hit both ends of the range exactly.
        let log_sampler = sampler
            .clone()
            .bandwidth_strategy(BandwidthStrategy::LogarithmicSpacing);
        let log_gammas = log_sampler.compute_gammas(&x).unwrap();
        assert_eq!(log_gammas.len(), 4);
        assert_abs_diff_eq!(log_gammas[0], 0.1, epsilon = 1e-10);
        assert_abs_diff_eq!(log_gammas[3], 10.0, epsilon = 1e-10);

        let lin_sampler = sampler
            .clone()
            .bandwidth_strategy(BandwidthStrategy::LinearSpacing);
        let lin_gammas = lin_sampler.compute_gammas(&x).unwrap();
        assert_eq!(lin_gammas.len(), 4);
        assert_abs_diff_eq!(lin_gammas[0], 0.1, epsilon = 1e-10);
        assert_abs_diff_eq!(lin_gammas[3], 10.0, epsilon = 1e-10);

        let geo_sampler = sampler
            .clone()
            .bandwidth_strategy(BandwidthStrategy::GeometricProgression);
        let geo_gammas = geo_sampler.compute_gammas(&x).unwrap();
        assert_eq!(geo_gammas.len(), 4);
        assert_abs_diff_eq!(geo_gammas[0], 0.1, epsilon = 1e-10);
        assert_abs_diff_eq!(geo_gammas[3], 10.0, epsilon = 1e-10);
    }
}