use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::rand_prelude::SliceRandom;
use scirs2_core::random::essentials::{Normal as RandNormal, Uniform as RandUniform};
use scirs2_core::random::rngs::StdRng as RealStdRng;
use scirs2_core::random::Distribution;
use scirs2_core::random::{thread_rng, Rng, SeedableRng};
use sklears_core::error::{Result, SklearsError};
use std::marker::PhantomData;

pub trait ApproximationState {}

#[derive(Debug, Clone, Copy)]
pub struct Untrained;
impl ApproximationState for Untrained {}

#[derive(Debug, Clone, Copy)]
pub struct Trained;
impl ApproximationState for Trained {}

pub trait KernelType {
    const NAME: &'static str;

    const SUPPORTS_PARAMETER_LEARNING: bool;

    const DEFAULT_BANDWIDTH: f64;
}
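
// Illustrative sketch (not part of the crate's API): a custom kernel marker type
// would plug into the same machinery by supplying the three associated constants.
//
//     struct MyCosineKernel; // hypothetical
//     impl KernelType for MyCosineKernel {
//         const NAME: &'static str = "MyCosine";
//         const SUPPORTS_PARAMETER_LEARNING: bool = false;
//         const DEFAULT_BANDWIDTH: f64 = 1.0;
//     }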

#[derive(Debug, Clone, Copy)]
pub struct RBFKernel;
impl KernelType for RBFKernel {
    const NAME: &'static str = "RBF";
    const SUPPORTS_PARAMETER_LEARNING: bool = true;
    const DEFAULT_BANDWIDTH: f64 = 1.0;
}

#[derive(Debug, Clone, Copy)]
pub struct LaplacianKernel;
impl KernelType for LaplacianKernel {
    const NAME: &'static str = "Laplacian";
    const SUPPORTS_PARAMETER_LEARNING: bool = true;
    const DEFAULT_BANDWIDTH: f64 = 1.0;
}

#[derive(Debug, Clone, Copy)]
pub struct PolynomialKernel;
impl KernelType for PolynomialKernel {
    const NAME: &'static str = "Polynomial";
    const SUPPORTS_PARAMETER_LEARNING: bool = true;
    const DEFAULT_BANDWIDTH: f64 = 1.0;
}

#[derive(Debug, Clone, Copy)]
pub struct ArcCosineKernel;
impl KernelType for ArcCosineKernel {
    const NAME: &'static str = "ArcCosine";
    const SUPPORTS_PARAMETER_LEARNING: bool = false;
    const DEFAULT_BANDWIDTH: f64 = 1.0;
}

pub trait ApproximationMethod {
    const NAME: &'static str;

    const SUPPORTS_INCREMENTAL: bool;

    const HAS_ERROR_BOUNDS: bool;

    const COMPLEXITY: ComplexityClass;
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ComplexityClass {
    Quadratic,
    QuasiLinear,
    Linear,
    DimensionDependent,
}

#[derive(Debug, Clone, Copy)]
pub struct RandomFourierFeatures;
impl ApproximationMethod for RandomFourierFeatures {
    const NAME: &'static str = "RandomFourierFeatures";
    const SUPPORTS_INCREMENTAL: bool = true;
    const HAS_ERROR_BOUNDS: bool = true;
    const COMPLEXITY: ComplexityClass = ComplexityClass::Linear;
}

#[derive(Debug, Clone, Copy)]
pub struct NystromMethod;
impl ApproximationMethod for NystromMethod {
    const NAME: &'static str = "Nystrom";
    const SUPPORTS_INCREMENTAL: bool = false;
    const HAS_ERROR_BOUNDS: bool = true;
    const COMPLEXITY: ComplexityClass = ComplexityClass::Quadratic;
}

#[derive(Debug, Clone, Copy)]
pub struct FastfoodMethod;
impl ApproximationMethod for FastfoodMethod {
    const NAME: &'static str = "Fastfood";
    const SUPPORTS_INCREMENTAL: bool = false;
    const HAS_ERROR_BOUNDS: bool = true;
    const COMPLEXITY: ComplexityClass = ComplexityClass::DimensionDependent;
}

#[derive(Debug, Clone)]
pub struct TypeSafeKernelApproximation<State, Kernel, Method, const N_COMPONENTS: usize>
where
    State: ApproximationState,
    Kernel: KernelType,
    Method: ApproximationMethod,
{
    _phantom: PhantomData<(State, Kernel, Method)>,

    parameters: ApproximationParameters,

    random_state: Option<u64>,
}

#[derive(Debug, Clone)]
pub struct ApproximationParameters {
    pub bandwidth: f64,

    pub degree: Option<usize>,

    pub coef0: Option<f64>,

    pub custom: std::collections::HashMap<String, f64>,
}

impl Default for ApproximationParameters {
    fn default() -> Self {
        Self {
            bandwidth: 1.0,
            degree: None,
            coef0: None,
            custom: std::collections::HashMap::new(),
        }
    }
}

#[derive(Debug, Clone)]
pub struct FittedTypeSafeKernelApproximation<Kernel, Method, const N_COMPONENTS: usize>
where
    Kernel: KernelType,
    Method: ApproximationMethod,
{
    _phantom: PhantomData<(Kernel, Method)>,

    transformation_params: TransformationParameters<N_COMPONENTS>,

    fitted_parameters: ApproximationParameters,

    quality_metrics: QualityMetrics,
}

#[derive(Debug, Clone)]
pub enum TransformationParameters<const N: usize> {
    RandomFeatures {
        weights: Array2<f64>,

        biases: Option<Array1<f64>>,
    },
    Nystrom {
        inducing_points: Array2<f64>,

        eigenvalues: Array1<f64>,

        eigenvectors: Array2<f64>,
    },
    Fastfood {
        structured_matrices: Vec<Array2<f64>>,
        scaling: Array1<f64>,
    },
}

#[derive(Debug, Clone)]
pub struct QualityMetrics {
    pub approximation_error: Option<f64>,

    pub effective_rank: Option<f64>,

    pub condition_number: Option<f64>,

    pub kernel_alignment: Option<f64>,
}

impl Default for QualityMetrics {
    fn default() -> Self {
        Self {
            approximation_error: None,
            effective_rank: None,
            condition_number: None,
            kernel_alignment: None,
        }
    }
}

impl<Kernel, Method, const N_COMPONENTS: usize>
    TypeSafeKernelApproximation<Untrained, Kernel, Method, N_COMPONENTS>
where
    Kernel: KernelType,
    Method: ApproximationMethod,
{
    pub fn new() -> Self {
        Self {
            _phantom: PhantomData,
            parameters: ApproximationParameters {
                bandwidth: Kernel::DEFAULT_BANDWIDTH,
                ..Default::default()
            },
            random_state: None,
        }
    }

    pub fn bandwidth(mut self, bandwidth: f64) -> Self
    where
        Kernel: KernelTypeWithBandwidth,
    {
        self.parameters.bandwidth = bandwidth;
        self
    }

    pub fn degree(mut self, degree: usize) -> Self
    where
        Kernel: PolynomialKernelType,
    {
        self.parameters.degree = Some(degree);
        self
    }

    pub fn random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }

    pub fn fit(
        self,
        data: &Array2<f64>,
    ) -> Result<FittedTypeSafeKernelApproximation<Kernel, Method, N_COMPONENTS>>
    where
        Kernel: FittableKernel<Method>,
        Method: FittableMethod<Kernel>,
    {
        self.fit_impl(data)
    }

    fn fit_impl(
        self,
        data: &Array2<f64>,
    ) -> Result<FittedTypeSafeKernelApproximation<Kernel, Method, N_COMPONENTS>>
    where
        Kernel: FittableKernel<Method>,
        Method: FittableMethod<Kernel>,
    {
        let transformation_params = match Method::NAME {
            "RandomFourierFeatures" => self.fit_random_fourier_features(data)?,
            "Nystrom" => self.fit_nystrom(data)?,
            "Fastfood" => self.fit_fastfood(data)?,
            _ => {
                return Err(SklearsError::InvalidOperation(format!(
                    "Unsupported method: {}",
                    Method::NAME
                )));
            }
        };

        Ok(FittedTypeSafeKernelApproximation {
            _phantom: PhantomData,
            transformation_params,
            fitted_parameters: self.parameters,
            quality_metrics: QualityMetrics::default(),
        })
    }

    fn fit_random_fourier_features(
        &self,
        data: &Array2<f64>,
    ) -> Result<TransformationParameters<N_COMPONENTS>> {
        let mut rng = match self.random_state {
            Some(seed) => RealStdRng::seed_from_u64(seed),
            None => RealStdRng::from_seed(thread_rng().gen()),
        };

        let (_, n_features) = data.dim();

        let weights = match Kernel::NAME {
            "RBF" => {
                let normal = RandNormal::new(0.0, self.parameters.bandwidth).unwrap();
                Array2::from_shape_fn((N_COMPONENTS, n_features), |_| rng.sample(normal))
            }
            "Laplacian" => {
                let uniform = RandUniform::new(0.0, std::f64::consts::PI).unwrap();
                Array2::from_shape_fn((N_COMPONENTS, n_features), |_| {
                    let u = rng.sample(uniform);
                    (u - std::f64::consts::PI / 2.0).tan() / self.parameters.bandwidth
                })
            }
            _ => {
                return Err(SklearsError::InvalidOperation(format!(
                    "RFF not supported for kernel: {}",
                    Kernel::NAME
                )));
            }
        };

        let uniform = RandUniform::new(0.0, 2.0 * std::f64::consts::PI).unwrap();
        let biases = Some(Array1::from_shape_fn(N_COMPONENTS, |_| rng.sample(uniform)));

        Ok(TransformationParameters::RandomFeatures { weights, biases })
    }
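
    // Background note: random Fourier features approximate a shift-invariant kernel
    // k(x, y) = k(x - y) by sampling frequencies from the kernel's spectral density
    // (Gaussian frequencies for RBF, Cauchy frequencies for Laplacian, as above) and
    // mapping each sample to cosine/sine features of the projection w . x;
    // `transform_random_features` below therefore emits 2 * N_COMPONENTS columns.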

    fn fit_nystrom(&self, data: &Array2<f64>) -> Result<TransformationParameters<N_COMPONENTS>> {
        let (n_samples, _) = data.dim();
        let n_inducing = N_COMPONENTS.min(n_samples);

        let mut rng = match self.random_state {
            Some(seed) => RealStdRng::seed_from_u64(seed),
            None => RealStdRng::from_seed(thread_rng().gen()),
        };

        let mut indices: Vec<usize> = (0..n_samples).collect();
        indices.shuffle(&mut rng);

        let inducing_indices = &indices[..n_inducing];
        let inducing_points = data.select(scirs2_core::ndarray::Axis(0), inducing_indices);

        let kernel_matrix = self.compute_kernel_matrix(&inducing_points, &inducing_points)?;

        // Simplified spectral decomposition: the diagonal stands in for the eigenvalues
        // and the identity for the eigenvectors of the inducing-point kernel matrix,
        // rather than a full eigendecomposition.
        let eigenvalues = Array1::from_shape_fn(n_inducing, |i| kernel_matrix[[i, i]]);
        let eigenvectors = Array2::eye(n_inducing);

        Ok(TransformationParameters::Nystrom {
            inducing_points,
            eigenvalues,
            eigenvectors,
        })
    }

    fn fit_fastfood(&self, data: &Array2<f64>) -> Result<TransformationParameters<N_COMPONENTS>> {
        let mut rng = match self.random_state {
            Some(seed) => RealStdRng::seed_from_u64(seed),
            None => RealStdRng::from_seed(thread_rng().gen()),
        };

        let (_, n_features) = data.dim();

        let mut structured_matrices = Vec::new();

        // Random-sign (Rademacher) diagonal matrix used as the structured component.
        let binary_matrix = Array2::from_shape_fn((n_features, n_features), |(i, j)| {
            if i == j {
                if rng.gen::<bool>() {
                    1.0
                } else {
                    -1.0
                }
            } else {
                0.0
            }
        });
        structured_matrices.push(binary_matrix);

        let normal = RandNormal::new(0.0, 1.0).unwrap();
        let scaling = Array1::from_shape_fn(N_COMPONENTS, |_| rng.sample(normal));

        Ok(TransformationParameters::Fastfood {
            structured_matrices,
            scaling,
        })
    }

    fn compute_kernel_matrix(&self, x1: &Array2<f64>, x2: &Array2<f64>) -> Result<Array2<f64>> {
        let (n1, _) = x1.dim();
        let (n2, _) = x2.dim();
        let mut kernel = Array2::zeros((n1, n2));

        for i in 0..n1 {
            for j in 0..n2 {
                let similarity = match Kernel::NAME {
                    "RBF" => {
                        let diff = &x1.row(i) - &x2.row(j);
                        let dist_sq = diff.mapv(|x| x * x).sum();
                        (-self.parameters.bandwidth * dist_sq).exp()
                    }
                    "Laplacian" => {
                        let diff = &x1.row(i) - &x2.row(j);
                        let dist = diff.mapv(|x| x.abs()).sum();
                        (-self.parameters.bandwidth * dist).exp()
                    }
                    "Polynomial" => {
                        let dot_product = x1.row(i).dot(&x2.row(j));
                        let degree = self.parameters.degree.unwrap_or(2) as i32;
                        let coef0 = self.parameters.coef0.unwrap_or(1.0);
                        (self.parameters.bandwidth * dot_product + coef0).powi(degree)
                    }
                    _ => {
                        return Err(SklearsError::InvalidOperation(format!(
                            "Unsupported kernel: {}",
                            Kernel::NAME
                        )));
                    }
                };
                kernel[[i, j]] = similarity;
            }
        }

        Ok(kernel)
    }
}

impl<Kernel, Method, const N_COMPONENTS: usize>
    FittedTypeSafeKernelApproximation<Kernel, Method, N_COMPONENTS>
where
    Kernel: KernelType,
    Method: ApproximationMethod,
{
    pub fn transform(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
        match &self.transformation_params {
            TransformationParameters::RandomFeatures { weights, biases } => {
                self.transform_random_features(data, weights, biases.as_ref())
            }
            TransformationParameters::Nystrom {
                inducing_points,
                eigenvalues,
                eigenvectors,
            } => self.transform_nystrom(data, inducing_points, eigenvalues, eigenvectors),
            TransformationParameters::Fastfood {
                structured_matrices,
                scaling,
            } => self.transform_fastfood(data, structured_matrices, scaling),
        }
    }

    fn transform_random_features(
        &self,
        data: &Array2<f64>,
        weights: &Array2<f64>,
        biases: Option<&Array1<f64>>,
    ) -> Result<Array2<f64>> {
        let (n_samples, _) = data.dim();
        let mut features = Array2::zeros((n_samples, N_COMPONENTS * 2));

        for (i, sample) in data.axis_iter(scirs2_core::ndarray::Axis(0)).enumerate() {
            for j in 0..N_COMPONENTS {
                let projection = sample.dot(&weights.row(j));
                let phase = if let Some(b) = biases {
                    projection + b[j]
                } else {
                    projection
                };

                features[[i, 2 * j]] = phase.cos();
                features[[i, 2 * j + 1]] = phase.sin();
            }
        }

        Ok(features)
    }

    fn transform_nystrom(
        &self,
        data: &Array2<f64>,
        inducing_points: &Array2<f64>,
        eigenvalues: &Array1<f64>,
        eigenvectors: &Array2<f64>,
    ) -> Result<Array2<f64>> {
        let kernel_matrix = self.compute_kernel_matrix_fitted(data, inducing_points)?;

        let mut features = Array2::zeros((data.nrows(), eigenvalues.len()));

        for i in 0..data.nrows() {
            for j in 0..eigenvalues.len() {
                if eigenvalues[j] > 1e-8 {
                    let mut feature_value = 0.0;
                    for k in 0..inducing_points.nrows() {
                        feature_value += kernel_matrix[[i, k]] * eigenvectors[[k, j]];
                    }
                    features[[i, j]] = feature_value / eigenvalues[j].sqrt();
                }
            }
        }

        Ok(features)
    }

    fn transform_fastfood(
        &self,
        data: &Array2<f64>,
        structured_matrices: &[Array2<f64>],
        scaling: &Array1<f64>,
    ) -> Result<Array2<f64>> {
        let (n_samples, n_features) = data.dim();
        let mut features = data.clone();

        for matrix in structured_matrices {
            features = features.dot(matrix);
        }

        let mut result = Array2::zeros((n_samples, N_COMPONENTS));
        for i in 0..n_samples {
            for j in 0..N_COMPONENTS.min(n_features) {
                result[[i, j]] = (features[[i, j]] * scaling[j]).cos();
            }
        }

        Ok(result)
    }

    fn compute_kernel_matrix_fitted(
        &self,
        x1: &Array2<f64>,
        x2: &Array2<f64>,
    ) -> Result<Array2<f64>> {
        let (n1, _) = x1.dim();
        let (n2, _) = x2.dim();
        let mut kernel = Array2::zeros((n1, n2));

        for i in 0..n1 {
            for j in 0..n2 {
                let similarity = match Kernel::NAME {
                    "RBF" => {
                        let diff = &x1.row(i) - &x2.row(j);
                        let dist_sq = diff.mapv(|x| x * x).sum();
                        (-self.fitted_parameters.bandwidth * dist_sq).exp()
                    }
                    "Laplacian" => {
                        let diff = &x1.row(i) - &x2.row(j);
                        let dist = diff.mapv(|x| x.abs()).sum();
                        (-self.fitted_parameters.bandwidth * dist).exp()
                    }
                    "Polynomial" => {
                        let dot_product = x1.row(i).dot(&x2.row(j));
                        let degree = self.fitted_parameters.degree.unwrap_or(2) as i32;
                        let coef0 = self.fitted_parameters.coef0.unwrap_or(1.0);
                        (self.fitted_parameters.bandwidth * dot_product + coef0).powi(degree)
                    }
                    _ => {
                        return Err(SklearsError::InvalidOperation(format!(
                            "Unsupported kernel: {}",
                            Kernel::NAME
                        )));
                    }
                };
                kernel[[i, j]] = similarity;
            }
        }

        Ok(kernel)
    }

    pub fn quality_metrics(&self) -> &QualityMetrics {
        &self.quality_metrics
    }

    pub const fn n_components() -> usize {
        N_COMPONENTS
    }

    pub fn kernel_name(&self) -> &'static str {
        Kernel::NAME
    }

    pub fn method_name(&self) -> &'static str {
        Method::NAME
    }
}

pub trait KernelTypeWithBandwidth: KernelType {}
impl KernelTypeWithBandwidth for RBFKernel {}
impl KernelTypeWithBandwidth for LaplacianKernel {}

pub trait PolynomialKernelType: KernelType {}
impl PolynomialKernelType for PolynomialKernel {}

pub trait FittableKernel<Method: ApproximationMethod>: KernelType {}
impl FittableKernel<RandomFourierFeatures> for RBFKernel {}
impl FittableKernel<RandomFourierFeatures> for LaplacianKernel {}
impl FittableKernel<NystromMethod> for RBFKernel {}
impl FittableKernel<NystromMethod> for LaplacianKernel {}
impl FittableKernel<NystromMethod> for PolynomialKernel {}
impl FittableKernel<FastfoodMethod> for RBFKernel {}

pub trait FittableMethod<Kernel: KernelType>: ApproximationMethod {}
impl FittableMethod<RBFKernel> for RandomFourierFeatures {}
impl FittableMethod<LaplacianKernel> for RandomFourierFeatures {}
impl FittableMethod<RBFKernel> for NystromMethod {}
impl FittableMethod<LaplacianKernel> for NystromMethod {}
impl FittableMethod<PolynomialKernel> for NystromMethod {}
impl FittableMethod<RBFKernel> for FastfoodMethod {}

pub type RBFRandomFourierFeatures<const N: usize> =
    TypeSafeKernelApproximation<Untrained, RBFKernel, RandomFourierFeatures, N>;

pub type LaplacianRandomFourierFeatures<const N: usize> =
    TypeSafeKernelApproximation<Untrained, LaplacianKernel, RandomFourierFeatures, N>;

pub type RBFNystrom<const N: usize> =
    TypeSafeKernelApproximation<Untrained, RBFKernel, NystromMethod, N>;

pub type PolynomialNystrom<const N: usize> =
    TypeSafeKernelApproximation<Untrained, PolynomialKernel, NystromMethod, N>;

pub type RBFFastfood<const N: usize> =
    TypeSafeKernelApproximation<Untrained, RBFKernel, FastfoodMethod, N>;

pub type FittedRBFRandomFourierFeatures<const N: usize> =
    FittedTypeSafeKernelApproximation<RBFKernel, RandomFourierFeatures, N>;

pub type FittedLaplacianRandomFourierFeatures<const N: usize> =
    FittedTypeSafeKernelApproximation<LaplacianKernel, RandomFourierFeatures, N>;

pub type FittedRBFNystrom<const N: usize> =
    FittedTypeSafeKernelApproximation<RBFKernel, NystromMethod, N>;
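
// Illustrative usage sketch (mirrors the tests below; error handling elided):
//
//     let approx: RBFRandomFourierFeatures<64> = TypeSafeKernelApproximation::new()
//         .bandwidth(0.5)
//         .random_state(0);
//     let fitted = approx.fit(&x)?;
//     let z = fitted.transform(&x)?; // shape (n_samples, 128): one cos/sin pair per component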

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_type_safe_rbf_rff() {
        let data = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0],];

        let approximation: RBFRandomFourierFeatures<10> = TypeSafeKernelApproximation::new()
            .bandwidth(1.5)
            .random_state(42);

        let fitted = approximation.fit(&data).unwrap();
        let features = fitted.transform(&data).unwrap();

        assert_eq!(features.shape(), &[4, 20]);
        assert_eq!(FittedRBFRandomFourierFeatures::<10>::n_components(), 10);
        assert_eq!(fitted.kernel_name(), "RBF");
        assert_eq!(fitted.method_name(), "RandomFourierFeatures");
    }

    #[test]
    fn test_type_safe_nystrom() {
        let data = array![
            [1.0, 2.0, 3.0],
            [2.0, 3.0, 4.0],
            [3.0, 4.0, 5.0],
            [4.0, 5.0, 6.0],
        ];

        let approximation: RBFNystrom<5> = TypeSafeKernelApproximation::new()
            .bandwidth(2.0)
            .random_state(42);

        let fitted = approximation.fit(&data).unwrap();
        let features = fitted.transform(&data).unwrap();

        assert_eq!(features.shape()[0], 4);
        assert_eq!(fitted.kernel_name(), "RBF");
        assert_eq!(fitted.method_name(), "Nystrom");
    }

    #[test]
    fn test_polynomial_kernel() {
        let data = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0],];

        let approximation: PolynomialNystrom<3> = TypeSafeKernelApproximation::new()
            .degree(3)
            .random_state(42);

        let fitted = approximation.fit(&data).unwrap();
        let features = fitted.transform(&data).unwrap();

        assert_eq!(fitted.kernel_name(), "Polynomial");
        assert!(features.nrows() == 3);
    }

    #[test]
    fn test_compile_time_constants() {
        assert_eq!(RBFKernel::NAME, "RBF");
        assert!(RBFKernel::SUPPORTS_PARAMETER_LEARNING);
        assert_eq!(RandomFourierFeatures::NAME, "RandomFourierFeatures");
        assert!(RandomFourierFeatures::SUPPORTS_INCREMENTAL);
        assert!(RandomFourierFeatures::HAS_ERROR_BOUNDS);
    }
}

pub trait ParameterValidation<const MIN: usize, const MAX: usize> {
    const IS_VALID: bool = MIN <= MAX;

    fn parameter_range() -> (usize, usize) {
        (MIN, MAX)
    }
}

#[derive(Debug, Clone, Copy)]
pub struct ValidatedComponents<const N: usize>;

impl<const N: usize> ValidatedComponents<N> {
    pub fn new() -> Self {
        assert!(N > 0, "Component count must be positive");
        assert!(N <= 10000, "Component count too large");
        Self
    }

    pub const fn count(&self) -> usize {
        N
    }
}

pub trait KernelMethodCompatibility<K: KernelType, M: ApproximationMethod> {
    const IS_COMPATIBLE: bool;

    const PERFORMANCE_TIER: PerformanceTier;

    const MEMORY_COMPLEXITY: ComplexityClass;
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PerformanceTier {
    Optimal,
    Good,
    Acceptable,
    Poor,
}

impl KernelMethodCompatibility<RBFKernel, RandomFourierFeatures> for () {
    const IS_COMPATIBLE: bool = true;
    const PERFORMANCE_TIER: PerformanceTier = PerformanceTier::Optimal;
    const MEMORY_COMPLEXITY: ComplexityClass = ComplexityClass::Linear;
}

impl KernelMethodCompatibility<LaplacianKernel, RandomFourierFeatures> for () {
    const IS_COMPATIBLE: bool = true;
    const PERFORMANCE_TIER: PerformanceTier = PerformanceTier::Optimal;
    const MEMORY_COMPLEXITY: ComplexityClass = ComplexityClass::Linear;
}

impl KernelMethodCompatibility<RBFKernel, NystromMethod> for () {
    const IS_COMPATIBLE: bool = true;
    const PERFORMANCE_TIER: PerformanceTier = PerformanceTier::Good;
    const MEMORY_COMPLEXITY: ComplexityClass = ComplexityClass::Quadratic;
}

impl KernelMethodCompatibility<PolynomialKernel, NystromMethod> for () {
    const IS_COMPATIBLE: bool = true;
    const PERFORMANCE_TIER: PerformanceTier = PerformanceTier::Good;
    const MEMORY_COMPLEXITY: ComplexityClass = ComplexityClass::Quadratic;
}

impl KernelMethodCompatibility<RBFKernel, FastfoodMethod> for () {
    const IS_COMPATIBLE: bool = true;
    const PERFORMANCE_TIER: PerformanceTier = PerformanceTier::Optimal;
    const MEMORY_COMPLEXITY: ComplexityClass = ComplexityClass::DimensionDependent;
}

impl KernelMethodCompatibility<ArcCosineKernel, RandomFourierFeatures> for () {
    const IS_COMPATIBLE: bool = true;
    const PERFORMANCE_TIER: PerformanceTier = PerformanceTier::Acceptable;
    const MEMORY_COMPLEXITY: ComplexityClass = ComplexityClass::Linear;
}

impl KernelMethodCompatibility<ArcCosineKernel, NystromMethod> for () {
    const IS_COMPATIBLE: bool = false;
    const PERFORMANCE_TIER: PerformanceTier = PerformanceTier::Poor;
    const MEMORY_COMPLEXITY: ComplexityClass = ComplexityClass::Quadratic;
}

impl KernelMethodCompatibility<ArcCosineKernel, FastfoodMethod> for () {
    const IS_COMPATIBLE: bool = false;
    const PERFORMANCE_TIER: PerformanceTier = PerformanceTier::Poor;
    const MEMORY_COMPLEXITY: ComplexityClass = ComplexityClass::DimensionDependent;
}

impl KernelMethodCompatibility<LaplacianKernel, NystromMethod> for () {
    const IS_COMPATIBLE: bool = true;
    const PERFORMANCE_TIER: PerformanceTier = PerformanceTier::Good;
    const MEMORY_COMPLEXITY: ComplexityClass = ComplexityClass::Quadratic;
}

impl KernelMethodCompatibility<PolynomialKernel, RandomFourierFeatures> for () {
    const IS_COMPATIBLE: bool = true;
    const PERFORMANCE_TIER: PerformanceTier = PerformanceTier::Good;
    const MEMORY_COMPLEXITY: ComplexityClass = ComplexityClass::Linear;
}

#[derive(Debug, Clone)]
pub struct ValidatedKernelApproximation<K, M, const N: usize>
where
    K: KernelType,
    M: ApproximationMethod,
    (): KernelMethodCompatibility<K, M>,
{
    inner: TypeSafeKernelApproximation<Untrained, K, M, N>,
    _validation: ValidatedComponents<N>,
}

impl<K, M, const N: usize> ValidatedKernelApproximation<K, M, N>
where
    K: KernelType,
    M: ApproximationMethod,
    (): KernelMethodCompatibility<K, M>,
{
    pub fn new() -> Self {
        assert!(
            <() as KernelMethodCompatibility<K, M>>::IS_COMPATIBLE,
            "Incompatible kernel-method combination"
        );

        Self {
            inner: TypeSafeKernelApproximation::new(),
            _validation: ValidatedComponents::new(),
        }
    }

    pub const fn performance_info() -> (PerformanceTier, ComplexityClass, ComplexityClass) {
        (
            <() as KernelMethodCompatibility<K, M>>::PERFORMANCE_TIER,
            M::COMPLEXITY,
            <() as KernelMethodCompatibility<K, M>>::MEMORY_COMPLEXITY,
        )
    }

    pub const fn is_optimal() -> bool {
        matches!(
            <() as KernelMethodCompatibility<K, M>>::PERFORMANCE_TIER,
            PerformanceTier::Optimal
        )
    }
}

#[derive(Debug, Clone, Copy)]
pub struct BoundedQualityMetrics<const MIN_ALIGNMENT: u32, const MAX_ERROR: u32> {
    kernel_alignment: f64,
    approximation_error: f64,
    effective_rank: f64,
}

impl<const MIN_ALIGNMENT: u32, const MAX_ERROR: u32>
    BoundedQualityMetrics<MIN_ALIGNMENT, MAX_ERROR>
{
    pub const fn new(alignment: f64, error: f64, rank: f64) -> Option<Self> {
        let min_align = MIN_ALIGNMENT as f64 / 100.0;
        let max_err = MAX_ERROR as f64 / 100.0;

        if alignment >= min_align && error <= max_err && rank > 0.0 {
            Some(Self {
                kernel_alignment: alignment,
                approximation_error: error,
                effective_rank: rank,
            })
        } else {
            None
        }
    }

    pub const fn bounds() -> (f64, f64) {
        (MIN_ALIGNMENT as f64 / 100.0, MAX_ERROR as f64 / 100.0)
    }

    pub const fn meets_standards(&self) -> bool {
        let (min_align, max_err) = Self::bounds();
        self.kernel_alignment >= min_align && self.approximation_error <= max_err
    }
}

pub type HighQualityMetrics = BoundedQualityMetrics<90, 5>;

pub type AcceptableQualityMetrics = BoundedQualityMetrics<70, 15>;
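
// The const-generic bounds are integer percentages: `BoundedQualityMetrics<90, 5>`
// accepts metrics only when kernel alignment >= 0.90 and approximation error <= 0.05,
// so `HighQualityMetrics::new(0.95, 0.03, 50.0)` succeeds while `(0.60, 0.20, ..)` does not.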

#[macro_export]
macro_rules! validated_kernel_approximation {
    (RBF, RandomFourierFeatures, $n:literal) => {
        ValidatedKernelApproximation::<RBFKernel, RandomFourierFeatures, $n>::new()
    };
    (Laplacian, RandomFourierFeatures, $n:literal) => {
        ValidatedKernelApproximation::<LaplacianKernel, RandomFourierFeatures, $n>::new()
    };
    (RBF, Nystrom, $n:literal) => {
        ValidatedKernelApproximation::<RBFKernel, NystromMethod, $n>::new()
    };
    (RBF, Fastfood, $n:literal) => {
        ValidatedKernelApproximation::<RBFKernel, FastfoodMethod, $n>::new()
    };
}

#[allow(non_snake_case)]
#[cfg(test)]
mod advanced_type_safety_tests {
    use super::*;

    #[test]
    fn test_validated_components() {
        let components = ValidatedComponents::<100>::new();
        assert_eq!(components.count(), 100);
    }

    #[test]
    fn test_kernel_method_compatibility() {
        assert!(<() as KernelMethodCompatibility<RBFKernel, RandomFourierFeatures>>::IS_COMPATIBLE);
        assert!(!<() as KernelMethodCompatibility<ArcCosineKernel, NystromMethod>>::IS_COMPATIBLE);

        assert_eq!(
            <() as KernelMethodCompatibility<RBFKernel, RandomFourierFeatures>>::PERFORMANCE_TIER,
            PerformanceTier::Optimal
        );
    }

    #[test]
    fn test_bounded_quality_metrics() {
        let high_quality = HighQualityMetrics::new(0.95, 0.03, 50.0).unwrap();
        assert!(high_quality.meets_standards());

        let low_quality = HighQualityMetrics::new(0.60, 0.20, 10.0);
        assert!(low_quality.is_none());
    }

    #[test]
    fn test_macro_creation() {
        let _rbf_rff = validated_kernel_approximation!(RBF, RandomFourierFeatures, 100);
        let _lap_rff = validated_kernel_approximation!(Laplacian, RandomFourierFeatures, 50);
        let _rbf_nys = validated_kernel_approximation!(RBF, Nystrom, 30);
        let _rbf_ff = validated_kernel_approximation!(RBF, Fastfood, 128);

        let (tier, complexity, memory) = ValidatedKernelApproximation::<
            RBFKernel,
            RandomFourierFeatures,
            100,
        >::performance_info();
        assert_eq!(tier, PerformanceTier::Optimal);
        assert_eq!(complexity, ComplexityClass::Linear);
        assert_eq!(memory, ComplexityClass::Linear);
    }
}

pub trait ComposableKernel: KernelType {
    type CompositionResult<Other: ComposableKernel>: ComposableKernel;

    fn compose<Other: ComposableKernel>(self) -> Self::CompositionResult<Other>;
}

#[derive(Debug, Clone, Copy)]
pub struct SumKernel<K1: KernelType, K2: KernelType> {
    _phantom: PhantomData<(K1, K2)>,
}

impl<K1: KernelType, K2: KernelType> KernelType for SumKernel<K1, K2> {
    const NAME: &'static str = "Sum";
    const SUPPORTS_PARAMETER_LEARNING: bool =
        K1::SUPPORTS_PARAMETER_LEARNING && K2::SUPPORTS_PARAMETER_LEARNING;
    const DEFAULT_BANDWIDTH: f64 = (K1::DEFAULT_BANDWIDTH + K2::DEFAULT_BANDWIDTH) / 2.0;
}

#[derive(Debug, Clone, Copy)]
pub struct ProductKernel<K1: KernelType, K2: KernelType> {
    _phantom: PhantomData<(K1, K2)>,
}

impl<K1: KernelType, K2: KernelType> KernelType for ProductKernel<K1, K2> {
    const NAME: &'static str = "Product";
    const SUPPORTS_PARAMETER_LEARNING: bool =
        K1::SUPPORTS_PARAMETER_LEARNING && K2::SUPPORTS_PARAMETER_LEARNING;
    const DEFAULT_BANDWIDTH: f64 = K1::DEFAULT_BANDWIDTH * K2::DEFAULT_BANDWIDTH;
}

impl ComposableKernel for RBFKernel {
    type CompositionResult<Other: ComposableKernel> = SumKernel<Self, Other>;

    fn compose<Other: ComposableKernel>(self) -> Self::CompositionResult<Other> {
        SumKernel {
            _phantom: PhantomData,
        }
    }
}

impl ComposableKernel for LaplacianKernel {
    type CompositionResult<Other: ComposableKernel> = SumKernel<Self, Other>;

    fn compose<Other: ComposableKernel>(self) -> Self::CompositionResult<Other> {
        SumKernel {
            _phantom: PhantomData,
        }
    }
}

impl ComposableKernel for PolynomialKernel {
    type CompositionResult<Other: ComposableKernel> = ProductKernel<Self, Other>;

    fn compose<Other: ComposableKernel>(self) -> Self::CompositionResult<Other> {
        ProductKernel {
            _phantom: PhantomData,
        }
    }
}

pub trait ValidatedFeatureSize<const N: usize> {
    const IS_POWER_OF_TWO: bool = (N != 0) && ((N & (N - 1)) == 0);
    const IS_REASONABLE_SIZE: bool = N >= 8 && N <= 8192;
    const IS_VALID: bool = Self::IS_POWER_OF_TWO && Self::IS_REASONABLE_SIZE;
}

impl<const N: usize> ValidatedFeatureSize<N> for () {}

#[derive(Debug, Clone)]
pub struct ValidatedFeatures<const N: usize> {
    _phantom: PhantomData<[f64; N]>,
}

impl<const N: usize> ValidatedFeatures<N>
where
    (): ValidatedFeatureSize<N>,
{
    pub const fn new() -> Self {
        Self {
            _phantom: PhantomData,
        }
    }

    pub const fn count() -> usize {
        N
    }

    pub const fn is_fastfood_optimal() -> bool {
        <() as ValidatedFeatureSize<N>>::IS_POWER_OF_TWO
    }
}

pub trait ApproximationQualityBounds<Method: ApproximationMethod> {
    const ERROR_BOUND_CONSTANT: f64;

    const SAMPLE_COMPLEXITY_EXPONENT: f64;

    const DIMENSION_DEPENDENCY: f64;

    fn error_bound(n_samples: usize, n_features: usize, n_components: usize) -> f64 {
        let base_rate = Self::ERROR_BOUND_CONSTANT;
        let sample_factor = (n_samples as f64).powf(-Self::SAMPLE_COMPLEXITY_EXPONENT);
        let dim_factor = (n_features as f64).powf(Self::DIMENSION_DEPENDENCY);
        let comp_factor = (n_components as f64).powf(-0.5);

        base_rate * sample_factor * dim_factor * comp_factor
    }
}
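
// The default bound factors as
//     error ~ C * n_samples^(-alpha) * n_features^(beta) * n_components^(-1/2)
// with C = ERROR_BOUND_CONSTANT, alpha = SAMPLE_COMPLEXITY_EXPONENT and
// beta = DIMENSION_DEPENDENCY; the impls below only supply these heuristic constants.
// For example, the RFF constants give roughly 2.0 * 1000^-0.25 * 10^0.1 * 128^-0.5 ~ 0.04
// for 1000 samples, 10 features and 128 components.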

impl ApproximationQualityBounds<RandomFourierFeatures> for () {
    const ERROR_BOUND_CONSTANT: f64 = 2.0;
    const SAMPLE_COMPLEXITY_EXPONENT: f64 = 0.25;
    const DIMENSION_DEPENDENCY: f64 = 0.1;
}

impl ApproximationQualityBounds<NystromMethod> for () {
    const ERROR_BOUND_CONSTANT: f64 = 1.5;
    const SAMPLE_COMPLEXITY_EXPONENT: f64 = 0.33;
    const DIMENSION_DEPENDENCY: f64 = 0.05;
}

impl ApproximationQualityBounds<FastfoodMethod> for () {
    const ERROR_BOUND_CONSTANT: f64 = 3.0;
    const SAMPLE_COMPLEXITY_EXPONENT: f64 = 0.2;
    const DIMENSION_DEPENDENCY: f64 = 0.15;
}

#[derive(Debug, Clone)]
pub struct TypeSafeKernelConfig<K: KernelType, M: ApproximationMethod, const N: usize>
where
    (): ValidatedFeatureSize<N>,
    (): KernelMethodCompatibility<K, M>,
    (): ApproximationQualityBounds<M>,
{
    kernel_type: PhantomData<K>,
    method_type: PhantomData<M>,
    features: ValidatedFeatures<N>,
    bandwidth: f64,
    quality_threshold: f64,
}

impl<K: KernelType, M: ApproximationMethod, const N: usize> TypeSafeKernelConfig<K, M, N>
where
    (): ValidatedFeatureSize<N>,
    (): KernelMethodCompatibility<K, M>,
    (): ApproximationQualityBounds<M>,
{
    pub fn new() -> Self {
        Self {
            kernel_type: PhantomData,
            method_type: PhantomData,
            features: ValidatedFeatures::new(),
            bandwidth: K::DEFAULT_BANDWIDTH,
            quality_threshold: 0.1,
        }
    }

    pub fn bandwidth(mut self, bandwidth: f64) -> Self {
        assert!(bandwidth > 0.0, "Bandwidth must be positive");
        self.bandwidth = bandwidth;
        self
    }

    pub fn quality_threshold(mut self, threshold: f64) -> Self {
        assert!(
            threshold > 0.0 && threshold < 1.0,
            "Quality threshold must be between 0 and 1"
        );
        self.quality_threshold = threshold;
        self
    }

    pub const fn performance_tier() -> PerformanceTier {
        <() as KernelMethodCompatibility<K, M>>::PERFORMANCE_TIER
    }

    pub const fn memory_complexity() -> ComplexityClass {
        <() as KernelMethodCompatibility<K, M>>::MEMORY_COMPLEXITY
    }

    pub fn theoretical_error_bound(&self, n_samples: usize, n_features: usize) -> f64 {
        <() as ApproximationQualityBounds<M>>::error_bound(n_samples, n_features, N)
    }

    pub fn meets_quality_requirements(&self, n_samples: usize, n_features: usize) -> bool {
        let error_bound = self.theoretical_error_bound(n_samples, n_features);
        error_bound <= self.quality_threshold
    }
}

pub type ValidatedRBFRandomFourier<const N: usize> =
    TypeSafeKernelConfig<RBFKernel, RandomFourierFeatures, N>;
pub type ValidatedLaplacianNystrom<const N: usize> =
    TypeSafeKernelConfig<LaplacianKernel, NystromMethod, N>;
pub type ValidatedPolynomialRFF<const N: usize> =
    TypeSafeKernelConfig<PolynomialKernel, RandomFourierFeatures, N>;

impl<K1: KernelType, K2: KernelType> ComposableKernel for SumKernel<K1, K2> {
    type CompositionResult<Other: ComposableKernel> = SumKernel<Self, Other>;

    fn compose<Other: ComposableKernel>(self) -> Self::CompositionResult<Other> {
        SumKernel {
            _phantom: PhantomData,
        }
    }
}

impl<K1: KernelType, K2: KernelType> ComposableKernel for ProductKernel<K1, K2> {
    type CompositionResult<Other: ComposableKernel> = ProductKernel<Self, Other>;

    fn compose<Other: ComposableKernel>(self) -> Self::CompositionResult<Other> {
        ProductKernel {
            _phantom: PhantomData,
        }
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod advanced_composition_tests {
    use super::*;

    #[test]
    fn test_validated_features() {
        let _features_64 = ValidatedFeatures::<64>::new();
        let _features_128 = ValidatedFeatures::<128>::new();
        let _features_256 = ValidatedFeatures::<256>::new();

        assert_eq!(ValidatedFeatures::<64>::count(), 64);
        assert!(ValidatedFeatures::<64>::is_fastfood_optimal());
    }

    #[test]
    fn test_type_safe_configs() {
        let config = ValidatedRBFRandomFourier::<128>::new()
            .bandwidth(1.5)
            .quality_threshold(0.05);

        assert_eq!(
            ValidatedRBFRandomFourier::<128>::performance_tier(),
            PerformanceTier::Optimal
        );
        assert!(config.meets_quality_requirements(1000, 10));
    }

    #[test]
    fn test_kernel_composition() {
        let _rbf = RBFKernel;
        let _laplacian = LaplacianKernel;
        let _polynomial = PolynomialKernel;

        let _composed1 = _rbf.compose::<LaplacianKernel>();
        let _composed2 = _polynomial.compose::<RBFKernel>();
    }

    #[test]
    fn test_approximation_bounds() {
        let rff_bound =
            <() as ApproximationQualityBounds<RandomFourierFeatures>>::error_bound(1000, 10, 100);
        let nystrom_bound =
            <() as ApproximationQualityBounds<NystromMethod>>::error_bound(1000, 10, 100);
        let fastfood_bound =
            <() as ApproximationQualityBounds<FastfoodMethod>>::error_bound(1000, 10, 100);

        assert!(rff_bound > 0.0);
        assert!(nystrom_bound > 0.0);
        assert!(fastfood_bound > 0.0);

        assert!(nystrom_bound < rff_bound);
    }
}

pub struct KernelPresets;

impl KernelPresets {
    pub fn fast_rbf_128() -> ValidatedRBFRandomFourier<128> {
        ValidatedRBFRandomFourier::<128>::new()
            .bandwidth(1.0)
            .quality_threshold(0.2)
    }

    pub fn balanced_rbf_256() -> ValidatedRBFRandomFourier<256> {
        ValidatedRBFRandomFourier::<256>::new()
            .bandwidth(1.0)
            .quality_threshold(0.1)
    }

    pub fn accurate_rbf_512() -> ValidatedRBFRandomFourier<512> {
        ValidatedRBFRandomFourier::<512>::new()
            .bandwidth(1.0)
            .quality_threshold(0.05)
    }

    pub fn ultrafast_rbf_64() -> ValidatedRBFRandomFourier<64> {
        ValidatedRBFRandomFourier::<64>::new()
            .bandwidth(1.0)
            .quality_threshold(0.3)
    }

    pub fn precise_nystroem_128() -> ValidatedLaplacianNystrom<128> {
        ValidatedLaplacianNystrom::<128>::new()
            .bandwidth(1.0)
            .quality_threshold(0.01)
    }

    pub fn memory_efficient_rbf_32() -> ValidatedRBFRandomFourier<32> {
        ValidatedRBFRandomFourier::<32>::new()
            .bandwidth(1.0)
            .quality_threshold(0.4)
    }

    pub fn polynomial_features_256() -> ValidatedPolynomialRFF<256> {
        ValidatedPolynomialRFF::<256>::new()
            .bandwidth(1.0)
            .quality_threshold(0.15)
    }
}
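
// Preset selection note (informal guidance, not measured numbers): the presets trade
// component count against the heuristic error bound above, e.g. `fast_rbf_128`
// tolerates a 0.2 bound while `accurate_rbf_512` tightens it to 0.05.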

#[derive(Debug, Clone)]
pub struct ProfileGuidedConfig {
    pub enable_pgo_feature_selection: bool,

    pub enable_pgo_bandwidth_optimization: bool,

    pub profile_data_path: Option<String>,

    pub target_architecture: TargetArchitecture,

    pub optimization_level: OptimizationLevel,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TargetArchitecture {
    X86_64Generic,
    X86_64AVX2,
    X86_64AVX512,
    ARM64,
    ARM64NEON,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OptimizationLevel {
    None,
    Basic,
    Aggressive,
    Maximum,
}

impl Default for ProfileGuidedConfig {
    fn default() -> Self {
        Self {
            enable_pgo_feature_selection: false,
            enable_pgo_bandwidth_optimization: false,
            profile_data_path: None,
            target_architecture: TargetArchitecture::X86_64Generic,
            optimization_level: OptimizationLevel::Basic,
        }
    }
}

impl ProfileGuidedConfig {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn enable_feature_selection(mut self) -> Self {
        self.enable_pgo_feature_selection = true;
        self
    }

    pub fn enable_bandwidth_optimization(mut self) -> Self {
        self.enable_pgo_bandwidth_optimization = true;
        self
    }

    pub fn profile_data_path<P: Into<String>>(mut self, path: P) -> Self {
        self.profile_data_path = Some(path.into());
        self
    }

    pub fn target_architecture(mut self, arch: TargetArchitecture) -> Self {
        self.target_architecture = arch;
        self
    }

    pub fn optimization_level(mut self, level: OptimizationLevel) -> Self {
        self.optimization_level = level;
        self
    }

    pub fn recommended_feature_count(&self, data_size: usize, dimensionality: usize) -> usize {
        let base_features = match self.target_architecture {
            TargetArchitecture::X86_64Generic => 128,
            TargetArchitecture::X86_64AVX2 => 256,
            TargetArchitecture::X86_64AVX512 => 512,
            TargetArchitecture::ARM64 => 128,
            TargetArchitecture::ARM64NEON => 256,
        };

        let scale_factor = match self.optimization_level {
            OptimizationLevel::None => 0.5,
            OptimizationLevel::Basic => 1.0,
            OptimizationLevel::Aggressive => 1.5,
            OptimizationLevel::Maximum => 2.0,
        };

        let scaled_features = (base_features as f64 * scale_factor) as usize;

        let data_adjustment = if data_size > 10000 {
            1.2
        } else if data_size < 1000 {
            0.8
        } else {
            1.0
        };

        let dimension_adjustment = if dimensionality > 100 {
            1.1
        } else if dimensionality < 10 {
            0.9
        } else {
            1.0
        };

        ((scaled_features as f64 * data_adjustment * dimension_adjustment) as usize)
            .max(32)
            .min(1024)
    }
}
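
// Worked example of the heuristic above (illustrative only): an AVX2 target with
// `OptimizationLevel::Aggressive`, 5_000 samples and 50 features resolves to
// 256 * 1.5 * 1.0 * 1.0 = 384 recommended components, clamped to the [32, 1024] range.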

#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SerializableKernelConfig {
    pub kernel_type: String,

    pub approximation_method: String,

    pub n_components: usize,

    pub bandwidth: f64,

    pub quality_threshold: f64,

    pub random_state: Option<u64>,

    pub additional_params: std::collections::HashMap<String, f64>,
}

#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SerializableFittedParams {
    pub config: SerializableKernelConfig,

    pub random_features: Option<Vec<Vec<f64>>>,

    pub selected_indices: Option<Vec<usize>>,

    pub eigenvalues: Option<Vec<f64>>,

    pub eigenvectors: Option<Vec<Vec<f64>>>,

    pub quality_metrics: std::collections::HashMap<String, f64>,

    pub fitted_timestamp: Option<u64>,
}

pub trait SerializableKernelApproximation {
    fn export_config(&self) -> Result<SerializableKernelConfig>;

    fn import_config(config: &SerializableKernelConfig) -> Result<Self>
    where
        Self: Sized;

    fn export_fitted_params(&self) -> Result<SerializableFittedParams>;

    fn import_fitted_params(&mut self, params: &SerializableFittedParams) -> Result<()>;

    fn save_to_file<P: AsRef<std::path::Path>>(&self, path: P) -> Result<()> {
        let config = self.export_config()?;
        let fitted_params = self.export_fitted_params()?;

        let model_data = serde_json::json!({
            "config": config,
            "fitted_params": fitted_params,
            "version": "1.0",
            "sklears_version": env!("CARGO_PKG_VERSION"),
        });

        std::fs::write(path, serde_json::to_string_pretty(&model_data).unwrap())
            .map_err(SklearsError::from)?;

        Ok(())
    }

    fn load_from_file<P: AsRef<std::path::Path>>(path: P) -> Result<Self>
    where
        Self: Sized,
    {
        let content = std::fs::read_to_string(path).map_err(SklearsError::from)?;

        let model_data: serde_json::Value = serde_json::from_str(&content).map_err(|e| {
            SklearsError::SerializationError(format!("Failed to parse JSON: {}", e))
        })?;

        let config: SerializableKernelConfig = serde_json::from_value(model_data["config"].clone())
            .map_err(|e| {
                SklearsError::SerializationError(format!("Failed to deserialize config: {}", e))
            })?;

        let fitted_params: SerializableFittedParams =
            serde_json::from_value(model_data["fitted_params"].clone()).map_err(|e| {
                SklearsError::SerializationError(format!(
                    "Failed to deserialize fitted params: {}",
                    e
                ))
            })?;

        let mut model = Self::import_config(&config)?;
        model.import_fitted_params(&fitted_params)?;

        Ok(model)
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod preset_tests {
    use super::*;

    #[test]
    fn test_kernel_presets() {
        let _fast_config = KernelPresets::fast_rbf_128();
        assert_eq!(
            ValidatedRBFRandomFourier::<128>::performance_tier(),
            PerformanceTier::Optimal
        );

        let _balanced_config = KernelPresets::balanced_rbf_256();
        assert_eq!(
            ValidatedRBFRandomFourier::<256>::performance_tier(),
            PerformanceTier::Optimal
        );

        let _accurate_config = KernelPresets::accurate_rbf_512();
        assert_eq!(
            ValidatedRBFRandomFourier::<512>::performance_tier(),
            PerformanceTier::Optimal
        );
    }

    #[test]
    fn test_profile_guided_config() {
        let pgo_config = ProfileGuidedConfig::new()
            .enable_feature_selection()
            .enable_bandwidth_optimization()
            .target_architecture(TargetArchitecture::X86_64AVX2)
            .optimization_level(OptimizationLevel::Aggressive);

        assert!(pgo_config.enable_pgo_feature_selection);
        assert!(pgo_config.enable_pgo_bandwidth_optimization);
        assert_eq!(
            pgo_config.target_architecture,
            TargetArchitecture::X86_64AVX2
        );
        assert_eq!(pgo_config.optimization_level, OptimizationLevel::Aggressive);

        let features = pgo_config.recommended_feature_count(5000, 50);
        assert!(features >= 32 && features <= 1024);
    }

    #[test]
    fn test_serializable_config() {
        let config = SerializableKernelConfig {
            kernel_type: "RBF".to_string(),
            approximation_method: "RandomFourierFeatures".to_string(),
            n_components: 256,
            bandwidth: 1.5,
            quality_threshold: 0.1,
            random_state: Some(42),
            additional_params: std::collections::HashMap::new(),
        };

        let serialized = serde_json::to_string(&config).unwrap();
        assert!(serialized.contains("RBF"));
        assert!(serialized.contains("RandomFourierFeatures"));

        let deserialized: SerializableKernelConfig = serde_json::from_str(&serialized).unwrap();
        assert_eq!(deserialized.kernel_type, "RBF");
        assert_eq!(deserialized.n_components, 256);
        assert_eq!(deserialized.bandwidth, 1.5);
    }
}