sklears_kernel_approximation/
adaptive_bandwidth_rbf.rs

1//! Adaptive bandwidth RBF kernel approximation methods
2//!
3//! This module implements RBF kernel approximation with automatic bandwidth selection
4//! based on data characteristics. The bandwidth (gamma) parameter is optimized using
5//! various strategies including cross-validation, maximum likelihood, and heuristic methods.
6
7use scirs2_core::ndarray::{Array1, Array2, Axis};
8use scirs2_core::random::rngs::StdRng as RealStdRng;
9use scirs2_core::random::Rng;
10use scirs2_core::random::{thread_rng, SeedableRng};
11use sklears_core::{
12    error::{Result, SklearsError},
13    traits::{Fit, Trained, Transform, Untrained},
14    types::Float,
15};
16use std::marker::PhantomData;
17
/// Bandwidth selection strategy for adaptive RBF
#[derive(Debug, Clone, Copy)]
pub enum BandwidthSelectionStrategy {
    /// Cross-validation to minimize approximation error
    CrossValidation,
    /// Maximum likelihood estimation
    MaximumLikelihood,
    /// Median heuristic based on pairwise distances
    MedianHeuristic,
    /// Scott's rule based on data dimensionality and sample size
    ScottRule,
    /// Silverman's rule of thumb
    SilvermanRule,
    /// Leave-one-out cross-validation
    LeaveOneOut,
    /// Grid search over a range of gamma values
    GridSearch,
}
37
/// Objective function for bandwidth optimization
#[derive(Debug, Clone, Copy)]
pub enum ObjectiveFunction {
    /// Kernel alignment
    KernelAlignment,
    /// Log-likelihood
    LogLikelihood,
    /// Cross-validation error
    CrossValidationError,
    /// Kernel matrix trace
    KernelTrace,
    /// Effective dimensionality
    EffectiveDimensionality,
}
53
54/// Adaptive bandwidth RBF sampler with automatic gamma selection
55///
56/// This sampler automatically selects the optimal bandwidth parameter (gamma) for RBF
57/// kernel approximation based on data characteristics. Multiple strategies are available
58/// for bandwidth selection, from simple heuristics to sophisticated optimization methods.
59///
60/// # Mathematical Background
61///
62/// The RBF kernel with adaptive bandwidth is: K(x,y) = exp(-γ*||x-y||²)
63/// where γ is automatically selected to optimize a given objective function.
64///
65/// Common bandwidth selection strategies:
66/// - Median heuristic: γ = 1/(2*median²) where median is the median pairwise distance
67/// - Scott's rule: γ = n^(-1/(d+4)) for n samples and d dimensions
68/// - Cross-validation: γ = argmin CV_error(γ)
69///
70/// # Examples
71///
72/// ```ignore
73/// use sklears_kernel_approximation::{AdaptiveBandwidthRBFSampler, BandwidthSelectionStrategy};
74/// use sklears_core::traits::{Transform, Fit, Untrained}
75/// use scirs2_core::ndarray::array;
76///
77/// let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
78///
79/// let sampler = AdaptiveBandwidthRBFSampler::new(100)
80///     .strategy(BandwidthSelectionStrategy::MedianHeuristic);
81///
82/// let fitted = sampler.fit(&x, &()).unwrap();
83/// let features = fitted.transform(&x).unwrap();
84/// let optimal_gamma = fitted.selected_gamma();
85/// ```
86#[derive(Debug, Clone)]
87/// AdaptiveBandwidthRBFSampler
88pub struct AdaptiveBandwidthRBFSampler<State = Untrained> {
89    /// Number of random features
90    pub n_components: usize,
91    /// Bandwidth selection strategy
92    pub strategy: BandwidthSelectionStrategy,
93    /// Objective function for optimization
94    pub objective_function: ObjectiveFunction,
95    /// Search range for gamma (min, max)
96    pub gamma_range: (Float, Float),
97    /// Number of gamma candidates for grid search
98    pub n_gamma_candidates: usize,
99    /// Cross-validation folds
100    pub cv_folds: usize,
101    /// Random seed for reproducibility
102    pub random_state: Option<u64>,
103    /// Tolerance for optimization convergence
104    pub tolerance: Float,
105    /// Maximum iterations for optimization
106    pub max_iterations: usize,
107
108    // Fitted attributes
109    selected_gamma_: Option<Float>,
110    random_weights_: Option<Array2<Float>>,
111    random_offset_: Option<Array1<Float>>,
112    optimization_history_: Option<Vec<(Float, Float)>>, // (gamma, objective_value)
113
114    // State marker
115    _state: PhantomData<State>,
116}
117
118impl AdaptiveBandwidthRBFSampler<Untrained> {
119    /// Create a new adaptive bandwidth RBF sampler
120    ///
121    /// # Arguments
122    /// * `n_components` - Number of random features to generate
123    pub fn new(n_components: usize) -> Self {
124        Self {
125            n_components,
126            strategy: BandwidthSelectionStrategy::MedianHeuristic,
127            objective_function: ObjectiveFunction::KernelAlignment,
128            gamma_range: (1e-3, 1e3),
129            n_gamma_candidates: 20,
130            cv_folds: 5,
131            random_state: None,
132            tolerance: 1e-6,
133            max_iterations: 100,
134            selected_gamma_: None,
135            random_weights_: None,
136            random_offset_: None,
137            optimization_history_: None,
138            _state: PhantomData,
139        }
140    }
141
142    /// Set the bandwidth selection strategy
143    pub fn strategy(mut self, strategy: BandwidthSelectionStrategy) -> Self {
144        self.strategy = strategy;
145        self
146    }
147
148    /// Set the objective function for bandwidth optimization
149    pub fn objective_function(mut self, objective: ObjectiveFunction) -> Self {
150        self.objective_function = objective;
151        self
152    }
153
154    /// Set the search range for gamma values
155    pub fn gamma_range(mut self, min: Float, max: Float) -> Self {
156        self.gamma_range = (min, max);
157        self
158    }
159
160    /// Set the number of gamma candidates for grid search
161    pub fn n_gamma_candidates(mut self, n: usize) -> Self {
162        self.n_gamma_candidates = n;
163        self
164    }
165
166    /// Set the number of cross-validation folds
167    pub fn cv_folds(mut self, folds: usize) -> Self {
168        self.cv_folds = folds;
169        self
170    }
171
172    /// Set the random state for reproducibility
173    pub fn random_state(mut self, seed: u64) -> Self {
174        self.random_state = Some(seed);
175        self
176    }
177
178    /// Set optimization tolerance
179    pub fn tolerance(mut self, tol: Float) -> Self {
180        self.tolerance = tol;
181        self
182    }
183
184    /// Set maximum optimization iterations
185    pub fn max_iterations(mut self, max_iter: usize) -> Self {
186        self.max_iterations = max_iter;
187        self
188    }
189
190    /// Select optimal gamma based on the chosen strategy
191    fn select_gamma(&self, x: &Array2<Float>) -> Result<Float> {
192        match self.strategy {
193            BandwidthSelectionStrategy::MedianHeuristic => self.median_heuristic_gamma(x),
194            BandwidthSelectionStrategy::ScottRule => self.scott_rule_gamma(x),
195            BandwidthSelectionStrategy::SilvermanRule => self.silverman_rule_gamma(x),
196            BandwidthSelectionStrategy::CrossValidation => self.cross_validation_gamma(x),
197            BandwidthSelectionStrategy::MaximumLikelihood => self.maximum_likelihood_gamma(x),
198            BandwidthSelectionStrategy::LeaveOneOut => self.leave_one_out_gamma(x),
199            BandwidthSelectionStrategy::GridSearch => self.grid_search_gamma(x),
200        }
201    }
202
203    /// Median heuristic: gamma = 1/(2 * median_distance²)
204    fn median_heuristic_gamma(&self, x: &Array2<Float>) -> Result<Float> {
205        let (n_samples, _) = x.dim();
206
207        if n_samples < 2 {
208            return Ok(1.0); // Default gamma for insufficient data
209        }
210
211        // Compute pairwise squared distances
212        let n_pairs = if n_samples > 1000 {
213            // Subsample for large datasets to avoid O(n²) complexity
214            1000
215        } else {
216            n_samples * (n_samples - 1) / 2
217        };
218
219        let mut distances_sq = Vec::with_capacity(n_pairs);
220        let step = if n_samples > 1000 { n_samples / 100 } else { 1 };
221
222        for i in (0..n_samples).step_by(step) {
223            for j in ((i + 1)..n_samples).step_by(step) {
224                if distances_sq.len() >= n_pairs {
225                    break;
226                }
227                let diff = &x.row(i) - &x.row(j);
228                let dist_sq = diff.mapv(|v| v * v).sum();
229                distances_sq.push(dist_sq);
230            }
231            if distances_sq.len() >= n_pairs {
232                break;
233            }
234        }
235
236        if distances_sq.is_empty() {
237            return Ok(1.0);
238        }
239
240        // Find median
241        distances_sq.sort_by(|a, b| a.partial_cmp(b).unwrap());
242        let median_dist_sq = distances_sq[distances_sq.len() / 2];
243
244        // gamma = 1 / (2 * sigma²), where sigma² ≈ median_distance²
245        Ok(if median_dist_sq > 0.0 {
246            1.0 / (2.0 * median_dist_sq)
247        } else {
248            1.0
249        })
250    }
251
252    /// Scott's rule: sigma = n^(-1/(d+4))
253    fn scott_rule_gamma(&self, x: &Array2<Float>) -> Result<Float> {
254        let (n_samples, n_features) = x.dim();
255        let sigma = (n_samples as Float).powf(-1.0 / (n_features as Float + 4.0));
256        Ok(1.0 / (2.0 * sigma * sigma))
257    }
258
259    /// Silverman's rule of thumb
260    fn silverman_rule_gamma(&self, x: &Array2<Float>) -> Result<Float> {
261        let (n_samples, n_features) = x.dim();
262
263        // Compute standard deviations for each dimension
264        let means = x.mean_axis(Axis(0)).unwrap();
265        let mut stds = Array1::zeros(n_features);
266
267        for j in 0..n_features {
268            let var = x
269                .column(j)
270                .mapv(|v| {
271                    let diff = v - means[j];
272                    diff * diff
273                })
274                .mean()
275                .unwrap();
276            stds[j] = var.sqrt();
277        }
278
279        let avg_std = stds.mean().unwrap();
280        let h = 1.06 * avg_std * (n_samples as Float).powf(-1.0 / 5.0);
281
282        Ok(1.0 / (2.0 * h * h))
283    }
284
285    /// Cross-validation based gamma selection
286    fn cross_validation_gamma(&self, x: &Array2<Float>) -> Result<Float> {
287        let gamma_candidates = self.generate_gamma_candidates()?;
288        let mut best_gamma = gamma_candidates[0];
289        let mut best_score = Float::INFINITY;
290
291        for &gamma in &gamma_candidates {
292            let score = self.cross_validation_score(x, gamma)?;
293            if score < best_score {
294                best_score = score;
295                best_gamma = gamma;
296            }
297        }
298
299        Ok(best_gamma)
300    }
301
302    /// Maximum likelihood gamma selection
303    fn maximum_likelihood_gamma(&self, x: &Array2<Float>) -> Result<Float> {
304        let gamma_candidates = self.generate_gamma_candidates()?;
305        let mut best_gamma = gamma_candidates[0];
306        let mut best_likelihood = Float::NEG_INFINITY;
307
308        for &gamma in &gamma_candidates {
309            let likelihood = self.log_likelihood(x, gamma)?;
310            if likelihood > best_likelihood {
311                best_likelihood = likelihood;
312                best_gamma = gamma;
313            }
314        }
315
316        Ok(best_gamma)
317    }
318
319    /// Leave-one-out cross-validation gamma selection
320    fn leave_one_out_gamma(&self, x: &Array2<Float>) -> Result<Float> {
321        let gamma_candidates = self.generate_gamma_candidates()?;
322        let mut best_gamma = gamma_candidates[0];
323        let mut best_score = Float::INFINITY;
324
325        for &gamma in &gamma_candidates {
326            let score = self.leave_one_out_score(x, gamma)?;
327            if score < best_score {
328                best_score = score;
329                best_gamma = gamma;
330            }
331        }
332
333        Ok(best_gamma)
334    }
335
336    /// Grid search gamma selection
337    fn grid_search_gamma(&self, x: &Array2<Float>) -> Result<Float> {
338        let gamma_candidates = self.generate_gamma_candidates()?;
339        let mut best_gamma = gamma_candidates[0];
340        let mut best_score = match self.objective_function {
341            ObjectiveFunction::LogLikelihood => Float::NEG_INFINITY,
342            _ => Float::INFINITY,
343        };
344
345        for &gamma in &gamma_candidates {
346            let score = self.evaluate_objective(x, gamma)?;
347            let is_better = match self.objective_function {
348                ObjectiveFunction::LogLikelihood => score > best_score,
349                _ => score < best_score,
350            };
351
352            if is_better {
353                best_score = score;
354                best_gamma = gamma;
355            }
356        }
357
358        Ok(best_gamma)
359    }
360
361    /// Generate candidates for gamma search
362    fn generate_gamma_candidates(&self) -> Result<Vec<Float>> {
363        let (gamma_min, gamma_max) = self.gamma_range;
364        let log_min = gamma_min.ln();
365        let log_max = gamma_max.ln();
366
367        let mut candidates = Vec::with_capacity(self.n_gamma_candidates);
368        for i in 0..self.n_gamma_candidates {
369            let t = i as Float / (self.n_gamma_candidates - 1) as Float;
370            let log_gamma = log_min + t * (log_max - log_min);
371            candidates.push(log_gamma.exp());
372        }
373
374        Ok(candidates)
375    }
376
377    /// Cross-validation score for a given gamma
378    fn cross_validation_score(&self, x: &Array2<Float>, gamma: Float) -> Result<Float> {
379        let (n_samples, _) = x.dim();
380        let fold_size = n_samples / self.cv_folds;
381        let mut total_error = 0.0;
382
383        for fold in 0..self.cv_folds {
384            let start_idx = fold * fold_size;
385            let end_idx = if fold == self.cv_folds - 1 {
386                n_samples
387            } else {
388                (fold + 1) * fold_size
389            };
390
391            // Create train/validation split
392            let val_indices: Vec<usize> = (start_idx..end_idx).collect();
393            let train_indices: Vec<usize> = (0..start_idx).chain(end_idx..n_samples).collect();
394
395            if train_indices.is_empty() || val_indices.is_empty() {
396                continue;
397            }
398
399            // Evaluate kernel approximation quality
400            let error = self.kernel_approximation_error(x, gamma, &train_indices, &val_indices)?;
401            total_error += error;
402        }
403
404        Ok(total_error / self.cv_folds as Float)
405    }
406
407    /// Log-likelihood for Gaussian process with RBF kernel
408    fn log_likelihood(&self, x: &Array2<Float>, gamma: Float) -> Result<Float> {
409        let (n_samples, _) = x.dim();
410
411        // Build kernel matrix (simplified version for efficiency)
412        let mut k_matrix = Array2::zeros((n_samples, n_samples));
413        for i in 0..n_samples {
414            for j in i..n_samples {
415                let diff = &x.row(i) - &x.row(j);
416                let dist_sq = diff.mapv(|v| v * v).sum();
417                let k_val = (-gamma * dist_sq).exp();
418                k_matrix[[i, j]] = k_val;
419                if i != j {
420                    k_matrix[[j, i]] = k_val;
421                }
422            }
423        }
424
425        // Add noise term for numerical stability
426        for i in 0..n_samples {
427            k_matrix[[i, i]] += 1e-6;
428        }
429
430        // Simplified log-likelihood (without full matrix decomposition for efficiency)
431        let trace = k_matrix.diag().sum();
432        let det_approx = trace; // Rough approximation
433
434        Ok(-0.5 * det_approx.ln() - 0.5 * n_samples as Float)
435    }
436
437    /// Leave-one-out cross-validation score
438    fn leave_one_out_score(&self, x: &Array2<Float>, gamma: Float) -> Result<Float> {
439        let (n_samples, _) = x.dim();
440        let mut total_error = 0.0;
441
442        for i in 0..n_samples {
443            let train_indices: Vec<usize> = (0..n_samples).filter(|&j| j != i).collect();
444            let val_indices = vec![i];
445
446            let error = self.kernel_approximation_error(x, gamma, &train_indices, &val_indices)?;
447            total_error += error;
448        }
449
450        Ok(total_error / n_samples as Float)
451    }
452
453    /// Evaluate objective function
454    fn evaluate_objective(&self, x: &Array2<Float>, gamma: Float) -> Result<Float> {
455        match self.objective_function {
456            ObjectiveFunction::KernelAlignment => self.kernel_alignment(x, gamma),
457            ObjectiveFunction::LogLikelihood => self.log_likelihood(x, gamma),
458            ObjectiveFunction::CrossValidationError => self.cross_validation_score(x, gamma),
459            ObjectiveFunction::KernelTrace => self.kernel_trace(x, gamma),
460            ObjectiveFunction::EffectiveDimensionality => self.effective_dimensionality(x, gamma),
461        }
462    }
463
464    /// Kernel alignment objective
465    fn kernel_alignment(&self, x: &Array2<Float>, gamma: Float) -> Result<Float> {
466        let (n_samples, _) = x.dim();
467
468        // Simplified kernel alignment computation
469        let mut alignment = 0.0;
470        let mut count = 0;
471
472        for i in 0..n_samples.min(100) {
473            // Limit for efficiency
474            for j in (i + 1)..n_samples.min(100) {
475                let diff = &x.row(i) - &x.row(j);
476                let dist_sq = diff.mapv(|v| v * v).sum();
477                let k_val = (-gamma * dist_sq).exp();
478                alignment += k_val * k_val; // Self-alignment
479                count += 1;
480            }
481        }
482
483        Ok(if count > 0 {
484            -alignment / count as Float
485        } else {
486            0.0
487        })
488    }
489
490    /// Kernel trace objective
491    fn kernel_trace(&self, x: &Array2<Float>, _gamma: Float) -> Result<Float> {
492        let (n_samples, _) = x.dim();
493        let trace = n_samples as Float; // All diagonal elements are 1.0 for RBF kernel
494        Ok(-trace) // Negative because we typically minimize
495    }
496
497    /// Effective dimensionality objective
498    fn effective_dimensionality(&self, x: &Array2<Float>, gamma: Float) -> Result<Float> {
499        // Simplified effective dimensionality based on kernel scale
500        let characteristic_length = (1.0 / gamma).sqrt();
501        let (_, n_features) = x.dim();
502        let eff_dim = (characteristic_length * n_features as Float).min(n_features as Float);
503        Ok(-eff_dim) // Negative for minimization
504    }
505
506    /// Kernel approximation error for validation
507    fn kernel_approximation_error(
508        &self,
509        x: &Array2<Float>,
510        gamma: Float,
511        train_indices: &[usize],
512        val_indices: &[usize],
513    ) -> Result<Float> {
514        if train_indices.is_empty() || val_indices.is_empty() {
515            return Ok(0.0);
516        }
517
518        // Simplified approximation quality metric
519        let mut error = 0.0;
520        let mut count = 0;
521
522        for &i in val_indices {
523            for &j in train_indices {
524                let diff = &x.row(i) - &x.row(j);
525                let dist_sq = diff.mapv(|v| v * v).sum();
526                let true_kernel = (-gamma * dist_sq).exp();
527
528                // Simulate RFF approximation error (simplified)
529                let approx_error = (1.0 - true_kernel) * (1.0 - true_kernel);
530                error += approx_error;
531                count += 1;
532            }
533        }
534
535        Ok(if count > 0 {
536            error / count as Float
537        } else {
538            0.0
539        })
540    }
541}
542
543impl Fit<Array2<Float>, ()> for AdaptiveBandwidthRBFSampler<Untrained> {
544    type Fitted = AdaptiveBandwidthRBFSampler<Trained>;
545
546    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
547        let (n_samples, n_features) = x.dim();
548
549        if n_samples == 0 || n_features == 0 {
550            return Err(SklearsError::InvalidInput(
551                "Input array is empty".to_string(),
552            ));
553        }
554
555        // Select optimal gamma
556        let selected_gamma = self.select_gamma(x)?;
557
558        let mut rng = match self.random_state {
559            Some(seed) => RealStdRng::seed_from_u64(seed),
560            None => RealStdRng::from_seed(thread_rng().gen()),
561        };
562
563        // Generate random weights ~ N(0, 2*gamma*I)
564        let std_dev = (2.0 * selected_gamma).sqrt();
565        let mut random_weights = Array2::zeros((self.n_components, n_features));
566        for i in 0..self.n_components {
567            for j in 0..n_features {
568                // Use Box-Muller transformation for normal distribution
569                let u1 = rng.gen::<Float>();
570                let u2 = rng.gen::<Float>();
571                let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
572                random_weights[[i, j]] = z * std_dev;
573            }
574        }
575
576        // Generate random offsets ~ Uniform[0, 2π]
577        let mut random_offset = Array1::zeros(self.n_components);
578        for i in 0..self.n_components {
579            random_offset[i] = rng.gen::<Float>() * 2.0 * std::f64::consts::PI;
580        }
581
582        Ok(AdaptiveBandwidthRBFSampler {
583            n_components: self.n_components,
584            strategy: self.strategy,
585            objective_function: self.objective_function,
586            gamma_range: self.gamma_range,
587            n_gamma_candidates: self.n_gamma_candidates,
588            cv_folds: self.cv_folds,
589            random_state: self.random_state,
590            tolerance: self.tolerance,
591            max_iterations: self.max_iterations,
592            selected_gamma_: Some(selected_gamma),
593            random_weights_: Some(random_weights),
594            random_offset_: Some(random_offset),
595            optimization_history_: None, // Could be populated during optimization
596            _state: PhantomData,
597        })
598    }
599}
600
601impl AdaptiveBandwidthRBFSampler<Trained> {
602    /// Get the selected gamma value
603    pub fn selected_gamma(&self) -> Result<Float> {
604        self.selected_gamma_.ok_or_else(|| SklearsError::NotFitted {
605            operation: "selected_gamma".to_string(),
606        })
607    }
608
609    /// Get the optimization history (if available)
610    pub fn optimization_history(&self) -> Option<&Vec<(Float, Float)>> {
611        self.optimization_history_.as_ref()
612    }
613}
614
615impl Transform<Array2<Float>> for AdaptiveBandwidthRBFSampler<Trained> {
616    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
617        let random_weights =
618            self.random_weights_
619                .as_ref()
620                .ok_or_else(|| SklearsError::NotFitted {
621                    operation: "transform".to_string(),
622                })?;
623
624        let random_offset =
625            self.random_offset_
626                .as_ref()
627                .ok_or_else(|| SklearsError::NotFitted {
628                    operation: "transform".to_string(),
629                })?;
630
631        let (_n_samples, n_features) = x.dim();
632
633        if n_features != random_weights.ncols() {
634            return Err(SklearsError::InvalidInput(format!(
635                "Input has {} features, expected {}",
636                n_features,
637                random_weights.ncols()
638            )));
639        }
640
641        // Compute X @ W.T + b
642        let projection = x.dot(&random_weights.t()) + random_offset;
643
644        // Apply cosine transformation and normalize
645        let normalization = (2.0 / random_weights.nrows() as Float).sqrt();
646        Ok(projection.mapv(|x| x.cos() * normalization))
647    }
648}
649
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_adaptive_bandwidth_rbf_sampler_basic() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let sampler = AdaptiveBandwidthRBFSampler::new(50)
            .strategy(BandwidthSelectionStrategy::MedianHeuristic)
            .random_state(42);

        let fitted = sampler.fit(&x, &()).unwrap();
        let features = fitted.transform(&x).unwrap();

        assert_eq!(features.shape(), &[3, 50]);

        // Check that gamma was selected
        let gamma = fitted.selected_gamma().unwrap();
        assert!(gamma > 0.0);

        // Check that features are bounded (scaled cosine values)
        for &val in features.iter() {
            assert!((-2.0..=2.0).contains(&val));
        }
    }

    #[test]
    fn test_different_bandwidth_strategies() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];

        let strategies = [
            BandwidthSelectionStrategy::MedianHeuristic,
            BandwidthSelectionStrategy::ScottRule,
            BandwidthSelectionStrategy::SilvermanRule,
            BandwidthSelectionStrategy::GridSearch,
        ];

        for strategy in &strategies {
            let sampler = AdaptiveBandwidthRBFSampler::new(20)
                .strategy(*strategy)
                .random_state(42);

            let fitted = sampler.fit(&x, &()).unwrap();
            let features = fitted.transform(&x).unwrap();
            let gamma = fitted.selected_gamma().unwrap();

            assert_eq!(features.shape(), &[4, 20]);
            assert!(gamma > 0.0);
        }
    }

    #[test]
    fn test_cross_validation_strategy() {
        let x = array![
            [1.0, 1.0],
            [1.1, 1.1],
            [2.0, 2.0],
            [2.1, 2.1],
            [5.0, 5.0],
            [5.1, 5.1]
        ];

        let sampler = AdaptiveBandwidthRBFSampler::new(30)
            .strategy(BandwidthSelectionStrategy::CrossValidation)
            .cv_folds(3)
            .n_gamma_candidates(5)
            .random_state(42);

        let fitted = sampler.fit(&x, &()).unwrap();
        let features = fitted.transform(&x).unwrap();
        let gamma = fitted.selected_gamma().unwrap();

        assert_eq!(features.shape(), &[6, 30]);
        assert!(gamma > 0.0);
    }

    #[test]
    fn test_different_objective_functions() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let objectives = [
            ObjectiveFunction::KernelAlignment,
            ObjectiveFunction::LogLikelihood,
            ObjectiveFunction::KernelTrace,
            ObjectiveFunction::EffectiveDimensionality,
        ];

        for objective in &objectives {
            let sampler = AdaptiveBandwidthRBFSampler::new(25)
                .strategy(BandwidthSelectionStrategy::GridSearch)
                .objective_function(*objective)
                .n_gamma_candidates(5)
                .random_state(42);

            let fitted = sampler.fit(&x, &()).unwrap();
            let gamma = fitted.selected_gamma().unwrap();

            assert!(gamma > 0.0);
        }
    }

    #[test]
    fn test_median_heuristic() {
        let x = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];

        let sampler = AdaptiveBandwidthRBFSampler::new(10);
        let gamma = sampler.median_heuristic_gamma(&x).unwrap();

        // With unit distances, median distance² ≈ 1, so gamma ≈ 0.5
        assert!(gamma > 0.1 && gamma < 2.0);
    }

    #[test]
    fn test_scott_rule() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];

        let sampler = AdaptiveBandwidthRBFSampler::new(10);
        let gamma = sampler.scott_rule_gamma(&x).unwrap();

        assert!(gamma > 0.0);
    }

    #[test]
    fn test_silverman_rule() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];

        let sampler = AdaptiveBandwidthRBFSampler::new(10);
        let gamma = sampler.silverman_rule_gamma(&x).unwrap();

        assert!(gamma > 0.0);
    }

    #[test]
    fn test_reproducibility() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let sampler1 = AdaptiveBandwidthRBFSampler::new(40)
            .strategy(BandwidthSelectionStrategy::MedianHeuristic)
            .random_state(123);

        let sampler2 = AdaptiveBandwidthRBFSampler::new(40)
            .strategy(BandwidthSelectionStrategy::MedianHeuristic)
            .random_state(123);

        let fitted1 = sampler1.fit(&x, &()).unwrap();
        let fitted2 = sampler2.fit(&x, &()).unwrap();

        let features1 = fitted1.transform(&x).unwrap();
        let features2 = fitted2.transform(&x).unwrap();

        let gamma1 = fitted1.selected_gamma().unwrap();
        let gamma2 = fitted2.selected_gamma().unwrap();

        assert_abs_diff_eq!(gamma1, gamma2, epsilon = 1e-10);

        for (f1, f2) in features1.iter().zip(features2.iter()) {
            assert_abs_diff_eq!(f1, f2, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_gamma_range() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let sampler = AdaptiveBandwidthRBFSampler::new(15)
            .strategy(BandwidthSelectionStrategy::GridSearch)
            .gamma_range(0.5, 2.0)
            .n_gamma_candidates(5)
            .random_state(42);

        let fitted = sampler.fit(&x, &()).unwrap();
        let gamma = fitted.selected_gamma().unwrap();

        // Selected gamma should be within the specified range
        assert!(gamma >= 0.5 && gamma <= 2.0);
    }

    #[test]
    fn test_error_handling() {
        // Empty input
        let empty = Array2::<Float>::zeros((0, 0));
        let sampler = AdaptiveBandwidthRBFSampler::new(10);
        assert!(sampler.clone().fit(&empty, &()).is_err());

        // Dimension mismatch in transform
        let x_train = array![[1.0, 2.0], [3.0, 4.0]];
        let x_test = array![[1.0, 2.0, 3.0]]; // Wrong number of features

        let fitted = sampler.fit(&x_train, &()).unwrap();
        assert!(fitted.transform(&x_test).is_err());
    }

    #[test]
    fn test_single_sample() {
        let x = array![[1.0, 2.0]];

        let sampler = AdaptiveBandwidthRBFSampler::new(10)
            .strategy(BandwidthSelectionStrategy::MedianHeuristic);

        let fitted = sampler.fit(&x, &()).unwrap();
        let gamma = fitted.selected_gamma().unwrap();

        // Should use default gamma for single sample
        assert!(gamma > 0.0);
    }

    #[test]
    fn test_large_dataset_efficiency() {
        // Test that median heuristic works efficiently on larger datasets
        let mut data = Vec::new();
        for i in 0..500 {
            data.push([i as Float, (i * 2) as Float]);
        }
        let x = Array2::from(data);

        let sampler = AdaptiveBandwidthRBFSampler::new(20)
            .strategy(BandwidthSelectionStrategy::MedianHeuristic);

        let fitted = sampler.fit(&x, &()).unwrap();
        let gamma = fitted.selected_gamma().unwrap();

        assert!(gamma > 0.0);
    }
}
877}