1use scirs2_core::ndarray::{Array1, Array2, Axis};
8use scirs2_core::random::rngs::StdRng as RealStdRng;
9use scirs2_core::random::{thread_rng, Rng, SeedableRng};
10use sklears_core::{
11 error::{Result, SklearsError},
12 traits::{Fit, Trained, Transform, Untrained},
13 types::Float,
14};
15use std::marker::PhantomData;
16
/// Strategy used by [`AdaptiveBandwidthRBFSampler`] to choose the RBF bandwidth `gamma`
/// during `fit`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BandwidthSelectionStrategy {
    /// Pick the gamma candidate with the best k-fold cross-validation score.
    CrossValidation,
    /// Pick the gamma candidate with the highest log-likelihood score.
    MaximumLikelihood,
    /// `gamma = 1 / (2 * median squared pairwise distance)`.
    MedianHeuristic,
    /// Scott's rule-of-thumb bandwidth.
    ScottRule,
    /// Silverman's rule-of-thumb bandwidth.
    SilvermanRule,
    /// Pick the gamma candidate with the best leave-one-out score.
    LeaveOneOut,
    /// Pick the gamma candidate that optimises the configured [`ObjectiveFunction`].
    GridSearch,
}
36
/// Objective scored for each gamma candidate by the `GridSearch` bandwidth strategy.
///
/// `LogLikelihood` is maximised; every other objective is minimised.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ObjectiveFunction {
    /// Negated mean squared kernel value over (at most) the first 100 samples.
    KernelAlignment,
    /// Log-likelihood score (higher is better).
    LogLikelihood,
    /// Cross-validation score (lower is better).
    CrossValidationError,
    /// Negated kernel trace.
    KernelTrace,
    /// Negated effective-dimensionality proxy.
    EffectiveDimensionality,
}
52
53#[derive(Debug, Clone)]
86pub struct AdaptiveBandwidthRBFSampler<State = Untrained> {
88 pub n_components: usize,
90 pub strategy: BandwidthSelectionStrategy,
92 pub objective_function: ObjectiveFunction,
94 pub gamma_range: (Float, Float),
96 pub n_gamma_candidates: usize,
98 pub cv_folds: usize,
100 pub random_state: Option<u64>,
102 pub tolerance: Float,
104 pub max_iterations: usize,
106
107 selected_gamma_: Option<Float>,
109 random_weights_: Option<Array2<Float>>,
110 random_offset_: Option<Array1<Float>>,
111 optimization_history_: Option<Vec<(Float, Float)>>, _state: PhantomData<State>,
115}
116
117impl AdaptiveBandwidthRBFSampler<Untrained> {
118 pub fn new(n_components: usize) -> Self {
123 Self {
124 n_components,
125 strategy: BandwidthSelectionStrategy::MedianHeuristic,
126 objective_function: ObjectiveFunction::KernelAlignment,
127 gamma_range: (1e-3, 1e3),
128 n_gamma_candidates: 20,
129 cv_folds: 5,
130 random_state: None,
131 tolerance: 1e-6,
132 max_iterations: 100,
133 selected_gamma_: None,
134 random_weights_: None,
135 random_offset_: None,
136 optimization_history_: None,
137 _state: PhantomData,
138 }
139 }
140
141 pub fn strategy(mut self, strategy: BandwidthSelectionStrategy) -> Self {
143 self.strategy = strategy;
144 self
145 }
146
147 pub fn objective_function(mut self, objective: ObjectiveFunction) -> Self {
149 self.objective_function = objective;
150 self
151 }
152
153 pub fn gamma_range(mut self, min: Float, max: Float) -> Self {
155 self.gamma_range = (min, max);
156 self
157 }
158
159 pub fn n_gamma_candidates(mut self, n: usize) -> Self {
161 self.n_gamma_candidates = n;
162 self
163 }
164
165 pub fn cv_folds(mut self, folds: usize) -> Self {
167 self.cv_folds = folds;
168 self
169 }
170
171 pub fn random_state(mut self, seed: u64) -> Self {
173 self.random_state = Some(seed);
174 self
175 }
176
177 pub fn tolerance(mut self, tol: Float) -> Self {
179 self.tolerance = tol;
180 self
181 }
182
183 pub fn max_iterations(mut self, max_iter: usize) -> Self {
185 self.max_iterations = max_iter;
186 self
187 }
188
189 fn select_gamma(&self, x: &Array2<Float>) -> Result<Float> {
191 match self.strategy {
192 BandwidthSelectionStrategy::MedianHeuristic => self.median_heuristic_gamma(x),
193 BandwidthSelectionStrategy::ScottRule => self.scott_rule_gamma(x),
194 BandwidthSelectionStrategy::SilvermanRule => self.silverman_rule_gamma(x),
195 BandwidthSelectionStrategy::CrossValidation => self.cross_validation_gamma(x),
196 BandwidthSelectionStrategy::MaximumLikelihood => self.maximum_likelihood_gamma(x),
197 BandwidthSelectionStrategy::LeaveOneOut => self.leave_one_out_gamma(x),
198 BandwidthSelectionStrategy::GridSearch => self.grid_search_gamma(x),
199 }
200 }
201
202 fn median_heuristic_gamma(&self, x: &Array2<Float>) -> Result<Float> {
204 let (n_samples, _) = x.dim();
205
206 if n_samples < 2 {
207 return Ok(1.0); }
209
210 let n_pairs = if n_samples > 1000 {
212 1000
214 } else {
215 n_samples * (n_samples - 1) / 2
216 };
217
218 let mut distances_sq = Vec::with_capacity(n_pairs);
219 let step = if n_samples > 1000 { n_samples / 100 } else { 1 };
220
221 for i in (0..n_samples).step_by(step) {
222 for j in ((i + 1)..n_samples).step_by(step) {
223 if distances_sq.len() >= n_pairs {
224 break;
225 }
226 let diff = &x.row(i) - &x.row(j);
227 let dist_sq = diff.mapv(|v| v * v).sum();
228 distances_sq.push(dist_sq);
229 }
230 if distances_sq.len() >= n_pairs {
231 break;
232 }
233 }
234
235 if distances_sq.is_empty() {
236 return Ok(1.0);
237 }
238
239 distances_sq.sort_by(|a, b| a.partial_cmp(b).unwrap());
241 let median_dist_sq = distances_sq[distances_sq.len() / 2];
242
243 Ok(if median_dist_sq > 0.0 {
245 1.0 / (2.0 * median_dist_sq)
246 } else {
247 1.0
248 })
249 }
250
251 fn scott_rule_gamma(&self, x: &Array2<Float>) -> Result<Float> {
253 let (n_samples, n_features) = x.dim();
254 let sigma = (n_samples as Float).powf(-1.0 / (n_features as Float + 4.0));
255 Ok(1.0 / (2.0 * sigma * sigma))
256 }
257
258 fn silverman_rule_gamma(&self, x: &Array2<Float>) -> Result<Float> {
260 let (n_samples, n_features) = x.dim();
261
262 let means = x.mean_axis(Axis(0)).unwrap();
264 let mut stds = Array1::zeros(n_features);
265
266 for j in 0..n_features {
267 let var = x
268 .column(j)
269 .mapv(|v| {
270 let diff = v - means[j];
271 diff * diff
272 })
273 .mean()
274 .unwrap();
275 stds[j] = var.sqrt();
276 }
277
278 let avg_std = stds.mean().unwrap();
279 let h = 1.06 * avg_std * (n_samples as Float).powf(-1.0 / 5.0);
280
281 Ok(1.0 / (2.0 * h * h))
282 }
283
284 fn cross_validation_gamma(&self, x: &Array2<Float>) -> Result<Float> {
286 let gamma_candidates = self.generate_gamma_candidates()?;
287 let mut best_gamma = gamma_candidates[0];
288 let mut best_score = Float::INFINITY;
289
290 for &gamma in &gamma_candidates {
291 let score = self.cross_validation_score(x, gamma)?;
292 if score < best_score {
293 best_score = score;
294 best_gamma = gamma;
295 }
296 }
297
298 Ok(best_gamma)
299 }
300
301 fn maximum_likelihood_gamma(&self, x: &Array2<Float>) -> Result<Float> {
303 let gamma_candidates = self.generate_gamma_candidates()?;
304 let mut best_gamma = gamma_candidates[0];
305 let mut best_likelihood = Float::NEG_INFINITY;
306
307 for &gamma in &gamma_candidates {
308 let likelihood = self.log_likelihood(x, gamma)?;
309 if likelihood > best_likelihood {
310 best_likelihood = likelihood;
311 best_gamma = gamma;
312 }
313 }
314
315 Ok(best_gamma)
316 }
317
318 fn leave_one_out_gamma(&self, x: &Array2<Float>) -> Result<Float> {
320 let gamma_candidates = self.generate_gamma_candidates()?;
321 let mut best_gamma = gamma_candidates[0];
322 let mut best_score = Float::INFINITY;
323
324 for &gamma in &gamma_candidates {
325 let score = self.leave_one_out_score(x, gamma)?;
326 if score < best_score {
327 best_score = score;
328 best_gamma = gamma;
329 }
330 }
331
332 Ok(best_gamma)
333 }
334
335 fn grid_search_gamma(&self, x: &Array2<Float>) -> Result<Float> {
337 let gamma_candidates = self.generate_gamma_candidates()?;
338 let mut best_gamma = gamma_candidates[0];
339 let mut best_score = match self.objective_function {
340 ObjectiveFunction::LogLikelihood => Float::NEG_INFINITY,
341 _ => Float::INFINITY,
342 };
343
344 for &gamma in &gamma_candidates {
345 let score = self.evaluate_objective(x, gamma)?;
346 let is_better = match self.objective_function {
347 ObjectiveFunction::LogLikelihood => score > best_score,
348 _ => score < best_score,
349 };
350
351 if is_better {
352 best_score = score;
353 best_gamma = gamma;
354 }
355 }
356
357 Ok(best_gamma)
358 }
359
360 fn generate_gamma_candidates(&self) -> Result<Vec<Float>> {
362 let (gamma_min, gamma_max) = self.gamma_range;
363 let log_min = gamma_min.ln();
364 let log_max = gamma_max.ln();
365
366 let mut candidates = Vec::with_capacity(self.n_gamma_candidates);
367 for i in 0..self.n_gamma_candidates {
368 let t = i as Float / (self.n_gamma_candidates - 1) as Float;
369 let log_gamma = log_min + t * (log_max - log_min);
370 candidates.push(log_gamma.exp());
371 }
372
373 Ok(candidates)
374 }
375
376 fn cross_validation_score(&self, x: &Array2<Float>, gamma: Float) -> Result<Float> {
378 let (n_samples, _) = x.dim();
379 let fold_size = n_samples / self.cv_folds;
380 let mut total_error = 0.0;
381
382 for fold in 0..self.cv_folds {
383 let start_idx = fold * fold_size;
384 let end_idx = if fold == self.cv_folds - 1 {
385 n_samples
386 } else {
387 (fold + 1) * fold_size
388 };
389
390 let val_indices: Vec<usize> = (start_idx..end_idx).collect();
392 let train_indices: Vec<usize> = (0..start_idx).chain(end_idx..n_samples).collect();
393
394 if train_indices.is_empty() || val_indices.is_empty() {
395 continue;
396 }
397
398 let error = self.kernel_approximation_error(x, gamma, &train_indices, &val_indices)?;
400 total_error += error;
401 }
402
403 Ok(total_error / self.cv_folds as Float)
404 }
405
406 fn log_likelihood(&self, x: &Array2<Float>, gamma: Float) -> Result<Float> {
408 let (n_samples, _) = x.dim();
409
410 let mut k_matrix = Array2::zeros((n_samples, n_samples));
412 for i in 0..n_samples {
413 for j in i..n_samples {
414 let diff = &x.row(i) - &x.row(j);
415 let dist_sq = diff.mapv(|v| v * v).sum();
416 let k_val = (-gamma * dist_sq).exp();
417 k_matrix[[i, j]] = k_val;
418 if i != j {
419 k_matrix[[j, i]] = k_val;
420 }
421 }
422 }
423
424 for i in 0..n_samples {
426 k_matrix[[i, i]] += 1e-6;
427 }
428
429 let trace = k_matrix.diag().sum();
431 let det_approx = trace; Ok(-0.5 * det_approx.ln() - 0.5 * n_samples as Float)
434 }
435
436 fn leave_one_out_score(&self, x: &Array2<Float>, gamma: Float) -> Result<Float> {
438 let (n_samples, _) = x.dim();
439 let mut total_error = 0.0;
440
441 for i in 0..n_samples {
442 let train_indices: Vec<usize> = (0..n_samples).filter(|&j| j != i).collect();
443 let val_indices = vec![i];
444
445 let error = self.kernel_approximation_error(x, gamma, &train_indices, &val_indices)?;
446 total_error += error;
447 }
448
449 Ok(total_error / n_samples as Float)
450 }
451
452 fn evaluate_objective(&self, x: &Array2<Float>, gamma: Float) -> Result<Float> {
454 match self.objective_function {
455 ObjectiveFunction::KernelAlignment => self.kernel_alignment(x, gamma),
456 ObjectiveFunction::LogLikelihood => self.log_likelihood(x, gamma),
457 ObjectiveFunction::CrossValidationError => self.cross_validation_score(x, gamma),
458 ObjectiveFunction::KernelTrace => self.kernel_trace(x, gamma),
459 ObjectiveFunction::EffectiveDimensionality => self.effective_dimensionality(x, gamma),
460 }
461 }
462
463 fn kernel_alignment(&self, x: &Array2<Float>, gamma: Float) -> Result<Float> {
465 let (n_samples, _) = x.dim();
466
467 let mut alignment = 0.0;
469 let mut count = 0;
470
471 for i in 0..n_samples.min(100) {
472 for j in (i + 1)..n_samples.min(100) {
474 let diff = &x.row(i) - &x.row(j);
475 let dist_sq = diff.mapv(|v| v * v).sum();
476 let k_val = (-gamma * dist_sq).exp();
477 alignment += k_val * k_val; count += 1;
479 }
480 }
481
482 Ok(if count > 0 {
483 -alignment / count as Float
484 } else {
485 0.0
486 })
487 }
488
489 fn kernel_trace(&self, x: &Array2<Float>, gamma: Float) -> Result<Float> {
491 let (n_samples, _) = x.dim();
492 let trace = n_samples as Float; Ok(-trace) }
495
496 fn effective_dimensionality(&self, x: &Array2<Float>, gamma: Float) -> Result<Float> {
498 let characteristic_length = (1.0 / gamma).sqrt();
500 let (_, n_features) = x.dim();
501 let eff_dim = (characteristic_length * n_features as Float).min(n_features as Float);
502 Ok(-eff_dim) }
504
505 fn kernel_approximation_error(
507 &self,
508 x: &Array2<Float>,
509 gamma: Float,
510 train_indices: &[usize],
511 val_indices: &[usize],
512 ) -> Result<Float> {
513 if train_indices.is_empty() || val_indices.is_empty() {
514 return Ok(0.0);
515 }
516
517 let mut error = 0.0;
519 let mut count = 0;
520
521 for &i in val_indices {
522 for &j in train_indices {
523 let diff = &x.row(i) - &x.row(j);
524 let dist_sq = diff.mapv(|v| v * v).sum();
525 let true_kernel = (-gamma * dist_sq).exp();
526
527 let approx_error = (1.0 - true_kernel) * (1.0 - true_kernel);
529 error += approx_error;
530 count += 1;
531 }
532 }
533
534 Ok(if count > 0 {
535 error / count as Float
536 } else {
537 0.0
538 })
539 }
540}
541
542impl Fit<Array2<Float>, ()> for AdaptiveBandwidthRBFSampler<Untrained> {
543 type Fitted = AdaptiveBandwidthRBFSampler<Trained>;
544
545 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
546 let (n_samples, n_features) = x.dim();
547
548 if n_samples == 0 || n_features == 0 {
549 return Err(SklearsError::InvalidInput(
550 "Input array is empty".to_string(),
551 ));
552 }
553
554 let selected_gamma = self.select_gamma(x)?;
556
557 let mut rng = match self.random_state {
558 Some(seed) => RealStdRng::seed_from_u64(seed),
559 None => RealStdRng::from_seed(thread_rng().gen()),
560 };
561
562 let std_dev = (2.0 * selected_gamma).sqrt();
564 let mut random_weights = Array2::zeros((self.n_components, n_features));
565 for i in 0..self.n_components {
566 for j in 0..n_features {
567 let u1 = rng.gen::<Float>();
569 let u2 = rng.gen::<Float>();
570 let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
571 random_weights[[i, j]] = z * std_dev;
572 }
573 }
574
575 let mut random_offset = Array1::zeros(self.n_components);
577 for i in 0..self.n_components {
578 random_offset[i] = rng.gen::<Float>() * 2.0 * std::f64::consts::PI;
579 }
580
581 Ok(AdaptiveBandwidthRBFSampler {
582 n_components: self.n_components,
583 strategy: self.strategy,
584 objective_function: self.objective_function,
585 gamma_range: self.gamma_range,
586 n_gamma_candidates: self.n_gamma_candidates,
587 cv_folds: self.cv_folds,
588 random_state: self.random_state,
589 tolerance: self.tolerance,
590 max_iterations: self.max_iterations,
591 selected_gamma_: Some(selected_gamma),
592 random_weights_: Some(random_weights),
593 random_offset_: Some(random_offset),
594 optimization_history_: None, _state: PhantomData,
596 })
597 }
598}
599
600impl AdaptiveBandwidthRBFSampler<Trained> {
601 pub fn selected_gamma(&self) -> Result<Float> {
603 self.selected_gamma_.ok_or_else(|| SklearsError::NotFitted {
604 operation: "selected_gamma".to_string(),
605 })
606 }
607
608 pub fn optimization_history(&self) -> Option<&Vec<(Float, Float)>> {
610 self.optimization_history_.as_ref()
611 }
612}
613
614impl Transform<Array2<Float>> for AdaptiveBandwidthRBFSampler<Trained> {
615 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
616 let random_weights =
617 self.random_weights_
618 .as_ref()
619 .ok_or_else(|| SklearsError::NotFitted {
620 operation: "transform".to_string(),
621 })?;
622
623 let random_offset =
624 self.random_offset_
625 .as_ref()
626 .ok_or_else(|| SklearsError::NotFitted {
627 operation: "transform".to_string(),
628 })?;
629
630 let (n_samples, n_features) = x.dim();
631
632 if n_features != random_weights.ncols() {
633 return Err(SklearsError::InvalidInput(format!(
634 "Input has {} features, expected {}",
635 n_features,
636 random_weights.ncols()
637 )));
638 }
639
640 let projection = x.dot(&random_weights.t()) + random_offset;
642
643 let normalization = (2.0 / random_weights.nrows() as Float).sqrt();
645 Ok(projection.mapv(|x| x.cos() * normalization))
646 }
647}
648
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// Fits `sampler` on `x`, panicking if fitting fails.
    fn fit_or_panic(
        sampler: AdaptiveBandwidthRBFSampler<Untrained>,
        x: &Array2<Float>,
    ) -> AdaptiveBandwidthRBFSampler<Trained> {
        sampler.fit(x, &()).unwrap()
    }

    #[test]
    fn test_adaptive_bandwidth_rbf_sampler_basic() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let model = fit_or_panic(
            AdaptiveBandwidthRBFSampler::new(50)
                .strategy(BandwidthSelectionStrategy::MedianHeuristic)
                .random_state(42),
            &data,
        );

        let embedded = model.transform(&data).unwrap();
        assert_eq!(embedded.shape(), &[3, 50]);
        assert!(model.selected_gamma().unwrap() > 0.0);
        // Features are scaled cosines, so they stay well within [-2, 2].
        assert!(embedded.iter().all(|v| (-2.0..=2.0).contains(v)));
    }

    #[test]
    fn test_different_bandwidth_strategies() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];

        for strategy in [
            BandwidthSelectionStrategy::MedianHeuristic,
            BandwidthSelectionStrategy::ScottRule,
            BandwidthSelectionStrategy::SilvermanRule,
            BandwidthSelectionStrategy::GridSearch,
        ]
        .iter()
        .copied()
        {
            let model = fit_or_panic(
                AdaptiveBandwidthRBFSampler::new(20)
                    .strategy(strategy)
                    .random_state(42),
                &data,
            );
            let embedded = model.transform(&data).unwrap();
            let gamma = model.selected_gamma().unwrap();

            assert_eq!(embedded.shape(), &[4, 20]);
            assert!(gamma > 0.0);
        }
    }

    #[test]
    fn test_cross_validation_strategy() {
        let data = array![
            [1.0, 1.0],
            [1.1, 1.1],
            [2.0, 2.0],
            [2.1, 2.1],
            [5.0, 5.0],
            [5.1, 5.1]
        ];
        let model = fit_or_panic(
            AdaptiveBandwidthRBFSampler::new(30)
                .strategy(BandwidthSelectionStrategy::CrossValidation)
                .cv_folds(3)
                .n_gamma_candidates(5)
                .random_state(42),
            &data,
        );

        let embedded = model.transform(&data).unwrap();
        let gamma = model.selected_gamma().unwrap();
        assert_eq!(embedded.shape(), &[6, 30]);
        assert!(gamma > 0.0);
    }

    #[test]
    fn test_different_objective_functions() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let objectives = [
            ObjectiveFunction::KernelAlignment,
            ObjectiveFunction::LogLikelihood,
            ObjectiveFunction::KernelTrace,
            ObjectiveFunction::EffectiveDimensionality,
        ];

        objectives.iter().for_each(|&objective| {
            let model = fit_or_panic(
                AdaptiveBandwidthRBFSampler::new(25)
                    .strategy(BandwidthSelectionStrategy::GridSearch)
                    .objective_function(objective)
                    .n_gamma_candidates(5)
                    .random_state(42),
                &data,
            );
            assert!(model.selected_gamma().unwrap() > 0.0);
        });
    }

    #[test]
    fn test_median_heuristic() {
        let unit_square = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];
        let gamma = AdaptiveBandwidthRBFSampler::new(10)
            .median_heuristic_gamma(&unit_square)
            .unwrap();
        assert!(gamma > 0.1 && gamma < 2.0);
    }

    #[test]
    fn test_scott_rule() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let gamma = AdaptiveBandwidthRBFSampler::new(10)
            .scott_rule_gamma(&data)
            .unwrap();
        assert!(gamma > 0.0);
    }

    #[test]
    fn test_silverman_rule() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let gamma = AdaptiveBandwidthRBFSampler::new(10)
            .silverman_rule_gamma(&data)
            .unwrap();
        assert!(gamma > 0.0);
    }

    #[test]
    fn test_reproducibility() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let configure = || {
            AdaptiveBandwidthRBFSampler::new(40)
                .strategy(BandwidthSelectionStrategy::MedianHeuristic)
                .random_state(123)
        };

        let first = fit_or_panic(configure(), &data);
        let second = fit_or_panic(configure(), &data);

        let first_features = first.transform(&data).unwrap();
        let second_features = second.transform(&data).unwrap();

        assert_abs_diff_eq!(
            first.selected_gamma().unwrap(),
            second.selected_gamma().unwrap(),
            epsilon = 1e-10
        );
        for (a, b) in first_features.iter().zip(second_features.iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_gamma_range() {
        let data = array![[1.0, 2.0], [3.0, 4.0]];
        let model = fit_or_panic(
            AdaptiveBandwidthRBFSampler::new(15)
                .strategy(BandwidthSelectionStrategy::GridSearch)
                .gamma_range(0.5, 2.0)
                .n_gamma_candidates(5)
                .random_state(42),
            &data,
        );

        let gamma = model.selected_gamma().unwrap();
        assert!((0.5..=2.0).contains(&gamma));
    }

    #[test]
    fn test_error_handling() {
        let sampler = AdaptiveBandwidthRBFSampler::new(10);

        // An empty matrix must be rejected by `fit`.
        let empty = Array2::<Float>::zeros((0, 0));
        assert!(sampler.clone().fit(&empty, &()).is_err());

        // A feature-count mismatch must be rejected by `transform`.
        let train = array![[1.0, 2.0], [3.0, 4.0]];
        let wrong_width = array![[1.0, 2.0, 3.0]];
        let model = fit_or_panic(sampler, &train);
        assert!(model.transform(&wrong_width).is_err());
    }

    #[test]
    fn test_single_sample() {
        let lone = array![[1.0, 2.0]];
        let model = fit_or_panic(
            AdaptiveBandwidthRBFSampler::new(10)
                .strategy(BandwidthSelectionStrategy::MedianHeuristic),
            &lone,
        );
        assert!(model.selected_gamma().unwrap() > 0.0);
    }

    #[test]
    fn test_large_dataset_efficiency() {
        // 500 rows of [i, 2 * i].
        let data = Array2::from_shape_fn((500, 2), |(row, col)| (row * (col + 1)) as Float);
        let model = fit_or_panic(
            AdaptiveBandwidthRBFSampler::new(20)
                .strategy(BandwidthSelectionStrategy::MedianHeuristic),
            &data,
        );
        assert!(model.selected_gamma().unwrap() > 0.0);
    }
}
876}