//! Core traits and types for kernel approximation: kernel methods, sampling
//! strategies, explicit feature maps, and approximation-quality metrics.

use scirs2_core::ndarray::{Array1, Array2};
use sklears_core::error::SklearsError;
use std::fmt::Debug;

/// A kernel approximation method.
pub trait KernelMethod: Send + Sync + Debug {
    /// Human-readable name of the method.
    fn name(&self) -> &str;

    /// Number of output features produced, if known ahead of time.
    fn n_output_features(&self) -> Option<usize>;

    /// Computational complexity class of the method.
    fn complexity(&self) -> Complexity;

    /// Theoretical or empirical error bound, if available.
    fn error_bound(&self) -> Option<ErrorBound>;

    /// Whether the method supports the given kernel type.
    fn supports_kernel(&self, kernel_type: KernelType) -> bool;

    /// All kernel types supported by the method.
    fn supported_kernels(&self) -> Vec<KernelType>;
}

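// Illustrative sketch of a `KernelMethod` implementation. The struct below is
// not part of the module's public API; its name, its field, and the quoted
// error bound are assumptions chosen only to show how the trait is meant to be
// implemented (here, in the spirit of random Fourier features for shift-invariant
// kernels).
#[allow(dead_code)]
#[derive(Debug, Clone)]
struct ExampleRandomFourierFeatures {
    n_components: usize,
}

#[allow(dead_code)]
impl KernelMethod for ExampleRandomFourierFeatures {
    fn name(&self) -> &str {
        "ExampleRandomFourierFeatures"
    }

    fn n_output_features(&self) -> Option<usize> {
        Some(self.n_components)
    }

    fn complexity(&self) -> Complexity {
        // Transforming n samples with d features into D components costs O(n*d*D).
        Complexity::LinearBoth
    }

    fn error_bound(&self) -> Option<ErrorBound> {
        Some(ErrorBound {
            bound_type: BoundType::Probabilistic,
            // Indicative O(1/sqrt(D)) Monte Carlo rate; the constant is a placeholder.
            error: 1.0 / (self.n_components as f64).sqrt(),
            confidence: Some(0.95),
            description: "Illustrative Monte Carlo rate for random features".to_string(),
        })
    }

    fn supports_kernel(&self, kernel_type: KernelType) -> bool {
        matches!(kernel_type, KernelType::RBF | KernelType::Laplacian)
    }

    fn supported_kernels(&self) -> Vec<KernelType> {
        vec![KernelType::RBF, KernelType::Laplacian]
    }
}
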
/// A strategy for selecting representative sample indices from a data matrix.
pub trait SamplingStrategy: Send + Sync + Debug {
    /// Select `n_samples` row indices from `data`.
    fn sample(&self, data: &Array2<f64>, n_samples: usize) -> Result<Vec<usize>, SklearsError>;

    /// Human-readable name of the strategy.
    fn name(&self) -> &str;

    /// Whether the strategy must be fitted before sampling.
    fn requires_fitting(&self) -> bool {
        false
    }

    /// Fit the strategy to the data (no-op by default).
    fn fit(&mut self, _data: &Array2<f64>) -> Result<(), SklearsError> {
        Ok(())
    }

    /// Optional per-sample weights produced by the strategy.
    fn weights(&self) -> Option<Array1<f64>> {
        None
    }
}

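// Minimal sketch of the intended calling convention for any `SamplingStrategy`
// (an assumption about usage, not an API defined elsewhere in this module):
// fit first when the strategy requires it, then sample.
#[allow(dead_code)]
fn run_sampling_sketch(
    strategy: &mut dyn SamplingStrategy,
    data: &Array2<f64>,
    n_samples: usize,
) -> Result<Vec<usize>, SklearsError> {
    if strategy.requires_fitting() {
        strategy.fit(data)?;
    }
    strategy.sample(data, n_samples)
}
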
/// An explicit feature map that approximates a kernel via inner products.
pub trait FeatureMap: Send + Sync + Debug {
    /// Map the input data into the approximate feature space.
    fn transform(&self, data: &Array2<f64>) -> Result<Array2<f64>, SklearsError>;

    /// Dimensionality of the output feature space.
    fn output_dim(&self) -> usize;

    /// Human-readable name of the feature map.
    fn name(&self) -> &str;

    /// Whether the map supports an inverse transform.
    fn is_invertible(&self) -> bool {
        false
    }

    /// Map features back to the input space (unsupported by default).
    fn inverse_transform(&self, _features: &Array2<f64>) -> Result<Array2<f64>, SklearsError> {
        Err(SklearsError::InvalidInput(
            "Inverse transform not supported".to_string(),
        ))
    }
}

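// Illustrative sketch of a `FeatureMap` implementation. `IdentityFeatureMap`
// is a hypothetical, trivially invertible map used only to show how the
// `is_invertible`/`inverse_transform` defaults are meant to be overridden.
#[allow(dead_code)]
#[derive(Debug, Clone)]
struct IdentityFeatureMap {
    n_features: usize,
}

#[allow(dead_code)]
impl FeatureMap for IdentityFeatureMap {
    fn transform(&self, data: &Array2<f64>) -> Result<Array2<f64>, SklearsError> {
        if data.ncols() != self.n_features {
            return Err(SklearsError::InvalidInput(format!(
                "Expected {} features, got {}",
                self.n_features,
                data.ncols()
            )));
        }
        Ok(data.to_owned())
    }

    fn output_dim(&self) -> usize {
        self.n_features
    }

    fn name(&self) -> &str {
        "IdentityFeatureMap"
    }

    fn is_invertible(&self) -> bool {
        true
    }

    fn inverse_transform(&self, features: &Array2<f64>) -> Result<Array2<f64>, SklearsError> {
        // The identity map is its own inverse.
        Ok(features.to_owned())
    }
}
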
/// A metric for assessing how well an approximate kernel matches the exact one.
pub trait ApproximationQuality: Send + Sync + Debug {
    /// Compute the quality score between the exact and approximate kernel matrices.
    fn compute(
        &self,
        exact_kernel: &Array2<f64>,
        approx_kernel: &Array2<f64>,
    ) -> Result<f64, SklearsError>;

    /// Human-readable name of the metric.
    fn name(&self) -> &str;

    /// Whether larger scores indicate a better approximation.
    fn higher_is_better(&self) -> bool;

    /// Threshold beyond which the approximation is considered acceptable, if any.
    fn acceptable_threshold(&self) -> Option<f64> {
        None
    }
}

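// Illustrative sketch of a lower-is-better `ApproximationQuality` metric: the
// relative Frobenius error ||K - K'||_F / ||K||_F. The struct and its 0.1
// threshold are assumptions added for demonstration; `KernelAlignmentMetric`
// below is the metric this module actually provides.
#[allow(dead_code)]
#[derive(Debug, Clone)]
struct RelativeFrobeniusError;

#[allow(dead_code)]
impl ApproximationQuality for RelativeFrobeniusError {
    fn compute(
        &self,
        exact_kernel: &Array2<f64>,
        approx_kernel: &Array2<f64>,
    ) -> Result<f64, SklearsError> {
        if exact_kernel.dim() != approx_kernel.dim() {
            return Err(SklearsError::InvalidInput(
                "Kernel matrices must have the same shape".to_string(),
            ));
        }
        let mut diff_sq = 0.0;
        let mut exact_sq = 0.0;
        for (e, a) in exact_kernel.iter().zip(approx_kernel.iter()) {
            diff_sq += (e - a) * (e - a);
            exact_sq += e * e;
        }
        if exact_sq < 1e-10 {
            return Ok(0.0);
        }
        Ok((diff_sq / exact_sq).sqrt())
    }

    fn name(&self) -> &str {
        "RelativeFrobeniusError"
    }

    fn higher_is_better(&self) -> bool {
        false
    }

    fn acceptable_threshold(&self) -> Option<f64> {
        Some(0.1)
    }
}
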
/// Computational complexity class of a kernel approximation method.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Complexity {
    /// O(d) - linear in the number of features.
    Linear,
    /// O(d log d) - quasi-linear in the number of features.
    QuasiLinear,
    /// O(n*d) - linear in both samples and features.
    LinearBoth,
    /// O(n*d^2) - quadratic in the number of features.
    QuadraticFeatures,
    /// O(n^2*d) - quadratic in the number of samples.
    QuadraticSamples,
    /// O(n^3) - cubic in the number of samples.
    Cubic,
    /// A custom complexity description.
    Custom(String),
}

impl Complexity {
    /// A human-readable description of the complexity class.
    pub fn description(&self) -> &str {
        match self {
            Complexity::Linear => "O(d) - Linear in features",
            Complexity::QuasiLinear => "O(d log d) - Quasi-linear",
            Complexity::LinearBoth => "O(n*d) - Linear in samples and features",
            Complexity::QuadraticFeatures => "O(n*d^2) - Quadratic in features",
            Complexity::QuadraticSamples => "O(n^2*d) - Quadratic in samples",
            Complexity::Cubic => "O(n^3) - Cubic complexity",
            Complexity::Custom(s) => s,
        }
    }
}

/// An error bound reported by a kernel approximation method.
#[derive(Debug, Clone)]
pub struct ErrorBound {
    /// The kind of guarantee the bound provides.
    pub bound_type: BoundType,
    /// The magnitude of the bound.
    pub error: f64,
    /// Confidence level for probabilistic bounds, if applicable.
    pub confidence: Option<f64>,
    /// Human-readable description of the bound.
    pub description: String,
}

/// The kind of guarantee an error bound provides.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BoundType {
    /// Holds with a stated probability.
    Probabilistic,
    /// Holds unconditionally.
    Deterministic,
    /// Holds in expectation.
    Expected,
    /// Measured empirically rather than proven.
    Empirical,
}

/// Kernel families that approximation methods may support.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KernelType {
    /// Radial basis function (Gaussian) kernel.
    RBF,
    /// Laplacian kernel.
    Laplacian,
    /// Polynomial kernel.
    Polynomial,
    /// Linear kernel.
    Linear,
    /// Arc-cosine kernel.
    ArcCosine,
    /// Chi-squared kernel.
    ChiSquared,
    /// String kernel.
    String,
    /// Graph kernel.
    Graph,
    /// User-defined kernel.
    Custom,
}

impl KernelType {
    /// A human-readable name for the kernel type.
    pub fn name(&self) -> &str {
        match self {
            KernelType::RBF => "RBF",
            KernelType::Laplacian => "Laplacian",
            KernelType::Polynomial => "Polynomial",
            KernelType::Linear => "Linear",
            KernelType::ArcCosine => "ArcCosine",
            KernelType::ChiSquared => "ChiSquared",
            KernelType::String => "String",
            KernelType::Graph => "Graph",
            KernelType::Custom => "Custom",
        }
    }
}

/// Uniform random sampling of row indices via reservoir sampling.
#[derive(Debug, Clone)]
pub struct UniformSampling {
    /// Seed for the random number generator.
    pub random_state: Option<u64>,
}

impl UniformSampling {
    /// Create a new uniform sampling strategy with an optional seed.
    pub fn new(random_state: Option<u64>) -> Self {
        Self { random_state }
    }
}

impl SamplingStrategy for UniformSampling {
    fn sample(&self, data: &Array2<f64>, n_samples: usize) -> Result<Vec<usize>, SklearsError> {
        use scirs2_core::random::seeded_rng;

        let (n_rows, _) = data.dim();
        if n_samples > n_rows {
            return Err(SklearsError::InvalidInput(format!(
                "Cannot sample {} points from {} samples",
                n_samples, n_rows
            )));
        }

        let mut rng = seeded_rng(self.random_state.unwrap_or(42));

        // Reservoir sampling: keep the first n_samples indices, then replace
        // each slot with probability n_samples / (i + 1).
        let mut indices: Vec<usize> = (0..n_samples).collect();
        for i in n_samples..n_rows {
            let j = rng.gen_range(0..=i);
            if j < n_samples {
                indices[j] = i;
            }
        }

        Ok(indices)
    }

    fn name(&self) -> &str {
        "UniformSampling"
    }
}

/// Sampling based on k-means clustering: selects the data points closest to
/// the cluster centers.
#[derive(Debug, Clone)]
pub struct KMeansSampling {
    /// Number of Lloyd iterations to run.
    pub n_iterations: usize,
    /// Seed for the random number generator.
    pub random_state: Option<u64>,
    /// Cached cluster centers (reserved for use by `fit`).
    centers: Option<Array2<f64>>,
}

impl KMeansSampling {
    /// Create a new k-means sampling strategy.
    pub fn new(n_iterations: usize, random_state: Option<u64>) -> Self {
        Self {
            n_iterations,
            random_state,
            centers: None,
        }
    }
}

impl SamplingStrategy for KMeansSampling {
    fn sample(&self, data: &Array2<f64>, n_samples: usize) -> Result<Vec<usize>, SklearsError> {
        use scirs2_core::random::seeded_rng;

        let (n_rows, n_features) = data.dim();
        if n_samples > n_rows {
            return Err(SklearsError::InvalidInput(format!(
                "Cannot sample {} points from {} samples",
                n_samples, n_rows
            )));
        }

        let mut rng = seeded_rng(self.random_state.unwrap_or(42));

        // Initialize centers with n_samples distinct randomly chosen data points.
        let mut centers = Array2::zeros((n_samples, n_features));
        let mut initial_indices: Vec<usize> = (0..n_rows).collect();
        for i in 0..n_samples {
            let idx = rng.gen_range(0..initial_indices.len());
            let sample_idx = initial_indices.swap_remove(idx);
            centers.row_mut(i).assign(&data.row(sample_idx));
        }

        // Lloyd's algorithm: assign each point to its nearest center, then
        // recompute every center as the mean of its assigned points.
        let mut assignments = vec![0; n_rows];
        for _ in 0..self.n_iterations {
            for i in 0..n_rows {
                let point = data.row(i);
                let mut min_dist = f64::INFINITY;
                let mut best_cluster = 0;

                for j in 0..n_samples {
                    let center = centers.row(j);
                    let dist: f64 = point
                        .iter()
                        .zip(center.iter())
                        .map(|(a, b)| (a - b).powi(2))
                        .sum();

                    if dist < min_dist {
                        min_dist = dist;
                        best_cluster = j;
                    }
                }
                assignments[i] = best_cluster;
            }

            let mut counts = vec![0; n_samples];
            centers.fill(0.0);

            for i in 0..n_rows {
                let cluster = assignments[i];
                let point = data.row(i);
                for (j, &val) in point.iter().enumerate() {
                    centers[[cluster, j]] += val;
                }
                counts[cluster] += 1;
            }

            for j in 0..n_samples {
                if counts[j] > 0 {
                    for k in 0..n_features {
                        centers[[j, k]] /= counts[j] as f64;
                    }
                }
            }
        }

        // Return the index of the data point closest to each final center.
        let mut selected_indices = Vec::with_capacity(n_samples);
        for center_idx in 0..n_samples {
            let center = centers.row(center_idx);
            let mut min_dist = f64::INFINITY;
            let mut best_idx = 0;

            for i in 0..n_rows {
                let point = data.row(i);
                let dist: f64 = point
                    .iter()
                    .zip(center.iter())
                    .map(|(a, b)| (a - b).powi(2))
                    .sum();

                if dist < min_dist {
                    min_dist = dist;
                    best_idx = i;
                }
            }
            selected_indices.push(best_idx);
        }

        Ok(selected_indices)
    }

    fn name(&self) -> &str {
        "KMeansSampling"
    }

    fn requires_fitting(&self) -> bool {
        true
    }

    fn fit(&mut self, data: &Array2<f64>) -> Result<(), SklearsError> {
        // Clustering is performed in `sample`; fitting is currently a no-op.
        let _ = data;
        Ok(())
    }
}

/// Kernel alignment: the normalized Frobenius inner product between the exact
/// and approximate kernel matrices (1.0 indicates a perfect match).
#[derive(Debug, Clone)]
pub struct KernelAlignmentMetric;

impl ApproximationQuality for KernelAlignmentMetric {
    fn compute(
        &self,
        exact_kernel: &Array2<f64>,
        approx_kernel: &Array2<f64>,
    ) -> Result<f64, SklearsError> {
        let (n1, m1) = exact_kernel.dim();
        let (n2, m2) = approx_kernel.dim();

        if n1 != n2 || m1 != m2 {
            return Err(SklearsError::InvalidInput(
                "Kernel matrices must have the same shape".to_string(),
            ));
        }

        // Alignment = <K, K'>_F / (||K||_F * ||K'||_F).
        let mut inner_product = 0.0;
        let mut exact_norm = 0.0;
        let mut approx_norm = 0.0;

        for i in 0..n1 {
            for j in 0..m1 {
                let exact_val = exact_kernel[[i, j]];
                let approx_val = approx_kernel[[i, j]];
                inner_product += exact_val * approx_val;
                exact_norm += exact_val * exact_val;
                approx_norm += approx_val * approx_val;
            }
        }

        if exact_norm < 1e-10 || approx_norm < 1e-10 {
            return Ok(0.0);
        }

        Ok(inner_product / (exact_norm.sqrt() * approx_norm.sqrt()))
    }

    fn name(&self) -> &str {
        "KernelAlignment"
    }

    fn higher_is_better(&self) -> bool {
        true
    }

    fn acceptable_threshold(&self) -> Option<f64> {
        Some(0.9)
    }
}

/// A kernel method composed of several underlying methods combined according
/// to a [`CombinationStrategy`].
#[derive(Debug)]
pub struct CompositeKernelMethod {
    /// The underlying kernel approximation methods.
    methods: Vec<Box<dyn KernelMethod>>,
    /// How the methods' outputs are combined.
    strategy: CombinationStrategy,
}

/// How the outputs of multiple kernel methods are combined.
#[derive(Debug, Clone, Copy)]
pub enum CombinationStrategy {
    /// Concatenate the feature maps.
    Concatenate,
    /// Average the feature maps.
    Average,
    /// Weighted sum of the feature maps.
    WeightedSum,
    /// Element-wise product of the feature maps.
    Product,
}

impl CompositeKernelMethod {
    /// Create an empty composite with the given combination strategy.
    pub fn new(strategy: CombinationStrategy) -> Self {
        Self {
            methods: Vec::new(),
            strategy,
        }
    }

    /// Add a kernel method to the composite.
    pub fn add_method(&mut self, method: Box<dyn KernelMethod>) {
        self.methods.push(method);
    }

    /// The combination strategy in use.
    pub fn strategy(&self) -> CombinationStrategy {
        self.strategy
    }

    /// Number of methods in the composite.
    pub fn len(&self) -> usize {
        self.methods.len()
    }

    /// Whether the composite contains no methods.
    pub fn is_empty(&self) -> bool {
        self.methods.is_empty()
    }
}

impl KernelMethod for CompositeKernelMethod {
    fn name(&self) -> &str {
        "CompositeKernel"
    }

    fn n_output_features(&self) -> Option<usize> {
        match self.strategy {
            // Concatenation sums the output dimensions; the total is only
            // known if every constituent method reports its dimension.
            CombinationStrategy::Concatenate => {
                let mut total = 0;
                for method in &self.methods {
                    if let Some(n) = method.n_output_features() {
                        total += n;
                    } else {
                        return None;
                    }
                }
                Some(total)
            }
            // The other strategies preserve the dimension of the first method.
            _ => self.methods.first().and_then(|m| m.n_output_features()),
        }
    }

    fn complexity(&self) -> Complexity {
        // Report the worst-case complexity among the constituent methods.
        let mut worst = Complexity::Linear;
        for method in &self.methods {
            let c = method.complexity();
            worst = match (worst, c.clone()) {
                (Complexity::Cubic, _) | (_, Complexity::Cubic) => Complexity::Cubic,
                (Complexity::QuadraticSamples, _) | (_, Complexity::QuadraticSamples) => {
                    Complexity::QuadraticSamples
                }
                (Complexity::QuadraticFeatures, _) | (_, Complexity::QuadraticFeatures) => {
                    Complexity::QuadraticFeatures
                }
                (Complexity::LinearBoth, _) | (_, Complexity::LinearBoth) => {
                    Complexity::LinearBoth
                }
                (Complexity::QuasiLinear, _) | (_, Complexity::QuasiLinear) => {
                    Complexity::QuasiLinear
                }
                _ => c,
            };
        }
        worst
    }

    fn error_bound(&self) -> Option<ErrorBound> {
        // Combining bounds across heterogeneous methods is not defined here.
        None
    }

    fn supports_kernel(&self, kernel_type: KernelType) -> bool {
        self.methods.iter().any(|m| m.supports_kernel(kernel_type))
    }

    fn supported_kernels(&self) -> Vec<KernelType> {
        let mut kernels = Vec::new();
        for method in &self.methods {
            for kernel in method.supported_kernels() {
                if !kernels.contains(&kernel) {
                    kernels.push(kernel);
                }
            }
        }
        kernels
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_complexity_description() {
        let c = Complexity::Linear;
        assert!(c.description().contains("Linear"));

        let c = Complexity::QuasiLinear;
        assert!(c.description().contains("Quasi-linear"));
    }

    #[test]
    fn test_kernel_type_name() {
        assert_eq!(KernelType::RBF.name(), "RBF");
        assert_eq!(KernelType::Polynomial.name(), "Polynomial");
    }

    #[test]
    fn test_uniform_sampling() {
        let strategy = UniformSampling::new(Some(42));
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];

        let indices = strategy.sample(&data, 2).unwrap();
        assert_eq!(indices.len(), 2);
        assert!(indices[0] < 4);
        assert!(indices[1] < 4);
    }

    #[test]
    fn test_kmeans_sampling() {
        let strategy = KMeansSampling::new(5, Some(42));
        let data = array![
            [1.0, 1.0],
            [1.1, 1.1],
            [5.0, 5.0],
            [5.1, 5.1],
            [9.0, 9.0],
            [9.1, 9.1]
        ];

        let indices = strategy.sample(&data, 3).unwrap();
        assert_eq!(indices.len(), 3);
    }

    #[test]
    fn test_kernel_alignment_metric() {
        let metric = KernelAlignmentMetric;
        let exact = array![[1.0, 0.5], [0.5, 1.0]];
        let approx = array![[1.0, 0.6], [0.6, 1.0]];

        let alignment = metric.compute(&exact, &approx).unwrap();
        assert!(alignment > 0.9 && alignment <= 1.0);
        assert!(metric.higher_is_better());
    }

    #[test]
    fn test_composite_kernel_method() {
        let composite = CompositeKernelMethod::new(CombinationStrategy::Concatenate);
        assert!(composite.is_empty());
        assert_eq!(composite.len(), 0);
    }

    #[test]
    fn test_bound_type() {
        let bound = ErrorBound {
            bound_type: BoundType::Probabilistic,
            error: 0.1,
            confidence: Some(0.95),
            description: "Test bound".to_string(),
        };

        assert_eq!(bound.bound_type, BoundType::Probabilistic);
        assert_eq!(bound.error, 0.1);
    }
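
    // Minimal usage sketch of composing kernel methods: a CompositeKernelMethod
    // is itself a KernelMethod, so composites can be nested. The asserted values
    // follow from the defaults reported for an empty inner composite.
    #[test]
    fn test_nested_composite_kernel_method() {
        let mut outer = CompositeKernelMethod::new(CombinationStrategy::Concatenate);
        let inner = CompositeKernelMethod::new(CombinationStrategy::Average);
        outer.add_method(Box::new(inner));

        assert_eq!(outer.len(), 1);
        // The empty inner composite reports no output dimension or kernels.
        assert_eq!(outer.n_output_features(), None);
        assert!(outer.supported_kernels().is_empty());
        assert!(!outer.supports_kernel(KernelType::RBF));
        assert_eq!(outer.complexity(), Complexity::Linear);
    }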
}