//! Random Fourier feature approximation of the RBF kernel with adaptive,
//! data-driven bandwidth (gamma) selection.

use scirs2_core::ndarray::{Array1, Array2, Axis};
use scirs2_core::random::rngs::StdRng as RealStdRng;
use scirs2_core::random::Rng;
use scirs2_core::random::{thread_rng, SeedableRng};
use sklears_core::{
    error::{Result, SklearsError},
    traits::{Fit, Trained, Transform, Untrained},
    types::Float,
};
use std::marker::PhantomData;

/// Strategy used to select the RBF bandwidth parameter `gamma`.
#[derive(Debug, Clone, Copy)]
pub enum BandwidthSelectionStrategy {
    /// Select gamma by k-fold cross-validation of the kernel approximation error.
    CrossValidation,
    /// Select gamma by maximizing an approximate log-likelihood.
    MaximumLikelihood,
    /// Median heuristic: gamma = 1 / (2 * median squared pairwise distance).
    MedianHeuristic,
    /// Scott's rule of thumb for kernel bandwidth.
    ScottRule,
    /// Silverman's rule of thumb for kernel bandwidth.
    SilvermanRule,
    /// Leave-one-out cross-validation over the gamma candidates.
    LeaveOneOut,
    /// Grid search over log-spaced gamma candidates using the configured objective.
    GridSearch,
}

/// Objective evaluated when selecting gamma by grid search.
#[derive(Debug, Clone, Copy)]
pub enum ObjectiveFunction {
    /// Mean squared kernel value over a subsample (maximized via its negative).
    KernelAlignment,
    /// Approximate log-likelihood of the kernel matrix.
    LogLikelihood,
    /// Cross-validated kernel approximation error.
    CrossValidationError,
    /// Trace of the kernel matrix.
    KernelTrace,
    /// Heuristic effective dimensionality of the feature map.
    EffectiveDimensionality,
}

/// RBF random Fourier feature sampler with data-driven bandwidth selection.
///
/// The bandwidth parameter `gamma` is selected from the training data according to
/// the configured [`BandwidthSelectionStrategy`] before the random features are drawn.
#[derive(Debug, Clone)]
pub struct AdaptiveBandwidthRBFSampler<State = Untrained> {
    /// Number of random Fourier features to generate.
    pub n_components: usize,
    /// Strategy used to select the bandwidth parameter.
    pub strategy: BandwidthSelectionStrategy,
    /// Objective function used by grid search.
    pub objective_function: ObjectiveFunction,
    /// Inclusive range of gamma values considered by candidate-based strategies.
    pub gamma_range: (Float, Float),
    /// Number of log-spaced gamma candidates to evaluate.
    pub n_gamma_candidates: usize,
    /// Number of folds used by cross-validation.
    pub cv_folds: usize,
    /// Optional seed for reproducible random feature generation.
    pub random_state: Option<u64>,
    /// Convergence tolerance for gamma selection.
    pub tolerance: Float,
    /// Maximum number of iterations for gamma selection.
    pub max_iterations: usize,

    selected_gamma_: Option<Float>,
    random_weights_: Option<Array2<Float>>,
    random_offset_: Option<Array1<Float>>,
    optimization_history_: Option<Vec<(Float, Float)>>,
    _state: PhantomData<State>,
}

impl AdaptiveBandwidthRBFSampler<Untrained> {
    /// Create a new sampler with `n_components` random features and default settings.
    pub fn new(n_components: usize) -> Self {
        Self {
            n_components,
            strategy: BandwidthSelectionStrategy::MedianHeuristic,
            objective_function: ObjectiveFunction::KernelAlignment,
            gamma_range: (1e-3, 1e3),
            n_gamma_candidates: 20,
            cv_folds: 5,
            random_state: None,
            tolerance: 1e-6,
            max_iterations: 100,
            selected_gamma_: None,
            random_weights_: None,
            random_offset_: None,
            optimization_history_: None,
            _state: PhantomData,
        }
    }

    /// Set the bandwidth selection strategy.
    pub fn strategy(mut self, strategy: BandwidthSelectionStrategy) -> Self {
        self.strategy = strategy;
        self
    }

    /// Set the objective function used by grid search.
    pub fn objective_function(mut self, objective: ObjectiveFunction) -> Self {
        self.objective_function = objective;
        self
    }

    /// Set the inclusive range of gamma values to consider.
    pub fn gamma_range(mut self, min: Float, max: Float) -> Self {
        self.gamma_range = (min, max);
        self
    }

    /// Set the number of gamma candidates generated for candidate-based strategies.
    pub fn n_gamma_candidates(mut self, n: usize) -> Self {
        self.n_gamma_candidates = n;
        self
    }

    /// Set the number of cross-validation folds.
    pub fn cv_folds(mut self, folds: usize) -> Self {
        self.cv_folds = folds;
        self
    }

    /// Set the random seed for reproducible results.
    pub fn random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }

    /// Set the convergence tolerance.
    pub fn tolerance(mut self, tol: Float) -> Self {
        self.tolerance = tol;
        self
    }

    /// Set the maximum number of iterations.
    pub fn max_iterations(mut self, max_iter: usize) -> Self {
        self.max_iterations = max_iter;
        self
    }

    fn select_gamma(&self, x: &Array2<Float>) -> Result<Float> {
        match self.strategy {
            BandwidthSelectionStrategy::MedianHeuristic => self.median_heuristic_gamma(x),
            BandwidthSelectionStrategy::ScottRule => self.scott_rule_gamma(x),
            BandwidthSelectionStrategy::SilvermanRule => self.silverman_rule_gamma(x),
            BandwidthSelectionStrategy::CrossValidation => self.cross_validation_gamma(x),
            BandwidthSelectionStrategy::MaximumLikelihood => self.maximum_likelihood_gamma(x),
            BandwidthSelectionStrategy::LeaveOneOut => self.leave_one_out_gamma(x),
            BandwidthSelectionStrategy::GridSearch => self.grid_search_gamma(x),
        }
    }

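    // Median heuristic, a common rule of thumb: with m equal to the median of the
    // squared pairwise distances, gamma = 1 / (2 * m), so a "typical" pair of points
    // has kernel value exp(-1/2). For large inputs only a subsample of pairs is used.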
    fn median_heuristic_gamma(&self, x: &Array2<Float>) -> Result<Float> {
        let (n_samples, _) = x.dim();

        if n_samples < 2 {
            return Ok(1.0);
        }

        let n_pairs = if n_samples > 1000 {
            1000
        } else {
            n_samples * (n_samples - 1) / 2
        };

        let mut distances_sq = Vec::with_capacity(n_pairs);
        let step = if n_samples > 1000 { n_samples / 100 } else { 1 };

        for i in (0..n_samples).step_by(step) {
            for j in ((i + 1)..n_samples).step_by(step) {
                if distances_sq.len() >= n_pairs {
                    break;
                }
                let diff = &x.row(i) - &x.row(j);
                let dist_sq = diff.mapv(|v| v * v).sum();
                distances_sq.push(dist_sq);
            }
            if distances_sq.len() >= n_pairs {
                break;
            }
        }

        if distances_sq.is_empty() {
            return Ok(1.0);
        }

        distances_sq.sort_by(|a, b| a.partial_cmp(b).unwrap());
        let median_dist_sq = distances_sq[distances_sq.len() / 2];

        Ok(if median_dist_sq > 0.0 {
            1.0 / (2.0 * median_dist_sq)
        } else {
            1.0
        })
    }

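    // Scott's rule of thumb: bandwidth h = n^(-1/(d+4)) (implicitly assuming roughly
    // unit-variance features), converted to the RBF parameterization gamma = 1 / (2 * h^2).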
    fn scott_rule_gamma(&self, x: &Array2<Float>) -> Result<Float> {
        let (n_samples, n_features) = x.dim();
        let sigma = (n_samples as Float).powf(-1.0 / (n_features as Float + 4.0));
        Ok(1.0 / (2.0 * sigma * sigma))
    }

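    // Silverman's rule of thumb: h = 1.06 * sigma_bar * n^(-1/5), where sigma_bar is the
    // average per-feature standard deviation, again mapped to gamma = 1 / (2 * h^2).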
    fn silverman_rule_gamma(&self, x: &Array2<Float>) -> Result<Float> {
        let (n_samples, n_features) = x.dim();

        let means = x.mean_axis(Axis(0)).unwrap();
        let mut stds = Array1::zeros(n_features);

        for j in 0..n_features {
            let var = x
                .column(j)
                .mapv(|v| {
                    let diff = v - means[j];
                    diff * diff
                })
                .mean()
                .unwrap();
            stds[j] = var.sqrt();
        }

        let avg_std = stds.mean().unwrap();
        let h = 1.06 * avg_std * (n_samples as Float).powf(-1.0 / 5.0);

        Ok(1.0 / (2.0 * h * h))
    }

    fn cross_validation_gamma(&self, x: &Array2<Float>) -> Result<Float> {
        let gamma_candidates = self.generate_gamma_candidates()?;
        let mut best_gamma = gamma_candidates[0];
        let mut best_score = Float::INFINITY;

        for &gamma in &gamma_candidates {
            let score = self.cross_validation_score(x, gamma)?;
            if score < best_score {
                best_score = score;
                best_gamma = gamma;
            }
        }

        Ok(best_gamma)
    }

    fn maximum_likelihood_gamma(&self, x: &Array2<Float>) -> Result<Float> {
        let gamma_candidates = self.generate_gamma_candidates()?;
        let mut best_gamma = gamma_candidates[0];
        let mut best_likelihood = Float::NEG_INFINITY;

        for &gamma in &gamma_candidates {
            let likelihood = self.log_likelihood(x, gamma)?;
            if likelihood > best_likelihood {
                best_likelihood = likelihood;
                best_gamma = gamma;
            }
        }

        Ok(best_gamma)
    }

    fn leave_one_out_gamma(&self, x: &Array2<Float>) -> Result<Float> {
        let gamma_candidates = self.generate_gamma_candidates()?;
        let mut best_gamma = gamma_candidates[0];
        let mut best_score = Float::INFINITY;

        for &gamma in &gamma_candidates {
            let score = self.leave_one_out_score(x, gamma)?;
            if score < best_score {
                best_score = score;
                best_gamma = gamma;
            }
        }

        Ok(best_gamma)
    }

    fn grid_search_gamma(&self, x: &Array2<Float>) -> Result<Float> {
        let gamma_candidates = self.generate_gamma_candidates()?;
        let mut best_gamma = gamma_candidates[0];
        let mut best_score = match self.objective_function {
            ObjectiveFunction::LogLikelihood => Float::NEG_INFINITY,
            _ => Float::INFINITY,
        };

        for &gamma in &gamma_candidates {
            let score = self.evaluate_objective(x, gamma)?;
            let is_better = match self.objective_function {
                ObjectiveFunction::LogLikelihood => score > best_score,
                _ => score < best_score,
            };

            if is_better {
                best_score = score;
                best_gamma = gamma;
            }
        }

        Ok(best_gamma)
    }

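    // Candidate gammas are spaced uniformly in log space across `gamma_range`, so a
    // small number of candidates still covers several orders of magnitude.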
    fn generate_gamma_candidates(&self) -> Result<Vec<Float>> {
        let (gamma_min, gamma_max) = self.gamma_range;
        let log_min = gamma_min.ln();
        let log_max = gamma_max.ln();

        let mut candidates = Vec::with_capacity(self.n_gamma_candidates);
        for i in 0..self.n_gamma_candidates {
            // Guard against division by zero when only a single candidate is requested.
            let denom = self.n_gamma_candidates.saturating_sub(1).max(1) as Float;
            let t = i as Float / denom;
            let log_gamma = log_min + t * (log_max - log_min);
            candidates.push(log_gamma.exp());
        }

        Ok(candidates)
    }

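    // K-fold cross-validation over contiguous blocks of rows: each block is held out in
    // turn and scored by the mean kernel approximation error against the remaining rows.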
    fn cross_validation_score(&self, x: &Array2<Float>, gamma: Float) -> Result<Float> {
        let (n_samples, _) = x.dim();
        let fold_size = n_samples / self.cv_folds;
        let mut total_error = 0.0;

        for fold in 0..self.cv_folds {
            let start_idx = fold * fold_size;
            let end_idx = if fold == self.cv_folds - 1 {
                n_samples
            } else {
                (fold + 1) * fold_size
            };

            let val_indices: Vec<usize> = (start_idx..end_idx).collect();
            let train_indices: Vec<usize> = (0..start_idx).chain(end_idx..n_samples).collect();

            if train_indices.is_empty() || val_indices.is_empty() {
                continue;
            }

            let error = self.kernel_approximation_error(x, gamma, &train_indices, &val_indices)?;
            total_error += error;
        }

        Ok(total_error / self.cv_folds as Float)
    }

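    // Builds the full RBF kernel matrix with a small diagonal jitter and scores a rough
    // likelihood proxy. Note that the log-determinant is approximated by the trace here,
    // which is a crude simplification rather than an exact Gaussian log-likelihood.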
    fn log_likelihood(&self, x: &Array2<Float>, gamma: Float) -> Result<Float> {
        let (n_samples, _) = x.dim();

        let mut k_matrix = Array2::zeros((n_samples, n_samples));
        for i in 0..n_samples {
            for j in i..n_samples {
                let diff = &x.row(i) - &x.row(j);
                let dist_sq = diff.mapv(|v| v * v).sum();
                let k_val = (-gamma * dist_sq).exp();
                k_matrix[[i, j]] = k_val;
                if i != j {
                    k_matrix[[j, i]] = k_val;
                }
            }
        }

        for i in 0..n_samples {
            k_matrix[[i, i]] += 1e-6;
        }

        let trace = k_matrix.diag().sum();
        let det_approx = trace;
        Ok(-0.5 * det_approx.ln() - 0.5 * n_samples as Float)
    }

    fn leave_one_out_score(&self, x: &Array2<Float>, gamma: Float) -> Result<Float> {
        let (n_samples, _) = x.dim();
        let mut total_error = 0.0;

        for i in 0..n_samples {
            let train_indices: Vec<usize> = (0..n_samples).filter(|&j| j != i).collect();
            let val_indices = vec![i];

            let error = self.kernel_approximation_error(x, gamma, &train_indices, &val_indices)?;
            total_error += error;
        }

        Ok(total_error / n_samples as Float)
    }

    fn evaluate_objective(&self, x: &Array2<Float>, gamma: Float) -> Result<Float> {
        match self.objective_function {
            ObjectiveFunction::KernelAlignment => self.kernel_alignment(x, gamma),
            ObjectiveFunction::LogLikelihood => self.log_likelihood(x, gamma),
            ObjectiveFunction::CrossValidationError => self.cross_validation_score(x, gamma),
            ObjectiveFunction::KernelTrace => self.kernel_trace(x, gamma),
            ObjectiveFunction::EffectiveDimensionality => self.effective_dimensionality(x, gamma),
        }
    }

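    // Kernel "alignment" proxy: the mean squared kernel value over (at most) the first
    // 100 samples, negated so that grid search, which minimizes non-likelihood
    // objectives, prefers gammas that yield larger average kernel responses.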
    fn kernel_alignment(&self, x: &Array2<Float>, gamma: Float) -> Result<Float> {
        let (n_samples, _) = x.dim();

        let mut alignment = 0.0;
        let mut count = 0;

        for i in 0..n_samples.min(100) {
            for j in (i + 1)..n_samples.min(100) {
                let diff = &x.row(i) - &x.row(j);
                let dist_sq = diff.mapv(|v| v * v).sum();
                let k_val = (-gamma * dist_sq).exp();
                alignment += k_val * k_val;
                count += 1;
            }
        }

        Ok(if count > 0 {
            -alignment / count as Float
        } else {
            0.0
        })
    }

    // For the RBF kernel, k(x, x) = 1, so the kernel trace equals the number of samples
    // regardless of gamma; the value is negated to match the minimization convention of
    // the grid search.
    fn kernel_trace(&self, x: &Array2<Float>, _gamma: Float) -> Result<Float> {
        let (n_samples, _) = x.dim();
        let trace = n_samples as Float;
        Ok(-trace)
    }

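    // Heuristic effective-dimensionality proxy: the characteristic length scale
    // 1 / sqrt(gamma) scaled by the feature count and capped at n_features, negated so
    // that minimizing the objective favors larger effective dimensionality.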
    fn effective_dimensionality(&self, x: &Array2<Float>, gamma: Float) -> Result<Float> {
        let characteristic_length = (1.0 / gamma).sqrt();
        let (_, n_features) = x.dim();
        let eff_dim = (characteristic_length * n_features as Float).min(n_features as Float);
        Ok(-eff_dim)
    }

    fn kernel_approximation_error(
        &self,
        x: &Array2<Float>,
        gamma: Float,
        train_indices: &[usize],
        val_indices: &[usize],
    ) -> Result<Float> {
        if train_indices.is_empty() || val_indices.is_empty() {
            return Ok(0.0);
        }

        let mut error = 0.0;
        let mut count = 0;

        for &i in val_indices {
            for &j in train_indices {
                let diff = &x.row(i) - &x.row(j);
                let dist_sq = diff.mapv(|v| v * v).sum();
                let true_kernel = (-gamma * dist_sq).exp();

                let approx_error = (1.0 - true_kernel) * (1.0 - true_kernel);
                error += approx_error;
                count += 1;
            }
        }

        Ok(if count > 0 {
            error / count as Float
        } else {
            0.0
        })
    }
}

impl Fit<Array2<Float>, ()> for AdaptiveBandwidthRBFSampler<Untrained> {
    type Fitted = AdaptiveBandwidthRBFSampler<Trained>;

    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
        let (n_samples, n_features) = x.dim();

        if n_samples == 0 || n_features == 0 {
            return Err(SklearsError::InvalidInput(
                "Input array is empty".to_string(),
            ));
        }

        let selected_gamma = self.select_gamma(x)?;

        let mut rng = match self.random_state {
            Some(seed) => RealStdRng::seed_from_u64(seed),
            None => RealStdRng::from_seed(thread_rng().gen()),
        };

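        // Random Fourier features for the RBF kernel: each projection weight is drawn
        // from N(0, 2 * gamma), i.e. standard deviation sqrt(2 * gamma), sampled here
        // via the Box-Muller transform from two uniform variates.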
        let std_dev = (2.0 * selected_gamma).sqrt();
        let mut random_weights = Array2::zeros((self.n_components, n_features));
        for i in 0..self.n_components {
            for j in 0..n_features {
                let u1 = rng.gen::<Float>();
                let u2 = rng.gen::<Float>();
                let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
                random_weights[[i, j]] = z * std_dev;
            }
        }

        let mut random_offset = Array1::zeros(self.n_components);
        for i in 0..self.n_components {
            random_offset[i] = rng.gen::<Float>() * 2.0 * std::f64::consts::PI;
        }

        Ok(AdaptiveBandwidthRBFSampler {
            n_components: self.n_components,
            strategy: self.strategy,
            objective_function: self.objective_function,
            gamma_range: self.gamma_range,
            n_gamma_candidates: self.n_gamma_candidates,
            cv_folds: self.cv_folds,
            random_state: self.random_state,
            tolerance: self.tolerance,
            max_iterations: self.max_iterations,
            selected_gamma_: Some(selected_gamma),
            random_weights_: Some(random_weights),
            random_offset_: Some(random_offset),
            optimization_history_: None,
            _state: PhantomData,
        })
    }
}

impl AdaptiveBandwidthRBFSampler<Trained> {
    /// Return the gamma value selected during fitting.
    pub fn selected_gamma(&self) -> Result<Float> {
        self.selected_gamma_.ok_or_else(|| SklearsError::NotFitted {
            operation: "selected_gamma".to_string(),
        })
    }

    /// Return the recorded optimization history, if any.
    pub fn optimization_history(&self) -> Option<&Vec<(Float, Float)>> {
        self.optimization_history_.as_ref()
    }
}

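// The transform maps each row x to sqrt(2 / D) * cos(W x + b), where W and b are the
// random weights and offsets drawn in `fit` and D is `n_components`; the dot product of
// two transformed rows approximates the RBF kernel exp(-gamma * ||x - y||^2) in expectation.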
impl Transform<Array2<Float>> for AdaptiveBandwidthRBFSampler<Trained> {
    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
        let random_weights =
            self.random_weights_
                .as_ref()
                .ok_or_else(|| SklearsError::NotFitted {
                    operation: "transform".to_string(),
                })?;

        let random_offset =
            self.random_offset_
                .as_ref()
                .ok_or_else(|| SklearsError::NotFitted {
                    operation: "transform".to_string(),
                })?;

        let (_n_samples, n_features) = x.dim();

        if n_features != random_weights.ncols() {
            return Err(SklearsError::InvalidInput(format!(
                "Input has {} features, expected {}",
                n_features,
                random_weights.ncols()
            )));
        }

        let projection = x.dot(&random_weights.t()) + random_offset;

        let normalization = (2.0 / random_weights.nrows() as Float).sqrt();
        Ok(projection.mapv(|x| x.cos() * normalization))
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_adaptive_bandwidth_rbf_sampler_basic() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let sampler = AdaptiveBandwidthRBFSampler::new(50)
            .strategy(BandwidthSelectionStrategy::MedianHeuristic)
            .random_state(42);

        let fitted = sampler.fit(&x, &()).unwrap();
        let features = fitted.transform(&x).unwrap();

        assert_eq!(features.shape(), &[3, 50]);

        let gamma = fitted.selected_gamma().unwrap();
        assert!(gamma > 0.0);

        for &val in features.iter() {
            assert!(val >= -2.0 && val <= 2.0);
        }
    }

    #[test]
    fn test_different_bandwidth_strategies() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];

        let strategies = [
            BandwidthSelectionStrategy::MedianHeuristic,
            BandwidthSelectionStrategy::ScottRule,
            BandwidthSelectionStrategy::SilvermanRule,
            BandwidthSelectionStrategy::GridSearch,
        ];

        for strategy in &strategies {
            let sampler = AdaptiveBandwidthRBFSampler::new(20)
                .strategy(*strategy)
                .random_state(42);

            let fitted = sampler.fit(&x, &()).unwrap();
            let features = fitted.transform(&x).unwrap();
            let gamma = fitted.selected_gamma().unwrap();

            assert_eq!(features.shape(), &[4, 20]);
            assert!(gamma > 0.0);
        }
    }

    #[test]
    fn test_cross_validation_strategy() {
        let x = array![
            [1.0, 1.0],
            [1.1, 1.1],
            [2.0, 2.0],
            [2.1, 2.1],
            [5.0, 5.0],
            [5.1, 5.1]
        ];

        let sampler = AdaptiveBandwidthRBFSampler::new(30)
            .strategy(BandwidthSelectionStrategy::CrossValidation)
            .cv_folds(3)
            .n_gamma_candidates(5)
            .random_state(42);

        let fitted = sampler.fit(&x, &()).unwrap();
        let features = fitted.transform(&x).unwrap();
        let gamma = fitted.selected_gamma().unwrap();

        assert_eq!(features.shape(), &[6, 30]);
        assert!(gamma > 0.0);
    }

    #[test]
    fn test_different_objective_functions() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let objectives = [
            ObjectiveFunction::KernelAlignment,
            ObjectiveFunction::LogLikelihood,
            ObjectiveFunction::KernelTrace,
            ObjectiveFunction::EffectiveDimensionality,
        ];

        for objective in &objectives {
            let sampler = AdaptiveBandwidthRBFSampler::new(25)
                .strategy(BandwidthSelectionStrategy::GridSearch)
                .objective_function(*objective)
                .n_gamma_candidates(5)
                .random_state(42);

            let fitted = sampler.fit(&x, &()).unwrap();
            let gamma = fitted.selected_gamma().unwrap();

            assert!(gamma > 0.0);
        }
    }

    #[test]
    fn test_median_heuristic() {
        let x = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];

        let sampler = AdaptiveBandwidthRBFSampler::new(10);
        let gamma = sampler.median_heuristic_gamma(&x).unwrap();

        assert!(gamma > 0.1 && gamma < 2.0);
    }

    #[test]
    fn test_scott_rule() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];

        let sampler = AdaptiveBandwidthRBFSampler::new(10);
        let gamma = sampler.scott_rule_gamma(&x).unwrap();

        assert!(gamma > 0.0);
    }

    #[test]
    fn test_silverman_rule() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];

        let sampler = AdaptiveBandwidthRBFSampler::new(10);
        let gamma = sampler.silverman_rule_gamma(&x).unwrap();

        assert!(gamma > 0.0);
    }

    #[test]
    fn test_reproducibility() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let sampler1 = AdaptiveBandwidthRBFSampler::new(40)
            .strategy(BandwidthSelectionStrategy::MedianHeuristic)
            .random_state(123);

        let sampler2 = AdaptiveBandwidthRBFSampler::new(40)
            .strategy(BandwidthSelectionStrategy::MedianHeuristic)
            .random_state(123);

        let fitted1 = sampler1.fit(&x, &()).unwrap();
        let fitted2 = sampler2.fit(&x, &()).unwrap();

        let features1 = fitted1.transform(&x).unwrap();
        let features2 = fitted2.transform(&x).unwrap();

        let gamma1 = fitted1.selected_gamma().unwrap();
        let gamma2 = fitted2.selected_gamma().unwrap();

        assert_abs_diff_eq!(gamma1, gamma2, epsilon = 1e-10);

        for (f1, f2) in features1.iter().zip(features2.iter()) {
            assert_abs_diff_eq!(f1, f2, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_gamma_range() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let sampler = AdaptiveBandwidthRBFSampler::new(15)
            .strategy(BandwidthSelectionStrategy::GridSearch)
            .gamma_range(0.5, 2.0)
            .n_gamma_candidates(5)
            .random_state(42);

        let fitted = sampler.fit(&x, &()).unwrap();
        let gamma = fitted.selected_gamma().unwrap();

        assert!(gamma >= 0.5 && gamma <= 2.0);
    }

    #[test]
    fn test_error_handling() {
        let empty = Array2::<Float>::zeros((0, 0));
        let sampler = AdaptiveBandwidthRBFSampler::new(10);
        assert!(sampler.clone().fit(&empty, &()).is_err());

        let x_train = array![[1.0, 2.0], [3.0, 4.0]];
        let x_test = array![[1.0, 2.0, 3.0]];
        let fitted = sampler.fit(&x_train, &()).unwrap();
        assert!(fitted.transform(&x_test).is_err());
    }

    #[test]
    fn test_single_sample() {
        let x = array![[1.0, 2.0]];

        let sampler = AdaptiveBandwidthRBFSampler::new(10)
            .strategy(BandwidthSelectionStrategy::MedianHeuristic);

        let fitted = sampler.fit(&x, &()).unwrap();
        let gamma = fitted.selected_gamma().unwrap();

        assert!(gamma > 0.0);
    }

    #[test]
    fn test_large_dataset_efficiency() {
        let mut data = Vec::new();
        for i in 0..500 {
            data.push([i as Float, (i * 2) as Float]);
        }
        let x = Array2::from(data);

        let sampler = AdaptiveBandwidthRBFSampler::new(20)
            .strategy(BandwidthSelectionStrategy::MedianHeuristic);

        let fitted = sampler.fit(&x, &()).unwrap();
        let gamma = fitted.selected_gamma().unwrap();

        assert!(gamma > 0.0);
    }
}