1use scirs2_core::ndarray::{Array1, Array2};
7use scirs2_core::rand_prelude::SliceRandom;
8use scirs2_core::random::essentials::{Normal as RandNormal, Uniform as RandUniform};
9use scirs2_core::random::rngs::StdRng as RealStdRng;
10use scirs2_core::random::{thread_rng, Rng, SeedableRng};
11use sklears_core::error::{Result, SklearsError};
12use std::marker::PhantomData;
13
/// Marker trait implemented by the compile-time training states below.
pub trait ApproximationState {}

/// Type-state marker: the approximation has not been fitted yet.
#[derive(Debug, Clone, Copy)]
pub struct Untrained;

impl ApproximationState for Untrained {}

/// Type-state marker: the approximation has been fitted to data.
/// NOTE(review): not referenced by the visible code (the fitted estimator is a
/// separate struct) — confirm whether it is used elsewhere.
#[derive(Debug, Clone, Copy)]
pub struct Trained;

impl ApproximationState for Trained {}
28
/// Compile-time description of a kernel function.
pub trait KernelType {
    /// Kernel name; used for runtime string dispatch in the fitting/transform code.
    const NAME: &'static str;

    /// Whether the kernel's parameters can be tuned/learned.
    const SUPPORTS_PARAMETER_LEARNING: bool;

    /// Bandwidth used when the caller does not set one explicitly.
    const DEFAULT_BANDWIDTH: f64;
}
40
/// Gaussian (RBF) kernel marker: k(x, y) = exp(-bandwidth * ||x - y||^2).
#[derive(Debug, Clone, Copy)]
pub struct RBFKernel;

impl KernelType for RBFKernel {
    const NAME: &'static str = "RBF";
    const SUPPORTS_PARAMETER_LEARNING: bool = true;
    const DEFAULT_BANDWIDTH: f64 = 1.0;
}

/// Laplacian kernel marker: k(x, y) = exp(-bandwidth * ||x - y||_1).
#[derive(Debug, Clone, Copy)]
pub struct LaplacianKernel;

impl KernelType for LaplacianKernel {
    const NAME: &'static str = "Laplacian";
    const SUPPORTS_PARAMETER_LEARNING: bool = true;
    const DEFAULT_BANDWIDTH: f64 = 1.0;
}

/// Polynomial kernel marker: k(x, y) = (bandwidth * <x, y> + coef0)^degree.
#[derive(Debug, Clone, Copy)]
pub struct PolynomialKernel;

impl KernelType for PolynomialKernel {
    const NAME: &'static str = "Polynomial";
    const SUPPORTS_PARAMETER_LEARNING: bool = true;
    const DEFAULT_BANDWIDTH: f64 = 1.0;
}

/// Arc-cosine kernel marker. No explicit formula is implemented in this file;
/// fitting with it fails at runtime with an "Unsupported kernel" error.
#[derive(Debug, Clone, Copy)]
pub struct ArcCosineKernel;

impl KernelType for ArcCosineKernel {
    const NAME: &'static str = "ArcCosine";
    const SUPPORTS_PARAMETER_LEARNING: bool = false;
    const DEFAULT_BANDWIDTH: f64 = 1.0;
}
80
/// Compile-time description of a kernel approximation algorithm.
pub trait ApproximationMethod {
    /// Method name; used for runtime string dispatch in `fit_impl`.
    const NAME: &'static str;

    /// Whether the method can incorporate new data incrementally.
    const SUPPORTS_INCREMENTAL: bool;

    /// Whether theoretical approximation-error bounds are available.
    const HAS_ERROR_BOUNDS: bool;

    /// Asymptotic complexity classification of the method.
    const COMPLEXITY: ComplexityClass;
}

/// Coarse asymptotic complexity classification.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ComplexityClass {
    /// O(n^2) in the number of samples.
    Quadratic,
    /// O(n log n).
    QuasiLinear,
    /// O(n).
    Linear,
    /// Cost driven primarily by input dimensionality rather than sample count.
    DimensionDependent,
}
109
/// Random Fourier features: Monte-Carlo sampling of the kernel's spectral
/// density; linear complexity, incremental, with known error bounds.
#[derive(Debug, Clone, Copy)]
pub struct RandomFourierFeatures;

impl ApproximationMethod for RandomFourierFeatures {
    const NAME: &'static str = "RandomFourierFeatures";
    const SUPPORTS_INCREMENTAL: bool = true;
    const HAS_ERROR_BOUNDS: bool = true;
    const COMPLEXITY: ComplexityClass = ComplexityClass::Linear;
}

/// Nystrom method: low-rank approximation via a random subset of inducing points.
#[derive(Debug, Clone, Copy)]
pub struct NystromMethod;

impl ApproximationMethod for NystromMethod {
    const NAME: &'static str = "Nystrom";
    const SUPPORTS_INCREMENTAL: bool = false;
    const HAS_ERROR_BOUNDS: bool = true;
    const COMPLEXITY: ComplexityClass = ComplexityClass::Quadratic;
}

/// Fastfood: structured-matrix acceleration of random features.
#[derive(Debug, Clone, Copy)]
pub struct FastfoodMethod;

impl ApproximationMethod for FastfoodMethod {
    const NAME: &'static str = "Fastfood";
    const SUPPORTS_INCREMENTAL: bool = false;
    const HAS_ERROR_BOUNDS: bool = true;
    const COMPLEXITY: ComplexityClass = ComplexityClass::DimensionDependent;
}
142
/// Kernel approximation estimator whose training state, kernel, method, and
/// output component count are all tracked at the type level.
#[derive(Debug, Clone)]
pub struct TypeSafeKernelApproximation<State, Kernel, Method, const N_COMPONENTS: usize>
where
    State: ApproximationState,
    Kernel: KernelType,
    Method: ApproximationMethod,
{
    /// Zero-sized markers carrying the type-level configuration.
    _phantom: PhantomData<(State, Kernel, Method)>,

    /// Kernel hyperparameters used at fit time.
    parameters: ApproximationParameters,

    /// Optional RNG seed for reproducible fitting.
    random_state: Option<u64>,
}
161
/// Runtime kernel hyperparameters; kernel-specific values are optional.
#[derive(Debug, Clone)]
pub struct ApproximationParameters {
    /// Bandwidth / scale parameter — used as `gamma` in the kernel formulas
    /// (e.g. exp(-bandwidth * d^2) for RBF).
    pub bandwidth: f64,

    /// Polynomial degree (polynomial kernel only; defaults to 2 when unset).
    pub degree: Option<usize>,

    /// Additive constant (polynomial kernel only; defaults to 1.0 when unset).
    pub coef0: Option<f64>,

    /// Free-form extra parameters keyed by name; unused by the visible code.
    pub custom: std::collections::HashMap<String, f64>,
}

impl Default for ApproximationParameters {
    fn default() -> Self {
        Self {
            bandwidth: 1.0,
            degree: None,
            coef0: None,
            custom: std::collections::HashMap::new(),
        }
    }
}
189
/// Result of fitting a `TypeSafeKernelApproximation`; only transformation is
/// available from here, so "transform before fit" is a compile error.
#[derive(Debug, Clone)]
pub struct FittedTypeSafeKernelApproximation<Kernel, Method, const N_COMPONENTS: usize>
where
    Kernel: KernelType,
    Method: ApproximationMethod,
{
    /// Zero-sized markers carrying the type-level configuration.
    _phantom: PhantomData<(Kernel, Method)>,

    /// Method-specific learned parameters (weights, inducing points, ...).
    transformation_params: TransformationParameters<N_COMPONENTS>,

    /// Hyperparameters captured at fit time; used by the kernel evaluations.
    fitted_parameters: ApproximationParameters,

    /// Diagnostics; currently default-initialized (all `None`) after fitting.
    quality_metrics: QualityMetrics,
}
210
/// Learned, method-specific transformation state.
#[derive(Debug, Clone)]
pub enum TransformationParameters<const N: usize> {
    /// Random Fourier features: projection weights plus optional phase offsets.
    RandomFeatures {
        /// (N, n_features) projection matrix sampled from the spectral density.
        weights: Array2<f64>,

        /// Optional per-component phase offsets in [0, 2*pi).
        biases: Option<Array1<f64>>,
    },
    /// Nystrom factorization over a subset of inducing points.
    Nystrom {
        /// Selected inducing rows of the training data.
        inducing_points: Array2<f64>,

        /// (Pseudo-)eigenvalues of the inducing kernel matrix.
        eigenvalues: Array1<f64>,

        /// (Pseudo-)eigenvectors of the inducing kernel matrix.
        eigenvectors: Array2<f64>,
    },
    /// Fastfood structured transform: matrix factors and Gaussian scaling.
    Fastfood {
        structured_matrices: Vec<Array2<f64>>,
        scaling: Array1<f64>,
    },
}
235
/// Optional diagnostics about approximation quality. All fields stay `None`
/// until computed (the visible fitting code never fills them in).
#[derive(Debug, Clone)]
#[derive(Default)]
pub struct QualityMetrics {
    /// Estimated approximation error vs. the exact kernel.
    pub approximation_error: Option<f64>,

    /// Effective rank of the approximation.
    pub effective_rank: Option<f64>,

    /// Condition number of the learned factorization.
    pub condition_number: Option<f64>,

    /// Alignment between approximate and exact kernel matrices.
    pub kernel_alignment: Option<f64>,
}
253
/// `Default` simply delegates to `new()` (kernel default bandwidth, no seed).
impl<Kernel, Method, const N_COMPONENTS: usize> Default
    for TypeSafeKernelApproximation<Untrained, Kernel, Method, N_COMPONENTS>
where
    Kernel: KernelType,
    Method: ApproximationMethod,
{
    fn default() -> Self {
        Self::new()
    }
}
265
266impl<Kernel, Method, const N_COMPONENTS: usize>
267 TypeSafeKernelApproximation<Untrained, Kernel, Method, N_COMPONENTS>
268where
269 Kernel: KernelType,
270 Method: ApproximationMethod,
271{
272 pub fn new() -> Self {
274 Self {
275 _phantom: PhantomData,
276 parameters: ApproximationParameters {
277 bandwidth: Kernel::DEFAULT_BANDWIDTH,
278 ..Default::default()
279 },
280 random_state: None,
281 }
282 }
283
284 pub fn bandwidth(mut self, bandwidth: f64) -> Self
286 where
287 Kernel: KernelTypeWithBandwidth,
288 {
289 self.parameters.bandwidth = bandwidth;
290 self
291 }
292
293 pub fn degree(mut self, degree: usize) -> Self
295 where
296 Kernel: PolynomialKernelType,
297 {
298 self.parameters.degree = Some(degree);
299 self
300 }
301
302 pub fn random_state(mut self, seed: u64) -> Self {
304 self.random_state = Some(seed);
305 self
306 }
307
308 pub fn fit(
310 self,
311 data: &Array2<f64>,
312 ) -> Result<FittedTypeSafeKernelApproximation<Kernel, Method, N_COMPONENTS>>
313 where
314 Kernel: FittableKernel<Method>,
315 Method: FittableMethod<Kernel>,
316 {
317 self.fit_impl(data)
318 }
319
320 fn fit_impl(
322 self,
323 data: &Array2<f64>,
324 ) -> Result<FittedTypeSafeKernelApproximation<Kernel, Method, N_COMPONENTS>>
325 where
326 Kernel: FittableKernel<Method>,
327 Method: FittableMethod<Kernel>,
328 {
329 let transformation_params = match Method::NAME {
330 "RandomFourierFeatures" => self.fit_random_fourier_features(data)?,
331 "Nystrom" => self.fit_nystrom(data)?,
332 "Fastfood" => self.fit_fastfood(data)?,
333 _ => {
334 return Err(SklearsError::InvalidOperation(format!(
335 "Unsupported method: {}",
336 Method::NAME
337 )));
338 }
339 };
340
341 Ok(FittedTypeSafeKernelApproximation {
342 _phantom: PhantomData,
343 transformation_params,
344 fitted_parameters: self.parameters,
345 quality_metrics: QualityMetrics::default(),
346 })
347 }
348
349 fn fit_random_fourier_features(
350 &self,
351 data: &Array2<f64>,
352 ) -> Result<TransformationParameters<N_COMPONENTS>> {
353 let mut rng = match self.random_state {
354 Some(seed) => RealStdRng::seed_from_u64(seed),
355 None => RealStdRng::from_seed(thread_rng().gen()),
356 };
357
358 let (_, n_features) = data.dim();
359
360 let weights = match Kernel::NAME {
361 "RBF" => {
362 let normal = RandNormal::new(0.0, self.parameters.bandwidth).unwrap();
363 Array2::from_shape_fn((N_COMPONENTS, n_features), |_| rng.sample(normal))
364 }
365 "Laplacian" => {
366 let uniform = RandUniform::new(0.0, std::f64::consts::PI).unwrap();
368 Array2::from_shape_fn((N_COMPONENTS, n_features), |_| {
369 let u = rng.sample(uniform);
370 (u - std::f64::consts::PI / 2.0).tan() / self.parameters.bandwidth
371 })
372 }
373 _ => {
374 return Err(SklearsError::InvalidOperation(format!(
375 "RFF not supported for kernel: {}",
376 Kernel::NAME
377 )));
378 }
379 };
380
381 let uniform = RandUniform::new(0.0, 2.0 * std::f64::consts::PI).unwrap();
382 let biases = Some(Array1::from_shape_fn(N_COMPONENTS, |_| rng.sample(uniform)));
383
384 Ok(TransformationParameters::RandomFeatures { weights, biases })
385 }
386
387 fn fit_nystrom(&self, data: &Array2<f64>) -> Result<TransformationParameters<N_COMPONENTS>> {
388 let (n_samples, _) = data.dim();
389 let n_inducing = N_COMPONENTS.min(n_samples);
390
391 let mut rng = match self.random_state {
393 Some(seed) => RealStdRng::seed_from_u64(seed),
394 None => RealStdRng::from_seed(thread_rng().gen()),
395 };
396
397 let mut indices: Vec<usize> = (0..n_samples).collect();
398 indices.shuffle(&mut rng);
399
400 let inducing_indices = &indices[..n_inducing];
401 let inducing_points = data.select(scirs2_core::ndarray::Axis(0), inducing_indices);
402
403 let kernel_matrix = self.compute_kernel_matrix(&inducing_points, &inducing_points)?;
405
406 let eigenvalues = Array1::from_shape_fn(n_inducing, |i| kernel_matrix[[i, i]]);
408 let eigenvectors = Array2::eye(n_inducing);
409
410 Ok(TransformationParameters::Nystrom {
411 inducing_points,
412 eigenvalues,
413 eigenvectors,
414 })
415 }
416
417 fn fit_fastfood(&self, data: &Array2<f64>) -> Result<TransformationParameters<N_COMPONENTS>> {
418 let mut rng = match self.random_state {
419 Some(seed) => RealStdRng::seed_from_u64(seed),
420 None => RealStdRng::from_seed(thread_rng().gen()),
421 };
422
423 let (_, n_features) = data.dim();
424
425 let mut structured_matrices = Vec::new();
427
428 let binary_matrix = Array2::from_shape_fn((n_features, n_features), |(i, j)| {
430 if i == j {
431 if rng.gen::<bool>() {
432 1.0
433 } else {
434 -1.0
435 }
436 } else {
437 0.0
438 }
439 });
440 structured_matrices.push(binary_matrix);
441
442 let scaling = Array1::from_shape_fn(N_COMPONENTS, |_| {
444 use scirs2_core::random::RandNormal;
445 let normal = RandNormal::new(0.0, 1.0).unwrap();
446 rng.sample(normal)
447 });
448
449 Ok(TransformationParameters::Fastfood {
450 structured_matrices,
451 scaling,
452 })
453 }
454
455 fn compute_kernel_matrix(&self, x1: &Array2<f64>, x2: &Array2<f64>) -> Result<Array2<f64>> {
456 let (n1, _) = x1.dim();
457 let (n2, _) = x2.dim();
458 let mut kernel = Array2::zeros((n1, n2));
459
460 for i in 0..n1 {
461 for j in 0..n2 {
462 let similarity = match Kernel::NAME {
463 "RBF" => {
464 let diff = &x1.row(i) - &x2.row(j);
465 let dist_sq = diff.mapv(|x| x * x).sum();
466 (-self.parameters.bandwidth * dist_sq).exp()
467 }
468 "Laplacian" => {
469 let diff = &x1.row(i) - &x2.row(j);
470 let dist = diff.mapv(|x| x.abs()).sum();
471 (-self.parameters.bandwidth * dist).exp()
472 }
473 "Polynomial" => {
474 let dot_product = x1.row(i).dot(&x2.row(j));
475 let degree = self.parameters.degree.unwrap_or(2) as i32;
476 let coef0 = self.parameters.coef0.unwrap_or(1.0);
477 (self.parameters.bandwidth * dot_product + coef0).powi(degree)
478 }
479 _ => {
480 return Err(SklearsError::InvalidOperation(format!(
481 "Unsupported kernel: {}",
482 Kernel::NAME
483 )));
484 }
485 };
486 kernel[[i, j]] = similarity;
487 }
488 }
489
490 Ok(kernel)
491 }
492}
493
494impl<Kernel, Method, const N_COMPONENTS: usize>
496 FittedTypeSafeKernelApproximation<Kernel, Method, N_COMPONENTS>
497where
498 Kernel: KernelType,
499 Method: ApproximationMethod,
500{
501 pub fn transform(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
503 match &self.transformation_params {
504 TransformationParameters::RandomFeatures { weights, biases } => {
505 self.transform_random_features(data, weights, biases.as_ref())
506 }
507 TransformationParameters::Nystrom {
508 inducing_points,
509 eigenvalues,
510 eigenvectors,
511 } => self.transform_nystrom(data, inducing_points, eigenvalues, eigenvectors),
512 TransformationParameters::Fastfood {
513 structured_matrices,
514 scaling,
515 } => self.transform_fastfood(data, structured_matrices, scaling),
516 }
517 }
518
519 fn transform_random_features(
520 &self,
521 data: &Array2<f64>,
522 weights: &Array2<f64>,
523 biases: Option<&Array1<f64>>,
524 ) -> Result<Array2<f64>> {
525 let (n_samples, _) = data.dim();
526 let mut features = Array2::zeros((n_samples, N_COMPONENTS * 2));
527
528 for (i, sample) in data.axis_iter(scirs2_core::ndarray::Axis(0)).enumerate() {
529 for j in 0..N_COMPONENTS {
530 let projection = sample.dot(&weights.row(j));
531 let phase = if let Some(b) = biases {
532 projection + b[j]
533 } else {
534 projection
535 };
536
537 features[[i, 2 * j]] = phase.cos();
538 features[[i, 2 * j + 1]] = phase.sin();
539 }
540 }
541
542 Ok(features)
543 }
544
545 fn transform_nystrom(
546 &self,
547 data: &Array2<f64>,
548 inducing_points: &Array2<f64>,
549 eigenvalues: &Array1<f64>,
550 eigenvectors: &Array2<f64>,
551 ) -> Result<Array2<f64>> {
552 let kernel_matrix = self.compute_kernel_matrix_fitted(data, inducing_points)?;
554
555 let mut features = Array2::zeros((data.nrows(), eigenvalues.len()));
557
558 for i in 0..data.nrows() {
559 for j in 0..eigenvalues.len() {
560 if eigenvalues[j] > 1e-8 {
561 let mut feature_value = 0.0;
562 for k in 0..inducing_points.nrows() {
563 feature_value += kernel_matrix[[i, k]] * eigenvectors[[k, j]];
564 }
565 features[[i, j]] = feature_value / eigenvalues[j].sqrt();
566 }
567 }
568 }
569
570 Ok(features)
571 }
572
573 fn transform_fastfood(
574 &self,
575 data: &Array2<f64>,
576 structured_matrices: &[Array2<f64>],
577 scaling: &Array1<f64>,
578 ) -> Result<Array2<f64>> {
579 let (n_samples, n_features) = data.dim();
580 let mut features = data.clone();
581
582 for matrix in structured_matrices {
584 features = features.dot(matrix);
585 }
586
587 let mut result = Array2::zeros((n_samples, N_COMPONENTS));
589 for i in 0..n_samples {
590 for j in 0..N_COMPONENTS.min(n_features) {
591 result[[i, j]] = (features[[i, j]] * scaling[j]).cos();
592 }
593 }
594
595 Ok(result)
596 }
597
598 fn compute_kernel_matrix_fitted(
599 &self,
600 x1: &Array2<f64>,
601 x2: &Array2<f64>,
602 ) -> Result<Array2<f64>> {
603 let (n1, _) = x1.dim();
604 let (n2, _) = x2.dim();
605 let mut kernel = Array2::zeros((n1, n2));
606
607 for i in 0..n1 {
608 for j in 0..n2 {
609 let similarity = match Kernel::NAME {
610 "RBF" => {
611 let diff = &x1.row(i) - &x2.row(j);
612 let dist_sq = diff.mapv(|x| x * x).sum();
613 (-self.fitted_parameters.bandwidth * dist_sq).exp()
614 }
615 "Laplacian" => {
616 let diff = &x1.row(i) - &x2.row(j);
617 let dist = diff.mapv(|x| x.abs()).sum();
618 (-self.fitted_parameters.bandwidth * dist).exp()
619 }
620 "Polynomial" => {
621 let dot_product = x1.row(i).dot(&x2.row(j));
622 let degree = self.fitted_parameters.degree.unwrap_or(2) as i32;
623 let coef0 = self.fitted_parameters.coef0.unwrap_or(1.0);
624 (self.fitted_parameters.bandwidth * dot_product + coef0).powi(degree)
625 }
626 _ => {
627 return Err(SklearsError::InvalidOperation(format!(
628 "Unsupported kernel: {}",
629 Kernel::NAME
630 )));
631 }
632 };
633 kernel[[i, j]] = similarity;
634 }
635 }
636
637 Ok(kernel)
638 }
639
640 pub fn quality_metrics(&self) -> &QualityMetrics {
642 &self.quality_metrics
643 }
644
645 pub const fn n_components() -> usize {
647 N_COMPONENTS
648 }
649
650 pub fn kernel_name(&self) -> &'static str {
652 Kernel::NAME
653 }
654
655 pub fn method_name(&self) -> &'static str {
657 Method::NAME
658 }
659}
660
/// Marker for kernels whose `bandwidth` builder method is meaningful.
pub trait KernelTypeWithBandwidth: KernelType {}
impl KernelTypeWithBandwidth for RBFKernel {}
impl KernelTypeWithBandwidth for LaplacianKernel {}

/// Marker for kernels that accept a polynomial `degree`.
pub trait PolynomialKernelType: KernelType {}
impl PolynomialKernelType for PolynomialKernel {}
671
/// Kernels that can be fitted with approximation method `Method`; the impls
/// below whitelist exactly the pairings `fit_impl` supports.
pub trait FittableKernel<Method: ApproximationMethod>: KernelType {}
impl FittableKernel<RandomFourierFeatures> for RBFKernel {}
impl FittableKernel<RandomFourierFeatures> for LaplacianKernel {}
impl FittableKernel<NystromMethod> for RBFKernel {}
impl FittableKernel<NystromMethod> for LaplacianKernel {}
impl FittableKernel<NystromMethod> for PolynomialKernel {}
impl FittableKernel<FastfoodMethod> for RBFKernel {}

/// Mirror of `FittableKernel`: methods that can fit kernel `Kernel`. `fit`
/// requires both directions, so the two lists must stay in sync.
pub trait FittableMethod<Kernel: KernelType>: ApproximationMethod {}
impl FittableMethod<RBFKernel> for RandomFourierFeatures {}
impl FittableMethod<LaplacianKernel> for RandomFourierFeatures {}
impl FittableMethod<RBFKernel> for NystromMethod {}
impl FittableMethod<LaplacianKernel> for NystromMethod {}
impl FittableMethod<PolynomialKernel> for NystromMethod {}
impl FittableMethod<RBFKernel> for FastfoodMethod {}
689
/// Untrained RBF kernel with random Fourier features and `N` components.
pub type RBFRandomFourierFeatures<const N: usize> =
    TypeSafeKernelApproximation<Untrained, RBFKernel, RandomFourierFeatures, N>;

/// Untrained Laplacian kernel with random Fourier features.
pub type LaplacianRandomFourierFeatures<const N: usize> =
    TypeSafeKernelApproximation<Untrained, LaplacianKernel, RandomFourierFeatures, N>;

/// Untrained RBF kernel with the Nystrom method.
pub type RBFNystrom<const N: usize> =
    TypeSafeKernelApproximation<Untrained, RBFKernel, NystromMethod, N>;

/// Untrained polynomial kernel with the Nystrom method.
pub type PolynomialNystrom<const N: usize> =
    TypeSafeKernelApproximation<Untrained, PolynomialKernel, NystromMethod, N>;

/// Untrained RBF kernel with the Fastfood method.
pub type RBFFastfood<const N: usize> =
    TypeSafeKernelApproximation<Untrained, RBFKernel, FastfoodMethod, N>;

/// Fitted counterpart of `RBFRandomFourierFeatures`.
pub type FittedRBFRandomFourierFeatures<const N: usize> =
    FittedTypeSafeKernelApproximation<RBFKernel, RandomFourierFeatures, N>;

/// Fitted counterpart of `LaplacianRandomFourierFeatures`.
pub type FittedLaplacianRandomFourierFeatures<const N: usize> =
    FittedTypeSafeKernelApproximation<LaplacianKernel, RandomFourierFeatures, N>;

/// Fitted counterpart of `RBFNystrom`.
pub type FittedRBFNystrom<const N: usize> =
    FittedTypeSafeKernelApproximation<RBFKernel, NystromMethod, N>;
715
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// RFF on an RBF kernel: output width is doubled by the cos/sin pairs.
    #[test]
    fn test_type_safe_rbf_rff() {
        let data = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];

        let model: RBFRandomFourierFeatures<10> = TypeSafeKernelApproximation::new()
            .bandwidth(1.5)
            .random_state(42);

        let fitted = model.fit(&data).unwrap();
        let features = fitted.transform(&data).unwrap();

        assert_eq!(features.shape(), &[4, 20]);
        assert_eq!(FittedRBFRandomFourierFeatures::<10>::n_components(), 10);
        assert_eq!(fitted.kernel_name(), "RBF");
        assert_eq!(fitted.method_name(), "RandomFourierFeatures");
    }

    /// Nystrom preserves the sample count in the output rows.
    #[test]
    fn test_type_safe_nystrom() {
        let data = array![
            [1.0, 2.0, 3.0],
            [2.0, 3.0, 4.0],
            [3.0, 4.0, 5.0],
            [4.0, 5.0, 6.0],
        ];

        let model: RBFNystrom<5> = TypeSafeKernelApproximation::new()
            .bandwidth(2.0)
            .random_state(42);

        let fitted = model.fit(&data).unwrap();
        let features = fitted.transform(&data).unwrap();

        assert_eq!(features.shape()[0], 4);
        assert_eq!(fitted.kernel_name(), "RBF");
        assert_eq!(fitted.method_name(), "Nystrom");
    }

    /// A polynomial kernel with an explicit degree fits through Nystrom.
    #[test]
    fn test_polynomial_kernel() {
        let data = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];

        let model: PolynomialNystrom<3> = TypeSafeKernelApproximation::new()
            .degree(3)
            .random_state(42);

        let fitted = model.fit(&data).unwrap();
        let features = fitted.transform(&data).unwrap();

        assert_eq!(fitted.kernel_name(), "Polynomial");
        assert!(features.nrows() == 3);
    }

    /// Compile-time constants are reachable without constructing anything.
    #[test]
    fn test_compile_time_constants() {
        assert_eq!(RBFKernel::NAME, "RBF");
        assert!(RBFKernel::SUPPORTS_PARAMETER_LEARNING);
        assert_eq!(RandomFourierFeatures::NAME, "RandomFourierFeatures");
        assert!(RandomFourierFeatures::SUPPORTS_INCREMENTAL);
        assert!(RandomFourierFeatures::HAS_ERROR_BOUNDS);
    }
}
805
/// Compile-time validation of an inclusive `[MIN, MAX]` parameter range.
pub trait ParameterValidation<const MIN: usize, const MAX: usize> {
    /// True when the range is well-formed (MIN <= MAX).
    const IS_VALID: bool = MIN <= MAX;

    /// Returns the `(MIN, MAX)` bounds at runtime.
    fn parameter_range() -> (usize, usize) {
        (MIN, MAX)
    }
}
820
/// Zero-sized proof token that a component count `N` passed the runtime
/// range checks in `new`.
#[derive(Debug, Clone, Copy)]
pub struct ValidatedComponents<const N: usize>;

impl<const N: usize> ValidatedComponents<N> {
    /// Validates `N` and returns the token.
    ///
    /// # Panics
    /// Panics when `N` is zero or exceeds 10000.
    pub fn new() -> Self {
        assert!(N > 0, "Component count must be positive");
        assert!(N <= 10000, "Component count too large");
        ValidatedComponents
    }

    /// The validated component count `N`.
    pub const fn count(&self) -> usize {
        N
    }
}

impl<const N: usize> Default for ValidatedComponents<N> {
    fn default() -> Self {
        Self::new()
    }
}
846
/// Compile-time compatibility/performance table for a kernel-method pairing,
/// implemented on `()` for each supported combination below.
pub trait KernelMethodCompatibility<K: KernelType, M: ApproximationMethod> {
    /// Whether the pairing is allowed at all (checked at runtime in
    /// `ValidatedKernelApproximation::new`).
    const IS_COMPATIBLE: bool;

    /// Qualitative performance rating of the pairing.
    const PERFORMANCE_TIER: PerformanceTier;

    /// Memory complexity of the pairing.
    const MEMORY_COMPLEXITY: ComplexityClass;
}

/// Qualitative performance rating, best to worst.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PerformanceTier {
    Optimal,
    Good,
    Acceptable,
    Poor,
}
872
// Compatibility table entries. Compatible pairings should mirror the
// `FittableKernel`/`FittableMethod` impls above; the ArcCosine entries are
// explicitly marked incompatible.

impl KernelMethodCompatibility<RBFKernel, RandomFourierFeatures> for () {
    const IS_COMPATIBLE: bool = true;
    const PERFORMANCE_TIER: PerformanceTier = PerformanceTier::Optimal;
    const MEMORY_COMPLEXITY: ComplexityClass = ComplexityClass::Linear;
}

impl KernelMethodCompatibility<LaplacianKernel, RandomFourierFeatures> for () {
    const IS_COMPATIBLE: bool = true;
    const PERFORMANCE_TIER: PerformanceTier = PerformanceTier::Optimal;
    const MEMORY_COMPLEXITY: ComplexityClass = ComplexityClass::Linear;
}

impl KernelMethodCompatibility<RBFKernel, NystromMethod> for () {
    const IS_COMPATIBLE: bool = true;
    const PERFORMANCE_TIER: PerformanceTier = PerformanceTier::Good;
    const MEMORY_COMPLEXITY: ComplexityClass = ComplexityClass::Quadratic;
}

impl KernelMethodCompatibility<PolynomialKernel, NystromMethod> for () {
    const IS_COMPATIBLE: bool = true;
    const PERFORMANCE_TIER: PerformanceTier = PerformanceTier::Good;
    const MEMORY_COMPLEXITY: ComplexityClass = ComplexityClass::Quadratic;
}

impl KernelMethodCompatibility<RBFKernel, FastfoodMethod> for () {
    const IS_COMPATIBLE: bool = true;
    const PERFORMANCE_TIER: PerformanceTier = PerformanceTier::Optimal;
    const MEMORY_COMPLEXITY: ComplexityClass = ComplexityClass::DimensionDependent;
}

// NOTE(review): ArcCosine + RFF is marked compatible here but `fit_impl` has
// no ArcCosine sampling branch, and no `FittableKernel` impl exists for it —
// confirm whether this entry should be `false`.
impl KernelMethodCompatibility<ArcCosineKernel, RandomFourierFeatures> for () {
    const IS_COMPATIBLE: bool = true;
    const PERFORMANCE_TIER: PerformanceTier = PerformanceTier::Acceptable;
    const MEMORY_COMPLEXITY: ComplexityClass = ComplexityClass::Linear;
}

impl KernelMethodCompatibility<ArcCosineKernel, NystromMethod> for () {
    const IS_COMPATIBLE: bool = false; const PERFORMANCE_TIER: PerformanceTier = PerformanceTier::Poor;
    const MEMORY_COMPLEXITY: ComplexityClass = ComplexityClass::Quadratic;
}

impl KernelMethodCompatibility<ArcCosineKernel, FastfoodMethod> for () {
    const IS_COMPATIBLE: bool = false; const PERFORMANCE_TIER: PerformanceTier = PerformanceTier::Poor;
    const MEMORY_COMPLEXITY: ComplexityClass = ComplexityClass::DimensionDependent;
}

impl KernelMethodCompatibility<LaplacianKernel, NystromMethod> for () {
    const IS_COMPATIBLE: bool = true;
    const PERFORMANCE_TIER: PerformanceTier = PerformanceTier::Good;
    const MEMORY_COMPLEXITY: ComplexityClass = ComplexityClass::Quadratic;
}

// NOTE(review): Polynomial + RFF is marked compatible, but there is no
// `FittableKernel<RandomFourierFeatures> for PolynomialKernel` impl and
// `fit_random_fourier_features` rejects it — confirm intent.
impl KernelMethodCompatibility<PolynomialKernel, RandomFourierFeatures> for () {
    const IS_COMPATIBLE: bool = true;
    const PERFORMANCE_TIER: PerformanceTier = PerformanceTier::Good;
    const MEMORY_COMPLEXITY: ComplexityClass = ComplexityClass::Linear;
}
935
/// Kernel approximation whose kernel/method pairing and component count are
/// validated when constructed.
#[derive(Debug, Clone)]
pub struct ValidatedKernelApproximation<K, M, const N: usize>
where
    K: KernelType,
    M: ApproximationMethod,
    (): KernelMethodCompatibility<K, M>,
{
    /// Underlying untrained approximation.
    /// NOTE(review): no accessor for `inner` is visible in this file.
    inner: TypeSafeKernelApproximation<Untrained, K, M, N>,
    /// Proof token that `N` passed the component-count range checks.
    _validation: ValidatedComponents<N>,
}

impl<K, M, const N: usize> Default for ValidatedKernelApproximation<K, M, N>
where
    K: KernelType,
    M: ApproximationMethod,
    (): KernelMethodCompatibility<K, M>,
{
    fn default() -> Self {
        Self::new()
    }
}

impl<K, M, const N: usize> ValidatedKernelApproximation<K, M, N>
where
    K: KernelType,
    M: ApproximationMethod,
    (): KernelMethodCompatibility<K, M>,
{
    /// Constructs the approximation after checking the compatibility table.
    ///
    /// # Panics
    /// Panics when the kernel/method pair is marked incompatible, or when `N`
    /// is out of range (see `ValidatedComponents::new`).
    pub fn new() -> Self {
        assert!(
            <() as KernelMethodCompatibility<K, M>>::IS_COMPATIBLE,
            "Incompatible kernel-method combination"
        );

        Self {
            inner: TypeSafeKernelApproximation::new(),
            _validation: ValidatedComponents::new(),
        }
    }

    /// Compile-time summary: (performance tier, time complexity, memory complexity).
    pub const fn performance_info() -> (PerformanceTier, ComplexityClass, ComplexityClass) {
        (
            <() as KernelMethodCompatibility<K, M>>::PERFORMANCE_TIER,
            M::COMPLEXITY,
            <() as KernelMethodCompatibility<K, M>>::MEMORY_COMPLEXITY,
        )
    }

    /// True when this pairing sits in the `Optimal` performance tier.
    pub const fn is_optimal() -> bool {
        matches!(
            <() as KernelMethodCompatibility<K, M>>::PERFORMANCE_TIER,
            PerformanceTier::Optimal
        )
    }
}
997
/// Quality metrics whose admissible range is encoded in the type: alignment
/// must be at least `MIN_ALIGNMENT`/100 and error at most `MAX_ERROR`/100.
#[derive(Debug, Clone, Copy)]
pub struct BoundedQualityMetrics<const MIN_ALIGNMENT: u32, const MAX_ERROR: u32> {
    kernel_alignment: f64,
    approximation_error: f64,
    effective_rank: f64,
}

impl<const MIN_ALIGNMENT: u32, const MAX_ERROR: u32>
    BoundedQualityMetrics<MIN_ALIGNMENT, MAX_ERROR>
{
    /// Validated constructor: `Some` only when `alignment` meets the encoded
    /// floor, `error` stays under the encoded ceiling, and `rank` is strictly
    /// positive; otherwise `None`.
    pub const fn new(alignment: f64, error: f64, rank: f64) -> Option<Self> {
        let (alignment_floor, error_ceiling) = Self::bounds();

        if alignment >= alignment_floor && error <= error_ceiling && rank > 0.0 {
            Some(BoundedQualityMetrics {
                kernel_alignment: alignment,
                approximation_error: error,
                effective_rank: rank,
            })
        } else {
            None
        }
    }

    /// The `(min alignment, max error)` bounds as fractions in [0, 1].
    pub const fn bounds() -> (f64, f64) {
        (MIN_ALIGNMENT as f64 / 100.0, MAX_ERROR as f64 / 100.0)
    }

    /// Re-checks the stored metrics against the type-level bounds; always true
    /// for values produced by `new`.
    pub const fn meets_standards(&self) -> bool {
        let (alignment_floor, error_ceiling) = Self::bounds();
        self.kernel_alignment >= alignment_floor && self.approximation_error <= error_ceiling
    }
}
1037
/// Strict bounds: alignment >= 0.90 and error <= 0.05.
pub type HighQualityMetrics = BoundedQualityMetrics<90, 5>;

/// Relaxed bounds: alignment >= 0.70 and error <= 0.15.
pub type AcceptableQualityMetrics = BoundedQualityMetrics<70, 15>;
1043
/// Shorthand for constructing a `ValidatedKernelApproximation` from kernel and
/// method names plus a literal component count, e.g.
/// `validated_kernel_approximation!(RBF, Nystrom, 30)`.
#[macro_export]
macro_rules! validated_kernel_approximation {
    (RBF, RandomFourierFeatures, $n:literal) => {
        ValidatedKernelApproximation::<RBFKernel, RandomFourierFeatures, $n>::new()
    };
    (Laplacian, RandomFourierFeatures, $n:literal) => {
        ValidatedKernelApproximation::<LaplacianKernel, RandomFourierFeatures, $n>::new()
    };
    (RBF, Nystrom, $n:literal) => {
        ValidatedKernelApproximation::<RBFKernel, NystromMethod, $n>::new()
    };
    (RBF, Fastfood, $n:literal) => {
        ValidatedKernelApproximation::<RBFKernel, FastfoodMethod, $n>::new()
    };
}
1060
#[allow(non_snake_case)]
#[cfg(test)]
mod advanced_type_safety_tests {
    use super::*;

    // Runtime range checks on the component-count proof token.
    #[test]
    fn test_validated_components() {
        let components = ValidatedComponents::<100>::new();
        assert_eq!(components.count(), 100);
    }

    // Compatibility table constants are queryable without instantiation.
    #[test]
    fn test_kernel_method_compatibility() {
        assert!(<() as KernelMethodCompatibility<RBFKernel, RandomFourierFeatures>>::IS_COMPATIBLE);
        assert!(!<() as KernelMethodCompatibility<ArcCosineKernel, NystromMethod>>::IS_COMPATIBLE);

        assert_eq!(
            <() as KernelMethodCompatibility<RBFKernel, RandomFourierFeatures>>::PERFORMANCE_TIER,
            PerformanceTier::Optimal
        );
    }

    // In-bounds metrics construct; out-of-bounds metrics are rejected.
    #[test]
    fn test_bounded_quality_metrics() {
        let high_quality = HighQualityMetrics::new(0.95, 0.03, 50.0).unwrap();
        assert!(high_quality.meets_standards());

        let low_quality = HighQualityMetrics::new(0.60, 0.20, 10.0);
        assert!(low_quality.is_none());
    }

    // The macro arms expand to compatible combinations only.
    #[test]
    fn test_macro_creation() {
        let _rbf_rff = validated_kernel_approximation!(RBF, RandomFourierFeatures, 100);
        let _lap_rff = validated_kernel_approximation!(Laplacian, RandomFourierFeatures, 50);
        let _rbf_nys = validated_kernel_approximation!(RBF, Nystrom, 30);
        let _rbf_ff = validated_kernel_approximation!(RBF, Fastfood, 128);

        let (tier, complexity, memory) = ValidatedKernelApproximation::<
            RBFKernel,
            RandomFourierFeatures,
            100,
        >::performance_info();
        assert_eq!(tier, PerformanceTier::Optimal);
        assert_eq!(complexity, ComplexityClass::Linear);
        assert_eq!(memory, ComplexityClass::Linear);
    }
}
1115
/// Kernels that can be combined into composite kernels; the composite type is
/// chosen per kernel through a generic associated type.
///
/// NOTE(review): the GAT is bounded by `ComposableKernel`, but `SumKernel` and
/// `ProductKernel` only implement `KernelType` in this file — unless impls
/// exist elsewhere, the bound is unsatisfiable and should likely be relaxed to
/// `: KernelType`. Confirm against the rest of the crate.
pub trait ComposableKernel: KernelType {
    /// Composite kernel produced by combining `Self` with `Other`.
    type CompositionResult<Other: ComposableKernel>: ComposableKernel;

    /// Consumes `self` and returns the zero-sized composite kernel marker.
    fn compose<Other: ComposableKernel>(self) -> Self::CompositionResult<Other>;
}
1128
/// Composite kernel marker representing the sum of two kernels.
/// NOTE(review): `NAME = "Sum"` has no branch in `compute_kernel_matrix`, so
/// fitting with it would fail at runtime.
#[derive(Debug, Clone, Copy)]
pub struct SumKernel<K1: KernelType, K2: KernelType> {
    _phantom: PhantomData<(K1, K2)>,
}

impl<K1: KernelType, K2: KernelType> KernelType for SumKernel<K1, K2> {
    const NAME: &'static str = "Sum";
    const SUPPORTS_PARAMETER_LEARNING: bool =
        K1::SUPPORTS_PARAMETER_LEARNING && K2::SUPPORTS_PARAMETER_LEARNING;
    // Average of the operands' defaults.
    const DEFAULT_BANDWIDTH: f64 = (K1::DEFAULT_BANDWIDTH + K2::DEFAULT_BANDWIDTH) / 2.0;
}

/// Composite kernel marker representing the product of two kernels.
/// NOTE(review): `NAME = "Product"` also has no evaluation branch here.
#[derive(Debug, Clone, Copy)]
pub struct ProductKernel<K1: KernelType, K2: KernelType> {
    _phantom: PhantomData<(K1, K2)>,
}

impl<K1: KernelType, K2: KernelType> KernelType for ProductKernel<K1, K2> {
    const NAME: &'static str = "Product";
    const SUPPORTS_PARAMETER_LEARNING: bool =
        K1::SUPPORTS_PARAMETER_LEARNING && K2::SUPPORTS_PARAMETER_LEARNING;
    // Product of the operands' defaults.
    const DEFAULT_BANDWIDTH: f64 = K1::DEFAULT_BANDWIDTH * K2::DEFAULT_BANDWIDTH;
}
1156
// RBF and Laplacian compose additively; Polynomial composes multiplicatively.

impl ComposableKernel for RBFKernel {
    type CompositionResult<Other: ComposableKernel> = SumKernel<Self, Other>;

    fn compose<Other: ComposableKernel>(self) -> Self::CompositionResult<Other> {
        SumKernel {
            _phantom: PhantomData,
        }
    }
}

impl ComposableKernel for LaplacianKernel {
    type CompositionResult<Other: ComposableKernel> = SumKernel<Self, Other>;

    fn compose<Other: ComposableKernel>(self) -> Self::CompositionResult<Other> {
        SumKernel {
            _phantom: PhantomData,
        }
    }
}

impl ComposableKernel for PolynomialKernel {
    type CompositionResult<Other: ComposableKernel> = ProductKernel<Self, Other>;

    fn compose<Other: ComposableKernel>(self) -> Self::CompositionResult<Other> {
        ProductKernel {
            _phantom: PhantomData,
        }
    }
}
1186
/// Compile-time predicates on a feature count `N`.
pub trait ValidatedFeatureSize<const N: usize> {
    /// True when `N` is a (nonzero) power of two.
    const IS_POWER_OF_TWO: bool = (N != 0) && ((N & (N - 1)) == 0);
    /// True when `N` lies in the supported range [8, 8192].
    const IS_REASONABLE_SIZE: bool = N >= 8 && N <= 8192;
    /// True when both predicates hold.
    const IS_VALID: bool = Self::IS_POWER_OF_TWO && Self::IS_REASONABLE_SIZE;
}

// Blanket impl on the unit type: every `N` gets the default predicates.
impl<const N: usize> ValidatedFeatureSize<N> for () {}

/// Zero-sized token standing for a feature vector of length `N`.
#[derive(Debug, Clone)]
pub struct ValidatedFeatures<const N: usize> {
    _phantom: PhantomData<[f64; N]>,
}

impl<const N: usize> ValidatedFeatures<N>
where
    (): ValidatedFeatureSize<N>,
{
    /// Constructs the token; purely type-level, no runtime checks.
    pub const fn new() -> Self {
        ValidatedFeatures {
            _phantom: PhantomData,
        }
    }

    /// The feature count `N`.
    pub const fn count() -> usize {
        N
    }

    /// Fastfood favors power-of-two sizes, so report that predicate.
    pub const fn is_fastfood_optimal() -> bool {
        <() as ValidatedFeatureSize<N>>::IS_POWER_OF_TWO
    }
}

impl<const N: usize> Default for ValidatedFeatures<N>
where
    (): ValidatedFeatureSize<N>,
{
    fn default() -> Self {
        Self::new()
    }
}
1234
1235pub trait ApproximationQualityBounds<Method: ApproximationMethod> {
1237 const ERROR_BOUND_CONSTANT: f64;
1239
1240 const SAMPLE_COMPLEXITY_EXPONENT: f64;
1242
1243 const DIMENSION_DEPENDENCY: f64;
1245
1246 fn error_bound(n_samples: usize, n_features: usize, n_components: usize) -> f64 {
1248 let base_rate = Self::ERROR_BOUND_CONSTANT;
1249 let sample_factor = (n_samples as f64).powf(-Self::SAMPLE_COMPLEXITY_EXPONENT);
1250 let dim_factor = (n_features as f64).powf(Self::DIMENSION_DEPENDENCY);
1251 let comp_factor = (n_components as f64).powf(-0.5);
1252
1253 base_rate * sample_factor * dim_factor * comp_factor
1254 }
1255}
1256
/// Bound constants for Random Fourier Features.
// NOTE(review): constants look heuristic rather than derived — confirm
// against the RFF approximation-error literature before relying on them.
impl ApproximationQualityBounds<RandomFourierFeatures> for () {
    const ERROR_BOUND_CONSTANT: f64 = 2.0;
    const SAMPLE_COMPLEXITY_EXPONENT: f64 = 0.25;
    const DIMENSION_DEPENDENCY: f64 = 0.1;
}
1262
/// Bound constants for the Nyström method: tighter constant and faster
/// sample-rate than RFF, per the values chosen here.
// NOTE(review): constants look heuristic — confirm against Nyström
// error-bound analyses before relying on them.
impl ApproximationQualityBounds<NystromMethod> for () {
    const ERROR_BOUND_CONSTANT: f64 = 1.5;
    const SAMPLE_COMPLEXITY_EXPONENT: f64 = 0.33;
    const DIMENSION_DEPENDENCY: f64 = 0.05;
}
1268
/// Bound constants for the Fastfood transform: loosest constant of the
/// three methods, as configured here.
// NOTE(review): constants look heuristic — confirm against the Fastfood
// paper's error analysis before relying on them.
impl ApproximationQualityBounds<FastfoodMethod> for () {
    const ERROR_BOUND_CONSTANT: f64 = 3.0;
    const SAMPLE_COMPLEXITY_EXPONENT: f64 = 0.2;
    const DIMENSION_DEPENDENCY: f64 = 0.15;
}
1274
/// Kernel-approximation configuration validated at compile time.
///
/// The bounds guarantee that `N` is an accepted feature count, that kernel
/// `K` is compatible with method `M`, and that theoretical error bounds
/// exist for `M`; invalid combinations fail to compile.
#[derive(Debug, Clone)]
pub struct TypeSafeKernelConfig<K: KernelType, M: ApproximationMethod, const N: usize>
where
    (): ValidatedFeatureSize<N>,
    (): KernelMethodCompatibility<K, M>,
    (): ApproximationQualityBounds<M>,
{
    // Zero-sized markers pinning the kernel and method types.
    kernel_type: PhantomData<K>,
    method_type: PhantomData<M>,
    // Marker for the validated component count `N`.
    features: ValidatedFeatures<N>,
    // Kernel bandwidth; starts at `K::DEFAULT_BANDWIDTH`.
    bandwidth: f64,
    // Maximum acceptable theoretical approximation error, in (0, 1).
    quality_threshold: f64,
}
1290
impl<K: KernelType, M: ApproximationMethod, const N: usize> Default
    for TypeSafeKernelConfig<K, M, N>
where
    (): ValidatedFeatureSize<N>,
    (): KernelMethodCompatibility<K, M>,
    (): ApproximationQualityBounds<M>,
{
    /// Delegates to [`TypeSafeKernelConfig::new`].
    fn default() -> Self {
        Self::new()
    }
}
1302
1303impl<K: KernelType, M: ApproximationMethod, const N: usize> TypeSafeKernelConfig<K, M, N>
1304where
1305 (): ValidatedFeatureSize<N>,
1306 (): KernelMethodCompatibility<K, M>,
1307 (): ApproximationQualityBounds<M>,
1308{
1309 pub fn new() -> Self {
1311 Self {
1312 kernel_type: PhantomData,
1313 method_type: PhantomData,
1314 features: ValidatedFeatures::new(),
1315 bandwidth: K::DEFAULT_BANDWIDTH,
1316 quality_threshold: 0.1,
1317 }
1318 }
1319
1320 pub fn bandwidth(mut self, bandwidth: f64) -> Self {
1322 assert!(bandwidth > 0.0, "Bandwidth must be positive");
1323 self.bandwidth = bandwidth;
1324 self
1325 }
1326
1327 pub fn quality_threshold(mut self, threshold: f64) -> Self {
1329 assert!(
1330 threshold > 0.0 && threshold < 1.0,
1331 "Quality threshold must be between 0 and 1"
1332 );
1333 self.quality_threshold = threshold;
1334 self
1335 }
1336
1337 pub const fn performance_tier() -> PerformanceTier {
1339 <() as KernelMethodCompatibility<K, M>>::PERFORMANCE_TIER
1340 }
1341
1342 pub const fn memory_complexity() -> ComplexityClass {
1344 <() as KernelMethodCompatibility<K, M>>::MEMORY_COMPLEXITY
1345 }
1346
1347 pub fn theoretical_error_bound(&self, n_samples: usize, n_features: usize) -> f64 {
1349 <() as ApproximationQualityBounds<M>>::error_bound(n_samples, n_features, N)
1350 }
1351
1352 pub fn meets_quality_requirements(&self, n_samples: usize, n_features: usize) -> bool {
1354 let error_bound = self.theoretical_error_bound(n_samples, n_features);
1355 error_bound <= self.quality_threshold
1356 }
1357}
1358
/// RBF kernel approximated with Random Fourier Features, `N` components.
pub type ValidatedRBFRandomFourier<const N: usize> =
    TypeSafeKernelConfig<RBFKernel, RandomFourierFeatures, N>;
/// Laplacian kernel approximated with the Nyström method, `N` components.
pub type ValidatedLaplacianNystrom<const N: usize> =
    TypeSafeKernelConfig<LaplacianKernel, NystromMethod, N>;
/// Polynomial kernel approximated with Random Fourier Features.
pub type ValidatedPolynomialRFF<const N: usize> =
    TypeSafeKernelConfig<PolynomialKernel, RandomFourierFeatures, N>;
1366
impl<K1: KernelType, K2: KernelType> ComposableKernel for SumKernel<K1, K2> {
    /// Composing onto a sum kernel nests it inside another sum marker.
    type CompositionResult<Other: ComposableKernel> = SumKernel<Self, Other>;

    /// Builds the composed, zero-sized kernel marker.
    fn compose<Other: ComposableKernel>(self) -> Self::CompositionResult<Other> {
        SumKernel {
            _phantom: PhantomData,
        }
    }
}
1376
impl<K1: KernelType, K2: KernelType> ComposableKernel for ProductKernel<K1, K2> {
    /// Composing onto a product kernel nests it inside another product
    /// marker.
    type CompositionResult<Other: ComposableKernel> = ProductKernel<Self, Other>;

    /// Builds the composed, zero-sized kernel marker.
    fn compose<Other: ComposableKernel>(self) -> Self::CompositionResult<Other> {
        ProductKernel {
            _phantom: PhantomData,
        }
    }
}
1386
#[allow(non_snake_case)]
#[cfg(test)]
mod advanced_composition_tests {
    use super::*;

    /// Power-of-two feature markers must construct and report their `N`.
    #[test]
    fn test_validated_features() {
        let _small = ValidatedFeatures::<64>::new();
        let _medium = ValidatedFeatures::<128>::new();
        let _large = ValidatedFeatures::<256>::new();

        assert_eq!(ValidatedFeatures::<64>::count(), 64);
        assert!(ValidatedFeatures::<64>::is_fastfood_optimal());
    }

    /// Builder methods keep the compile-time validated config usable.
    #[test]
    fn test_type_safe_configs() {
        let config = ValidatedRBFRandomFourier::<128>::new();
        let config = config.bandwidth(1.5);
        let config = config.quality_threshold(0.05);

        assert_eq!(
            ValidatedRBFRandomFourier::<128>::performance_tier(),
            PerformanceTier::Optimal
        );
        assert!(config.meets_quality_requirements(1000, 10));
    }

    /// Kernel markers compose without runtime state.
    #[test]
    fn test_kernel_composition() {
        let rbf = RBFKernel;
        let _laplacian = LaplacianKernel;
        let polynomial = PolynomialKernel;

        let _rbf_plus_laplacian = rbf.compose::<LaplacianKernel>();
        let _poly_with_rbf = polynomial.compose::<RBFKernel>();
    }

    /// All bounds are positive; Nyström is the tightest at this size.
    #[test]
    fn test_approximation_bounds() {
        let (n, d, m) = (1000, 10, 100);
        let rff =
            <() as ApproximationQualityBounds<RandomFourierFeatures>>::error_bound(n, d, m);
        let nystrom = <() as ApproximationQualityBounds<NystromMethod>>::error_bound(n, d, m);
        let fastfood = <() as ApproximationQualityBounds<FastfoodMethod>>::error_bound(n, d, m);

        for bound in [rff, nystrom, fastfood] {
            assert!(bound > 0.0);
        }

        assert!(nystrom < rff);
    }
}
1444
/// Namespace of ready-made, compile-time validated kernel configurations.
pub struct KernelPresets;
1451
1452impl KernelPresets {
1453 pub fn fast_rbf_128() -> ValidatedRBFRandomFourier<128> {
1455 ValidatedRBFRandomFourier::<128>::new()
1456 .bandwidth(1.0)
1457 .quality_threshold(0.2)
1458 }
1459
1460 pub fn balanced_rbf_256() -> ValidatedRBFRandomFourier<256> {
1462 ValidatedRBFRandomFourier::<256>::new()
1463 .bandwidth(1.0)
1464 .quality_threshold(0.1)
1465 }
1466
1467 pub fn accurate_rbf_512() -> ValidatedRBFRandomFourier<512> {
1469 ValidatedRBFRandomFourier::<512>::new()
1470 .bandwidth(1.0)
1471 .quality_threshold(0.05)
1472 }
1473
1474 pub fn ultrafast_rbf_64() -> ValidatedRBFRandomFourier<64> {
1476 ValidatedRBFRandomFourier::<64>::new()
1477 .bandwidth(1.0)
1478 .quality_threshold(0.3)
1479 }
1480
1481 pub fn precise_nystroem_128() -> ValidatedLaplacianNystrom<128> {
1483 ValidatedLaplacianNystrom::<128>::new()
1484 .bandwidth(1.0)
1485 .quality_threshold(0.01)
1486 }
1487
1488 pub fn memory_efficient_rbf_32() -> ValidatedRBFRandomFourier<32> {
1490 ValidatedRBFRandomFourier::<32>::new()
1491 .bandwidth(1.0)
1492 .quality_threshold(0.4)
1493 }
1494
1495 pub fn polynomial_features_256() -> ValidatedPolynomialRFF<256> {
1497 ValidatedPolynomialRFF::<256>::new()
1498 .bandwidth(1.0)
1499 .quality_threshold(0.15)
1500 }
1501}
1502
/// Configuration for profile-guided optimization (PGO) of kernel
/// approximation: which PGO passes to enable and what hardware to tune for.
#[derive(Debug, Clone)]
pub struct ProfileGuidedConfig {
    /// Choose the feature count based on collected profile data.
    pub enable_pgo_feature_selection: bool,

    /// Tune the kernel bandwidth based on collected profile data.
    pub enable_pgo_bandwidth_optimization: bool,

    /// Location of previously collected profile data, if any.
    pub profile_data_path: Option<String>,

    /// CPU architecture the configuration is tuned for.
    pub target_architecture: TargetArchitecture,

    /// How aggressively to scale resource usage for speed.
    pub optimization_level: OptimizationLevel,
}
1526
/// CPU architectures the recommendation logic can tune for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TargetArchitecture {
    /// Baseline x86-64 without wide SIMD assumptions.
    X86_64Generic,
    /// x86-64 with AVX2 (256-bit SIMD).
    X86_64AVX2,
    /// x86-64 with AVX-512 (512-bit SIMD).
    X86_64AVX512,
    /// Baseline 64-bit ARM.
    ARM64,
    /// 64-bit ARM with NEON SIMD.
    ARM64NEON,
}
1542
/// How aggressively feature counts are scaled (0.5x .. 2x the base).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OptimizationLevel {
    /// Halve the base feature count.
    None,
    /// Use the architecture's base feature count unchanged.
    Basic,
    /// Scale the base feature count by 1.5x.
    Aggressive,
    /// Double the base feature count.
    Maximum,
}
1556
1557impl Default for ProfileGuidedConfig {
1558 fn default() -> Self {
1559 Self {
1560 enable_pgo_feature_selection: false,
1561 enable_pgo_bandwidth_optimization: false,
1562 profile_data_path: None,
1563 target_architecture: TargetArchitecture::X86_64Generic,
1564 optimization_level: OptimizationLevel::Basic,
1565 }
1566 }
1567}
1568
1569impl ProfileGuidedConfig {
1570 pub fn new() -> Self {
1572 Self::default()
1573 }
1574
1575 pub fn enable_feature_selection(mut self) -> Self {
1577 self.enable_pgo_feature_selection = true;
1578 self
1579 }
1580
1581 pub fn enable_bandwidth_optimization(mut self) -> Self {
1583 self.enable_pgo_bandwidth_optimization = true;
1584 self
1585 }
1586
1587 pub fn profile_data_path<P: Into<String>>(mut self, path: P) -> Self {
1589 self.profile_data_path = Some(path.into());
1590 self
1591 }
1592
1593 pub fn target_architecture(mut self, arch: TargetArchitecture) -> Self {
1595 self.target_architecture = arch;
1596 self
1597 }
1598
1599 pub fn optimization_level(mut self, level: OptimizationLevel) -> Self {
1601 self.optimization_level = level;
1602 self
1603 }
1604
1605 pub fn recommended_feature_count(&self, data_size: usize, dimensionality: usize) -> usize {
1607 let base_features = match self.target_architecture {
1608 TargetArchitecture::X86_64Generic => 128,
1609 TargetArchitecture::X86_64AVX2 => 256,
1610 TargetArchitecture::X86_64AVX512 => 512,
1611 TargetArchitecture::ARM64 => 128,
1612 TargetArchitecture::ARM64NEON => 256,
1613 };
1614
1615 let scale_factor = match self.optimization_level {
1616 OptimizationLevel::None => 0.5,
1617 OptimizationLevel::Basic => 1.0,
1618 OptimizationLevel::Aggressive => 1.5,
1619 OptimizationLevel::Maximum => 2.0,
1620 };
1621
1622 let scaled_features = (base_features as f64 * scale_factor) as usize;
1623
1624 let data_adjustment = if data_size > 10000 {
1626 1.2
1627 } else if data_size < 1000 {
1628 0.8
1629 } else {
1630 1.0
1631 };
1632
1633 let dimension_adjustment = if dimensionality > 100 {
1634 1.1
1635 } else if dimensionality < 10 {
1636 0.9
1637 } else {
1638 1.0
1639 };
1640
1641 ((scaled_features as f64 * data_adjustment * dimension_adjustment) as usize)
1642 .max(32)
1643 .min(1024)
1644 }
1645}
1646
/// Plain-data, serde-friendly description of a kernel-approximation setup.
///
/// Unlike `TypeSafeKernelConfig`, everything here is runtime data so the
/// configuration can round-trip through JSON.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SerializableKernelConfig {
    /// Kernel name, e.g. "RBF".
    pub kernel_type: String,

    /// Approximation method name, e.g. "RandomFourierFeatures".
    pub approximation_method: String,

    /// Number of approximation components.
    pub n_components: usize,

    /// Kernel bandwidth.
    pub bandwidth: f64,

    /// Acceptable approximation-error threshold.
    pub quality_threshold: f64,

    /// RNG seed for reproducibility, if fixed.
    pub random_state: Option<u64>,

    /// Additional named numeric parameters (method-specific).
    pub additional_params: std::collections::HashMap<String, f64>,
}
1676
/// Serializable snapshot of a fitted kernel approximator.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SerializableFittedParams {
    /// Configuration the model was fitted with.
    pub config: SerializableKernelConfig,

    /// Random feature weights, if the method uses them (RFF-style).
    pub random_features: Option<Vec<Vec<f64>>>,

    /// Selected sample indices, if the method subsamples (Nyström-style).
    pub selected_indices: Option<Vec<usize>>,

    /// Eigenvalues, if an eigendecomposition was computed.
    pub eigenvalues: Option<Vec<f64>>,

    /// Eigenvectors, if an eigendecomposition was computed.
    pub eigenvectors: Option<Vec<Vec<f64>>>,

    /// Named quality metrics recorded at fit time.
    pub quality_metrics: std::collections::HashMap<String, f64>,

    /// Timestamp of when fitting completed, if recorded.
    // NOTE(review): unit (seconds vs millis since epoch) is not visible
    // here — confirm with the producing implementation.
    pub fitted_timestamp: Option<u64>,
}
1702
1703pub trait SerializableKernelApproximation {
1705 fn export_config(&self) -> Result<SerializableKernelConfig>;
1707
1708 fn import_config(config: &SerializableKernelConfig) -> Result<Self>
1710 where
1711 Self: Sized;
1712
1713 fn export_fitted_params(&self) -> Result<SerializableFittedParams>;
1715
1716 fn import_fitted_params(&mut self, params: &SerializableFittedParams) -> Result<()>;
1718
1719 fn save_to_file<P: AsRef<std::path::Path>>(&self, path: P) -> Result<()> {
1721 let config = self.export_config()?;
1722 let fitted_params = self.export_fitted_params()?;
1723
1724 let model_data = serde_json::json!({
1725 "config": config,
1726 "fitted_params": fitted_params,
1727 "version": "1.0",
1728 "sklears_version": env!("CARGO_PKG_VERSION"),
1729 });
1730
1731 std::fs::write(path, serde_json::to_string_pretty(&model_data).unwrap())
1732 .map_err(SklearsError::from)?;
1733
1734 Ok(())
1735 }
1736
1737 fn load_from_file<P: AsRef<std::path::Path>>(path: P) -> Result<Self>
1739 where
1740 Self: Sized,
1741 {
1742 let content = std::fs::read_to_string(path).map_err(SklearsError::from)?;
1743
1744 let model_data: serde_json::Value = serde_json::from_str(&content).map_err(|e| {
1745 SklearsError::SerializationError(format!("Failed to parse JSON: {}", e))
1746 })?;
1747
1748 let config: SerializableKernelConfig = serde_json::from_value(model_data["config"].clone())
1749 .map_err(|e| {
1750 SklearsError::SerializationError(format!("Failed to deserialize config: {}", e))
1751 })?;
1752
1753 let fitted_params: SerializableFittedParams =
1754 serde_json::from_value(model_data["fitted_params"].clone()).map_err(|e| {
1755 SklearsError::SerializationError(format!(
1756 "Failed to deserialize fitted params: {}",
1757 e
1758 ))
1759 })?;
1760
1761 let mut model = Self::import_config(&config)?;
1762 model.import_fitted_params(&fitted_params)?;
1763
1764 Ok(model)
1765 }
1766}
1767
#[allow(non_snake_case)]
#[cfg(test)]
mod preset_tests {
    use super::*;

    /// All RBF/RFF presets construct and report the optimal tier.
    #[test]
    fn test_kernel_presets() {
        let _fast = KernelPresets::fast_rbf_128();
        assert_eq!(
            ValidatedRBFRandomFourier::<128>::performance_tier(),
            PerformanceTier::Optimal
        );

        let _balanced = KernelPresets::balanced_rbf_256();
        assert_eq!(
            ValidatedRBFRandomFourier::<256>::performance_tier(),
            PerformanceTier::Optimal
        );

        let _accurate = KernelPresets::accurate_rbf_512();
        assert_eq!(
            ValidatedRBFRandomFourier::<512>::performance_tier(),
            PerformanceTier::Optimal
        );
    }

    /// Builder methods set their flags and the recommendation stays in
    /// the documented `[32, 1024]` range.
    #[test]
    fn test_profile_guided_config() {
        let pgo = ProfileGuidedConfig::new()
            .enable_feature_selection()
            .enable_bandwidth_optimization()
            .target_architecture(TargetArchitecture::X86_64AVX2)
            .optimization_level(OptimizationLevel::Aggressive);

        assert!(pgo.enable_pgo_feature_selection);
        assert!(pgo.enable_pgo_bandwidth_optimization);
        assert_eq!(pgo.target_architecture, TargetArchitecture::X86_64AVX2);
        assert_eq!(pgo.optimization_level, OptimizationLevel::Aggressive);

        let recommended = pgo.recommended_feature_count(5000, 50);
        assert!((32..=1024).contains(&recommended));
    }

    /// The plain-data config round-trips through JSON unchanged.
    #[test]
    fn test_serializable_config() {
        let config = SerializableKernelConfig {
            kernel_type: "RBF".to_string(),
            approximation_method: "RandomFourierFeatures".to_string(),
            n_components: 256,
            bandwidth: 1.5,
            quality_threshold: 0.1,
            random_state: Some(42),
            additional_params: std::collections::HashMap::new(),
        };

        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("RBF"));
        assert!(json.contains("RandomFourierFeatures"));

        let round_trip: SerializableKernelConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(round_trip.kernel_type, "RBF");
        assert_eq!(round_trip.n_components, 256);
        assert_eq!(round_trip.bandwidth, 1.5);
    }
}