sklears_kernel_approximation/
adaptive_bandwidth_rbf.rs

1//! Adaptive bandwidth RBF kernel approximation methods
2//!
3//! This module implements RBF kernel approximation with automatic bandwidth selection
4//! based on data characteristics. The bandwidth (gamma) parameter is optimized using
5//! various strategies including cross-validation, maximum likelihood, and heuristic methods.
6
7use scirs2_core::ndarray::{Array1, Array2, Axis};
8use scirs2_core::random::rngs::StdRng as RealStdRng;
9use scirs2_core::random::{thread_rng, Rng, SeedableRng};
10use sklears_core::{
11    error::{Result, SklearsError},
12    traits::{Fit, Trained, Transform, Untrained},
13    types::Float,
14};
15use std::marker::PhantomData;
16
/// Bandwidth selection strategy for adaptive RBF
///
/// Each variant names one way of choosing the RBF bandwidth parameter
/// (gamma) from the training data during `fit`.
#[derive(Debug, Clone, Copy)]
pub enum BandwidthSelectionStrategy {
    /// Cross-validation to minimize approximation error
    CrossValidation,
    /// Maximum likelihood estimation
    MaximumLikelihood,
    /// Median heuristic based on pairwise distances
    MedianHeuristic,
    /// Scott's rule based on data dimensionality and sample size
    ScottRule,
    /// Silverman's rule of thumb
    SilvermanRule,
    /// Leave-one-out cross-validation
    LeaveOneOut,
    /// Grid search over a range of gamma values
    GridSearch,
}
36
/// Objective function for bandwidth optimization
///
/// Used by the grid-search strategy to score candidate gamma values.
/// `LogLikelihood` is maximized; every other objective is minimized.
#[derive(Debug, Clone, Copy)]
pub enum ObjectiveFunction {
    /// Kernel alignment (negated so that smaller scores are better)
    KernelAlignment,
    /// Log-likelihood (larger scores are better)
    LogLikelihood,
    /// Cross-validation error
    CrossValidationError,
    /// Kernel matrix trace
    KernelTrace,
    /// Effective dimensionality
    EffectiveDimensionality,
}
52
/// Adaptive bandwidth RBF sampler with automatic gamma selection
///
/// This sampler automatically selects the optimal bandwidth parameter (gamma) for RBF
/// kernel approximation based on data characteristics. Multiple strategies are available
/// for bandwidth selection, from simple heuristics to sophisticated optimization methods.
///
/// # Mathematical Background
///
/// The RBF kernel with adaptive bandwidth is: K(x,y) = exp(-γ*||x-y||²)
/// where γ is automatically selected to optimize a given objective function.
///
/// Common bandwidth selection strategies:
/// - Median heuristic: γ = 1/(2*median²) where median is the median pairwise distance
/// - Scott's rule: γ = n^(-1/(d+4)) for n samples and d dimensions
/// - Cross-validation: γ = argmin CV_error(γ)
///
/// # Examples
///
/// ```ignore
/// use sklears_kernel_approximation::{AdaptiveBandwidthRBFSampler, BandwidthSelectionStrategy};
/// use sklears_core::traits::{Transform, Fit, Untrained};
/// use scirs2_core::ndarray::array;
///
/// let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
///
/// let sampler = AdaptiveBandwidthRBFSampler::new(100)
///     .strategy(BandwidthSelectionStrategy::MedianHeuristic);
///
/// let fitted = sampler.fit(&x, &()).unwrap();
/// let features = fitted.transform(&x).unwrap();
/// let optimal_gamma = fitted.selected_gamma();
/// ```
#[derive(Debug, Clone)]
pub struct AdaptiveBandwidthRBFSampler<State = Untrained> {
    /// Number of random features
    pub n_components: usize,
    /// Bandwidth selection strategy
    pub strategy: BandwidthSelectionStrategy,
    /// Objective function for optimization
    pub objective_function: ObjectiveFunction,
    /// Search range for gamma (min, max)
    pub gamma_range: (Float, Float),
    /// Number of gamma candidates for grid search
    pub n_gamma_candidates: usize,
    /// Cross-validation folds
    pub cv_folds: usize,
    /// Random seed for reproducibility
    pub random_state: Option<u64>,
    /// Tolerance for optimization convergence
    pub tolerance: Float,
    /// Maximum iterations for optimization
    pub max_iterations: usize,

    // Fitted attributes (populated by `fit`, `None` while untrained)
    // Gamma value chosen by the configured selection strategy.
    selected_gamma_: Option<Float>,
    // Random Fourier projection matrix, shape (n_components, n_features).
    random_weights_: Option<Array2<Float>>,
    // Random phase offsets in [0, 2π), length n_components.
    random_offset_: Option<Array1<Float>>,
    optimization_history_: Option<Vec<(Float, Float)>>, // (gamma, objective_value)

    // State marker
    _state: PhantomData<State>,
}
116
117impl AdaptiveBandwidthRBFSampler<Untrained> {
118    /// Create a new adaptive bandwidth RBF sampler
119    ///
120    /// # Arguments
121    /// * `n_components` - Number of random features to generate
122    pub fn new(n_components: usize) -> Self {
123        Self {
124            n_components,
125            strategy: BandwidthSelectionStrategy::MedianHeuristic,
126            objective_function: ObjectiveFunction::KernelAlignment,
127            gamma_range: (1e-3, 1e3),
128            n_gamma_candidates: 20,
129            cv_folds: 5,
130            random_state: None,
131            tolerance: 1e-6,
132            max_iterations: 100,
133            selected_gamma_: None,
134            random_weights_: None,
135            random_offset_: None,
136            optimization_history_: None,
137            _state: PhantomData,
138        }
139    }
140
141    /// Set the bandwidth selection strategy
142    pub fn strategy(mut self, strategy: BandwidthSelectionStrategy) -> Self {
143        self.strategy = strategy;
144        self
145    }
146
147    /// Set the objective function for bandwidth optimization
148    pub fn objective_function(mut self, objective: ObjectiveFunction) -> Self {
149        self.objective_function = objective;
150        self
151    }
152
153    /// Set the search range for gamma values
154    pub fn gamma_range(mut self, min: Float, max: Float) -> Self {
155        self.gamma_range = (min, max);
156        self
157    }
158
159    /// Set the number of gamma candidates for grid search
160    pub fn n_gamma_candidates(mut self, n: usize) -> Self {
161        self.n_gamma_candidates = n;
162        self
163    }
164
165    /// Set the number of cross-validation folds
166    pub fn cv_folds(mut self, folds: usize) -> Self {
167        self.cv_folds = folds;
168        self
169    }
170
171    /// Set the random state for reproducibility
172    pub fn random_state(mut self, seed: u64) -> Self {
173        self.random_state = Some(seed);
174        self
175    }
176
177    /// Set optimization tolerance
178    pub fn tolerance(mut self, tol: Float) -> Self {
179        self.tolerance = tol;
180        self
181    }
182
183    /// Set maximum optimization iterations
184    pub fn max_iterations(mut self, max_iter: usize) -> Self {
185        self.max_iterations = max_iter;
186        self
187    }
188
189    /// Select optimal gamma based on the chosen strategy
190    fn select_gamma(&self, x: &Array2<Float>) -> Result<Float> {
191        match self.strategy {
192            BandwidthSelectionStrategy::MedianHeuristic => self.median_heuristic_gamma(x),
193            BandwidthSelectionStrategy::ScottRule => self.scott_rule_gamma(x),
194            BandwidthSelectionStrategy::SilvermanRule => self.silverman_rule_gamma(x),
195            BandwidthSelectionStrategy::CrossValidation => self.cross_validation_gamma(x),
196            BandwidthSelectionStrategy::MaximumLikelihood => self.maximum_likelihood_gamma(x),
197            BandwidthSelectionStrategy::LeaveOneOut => self.leave_one_out_gamma(x),
198            BandwidthSelectionStrategy::GridSearch => self.grid_search_gamma(x),
199        }
200    }
201
202    /// Median heuristic: gamma = 1/(2 * median_distance²)
203    fn median_heuristic_gamma(&self, x: &Array2<Float>) -> Result<Float> {
204        let (n_samples, _) = x.dim();
205
206        if n_samples < 2 {
207            return Ok(1.0); // Default gamma for insufficient data
208        }
209
210        // Compute pairwise squared distances
211        let n_pairs = if n_samples > 1000 {
212            // Subsample for large datasets to avoid O(n²) complexity
213            1000
214        } else {
215            n_samples * (n_samples - 1) / 2
216        };
217
218        let mut distances_sq = Vec::with_capacity(n_pairs);
219        let step = if n_samples > 1000 { n_samples / 100 } else { 1 };
220
221        for i in (0..n_samples).step_by(step) {
222            for j in ((i + 1)..n_samples).step_by(step) {
223                if distances_sq.len() >= n_pairs {
224                    break;
225                }
226                let diff = &x.row(i) - &x.row(j);
227                let dist_sq = diff.mapv(|v| v * v).sum();
228                distances_sq.push(dist_sq);
229            }
230            if distances_sq.len() >= n_pairs {
231                break;
232            }
233        }
234
235        if distances_sq.is_empty() {
236            return Ok(1.0);
237        }
238
239        // Find median
240        distances_sq.sort_by(|a, b| a.partial_cmp(b).unwrap());
241        let median_dist_sq = distances_sq[distances_sq.len() / 2];
242
243        // gamma = 1 / (2 * sigma²), where sigma² ≈ median_distance²
244        Ok(if median_dist_sq > 0.0 {
245            1.0 / (2.0 * median_dist_sq)
246        } else {
247            1.0
248        })
249    }
250
251    /// Scott's rule: sigma = n^(-1/(d+4))
252    fn scott_rule_gamma(&self, x: &Array2<Float>) -> Result<Float> {
253        let (n_samples, n_features) = x.dim();
254        let sigma = (n_samples as Float).powf(-1.0 / (n_features as Float + 4.0));
255        Ok(1.0 / (2.0 * sigma * sigma))
256    }
257
258    /// Silverman's rule of thumb
259    fn silverman_rule_gamma(&self, x: &Array2<Float>) -> Result<Float> {
260        let (n_samples, n_features) = x.dim();
261
262        // Compute standard deviations for each dimension
263        let means = x.mean_axis(Axis(0)).unwrap();
264        let mut stds = Array1::zeros(n_features);
265
266        for j in 0..n_features {
267            let var = x
268                .column(j)
269                .mapv(|v| {
270                    let diff = v - means[j];
271                    diff * diff
272                })
273                .mean()
274                .unwrap();
275            stds[j] = var.sqrt();
276        }
277
278        let avg_std = stds.mean().unwrap();
279        let h = 1.06 * avg_std * (n_samples as Float).powf(-1.0 / 5.0);
280
281        Ok(1.0 / (2.0 * h * h))
282    }
283
284    /// Cross-validation based gamma selection
285    fn cross_validation_gamma(&self, x: &Array2<Float>) -> Result<Float> {
286        let gamma_candidates = self.generate_gamma_candidates()?;
287        let mut best_gamma = gamma_candidates[0];
288        let mut best_score = Float::INFINITY;
289
290        for &gamma in &gamma_candidates {
291            let score = self.cross_validation_score(x, gamma)?;
292            if score < best_score {
293                best_score = score;
294                best_gamma = gamma;
295            }
296        }
297
298        Ok(best_gamma)
299    }
300
301    /// Maximum likelihood gamma selection
302    fn maximum_likelihood_gamma(&self, x: &Array2<Float>) -> Result<Float> {
303        let gamma_candidates = self.generate_gamma_candidates()?;
304        let mut best_gamma = gamma_candidates[0];
305        let mut best_likelihood = Float::NEG_INFINITY;
306
307        for &gamma in &gamma_candidates {
308            let likelihood = self.log_likelihood(x, gamma)?;
309            if likelihood > best_likelihood {
310                best_likelihood = likelihood;
311                best_gamma = gamma;
312            }
313        }
314
315        Ok(best_gamma)
316    }
317
318    /// Leave-one-out cross-validation gamma selection
319    fn leave_one_out_gamma(&self, x: &Array2<Float>) -> Result<Float> {
320        let gamma_candidates = self.generate_gamma_candidates()?;
321        let mut best_gamma = gamma_candidates[0];
322        let mut best_score = Float::INFINITY;
323
324        for &gamma in &gamma_candidates {
325            let score = self.leave_one_out_score(x, gamma)?;
326            if score < best_score {
327                best_score = score;
328                best_gamma = gamma;
329            }
330        }
331
332        Ok(best_gamma)
333    }
334
335    /// Grid search gamma selection
336    fn grid_search_gamma(&self, x: &Array2<Float>) -> Result<Float> {
337        let gamma_candidates = self.generate_gamma_candidates()?;
338        let mut best_gamma = gamma_candidates[0];
339        let mut best_score = match self.objective_function {
340            ObjectiveFunction::LogLikelihood => Float::NEG_INFINITY,
341            _ => Float::INFINITY,
342        };
343
344        for &gamma in &gamma_candidates {
345            let score = self.evaluate_objective(x, gamma)?;
346            let is_better = match self.objective_function {
347                ObjectiveFunction::LogLikelihood => score > best_score,
348                _ => score < best_score,
349            };
350
351            if is_better {
352                best_score = score;
353                best_gamma = gamma;
354            }
355        }
356
357        Ok(best_gamma)
358    }
359
360    /// Generate candidates for gamma search
361    fn generate_gamma_candidates(&self) -> Result<Vec<Float>> {
362        let (gamma_min, gamma_max) = self.gamma_range;
363        let log_min = gamma_min.ln();
364        let log_max = gamma_max.ln();
365
366        let mut candidates = Vec::with_capacity(self.n_gamma_candidates);
367        for i in 0..self.n_gamma_candidates {
368            let t = i as Float / (self.n_gamma_candidates - 1) as Float;
369            let log_gamma = log_min + t * (log_max - log_min);
370            candidates.push(log_gamma.exp());
371        }
372
373        Ok(candidates)
374    }
375
376    /// Cross-validation score for a given gamma
377    fn cross_validation_score(&self, x: &Array2<Float>, gamma: Float) -> Result<Float> {
378        let (n_samples, _) = x.dim();
379        let fold_size = n_samples / self.cv_folds;
380        let mut total_error = 0.0;
381
382        for fold in 0..self.cv_folds {
383            let start_idx = fold * fold_size;
384            let end_idx = if fold == self.cv_folds - 1 {
385                n_samples
386            } else {
387                (fold + 1) * fold_size
388            };
389
390            // Create train/validation split
391            let val_indices: Vec<usize> = (start_idx..end_idx).collect();
392            let train_indices: Vec<usize> = (0..start_idx).chain(end_idx..n_samples).collect();
393
394            if train_indices.is_empty() || val_indices.is_empty() {
395                continue;
396            }
397
398            // Evaluate kernel approximation quality
399            let error = self.kernel_approximation_error(x, gamma, &train_indices, &val_indices)?;
400            total_error += error;
401        }
402
403        Ok(total_error / self.cv_folds as Float)
404    }
405
406    /// Log-likelihood for Gaussian process with RBF kernel
407    fn log_likelihood(&self, x: &Array2<Float>, gamma: Float) -> Result<Float> {
408        let (n_samples, _) = x.dim();
409
410        // Build kernel matrix (simplified version for efficiency)
411        let mut k_matrix = Array2::zeros((n_samples, n_samples));
412        for i in 0..n_samples {
413            for j in i..n_samples {
414                let diff = &x.row(i) - &x.row(j);
415                let dist_sq = diff.mapv(|v| v * v).sum();
416                let k_val = (-gamma * dist_sq).exp();
417                k_matrix[[i, j]] = k_val;
418                if i != j {
419                    k_matrix[[j, i]] = k_val;
420                }
421            }
422        }
423
424        // Add noise term for numerical stability
425        for i in 0..n_samples {
426            k_matrix[[i, i]] += 1e-6;
427        }
428
429        // Simplified log-likelihood (without full matrix decomposition for efficiency)
430        let trace = k_matrix.diag().sum();
431        let det_approx = trace; // Rough approximation
432
433        Ok(-0.5 * det_approx.ln() - 0.5 * n_samples as Float)
434    }
435
436    /// Leave-one-out cross-validation score
437    fn leave_one_out_score(&self, x: &Array2<Float>, gamma: Float) -> Result<Float> {
438        let (n_samples, _) = x.dim();
439        let mut total_error = 0.0;
440
441        for i in 0..n_samples {
442            let train_indices: Vec<usize> = (0..n_samples).filter(|&j| j != i).collect();
443            let val_indices = vec![i];
444
445            let error = self.kernel_approximation_error(x, gamma, &train_indices, &val_indices)?;
446            total_error += error;
447        }
448
449        Ok(total_error / n_samples as Float)
450    }
451
452    /// Evaluate objective function
453    fn evaluate_objective(&self, x: &Array2<Float>, gamma: Float) -> Result<Float> {
454        match self.objective_function {
455            ObjectiveFunction::KernelAlignment => self.kernel_alignment(x, gamma),
456            ObjectiveFunction::LogLikelihood => self.log_likelihood(x, gamma),
457            ObjectiveFunction::CrossValidationError => self.cross_validation_score(x, gamma),
458            ObjectiveFunction::KernelTrace => self.kernel_trace(x, gamma),
459            ObjectiveFunction::EffectiveDimensionality => self.effective_dimensionality(x, gamma),
460        }
461    }
462
463    /// Kernel alignment objective
464    fn kernel_alignment(&self, x: &Array2<Float>, gamma: Float) -> Result<Float> {
465        let (n_samples, _) = x.dim();
466
467        // Simplified kernel alignment computation
468        let mut alignment = 0.0;
469        let mut count = 0;
470
471        for i in 0..n_samples.min(100) {
472            // Limit for efficiency
473            for j in (i + 1)..n_samples.min(100) {
474                let diff = &x.row(i) - &x.row(j);
475                let dist_sq = diff.mapv(|v| v * v).sum();
476                let k_val = (-gamma * dist_sq).exp();
477                alignment += k_val * k_val; // Self-alignment
478                count += 1;
479            }
480        }
481
482        Ok(if count > 0 {
483            -alignment / count as Float
484        } else {
485            0.0
486        })
487    }
488
489    /// Kernel trace objective
490    fn kernel_trace(&self, x: &Array2<Float>, gamma: Float) -> Result<Float> {
491        let (n_samples, _) = x.dim();
492        let trace = n_samples as Float; // All diagonal elements are 1.0 for RBF kernel
493        Ok(-trace) // Negative because we typically minimize
494    }
495
496    /// Effective dimensionality objective
497    fn effective_dimensionality(&self, x: &Array2<Float>, gamma: Float) -> Result<Float> {
498        // Simplified effective dimensionality based on kernel scale
499        let characteristic_length = (1.0 / gamma).sqrt();
500        let (_, n_features) = x.dim();
501        let eff_dim = (characteristic_length * n_features as Float).min(n_features as Float);
502        Ok(-eff_dim) // Negative for minimization
503    }
504
505    /// Kernel approximation error for validation
506    fn kernel_approximation_error(
507        &self,
508        x: &Array2<Float>,
509        gamma: Float,
510        train_indices: &[usize],
511        val_indices: &[usize],
512    ) -> Result<Float> {
513        if train_indices.is_empty() || val_indices.is_empty() {
514            return Ok(0.0);
515        }
516
517        // Simplified approximation quality metric
518        let mut error = 0.0;
519        let mut count = 0;
520
521        for &i in val_indices {
522            for &j in train_indices {
523                let diff = &x.row(i) - &x.row(j);
524                let dist_sq = diff.mapv(|v| v * v).sum();
525                let true_kernel = (-gamma * dist_sq).exp();
526
527                // Simulate RFF approximation error (simplified)
528                let approx_error = (1.0 - true_kernel) * (1.0 - true_kernel);
529                error += approx_error;
530                count += 1;
531            }
532        }
533
534        Ok(if count > 0 {
535            error / count as Float
536        } else {
537            0.0
538        })
539    }
540}
541
542impl Fit<Array2<Float>, ()> for AdaptiveBandwidthRBFSampler<Untrained> {
543    type Fitted = AdaptiveBandwidthRBFSampler<Trained>;
544
545    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
546        let (n_samples, n_features) = x.dim();
547
548        if n_samples == 0 || n_features == 0 {
549            return Err(SklearsError::InvalidInput(
550                "Input array is empty".to_string(),
551            ));
552        }
553
554        // Select optimal gamma
555        let selected_gamma = self.select_gamma(x)?;
556
557        let mut rng = match self.random_state {
558            Some(seed) => RealStdRng::seed_from_u64(seed),
559            None => RealStdRng::from_seed(thread_rng().gen()),
560        };
561
562        // Generate random weights ~ N(0, 2*gamma*I)
563        let std_dev = (2.0 * selected_gamma).sqrt();
564        let mut random_weights = Array2::zeros((self.n_components, n_features));
565        for i in 0..self.n_components {
566            for j in 0..n_features {
567                // Use Box-Muller transformation for normal distribution
568                let u1 = rng.gen::<Float>();
569                let u2 = rng.gen::<Float>();
570                let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
571                random_weights[[i, j]] = z * std_dev;
572            }
573        }
574
575        // Generate random offsets ~ Uniform[0, 2π]
576        let mut random_offset = Array1::zeros(self.n_components);
577        for i in 0..self.n_components {
578            random_offset[i] = rng.gen::<Float>() * 2.0 * std::f64::consts::PI;
579        }
580
581        Ok(AdaptiveBandwidthRBFSampler {
582            n_components: self.n_components,
583            strategy: self.strategy,
584            objective_function: self.objective_function,
585            gamma_range: self.gamma_range,
586            n_gamma_candidates: self.n_gamma_candidates,
587            cv_folds: self.cv_folds,
588            random_state: self.random_state,
589            tolerance: self.tolerance,
590            max_iterations: self.max_iterations,
591            selected_gamma_: Some(selected_gamma),
592            random_weights_: Some(random_weights),
593            random_offset_: Some(random_offset),
594            optimization_history_: None, // Could be populated during optimization
595            _state: PhantomData,
596        })
597    }
598}
599
600impl AdaptiveBandwidthRBFSampler<Trained> {
601    /// Get the selected gamma value
602    pub fn selected_gamma(&self) -> Result<Float> {
603        self.selected_gamma_.ok_or_else(|| SklearsError::NotFitted {
604            operation: "selected_gamma".to_string(),
605        })
606    }
607
608    /// Get the optimization history (if available)
609    pub fn optimization_history(&self) -> Option<&Vec<(Float, Float)>> {
610        self.optimization_history_.as_ref()
611    }
612}
613
614impl Transform<Array2<Float>> for AdaptiveBandwidthRBFSampler<Trained> {
615    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
616        let random_weights =
617            self.random_weights_
618                .as_ref()
619                .ok_or_else(|| SklearsError::NotFitted {
620                    operation: "transform".to_string(),
621                })?;
622
623        let random_offset =
624            self.random_offset_
625                .as_ref()
626                .ok_or_else(|| SklearsError::NotFitted {
627                    operation: "transform".to_string(),
628                })?;
629
630        let (n_samples, n_features) = x.dim();
631
632        if n_features != random_weights.ncols() {
633            return Err(SklearsError::InvalidInput(format!(
634                "Input has {} features, expected {}",
635                n_features,
636                random_weights.ncols()
637            )));
638        }
639
640        // Compute X @ W.T + b
641        let projection = x.dot(&random_weights.t()) + random_offset;
642
643        // Apply cosine transformation and normalize
644        let normalization = (2.0 / random_weights.nrows() as Float).sqrt();
645        Ok(projection.mapv(|x| x.cos() * normalization))
646    }
647}
648
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    // End-to-end: fit + transform with the median heuristic.
    #[test]
    fn test_adaptive_bandwidth_rbf_sampler_basic() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let sampler = AdaptiveBandwidthRBFSampler::new(50)
            .strategy(BandwidthSelectionStrategy::MedianHeuristic)
            .random_state(42);

        let fitted = sampler.fit(&x, &()).unwrap();
        let features = fitted.transform(&x).unwrap();

        assert_eq!(features.shape(), &[3, 50]);

        // Check that gamma was selected
        let gamma = fitted.selected_gamma().unwrap();
        assert!(gamma > 0.0);

        // Check that features are bounded (cosine function)
        // Bound is sqrt(2/D) * |cos| <= sqrt(2) < 2.
        for &val in features.iter() {
            assert!(val >= -2.0 && val <= 2.0);
        }
    }

    // Every heuristic/search strategy should produce a positive gamma.
    #[test]
    fn test_different_bandwidth_strategies() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];

        let strategies = [
            BandwidthSelectionStrategy::MedianHeuristic,
            BandwidthSelectionStrategy::ScottRule,
            BandwidthSelectionStrategy::SilvermanRule,
            BandwidthSelectionStrategy::GridSearch,
        ];

        for strategy in &strategies {
            let sampler = AdaptiveBandwidthRBFSampler::new(20)
                .strategy(*strategy)
                .random_state(42);

            let fitted = sampler.fit(&x, &()).unwrap();
            let features = fitted.transform(&x).unwrap();
            let gamma = fitted.selected_gamma().unwrap();

            assert_eq!(features.shape(), &[4, 20]);
            assert!(gamma > 0.0);
        }
    }

    // k-fold CV selection on clustered data.
    #[test]
    fn test_cross_validation_strategy() {
        let x = array![
            [1.0, 1.0],
            [1.1, 1.1],
            [2.0, 2.0],
            [2.1, 2.1],
            [5.0, 5.0],
            [5.1, 5.1]
        ];

        let sampler = AdaptiveBandwidthRBFSampler::new(30)
            .strategy(BandwidthSelectionStrategy::CrossValidation)
            .cv_folds(3)
            .n_gamma_candidates(5)
            .random_state(42);

        let fitted = sampler.fit(&x, &()).unwrap();
        let features = fitted.transform(&x).unwrap();
        let gamma = fitted.selected_gamma().unwrap();

        assert_eq!(features.shape(), &[6, 30]);
        assert!(gamma > 0.0);
    }

    // Grid search should succeed under each objective function.
    #[test]
    fn test_different_objective_functions() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let objectives = [
            ObjectiveFunction::KernelAlignment,
            ObjectiveFunction::LogLikelihood,
            ObjectiveFunction::KernelTrace,
            ObjectiveFunction::EffectiveDimensionality,
        ];

        for objective in &objectives {
            let sampler = AdaptiveBandwidthRBFSampler::new(25)
                .strategy(BandwidthSelectionStrategy::GridSearch)
                .objective_function(*objective)
                .n_gamma_candidates(5)
                .random_state(42);

            let fitted = sampler.fit(&x, &()).unwrap();
            let gamma = fitted.selected_gamma().unwrap();

            assert!(gamma > 0.0);
        }
    }

    #[test]
    fn test_median_heuristic() {
        let x = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];

        let sampler = AdaptiveBandwidthRBFSampler::new(10);
        let gamma = sampler.median_heuristic_gamma(&x).unwrap();

        // With unit distances, median distance² ≈ 1, so gamma ≈ 0.5
        assert!(gamma > 0.1 && gamma < 2.0);
    }

    #[test]
    fn test_scott_rule() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];

        let sampler = AdaptiveBandwidthRBFSampler::new(10);
        let gamma = sampler.scott_rule_gamma(&x).unwrap();

        assert!(gamma > 0.0);
    }

    #[test]
    fn test_silverman_rule() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];

        let sampler = AdaptiveBandwidthRBFSampler::new(10);
        let gamma = sampler.silverman_rule_gamma(&x).unwrap();

        assert!(gamma > 0.0);
    }

    // Same seed must yield identical gamma and identical features.
    #[test]
    fn test_reproducibility() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let sampler1 = AdaptiveBandwidthRBFSampler::new(40)
            .strategy(BandwidthSelectionStrategy::MedianHeuristic)
            .random_state(123);

        let sampler2 = AdaptiveBandwidthRBFSampler::new(40)
            .strategy(BandwidthSelectionStrategy::MedianHeuristic)
            .random_state(123);

        let fitted1 = sampler1.fit(&x, &()).unwrap();
        let fitted2 = sampler2.fit(&x, &()).unwrap();

        let features1 = fitted1.transform(&x).unwrap();
        let features2 = fitted2.transform(&x).unwrap();

        let gamma1 = fitted1.selected_gamma().unwrap();
        let gamma2 = fitted2.selected_gamma().unwrap();

        assert_abs_diff_eq!(gamma1, gamma2, epsilon = 1e-10);

        for (f1, f2) in features1.iter().zip(features2.iter()) {
            assert_abs_diff_eq!(f1, f2, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_gamma_range() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let sampler = AdaptiveBandwidthRBFSampler::new(15)
            .strategy(BandwidthSelectionStrategy::GridSearch)
            .gamma_range(0.5, 2.0)
            .n_gamma_candidates(5)
            .random_state(42);

        let fitted = sampler.fit(&x, &()).unwrap();
        let gamma = fitted.selected_gamma().unwrap();

        // Selected gamma should be within the specified range
        assert!(gamma >= 0.5 && gamma <= 2.0);
    }

    #[test]
    fn test_error_handling() {
        // Empty input
        let empty = Array2::<Float>::zeros((0, 0));
        let sampler = AdaptiveBandwidthRBFSampler::new(10);
        assert!(sampler.clone().fit(&empty, &()).is_err());

        // Dimension mismatch in transform
        let x_train = array![[1.0, 2.0], [3.0, 4.0]];
        let x_test = array![[1.0, 2.0, 3.0]]; // Wrong number of features

        let fitted = sampler.fit(&x_train, &()).unwrap();
        assert!(fitted.transform(&x_test).is_err());
    }

    #[test]
    fn test_single_sample() {
        let x = array![[1.0, 2.0]];

        let sampler = AdaptiveBandwidthRBFSampler::new(10)
            .strategy(BandwidthSelectionStrategy::MedianHeuristic);

        let fitted = sampler.fit(&x, &()).unwrap();
        let gamma = fitted.selected_gamma().unwrap();

        // Should use default gamma for single sample
        assert!(gamma > 0.0);
    }

    #[test]
    fn test_large_dataset_efficiency() {
        // Test that median heuristic works efficiently on larger datasets
        // (exercises the pairwise-distance subsampling path).
        let mut data = Vec::new();
        for i in 0..500 {
            data.push([i as Float, (i * 2) as Float]);
        }
        let x = Array2::from(data);

        let sampler = AdaptiveBandwidthRBFSampler::new(20)
            .strategy(BandwidthSelectionStrategy::MedianHeuristic);

        let fitted = sampler.fit(&x, &()).unwrap();
        let gamma = fitted.selected_gamma().unwrap();

        assert!(gamma > 0.0);
    }
}