1use scirs2_core::ndarray::{Array1, Array2};
7use scirs2_core::rand_prelude::SliceRandom;
8use scirs2_core::random::essentials::{Normal as RandNormal, Uniform as RandUniform};
9use scirs2_core::random::rngs::StdRng as RealStdRng;
10use scirs2_core::random::RngExt;
11use scirs2_core::random::{thread_rng, SeedableRng};
12use sklears_core::error::{Result, SklearsError};
13use std::marker::PhantomData;
14
/// Marker trait for the compile-time training state of a kernel approximation.
pub trait ApproximationState {}

/// Type-state marker: the approximation has not been fitted to data yet.
#[derive(Debug, Clone, Copy)]
pub struct Untrained;
impl ApproximationState for Untrained {}

/// Type-state marker: the approximation has been fitted to data.
#[derive(Debug, Clone, Copy)]
pub struct Trained;
impl ApproximationState for Trained {}
29
/// Compile-time description of a kernel function.
///
/// Implementors are zero-sized marker types; all information is carried
/// in associated constants so it is available at compile time.
pub trait KernelType {
    /// Human-readable kernel name; also used for runtime dispatch in `fit`.
    const NAME: &'static str;

    /// Whether the kernel's parameters can be learned from data.
    const SUPPORTS_PARAMETER_LEARNING: bool;

    /// Default bandwidth used when none is supplied by the caller.
    const DEFAULT_BANDWIDTH: f64;
}

/// Gaussian radial-basis-function kernel marker.
#[derive(Debug, Clone, Copy)]
pub struct RBFKernel;
impl KernelType for RBFKernel {
    const NAME: &'static str = "RBF";
    const SUPPORTS_PARAMETER_LEARNING: bool = true;
    const DEFAULT_BANDWIDTH: f64 = 1.0;
}

/// Laplacian (L1 distance) kernel marker.
#[derive(Debug, Clone, Copy)]
pub struct LaplacianKernel;
impl KernelType for LaplacianKernel {
    const NAME: &'static str = "Laplacian";
    const SUPPORTS_PARAMETER_LEARNING: bool = true;
    const DEFAULT_BANDWIDTH: f64 = 1.0;
}

/// Polynomial kernel marker (degree/coef0 set via `ApproximationParameters`).
#[derive(Debug, Clone, Copy)]
pub struct PolynomialKernel;
impl KernelType for PolynomialKernel {
    const NAME: &'static str = "Polynomial";
    const SUPPORTS_PARAMETER_LEARNING: bool = true;
    const DEFAULT_BANDWIDTH: f64 = 1.0;
}

/// Arc-cosine kernel marker; its parameters are fixed by construction.
#[derive(Debug, Clone, Copy)]
pub struct ArcCosineKernel;
impl KernelType for ArcCosineKernel {
    const NAME: &'static str = "ArcCosine";
    const SUPPORTS_PARAMETER_LEARNING: bool = false;
    const DEFAULT_BANDWIDTH: f64 = 1.0;
}
81
/// Compile-time description of a kernel approximation algorithm.
pub trait ApproximationMethod {
    /// Algorithm name; also used for runtime dispatch in `fit`.
    const NAME: &'static str;

    /// Whether the method supports incremental (streaming) updates.
    const SUPPORTS_INCREMENTAL: bool;

    /// Whether theoretical approximation-error bounds are known.
    const HAS_ERROR_BOUNDS: bool;

    /// Asymptotic time-complexity class of the fitting procedure.
    const COMPLEXITY: ComplexityClass;
}

/// Coarse asymptotic complexity classification used in compile-time metadata.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ComplexityClass {
    /// O(n^2) in the number of samples.
    Quadratic,
    /// O(n log n).
    QuasiLinear,
    /// O(n).
    Linear,
    /// Complexity governed by the input dimensionality rather than samples.
    DimensionDependent,
}

/// Random Fourier features (Rahimi–Recht style) marker.
#[derive(Debug, Clone, Copy)]
pub struct RandomFourierFeatures;
impl ApproximationMethod for RandomFourierFeatures {
    const NAME: &'static str = "RandomFourierFeatures";
    const SUPPORTS_INCREMENTAL: bool = true;
    const HAS_ERROR_BOUNDS: bool = true;
    const COMPLEXITY: ComplexityClass = ComplexityClass::Linear;
}

/// Nystrom low-rank approximation marker.
#[derive(Debug, Clone, Copy)]
pub struct NystromMethod;
impl ApproximationMethod for NystromMethod {
    const NAME: &'static str = "Nystrom";
    const SUPPORTS_INCREMENTAL: bool = false;
    const HAS_ERROR_BOUNDS: bool = true;
    const COMPLEXITY: ComplexityClass = ComplexityClass::Quadratic;
}

/// Fastfood structured-matrix approximation marker.
#[derive(Debug, Clone, Copy)]
pub struct FastfoodMethod;
impl ApproximationMethod for FastfoodMethod {
    const NAME: &'static str = "Fastfood";
    const SUPPORTS_INCREMENTAL: bool = false;
    const HAS_ERROR_BOUNDS: bool = true;
    const COMPLEXITY: ComplexityClass = ComplexityClass::DimensionDependent;
}
143
/// Type-state kernel approximation.
///
/// `State` tracks the fit status at compile time, `Kernel` and `Method`
/// select the kernel function and approximation algorithm, and
/// `N_COMPONENTS` fixes the number of approximation components.
#[derive(Debug, Clone)]
pub struct TypeSafeKernelApproximation<State, Kernel, Method, const N_COMPONENTS: usize>
where
    State: ApproximationState,
    Kernel: KernelType,
    Method: ApproximationMethod,
{
    // Zero-sized marker carrying the compile-time type parameters.
    _phantom: PhantomData<(State, Kernel, Method)>,

    // User-configurable kernel parameters applied during fitting.
    parameters: ApproximationParameters,

    // Optional RNG seed; `None` means seed from the thread-local RNG.
    random_state: Option<u64>,
}

/// Runtime-tunable parameters shared by all kernels.
#[derive(Debug, Clone)]
pub struct ApproximationParameters {
    /// Kernel bandwidth. NOTE(review): its interpretation differs by kernel —
    /// it multiplies the distance in the exponent for RBF/Laplacian and scales
    /// the dot product for polynomial kernels (see `compute_kernel_matrix`).
    pub bandwidth: f64,

    /// Polynomial degree; defaults to 2 when `None` (polynomial kernel only).
    pub degree: Option<usize>,

    /// Independent term `coef0`; defaults to 1.0 when `None` (polynomial only).
    pub coef0: Option<f64>,

    /// Additional named parameters for custom kernel extensions.
    pub custom: std::collections::HashMap<String, f64>,
}

impl Default for ApproximationParameters {
    /// Unit bandwidth, no polynomial parameters, no custom entries.
    fn default() -> Self {
        Self {
            bandwidth: 1.0,
            degree: None,
            coef0: None,
            custom: std::collections::HashMap::new(),
        }
    }
}
190
/// A fitted kernel approximation, produced by
/// `TypeSafeKernelApproximation::fit`; only exposes `transform` and
/// read-only accessors (the `Untrained` state is consumed by `fit`).
#[derive(Debug, Clone)]
pub struct FittedTypeSafeKernelApproximation<Kernel, Method, const N_COMPONENTS: usize>
where
    Kernel: KernelType,
    Method: ApproximationMethod,
{
    // Zero-sized marker carrying the kernel/method type parameters.
    _phantom: PhantomData<(Kernel, Method)>,

    // Method-specific data learned during fitting.
    transformation_params: TransformationParameters<N_COMPONENTS>,

    // Snapshot of the parameters the model was fitted with.
    fitted_parameters: ApproximationParameters,

    // Optional quality diagnostics (all `None` until populated).
    quality_metrics: QualityMetrics,
}

/// Method-specific learned state, one variant per approximation algorithm.
#[derive(Debug, Clone)]
pub enum TransformationParameters<const N: usize> {
    /// Random Fourier features state.
    RandomFeatures {
        /// Projection matrix, shape `(N, n_features)`.
        weights: Array2<f64>,

        /// Optional per-component phase offsets, length `N`.
        biases: Option<Array1<f64>>,
    },
    /// Nystrom low-rank state.
    Nystrom {
        /// Subsampled landmark rows from the training data.
        inducing_points: Array2<f64>,

        /// Eigenvalues of the inducing-point kernel matrix.
        eigenvalues: Array1<f64>,

        /// Corresponding eigenvectors (columns).
        eigenvectors: Array2<f64>,
    },
    /// Fastfood structured-matrix state.
    Fastfood {
        /// Structured (diagonal-sign etc.) projection factors.
        structured_matrices: Vec<Array2<f64>>,
        /// Per-component Gaussian scaling, length `N`.
        scaling: Array1<f64>,
    },
}
236
/// Optional diagnostics describing how faithful the approximation is.
///
/// All fields default to `None`; they are populated only when the
/// corresponding metric has actually been computed.
// The two separate `#[derive]` attributes were merged into one, matching
// the single-attribute style used by every other type in this file.
#[derive(Debug, Clone, Default)]
pub struct QualityMetrics {
    /// Measured approximation error (lower is better).
    pub approximation_error: Option<f64>,

    /// Effective numerical rank of the feature map.
    pub effective_rank: Option<f64>,

    /// Condition number of the approximated kernel matrix.
    pub condition_number: Option<f64>,

    /// Alignment between the exact and approximate kernels, in [0, 1].
    pub kernel_alignment: Option<f64>,
}
254
impl<Kernel, Method, const N_COMPONENTS: usize> Default
    for TypeSafeKernelApproximation<Untrained, Kernel, Method, N_COMPONENTS>
where
    Kernel: KernelType,
    Method: ApproximationMethod,
{
    /// Equivalent to [`TypeSafeKernelApproximation::new`].
    fn default() -> Self {
        Self::new()
    }
}
266
267impl<Kernel, Method, const N_COMPONENTS: usize>
268 TypeSafeKernelApproximation<Untrained, Kernel, Method, N_COMPONENTS>
269where
270 Kernel: KernelType,
271 Method: ApproximationMethod,
272{
273 pub fn new() -> Self {
275 Self {
276 _phantom: PhantomData,
277 parameters: ApproximationParameters {
278 bandwidth: Kernel::DEFAULT_BANDWIDTH,
279 ..Default::default()
280 },
281 random_state: None,
282 }
283 }
284
285 pub fn bandwidth(mut self, bandwidth: f64) -> Self
287 where
288 Kernel: KernelTypeWithBandwidth,
289 {
290 self.parameters.bandwidth = bandwidth;
291 self
292 }
293
294 pub fn degree(mut self, degree: usize) -> Self
296 where
297 Kernel: PolynomialKernelType,
298 {
299 self.parameters.degree = Some(degree);
300 self
301 }
302
303 pub fn random_state(mut self, seed: u64) -> Self {
305 self.random_state = Some(seed);
306 self
307 }
308
309 pub fn fit(
311 self,
312 data: &Array2<f64>,
313 ) -> Result<FittedTypeSafeKernelApproximation<Kernel, Method, N_COMPONENTS>>
314 where
315 Kernel: FittableKernel<Method>,
316 Method: FittableMethod<Kernel>,
317 {
318 self.fit_impl(data)
319 }
320
321 fn fit_impl(
323 self,
324 data: &Array2<f64>,
325 ) -> Result<FittedTypeSafeKernelApproximation<Kernel, Method, N_COMPONENTS>>
326 where
327 Kernel: FittableKernel<Method>,
328 Method: FittableMethod<Kernel>,
329 {
330 let transformation_params = match Method::NAME {
331 "RandomFourierFeatures" => self.fit_random_fourier_features(data)?,
332 "Nystrom" => self.fit_nystrom(data)?,
333 "Fastfood" => self.fit_fastfood(data)?,
334 _ => {
335 return Err(SklearsError::InvalidOperation(format!(
336 "Unsupported method: {}",
337 Method::NAME
338 )));
339 }
340 };
341
342 Ok(FittedTypeSafeKernelApproximation {
343 _phantom: PhantomData,
344 transformation_params,
345 fitted_parameters: self.parameters,
346 quality_metrics: QualityMetrics::default(),
347 })
348 }
349
350 fn fit_random_fourier_features(
351 &self,
352 data: &Array2<f64>,
353 ) -> Result<TransformationParameters<N_COMPONENTS>> {
354 let mut rng = match self.random_state {
355 Some(seed) => RealStdRng::seed_from_u64(seed),
356 None => RealStdRng::from_seed(thread_rng().random()),
357 };
358
359 let (_, n_features) = data.dim();
360
361 let weights = match Kernel::NAME {
362 "RBF" => {
363 let normal = RandNormal::new(0.0, self.parameters.bandwidth)
364 .expect("operation should succeed");
365 Array2::from_shape_fn((N_COMPONENTS, n_features), |_| rng.sample(normal))
366 }
367 "Laplacian" => {
368 let uniform =
370 RandUniform::new(0.0, std::f64::consts::PI).expect("operation should succeed");
371 Array2::from_shape_fn((N_COMPONENTS, n_features), |_| {
372 let u = rng.sample(uniform);
373 (u - std::f64::consts::PI / 2.0).tan() / self.parameters.bandwidth
374 })
375 }
376 _ => {
377 return Err(SklearsError::InvalidOperation(format!(
378 "RFF not supported for kernel: {}",
379 Kernel::NAME
380 )));
381 }
382 };
383
384 let uniform =
385 RandUniform::new(0.0, 2.0 * std::f64::consts::PI).expect("operation should succeed");
386 let biases = Some(Array1::from_shape_fn(N_COMPONENTS, |_| rng.sample(uniform)));
387
388 Ok(TransformationParameters::RandomFeatures { weights, biases })
389 }
390
391 fn fit_nystrom(&self, data: &Array2<f64>) -> Result<TransformationParameters<N_COMPONENTS>> {
392 let (n_samples, _) = data.dim();
393 let n_inducing = N_COMPONENTS.min(n_samples);
394
395 let mut rng = match self.random_state {
397 Some(seed) => RealStdRng::seed_from_u64(seed),
398 None => RealStdRng::from_seed(thread_rng().random()),
399 };
400
401 let mut indices: Vec<usize> = (0..n_samples).collect();
402 indices.shuffle(&mut rng);
403
404 let inducing_indices = &indices[..n_inducing];
405 let inducing_points = data.select(scirs2_core::ndarray::Axis(0), inducing_indices);
406
407 let kernel_matrix = self.compute_kernel_matrix(&inducing_points, &inducing_points)?;
409
410 let eigenvalues = Array1::from_shape_fn(n_inducing, |i| kernel_matrix[[i, i]]);
412 let eigenvectors = Array2::eye(n_inducing);
413
414 Ok(TransformationParameters::Nystrom {
415 inducing_points,
416 eigenvalues,
417 eigenvectors,
418 })
419 }
420
421 fn fit_fastfood(&self, data: &Array2<f64>) -> Result<TransformationParameters<N_COMPONENTS>> {
422 let mut rng = match self.random_state {
423 Some(seed) => RealStdRng::seed_from_u64(seed),
424 None => RealStdRng::from_seed(thread_rng().random()),
425 };
426
427 let (_, n_features) = data.dim();
428
429 let mut structured_matrices = Vec::new();
431
432 let binary_matrix = Array2::from_shape_fn((n_features, n_features), |(i, j)| {
434 if i == j {
435 if rng.random::<bool>() {
436 1.0
437 } else {
438 -1.0
439 }
440 } else {
441 0.0
442 }
443 });
444 structured_matrices.push(binary_matrix);
445
446 let scaling = Array1::from_shape_fn(N_COMPONENTS, |_| {
448 use scirs2_core::random::RandNormal;
449 let normal = RandNormal::new(0.0, 1.0).expect("operation should succeed");
450 rng.sample(normal)
451 });
452
453 Ok(TransformationParameters::Fastfood {
454 structured_matrices,
455 scaling,
456 })
457 }
458
459 fn compute_kernel_matrix(&self, x1: &Array2<f64>, x2: &Array2<f64>) -> Result<Array2<f64>> {
460 let (n1, _) = x1.dim();
461 let (n2, _) = x2.dim();
462 let mut kernel = Array2::zeros((n1, n2));
463
464 for i in 0..n1 {
465 for j in 0..n2 {
466 let similarity = match Kernel::NAME {
467 "RBF" => {
468 let diff = &x1.row(i) - &x2.row(j);
469 let dist_sq = diff.mapv(|x| x * x).sum();
470 (-self.parameters.bandwidth * dist_sq).exp()
471 }
472 "Laplacian" => {
473 let diff = &x1.row(i) - &x2.row(j);
474 let dist = diff.mapv(|x| x.abs()).sum();
475 (-self.parameters.bandwidth * dist).exp()
476 }
477 "Polynomial" => {
478 let dot_product = x1.row(i).dot(&x2.row(j));
479 let degree = self.parameters.degree.unwrap_or(2) as i32;
480 let coef0 = self.parameters.coef0.unwrap_or(1.0);
481 (self.parameters.bandwidth * dot_product + coef0).powi(degree)
482 }
483 _ => {
484 return Err(SklearsError::InvalidOperation(format!(
485 "Unsupported kernel: {}",
486 Kernel::NAME
487 )));
488 }
489 };
490 kernel[[i, j]] = similarity;
491 }
492 }
493
494 Ok(kernel)
495 }
496}
497
impl<Kernel, Method, const N_COMPONENTS: usize>
    FittedTypeSafeKernelApproximation<Kernel, Method, N_COMPONENTS>
where
    Kernel: KernelType,
    Method: ApproximationMethod,
{
    /// Maps `data` into the approximate feature space learned during `fit`,
    /// dispatching on which variant of `TransformationParameters` was stored.
    ///
    /// # Errors
    /// Propagates errors from the per-method transforms (e.g. an
    /// unsupported kernel in the Nystrom path).
    pub fn transform(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
        match &self.transformation_params {
            TransformationParameters::RandomFeatures { weights, biases } => {
                self.transform_random_features(data, weights, biases.as_ref())
            }
            TransformationParameters::Nystrom {
                inducing_points,
                eigenvalues,
                eigenvectors,
            } => self.transform_nystrom(data, inducing_points, eigenvalues, eigenvectors),
            TransformationParameters::Fastfood {
                structured_matrices,
                scaling,
            } => self.transform_fastfood(data, structured_matrices, scaling),
        }
    }

    // RFF transform: each component contributes an interleaved cos/sin pair,
    // so the output has 2 * N_COMPONENTS columns (column 2j = cos, 2j+1 = sin).
    fn transform_random_features(
        &self,
        data: &Array2<f64>,
        weights: &Array2<f64>,
        biases: Option<&Array1<f64>>,
    ) -> Result<Array2<f64>> {
        let (n_samples, _) = data.dim();
        let mut features = Array2::zeros((n_samples, N_COMPONENTS * 2));

        for (i, sample) in data.axis_iter(scirs2_core::ndarray::Axis(0)).enumerate() {
            for j in 0..N_COMPONENTS {
                // Project onto the j-th random direction, then add the
                // phase offset if one was sampled at fit time.
                let projection = sample.dot(&weights.row(j));
                let phase = if let Some(b) = biases {
                    projection + b[j]
                } else {
                    projection
                };

                features[[i, 2 * j]] = phase.cos();
                features[[i, 2 * j + 1]] = phase.sin();
            }
        }

        Ok(features)
    }

    // Nystrom transform: project the data-landmark kernel onto the stored
    // eigenvectors, scaling by 1/sqrt(eigenvalue).
    fn transform_nystrom(
        &self,
        data: &Array2<f64>,
        inducing_points: &Array2<f64>,
        eigenvalues: &Array1<f64>,
        eigenvectors: &Array2<f64>,
    ) -> Result<Array2<f64>> {
        let kernel_matrix = self.compute_kernel_matrix_fitted(data, inducing_points)?;

        let mut features = Array2::zeros((data.nrows(), eigenvalues.len()));

        for i in 0..data.nrows() {
            for j in 0..eigenvalues.len() {
                // Skip numerically negligible eigenvalues; those output
                // columns stay zero instead of blowing up under 1/sqrt.
                if eigenvalues[j] > 1e-8 {
                    let mut feature_value = 0.0;
                    for k in 0..inducing_points.nrows() {
                        feature_value += kernel_matrix[[i, k]] * eigenvectors[[k, j]];
                    }
                    features[[i, j]] = feature_value / eigenvalues[j].sqrt();
                }
            }
        }

        Ok(features)
    }

    // Fastfood transform: apply the structured factors, then a scaled cosine.
    // When N_COMPONENTS > n_features the extra output columns remain zero.
    fn transform_fastfood(
        &self,
        data: &Array2<f64>,
        structured_matrices: &[Array2<f64>],
        scaling: &Array1<f64>,
    ) -> Result<Array2<f64>> {
        let (n_samples, n_features) = data.dim();
        let mut features = data.clone();

        // Multiply through each structured projection factor in order.
        for matrix in structured_matrices {
            features = features.dot(matrix);
        }

        let mut result = Array2::zeros((n_samples, N_COMPONENTS));
        for i in 0..n_samples {
            for j in 0..N_COMPONENTS.min(n_features) {
                result[[i, j]] = (features[[i, j]] * scaling[j]).cos();
            }
        }

        Ok(result)
    }

    // Same pairwise kernel computation as the untrained impl, but reading
    // from the fitted parameter snapshot.
    fn compute_kernel_matrix_fitted(
        &self,
        x1: &Array2<f64>,
        x2: &Array2<f64>,
    ) -> Result<Array2<f64>> {
        let (n1, _) = x1.dim();
        let (n2, _) = x2.dim();
        let mut kernel = Array2::zeros((n1, n2));

        for i in 0..n1 {
            for j in 0..n2 {
                let similarity = match Kernel::NAME {
                    "RBF" => {
                        // exp(-bandwidth * ||x1_i - x2_j||^2)
                        let diff = &x1.row(i) - &x2.row(j);
                        let dist_sq = diff.mapv(|x| x * x).sum();
                        (-self.fitted_parameters.bandwidth * dist_sq).exp()
                    }
                    "Laplacian" => {
                        // exp(-bandwidth * ||x1_i - x2_j||_1)
                        let diff = &x1.row(i) - &x2.row(j);
                        let dist = diff.mapv(|x| x.abs()).sum();
                        (-self.fitted_parameters.bandwidth * dist).exp()
                    }
                    "Polynomial" => {
                        // (bandwidth * <x1_i, x2_j> + coef0)^degree
                        let dot_product = x1.row(i).dot(&x2.row(j));
                        let degree = self.fitted_parameters.degree.unwrap_or(2) as i32;
                        let coef0 = self.fitted_parameters.coef0.unwrap_or(1.0);
                        (self.fitted_parameters.bandwidth * dot_product + coef0).powi(degree)
                    }
                    _ => {
                        return Err(SklearsError::InvalidOperation(format!(
                            "Unsupported kernel: {}",
                            Kernel::NAME
                        )));
                    }
                };
                kernel[[i, j]] = similarity;
            }
        }

        Ok(kernel)
    }

    /// Read-only access to the (possibly unpopulated) quality diagnostics.
    pub fn quality_metrics(&self) -> &QualityMetrics {
        &self.quality_metrics
    }

    /// Number of approximation components, fixed at compile time.
    /// Note: the RFF transform outputs `2 * N_COMPONENTS` columns.
    pub const fn n_components() -> usize {
        N_COMPONENTS
    }

    /// The kernel's compile-time name.
    pub fn kernel_name(&self) -> &'static str {
        Kernel::NAME
    }

    /// The approximation method's compile-time name.
    pub fn method_name(&self) -> &'static str {
        Method::NAME
    }
}
664
/// Capability marker: kernels with a meaningful bandwidth parameter,
/// gating the `bandwidth` builder method at compile time.
pub trait KernelTypeWithBandwidth: KernelType {}
impl KernelTypeWithBandwidth for RBFKernel {}
impl KernelTypeWithBandwidth for LaplacianKernel {}

/// Capability marker: kernels with a polynomial degree, gating the
/// `degree` builder method at compile time.
pub trait PolynomialKernelType: KernelType {}
impl PolynomialKernelType for PolynomialKernel {}

/// Kernel-side half of the fit gate: a kernel may be fitted with `Method`
/// only if this impl exists.
pub trait FittableKernel<Method: ApproximationMethod>: KernelType {}
impl FittableKernel<RandomFourierFeatures> for RBFKernel {}
impl FittableKernel<RandomFourierFeatures> for LaplacianKernel {}
impl FittableKernel<NystromMethod> for RBFKernel {}
impl FittableKernel<NystromMethod> for LaplacianKernel {}
impl FittableKernel<NystromMethod> for PolynomialKernel {}
impl FittableKernel<FastfoodMethod> for RBFKernel {}

/// Method-side half of the fit gate, mirroring `FittableKernel`.
pub trait FittableMethod<Kernel: KernelType>: ApproximationMethod {}
impl FittableMethod<RBFKernel> for RandomFourierFeatures {}
impl FittableMethod<LaplacianKernel> for RandomFourierFeatures {}
impl FittableMethod<RBFKernel> for NystromMethod {}
impl FittableMethod<LaplacianKernel> for NystromMethod {}
impl FittableMethod<PolynomialKernel> for NystromMethod {}
impl FittableMethod<RBFKernel> for FastfoodMethod {}
693
/// Untrained RBF kernel approximated with random Fourier features.
pub type RBFRandomFourierFeatures<const N: usize> =
    TypeSafeKernelApproximation<Untrained, RBFKernel, RandomFourierFeatures, N>;

/// Untrained Laplacian kernel approximated with random Fourier features.
pub type LaplacianRandomFourierFeatures<const N: usize> =
    TypeSafeKernelApproximation<Untrained, LaplacianKernel, RandomFourierFeatures, N>;

/// Untrained RBF kernel approximated with the Nystrom method.
pub type RBFNystrom<const N: usize> =
    TypeSafeKernelApproximation<Untrained, RBFKernel, NystromMethod, N>;

/// Untrained polynomial kernel approximated with the Nystrom method.
pub type PolynomialNystrom<const N: usize> =
    TypeSafeKernelApproximation<Untrained, PolynomialKernel, NystromMethod, N>;

/// Untrained RBF kernel approximated with Fastfood.
pub type RBFFastfood<const N: usize> =
    TypeSafeKernelApproximation<Untrained, RBFKernel, FastfoodMethod, N>;

/// Fitted counterpart of [`RBFRandomFourierFeatures`].
pub type FittedRBFRandomFourierFeatures<const N: usize> =
    FittedTypeSafeKernelApproximation<RBFKernel, RandomFourierFeatures, N>;

/// Fitted counterpart of [`LaplacianRandomFourierFeatures`].
pub type FittedLaplacianRandomFourierFeatures<const N: usize> =
    FittedTypeSafeKernelApproximation<LaplacianKernel, RandomFourierFeatures, N>;

/// Fitted counterpart of [`RBFNystrom`].
pub type FittedRBFNystrom<const N: usize> =
    FittedTypeSafeKernelApproximation<RBFKernel, NystromMethod, N>;
719
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // End-to-end RFF fit + transform; output width is 2 * N (cos/sin pairs).
    #[test]
    fn test_type_safe_rbf_rff() {
        let data = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0],];

        let approximation: RBFRandomFourierFeatures<10> = TypeSafeKernelApproximation::new()
            .bandwidth(1.5)
            .random_state(42);

        let fitted = approximation.fit(&data).expect("operation should succeed");
        let features = fitted.transform(&data).expect("operation should succeed");

        // 10 components => 20 output columns (interleaved cos/sin).
        assert_eq!(features.shape(), &[4, 20]);
        assert_eq!(FittedRBFRandomFourierFeatures::<10>::n_components(), 10);
        assert_eq!(fitted.kernel_name(), "RBF");
        assert_eq!(fitted.method_name(), "RandomFourierFeatures");
    }

    // End-to-end Nystrom fit + transform preserves the sample count.
    #[test]
    fn test_type_safe_nystrom() {
        let data = array![
            [1.0, 2.0, 3.0],
            [2.0, 3.0, 4.0],
            [3.0, 4.0, 5.0],
            [4.0, 5.0, 6.0],
        ];

        let approximation: RBFNystrom<5> = TypeSafeKernelApproximation::new()
            .bandwidth(2.0)
            .random_state(42);

        let fitted = approximation.fit(&data).expect("operation should succeed");
        let features = fitted.transform(&data).expect("operation should succeed");

        assert_eq!(features.shape()[0], 4);
        assert_eq!(fitted.kernel_name(), "RBF");
        assert_eq!(fitted.method_name(), "Nystrom");
    }

    // The degree builder is only available for polynomial kernels.
    #[test]
    fn test_polynomial_kernel() {
        let data = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0],];

        let approximation: PolynomialNystrom<3> = TypeSafeKernelApproximation::new()
            .degree(3)
            .random_state(42);

        let fitted = approximation.fit(&data).expect("operation should succeed");
        let features = fitted.transform(&data).expect("operation should succeed");

        assert_eq!(fitted.kernel_name(), "Polynomial");
        assert!(features.nrows() == 3);
    }

    // Sanity-check the compile-time metadata constants.
    #[test]
    fn test_compile_time_constants() {
        assert_eq!(RBFKernel::NAME, "RBF");
        assert!(RBFKernel::SUPPORTS_PARAMETER_LEARNING);
        assert_eq!(RandomFourierFeatures::NAME, "RandomFourierFeatures");
        assert!(RandomFourierFeatures::SUPPORTS_INCREMENTAL);
        assert!(RandomFourierFeatures::HAS_ERROR_BOUNDS);
    }
}
809
/// Compile-time parameter range check expressed with const generics.
pub trait ParameterValidation<const MIN: usize, const MAX: usize> {
    /// True when the declared range is non-empty (MIN <= MAX).
    const IS_VALID: bool = MIN <= MAX;

    /// Returns the `(MIN, MAX)` range as runtime values.
    fn parameter_range() -> (usize, usize) {
        (MIN, MAX)
    }
}

/// Zero-sized witness that a component count `N` has passed validation.
#[derive(Debug, Clone, Copy)]
pub struct ValidatedComponents<const N: usize>;

impl<const N: usize> Default for ValidatedComponents<N> {
    /// Equivalent to [`ValidatedComponents::new`] (panics on invalid `N`).
    fn default() -> Self {
        Self::new()
    }
}

impl<const N: usize> ValidatedComponents<N> {
    /// Constructs the witness, checking `N` at runtime.
    ///
    /// # Panics
    /// Panics if `N == 0` or `N > 10000`.
    pub fn new() -> Self {
        assert!(N > 0, "Component count must be positive");
        assert!(N <= 10000, "Component count too large");
        Self
    }

    /// The validated component count.
    pub const fn count(&self) -> usize {
        N
    }
}
850
/// Compile-time compatibility table between kernels and approximation
/// methods; implemented on `()` and queried via fully-qualified syntax.
pub trait KernelMethodCompatibility<K: KernelType, M: ApproximationMethod> {
    /// Whether the kernel/method pair has a working fit path.
    const IS_COMPATIBLE: bool;

    /// Qualitative performance rating for the pair.
    const PERFORMANCE_TIER: PerformanceTier;

    /// Memory complexity class of the pair.
    const MEMORY_COMPLEXITY: ComplexityClass;
}

/// Qualitative performance rating for a kernel/method pairing.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PerformanceTier {
    /// Best-known pairing for this kernel.
    Optimal,
    /// Works well with minor trade-offs.
    Good,
    /// Usable but with notable trade-offs.
    Acceptable,
    /// Discouraged or unsupported pairing.
    Poor,
}

impl KernelMethodCompatibility<RBFKernel, RandomFourierFeatures> for () {
    const IS_COMPATIBLE: bool = true;
    const PERFORMANCE_TIER: PerformanceTier = PerformanceTier::Optimal;
    const MEMORY_COMPLEXITY: ComplexityClass = ComplexityClass::Linear;
}

impl KernelMethodCompatibility<LaplacianKernel, RandomFourierFeatures> for () {
    const IS_COMPATIBLE: bool = true;
    const PERFORMANCE_TIER: PerformanceTier = PerformanceTier::Optimal;
    const MEMORY_COMPLEXITY: ComplexityClass = ComplexityClass::Linear;
}

impl KernelMethodCompatibility<RBFKernel, NystromMethod> for () {
    const IS_COMPATIBLE: bool = true;
    const PERFORMANCE_TIER: PerformanceTier = PerformanceTier::Good;
    const MEMORY_COMPLEXITY: ComplexityClass = ComplexityClass::Quadratic;
}

impl KernelMethodCompatibility<PolynomialKernel, NystromMethod> for () {
    const IS_COMPATIBLE: bool = true;
    const PERFORMANCE_TIER: PerformanceTier = PerformanceTier::Good;
    const MEMORY_COMPLEXITY: ComplexityClass = ComplexityClass::Quadratic;
}

impl KernelMethodCompatibility<RBFKernel, FastfoodMethod> for () {
    const IS_COMPATIBLE: bool = true;
    const PERFORMANCE_TIER: PerformanceTier = PerformanceTier::Optimal;
    const MEMORY_COMPLEXITY: ComplexityClass = ComplexityClass::DimensionDependent;
}

impl KernelMethodCompatibility<ArcCosineKernel, RandomFourierFeatures> for () {
    const IS_COMPATIBLE: bool = true;
    const PERFORMANCE_TIER: PerformanceTier = PerformanceTier::Acceptable;
    const MEMORY_COMPLEXITY: ComplexityClass = ComplexityClass::Quadratic;
}

// ArcCosine + Nystrom is declared incompatible (no fit path exists).
impl KernelMethodCompatibility<ArcCosineKernel, NystromMethod> for () {
    const IS_COMPATIBLE: bool = false;
    const PERFORMANCE_TIER: PerformanceTier = PerformanceTier::Poor;
    const MEMORY_COMPLEXITY: ComplexityClass = ComplexityClass::Quadratic;
}

// ArcCosine + Fastfood is declared incompatible (no fit path exists).
impl KernelMethodCompatibility<ArcCosineKernel, FastfoodMethod> for () {
    const IS_COMPATIBLE: bool = false;
    const PERFORMANCE_TIER: PerformanceTier = PerformanceTier::Poor;
    const MEMORY_COMPLEXITY: ComplexityClass = ComplexityClass::DimensionDependent;
}

impl KernelMethodCompatibility<LaplacianKernel, NystromMethod> for () {
    const IS_COMPATIBLE: bool = true;
    const PERFORMANCE_TIER: PerformanceTier = PerformanceTier::Good;
    const MEMORY_COMPLEXITY: ComplexityClass = ComplexityClass::Quadratic;
}

impl KernelMethodCompatibility<PolynomialKernel, RandomFourierFeatures> for () {
    const IS_COMPATIBLE: bool = true;
    const PERFORMANCE_TIER: PerformanceTier = PerformanceTier::Good;
    const MEMORY_COMPLEXITY: ComplexityClass = ComplexityClass::Linear;
}
939
/// Kernel approximation wrapper that additionally requires a
/// `KernelMethodCompatibility` entry for the chosen pairing and a
/// validated component count.
#[derive(Debug, Clone)]
pub struct ValidatedKernelApproximation<K, M, const N: usize>
where
    K: KernelType,
    M: ApproximationMethod,
    (): KernelMethodCompatibility<K, M>,
{
    // The underlying untrained approximation.
    inner: TypeSafeKernelApproximation<Untrained, K, M, N>,
    // Zero-sized witness that `N` passed range validation.
    _validation: ValidatedComponents<N>,
}

impl<K, M, const N: usize> Default for ValidatedKernelApproximation<K, M, N>
where
    K: KernelType,
    M: ApproximationMethod,
    (): KernelMethodCompatibility<K, M>,
{
    /// Equivalent to [`ValidatedKernelApproximation::new`] (may panic).
    fn default() -> Self {
        Self::new()
    }
}

impl<K, M, const N: usize> ValidatedKernelApproximation<K, M, N>
where
    K: KernelType,
    M: ApproximationMethod,
    (): KernelMethodCompatibility<K, M>,
{
    /// Constructs the wrapper after checking the compatibility table.
    ///
    /// # Panics
    /// Panics if the pairing's `IS_COMPATIBLE` entry is `false`, or if `N`
    /// fails `ValidatedComponents` range checks.
    pub fn new() -> Self {
        assert!(
            <() as KernelMethodCompatibility<K, M>>::IS_COMPATIBLE,
            "Incompatible kernel-method combination"
        );

        Self {
            inner: TypeSafeKernelApproximation::new(),
            _validation: ValidatedComponents::new(),
        }
    }

    /// Returns `(performance tier, time complexity, memory complexity)`
    /// from the compile-time tables.
    pub const fn performance_info() -> (PerformanceTier, ComplexityClass, ComplexityClass) {
        (
            <() as KernelMethodCompatibility<K, M>>::PERFORMANCE_TIER,
            M::COMPLEXITY,
            <() as KernelMethodCompatibility<K, M>>::MEMORY_COMPLEXITY,
        )
    }

    /// True when the pairing is rated `Optimal`.
    pub const fn is_optimal() -> bool {
        matches!(
            <() as KernelMethodCompatibility<K, M>>::PERFORMANCE_TIER,
            PerformanceTier::Optimal
        )
    }
}
1001
/// Quality metrics whose admissible ranges are encoded in the type.
///
/// `MIN_ALIGNMENT` and `MAX_ERROR` are integer hundredths (percent):
/// `BoundedQualityMetrics<90, 5>` requires alignment >= 0.90 and
/// approximation error <= 0.05.
#[derive(Debug, Clone, Copy)]
pub struct BoundedQualityMetrics<const MIN_ALIGNMENT: u32, const MAX_ERROR: u32> {
    // Alignment between exact and approximate kernels.
    kernel_alignment: f64,
    // Relative approximation error.
    approximation_error: f64,
    // Effective rank; must be strictly positive.
    effective_rank: f64,
}

impl<const MIN_ALIGNMENT: u32, const MAX_ERROR: u32>
    BoundedQualityMetrics<MIN_ALIGNMENT, MAX_ERROR>
{
    /// Validates the metrics against the type-level bounds; returns `None`
    /// if alignment is too low, error is too high, or rank is not positive.
    pub const fn new(alignment: f64, error: f64, rank: f64) -> Option<Self> {
        // Bounds are stored as integer hundredths so they can be const generics.
        let min_align = MIN_ALIGNMENT as f64 / 100.0;
        let max_err = MAX_ERROR as f64 / 100.0;

        if alignment >= min_align && error <= max_err && rank > 0.0 {
            Some(Self {
                kernel_alignment: alignment,
                approximation_error: error,
                effective_rank: rank,
            })
        } else {
            None
        }
    }

    /// Returns `(minimum alignment, maximum error)` as fractions.
    pub const fn bounds() -> (f64, f64) {
        (MIN_ALIGNMENT as f64 / 100.0, MAX_ERROR as f64 / 100.0)
    }

    /// Re-checks the stored alignment/error against the type-level bounds
    /// (always true for values accepted by `new`).
    pub const fn meets_standards(&self) -> bool {
        let (min_align, max_err) = Self::bounds();
        self.kernel_alignment >= min_align && self.approximation_error <= max_err
    }
}

/// Metrics constrained to alignment >= 0.90 and error <= 0.05.
pub type HighQualityMetrics = BoundedQualityMetrics<90, 5>;

/// Metrics constrained to alignment >= 0.70 and error <= 0.15.
pub type AcceptableQualityMetrics = BoundedQualityMetrics<70, 15>;
1047
/// Convenience macro constructing a [`ValidatedKernelApproximation`] for the
/// supported kernel/method pairings, e.g.
/// `validated_kernel_approximation!(RBF, RandomFourierFeatures, 100)`.
#[macro_export]
macro_rules! validated_kernel_approximation {
    (RBF, RandomFourierFeatures, $n:literal) => {
        ValidatedKernelApproximation::<RBFKernel, RandomFourierFeatures, $n>::new()
    };
    (Laplacian, RandomFourierFeatures, $n:literal) => {
        ValidatedKernelApproximation::<LaplacianKernel, RandomFourierFeatures, $n>::new()
    };
    (RBF, Nystrom, $n:literal) => {
        ValidatedKernelApproximation::<RBFKernel, NystromMethod, $n>::new()
    };
    (RBF, Fastfood, $n:literal) => {
        ValidatedKernelApproximation::<RBFKernel, FastfoodMethod, $n>::new()
    };
}
1064
#[allow(non_snake_case)]
#[cfg(test)]
mod advanced_type_safety_tests {
    use super::*;

    // The component-count witness reports the validated value.
    #[test]
    fn test_validated_components() {
        let components = ValidatedComponents::<100>::new();
        assert_eq!(components.count(), 100);
    }

    // The compile-time compatibility table answers both directions.
    #[test]
    fn test_kernel_method_compatibility() {
        assert!(<() as KernelMethodCompatibility<RBFKernel, RandomFourierFeatures>>::IS_COMPATIBLE);
        assert!(!<() as KernelMethodCompatibility<ArcCosineKernel, NystromMethod>>::IS_COMPATIBLE);

        assert_eq!(
            <() as KernelMethodCompatibility<RBFKernel, RandomFourierFeatures>>::PERFORMANCE_TIER,
            PerformanceTier::Optimal
        );
    }

    // Metrics inside the bounds validate; metrics outside return None.
    #[test]
    fn test_bounded_quality_metrics() {
        let high_quality =
            HighQualityMetrics::new(0.95, 0.03, 50.0).expect("operation should succeed");
        assert!(high_quality.meets_standards());

        let low_quality = HighQualityMetrics::new(0.60, 0.20, 10.0);
        assert!(low_quality.is_none());
    }

    // Each macro arm expands, and performance_info reads the tables.
    #[test]
    fn test_macro_creation() {
        let _rbf_rff = validated_kernel_approximation!(RBF, RandomFourierFeatures, 100);
        let _lap_rff = validated_kernel_approximation!(Laplacian, RandomFourierFeatures, 50);
        let _rbf_nys = validated_kernel_approximation!(RBF, Nystrom, 30);
        let _rbf_ff = validated_kernel_approximation!(RBF, Fastfood, 128);

        let (tier, complexity, memory) = ValidatedKernelApproximation::<
            RBFKernel,
            RandomFourierFeatures,
            100,
        >::performance_info();
        assert_eq!(tier, PerformanceTier::Optimal);
        assert_eq!(complexity, ComplexityClass::Linear);
        assert_eq!(memory, ComplexityClass::Linear);
    }
}
1120
/// Kernels that can be combined with another kernel at the type level.
///
/// The resulting kernel type (sum or product) is chosen per implementor
/// via the generic associated type.
pub trait ComposableKernel: KernelType {
    /// Kernel type produced by composing `Self` with `Other`.
    type CompositionResult<Other: ComposableKernel>: ComposableKernel;

    /// Composes this kernel with `Other` (consumes the zero-sized marker).
    fn compose<Other: ComposableKernel>(self) -> Self::CompositionResult<Other>;
}

/// Type-level sum of two kernels.
#[derive(Debug, Clone, Copy)]
pub struct SumKernel<K1: KernelType, K2: KernelType> {
    _phantom: PhantomData<(K1, K2)>,
}

impl<K1: KernelType, K2: KernelType> KernelType for SumKernel<K1, K2> {
    const NAME: &'static str = "Sum";
    // Learnable only when both operands are.
    const SUPPORTS_PARAMETER_LEARNING: bool =
        K1::SUPPORTS_PARAMETER_LEARNING && K2::SUPPORTS_PARAMETER_LEARNING;
    // Average of the operands' defaults.
    const DEFAULT_BANDWIDTH: f64 = (K1::DEFAULT_BANDWIDTH + K2::DEFAULT_BANDWIDTH) / 2.0;
}

/// Type-level product of two kernels.
#[derive(Debug, Clone, Copy)]
pub struct ProductKernel<K1: KernelType, K2: KernelType> {
    _phantom: PhantomData<(K1, K2)>,
}

impl<K1: KernelType, K2: KernelType> KernelType for ProductKernel<K1, K2> {
    const NAME: &'static str = "Product";
    // Learnable only when both operands are.
    const SUPPORTS_PARAMETER_LEARNING: bool =
        K1::SUPPORTS_PARAMETER_LEARNING && K2::SUPPORTS_PARAMETER_LEARNING;
    // Product of the operands' defaults.
    const DEFAULT_BANDWIDTH: f64 = K1::DEFAULT_BANDWIDTH * K2::DEFAULT_BANDWIDTH;
}

// RBF composes additively.
impl ComposableKernel for RBFKernel {
    type CompositionResult<Other: ComposableKernel> = SumKernel<Self, Other>;

    fn compose<Other: ComposableKernel>(self) -> Self::CompositionResult<Other> {
        SumKernel {
            _phantom: PhantomData,
        }
    }
}

// Laplacian composes additively.
impl ComposableKernel for LaplacianKernel {
    type CompositionResult<Other: ComposableKernel> = SumKernel<Self, Other>;

    fn compose<Other: ComposableKernel>(self) -> Self::CompositionResult<Other> {
        SumKernel {
            _phantom: PhantomData,
        }
    }
}

// Polynomial composes multiplicatively.
impl ComposableKernel for PolynomialKernel {
    type CompositionResult<Other: ComposableKernel> = ProductKernel<Self, Other>;

    fn compose<Other: ComposableKernel>(self) -> Self::CompositionResult<Other> {
        ProductKernel {
            _phantom: PhantomData,
        }
    }
}
1191
/// Compile-time checks on a feature count `N`.
pub trait ValidatedFeatureSize<const N: usize> {
    /// True when `N` is a power of two (bit trick: N & (N-1) == 0).
    const IS_POWER_OF_TWO: bool = (N != 0) && ((N & (N - 1)) == 0);
    /// True when `N` is in the supported range [8, 8192].
    const IS_REASONABLE_SIZE: bool = N >= 8 && N <= 8192;
    /// True when both checks pass.
    const IS_VALID: bool = Self::IS_POWER_OF_TWO && Self::IS_REASONABLE_SIZE;
}

impl<const N: usize> ValidatedFeatureSize<N> for () {}

/// Zero-sized token representing `N` validated feature slots.
#[derive(Debug, Clone)]
pub struct ValidatedFeatures<const N: usize> {
    _phantom: PhantomData<[f64; N]>,
}

impl<const N: usize> Default for ValidatedFeatures<N>
where
    (): ValidatedFeatureSize<N>,
{
    /// Equivalent to [`ValidatedFeatures::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl<const N: usize> ValidatedFeatures<N>
where
    (): ValidatedFeatureSize<N>,
{
    /// Constructs the token. NOTE(review): `IS_VALID` is not asserted here,
    /// so invalid `N` is representable — confirm whether that is intended.
    pub const fn new() -> Self {
        Self {
            _phantom: PhantomData,
        }
    }

    /// The feature count.
    pub const fn count() -> usize {
        N
    }

    /// True when `N` is a power of two (Fastfood's preferred shape).
    pub const fn is_fastfood_optimal() -> bool {
        <() as ValidatedFeatureSize<N>>::IS_POWER_OF_TWO
    }
}
1239
1240pub trait ApproximationQualityBounds<Method: ApproximationMethod> {
1242 const ERROR_BOUND_CONSTANT: f64;
1244
1245 const SAMPLE_COMPLEXITY_EXPONENT: f64;
1247
1248 const DIMENSION_DEPENDENCY: f64;
1250
1251 fn error_bound(n_samples: usize, n_features: usize, n_components: usize) -> f64 {
1253 let base_rate = Self::ERROR_BOUND_CONSTANT;
1254 let sample_factor = (n_samples as f64).powf(-Self::SAMPLE_COMPLEXITY_EXPONENT);
1255 let dim_factor = (n_features as f64).powf(Self::DIMENSION_DEPENDENCY);
1256 let comp_factor = (n_components as f64).powf(-0.5);
1257
1258 base_rate * sample_factor * dim_factor * comp_factor
1259 }
1260}
1261
/// Quality-bound constants for Random Fourier Features.
/// NOTE(review): the constants appear empirical — derivation not shown here.
impl ApproximationQualityBounds<RandomFourierFeatures> for () {
    const ERROR_BOUND_CONSTANT: f64 = 2.0;
    const SAMPLE_COMPLEXITY_EXPONENT: f64 = 0.25;
    const DIMENSION_DEPENDENCY: f64 = 0.1;
}

/// Quality-bound constants for the Nyström method: smaller leading constant
/// and a faster sample-rate exponent than the RFF constants above.
impl ApproximationQualityBounds<NystromMethod> for () {
    const ERROR_BOUND_CONSTANT: f64 = 1.5;
    const SAMPLE_COMPLEXITY_EXPONENT: f64 = 0.33;
    const DIMENSION_DEPENDENCY: f64 = 0.05;
}

/// Quality-bound constants for Fastfood: the largest leading constant and
/// dimension dependency of the three methods.
impl ApproximationQualityBounds<FastfoodMethod> for () {
    const ERROR_BOUND_CONSTANT: f64 = 3.0;
    const SAMPLE_COMPLEXITY_EXPONENT: f64 = 0.2;
    const DIMENSION_DEPENDENCY: f64 = 0.15;
}
1279
1280#[derive(Debug, Clone)]
1282pub struct TypeSafeKernelConfig<K: KernelType, M: ApproximationMethod, const N: usize>
1284where
1285 (): ValidatedFeatureSize<N>,
1286 (): KernelMethodCompatibility<K, M>,
1287 (): ApproximationQualityBounds<M>,
1288{
1289 kernel_type: PhantomData<K>,
1290 method_type: PhantomData<M>,
1291 features: ValidatedFeatures<N>,
1292 bandwidth: f64,
1293 quality_threshold: f64,
1294}
1295
1296impl<K: KernelType, M: ApproximationMethod, const N: usize> Default
1297 for TypeSafeKernelConfig<K, M, N>
1298where
1299 (): ValidatedFeatureSize<N>,
1300 (): KernelMethodCompatibility<K, M>,
1301 (): ApproximationQualityBounds<M>,
1302{
1303 fn default() -> Self {
1304 Self::new()
1305 }
1306}
1307
1308impl<K: KernelType, M: ApproximationMethod, const N: usize> TypeSafeKernelConfig<K, M, N>
1309where
1310 (): ValidatedFeatureSize<N>,
1311 (): KernelMethodCompatibility<K, M>,
1312 (): ApproximationQualityBounds<M>,
1313{
1314 pub fn new() -> Self {
1316 Self {
1317 kernel_type: PhantomData,
1318 method_type: PhantomData,
1319 features: ValidatedFeatures::new(),
1320 bandwidth: K::DEFAULT_BANDWIDTH,
1321 quality_threshold: 0.1,
1322 }
1323 }
1324
1325 pub fn bandwidth(mut self, bandwidth: f64) -> Self {
1327 assert!(bandwidth > 0.0, "Bandwidth must be positive");
1328 self.bandwidth = bandwidth;
1329 self
1330 }
1331
1332 pub fn quality_threshold(mut self, threshold: f64) -> Self {
1334 assert!(
1335 threshold > 0.0 && threshold < 1.0,
1336 "Quality threshold must be between 0 and 1"
1337 );
1338 self.quality_threshold = threshold;
1339 self
1340 }
1341
1342 pub const fn performance_tier() -> PerformanceTier {
1344 <() as KernelMethodCompatibility<K, M>>::PERFORMANCE_TIER
1345 }
1346
1347 pub const fn memory_complexity() -> ComplexityClass {
1349 <() as KernelMethodCompatibility<K, M>>::MEMORY_COMPLEXITY
1350 }
1351
1352 pub fn theoretical_error_bound(&self, n_samples: usize, n_features: usize) -> f64 {
1354 <() as ApproximationQualityBounds<M>>::error_bound(n_samples, n_features, N)
1355 }
1356
1357 pub fn meets_quality_requirements(&self, n_samples: usize, n_features: usize) -> bool {
1359 let error_bound = self.theoretical_error_bound(n_samples, n_features);
1360 error_bound <= self.quality_threshold
1361 }
1362}
1363
/// RBF kernel approximated with Random Fourier Features.
pub type ValidatedRBFRandomFourier<const N: usize> =
    TypeSafeKernelConfig<RBFKernel, RandomFourierFeatures, N>;
/// Laplacian kernel approximated with the Nyström method.
pub type ValidatedLaplacianNystrom<const N: usize> =
    TypeSafeKernelConfig<LaplacianKernel, NystromMethod, N>;
/// Polynomial kernel approximated with Random Fourier Features.
pub type ValidatedPolynomialRFF<const N: usize> =
    TypeSafeKernelConfig<PolynomialKernel, RandomFourierFeatures, N>;
1371
1372impl<K1: KernelType, K2: KernelType> ComposableKernel for SumKernel<K1, K2> {
1373 type CompositionResult<Other: ComposableKernel> = SumKernel<Self, Other>;
1374
1375 fn compose<Other: ComposableKernel>(self) -> Self::CompositionResult<Other> {
1376 SumKernel {
1377 _phantom: PhantomData,
1378 }
1379 }
1380}
1381
1382impl<K1: KernelType, K2: KernelType> ComposableKernel for ProductKernel<K1, K2> {
1383 type CompositionResult<Other: ComposableKernel> = ProductKernel<Self, Other>;
1384
1385 fn compose<Other: ComposableKernel>(self) -> Self::CompositionResult<Other> {
1386 ProductKernel {
1387 _phantom: PhantomData,
1388 }
1389 }
1390}
1391
#[allow(non_snake_case)]
#[cfg(test)]
mod advanced_composition_tests {
    use super::*;

    /// Power-of-two feature counts must construct and report correctly.
    #[test]
    fn test_validated_features() {
        // Successfully constructing these markers is itself the check.
        let _sixty_four = ValidatedFeatures::<64>::new();
        let _one_two_eight = ValidatedFeatures::<128>::new();
        let _two_five_six = ValidatedFeatures::<256>::new();

        assert_eq!(ValidatedFeatures::<64>::count(), 64);
        assert!(ValidatedFeatures::<64>::is_fastfood_optimal());
    }

    /// Builder chain on a validated config, plus compile-time tier lookup.
    #[test]
    fn test_type_safe_configs() {
        let config = ValidatedRBFRandomFourier::<128>::new()
            .bandwidth(1.5)
            .quality_threshold(0.05);

        assert_eq!(
            ValidatedRBFRandomFourier::<128>::performance_tier(),
            PerformanceTier::Optimal
        );
        assert!(config.meets_quality_requirements(1000, 10));
    }

    /// Kernel composition resolves entirely at the type level; the values
    /// only drive type inference.
    #[test]
    fn test_kernel_composition() {
        let _summed = RBFKernel.compose::<LaplacianKernel>();
        let _multiplied = PolynomialKernel.compose::<RBFKernel>();
    }

    /// All method bounds are positive, and Nyström's constants give a
    /// tighter bound than RFF at this problem size.
    #[test]
    fn test_approximation_bounds() {
        let rff =
            <() as ApproximationQualityBounds<RandomFourierFeatures>>::error_bound(1000, 10, 100);
        let nystrom = <() as ApproximationQualityBounds<NystromMethod>>::error_bound(1000, 10, 100);
        let fastfood =
            <() as ApproximationQualityBounds<FastfoodMethod>>::error_bound(1000, 10, 100);

        assert!(rff > 0.0);
        assert!(nystrom > 0.0);
        assert!(fastfood > 0.0);

        assert!(nystrom < rff);
    }
}
1449
1450pub struct KernelPresets;
1456
1457impl KernelPresets {
1458 pub fn fast_rbf_128() -> ValidatedRBFRandomFourier<128> {
1460 ValidatedRBFRandomFourier::<128>::new()
1461 .bandwidth(1.0)
1462 .quality_threshold(0.2)
1463 }
1464
1465 pub fn balanced_rbf_256() -> ValidatedRBFRandomFourier<256> {
1467 ValidatedRBFRandomFourier::<256>::new()
1468 .bandwidth(1.0)
1469 .quality_threshold(0.1)
1470 }
1471
1472 pub fn accurate_rbf_512() -> ValidatedRBFRandomFourier<512> {
1474 ValidatedRBFRandomFourier::<512>::new()
1475 .bandwidth(1.0)
1476 .quality_threshold(0.05)
1477 }
1478
1479 pub fn ultrafast_rbf_64() -> ValidatedRBFRandomFourier<64> {
1481 ValidatedRBFRandomFourier::<64>::new()
1482 .bandwidth(1.0)
1483 .quality_threshold(0.3)
1484 }
1485
1486 pub fn precise_nystroem_128() -> ValidatedLaplacianNystrom<128> {
1488 ValidatedLaplacianNystrom::<128>::new()
1489 .bandwidth(1.0)
1490 .quality_threshold(0.01)
1491 }
1492
1493 pub fn memory_efficient_rbf_32() -> ValidatedRBFRandomFourier<32> {
1495 ValidatedRBFRandomFourier::<32>::new()
1496 .bandwidth(1.0)
1497 .quality_threshold(0.4)
1498 }
1499
1500 pub fn polynomial_features_256() -> ValidatedPolynomialRFF<256> {
1502 ValidatedPolynomialRFF::<256>::new()
1503 .bandwidth(1.0)
1504 .quality_threshold(0.15)
1505 }
1506}
1507
/// Configuration for profile-guided optimization (PGO) of kernel
/// approximation parameters.
#[derive(Debug, Clone)]
pub struct ProfileGuidedConfig {
    /// Whether the feature count may be chosen using profile data.
    pub enable_pgo_feature_selection: bool,

    /// Whether the kernel bandwidth may be tuned using profile data.
    pub enable_pgo_bandwidth_optimization: bool,

    /// Optional path to a collected profile-data file.
    pub profile_data_path: Option<String>,

    /// CPU architecture used to pick the base feature count in
    /// `recommended_feature_count`.
    pub target_architecture: TargetArchitecture,

    /// Scaling level applied to the base feature count.
    pub optimization_level: OptimizationLevel,
}
1531
/// CPU architecture the feature-count recommendation is tuned for; wider
/// SIMD targets receive a larger base feature count.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TargetArchitecture {
    /// Generic x86-64 without assumed SIMD extensions (base 128 features).
    X86_64Generic,
    /// x86-64 with AVX2 (base 256 features).
    X86_64AVX2,
    /// x86-64 with AVX-512 (base 512 features).
    X86_64AVX512,
    /// Generic AArch64 (base 128 features).
    ARM64,
    /// AArch64 with NEON (base 256 features).
    ARM64NEON,
}
1547
/// How aggressively to scale the architecture's base feature count in
/// `recommended_feature_count`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OptimizationLevel {
    /// Halve the base feature count (0.5x).
    None,
    /// Use the base feature count unchanged (1.0x).
    Basic,
    /// Scale the base feature count by 1.5x.
    Aggressive,
    /// Double the base feature count (2.0x).
    Maximum,
}
1561
1562impl Default for ProfileGuidedConfig {
1563 fn default() -> Self {
1564 Self {
1565 enable_pgo_feature_selection: false,
1566 enable_pgo_bandwidth_optimization: false,
1567 profile_data_path: None,
1568 target_architecture: TargetArchitecture::X86_64Generic,
1569 optimization_level: OptimizationLevel::Basic,
1570 }
1571 }
1572}
1573
1574impl ProfileGuidedConfig {
1575 pub fn new() -> Self {
1577 Self::default()
1578 }
1579
1580 pub fn enable_feature_selection(mut self) -> Self {
1582 self.enable_pgo_feature_selection = true;
1583 self
1584 }
1585
1586 pub fn enable_bandwidth_optimization(mut self) -> Self {
1588 self.enable_pgo_bandwidth_optimization = true;
1589 self
1590 }
1591
1592 pub fn profile_data_path<P: Into<String>>(mut self, path: P) -> Self {
1594 self.profile_data_path = Some(path.into());
1595 self
1596 }
1597
1598 pub fn target_architecture(mut self, arch: TargetArchitecture) -> Self {
1600 self.target_architecture = arch;
1601 self
1602 }
1603
1604 pub fn optimization_level(mut self, level: OptimizationLevel) -> Self {
1606 self.optimization_level = level;
1607 self
1608 }
1609
1610 pub fn recommended_feature_count(&self, data_size: usize, dimensionality: usize) -> usize {
1612 let base_features = match self.target_architecture {
1613 TargetArchitecture::X86_64Generic => 128,
1614 TargetArchitecture::X86_64AVX2 => 256,
1615 TargetArchitecture::X86_64AVX512 => 512,
1616 TargetArchitecture::ARM64 => 128,
1617 TargetArchitecture::ARM64NEON => 256,
1618 };
1619
1620 let scale_factor = match self.optimization_level {
1621 OptimizationLevel::None => 0.5,
1622 OptimizationLevel::Basic => 1.0,
1623 OptimizationLevel::Aggressive => 1.5,
1624 OptimizationLevel::Maximum => 2.0,
1625 };
1626
1627 let scaled_features = (base_features as f64 * scale_factor) as usize;
1628
1629 let data_adjustment = if data_size > 10000 {
1631 1.2
1632 } else if data_size < 1000 {
1633 0.8
1634 } else {
1635 1.0
1636 };
1637
1638 let dimension_adjustment = if dimensionality > 100 {
1639 1.1
1640 } else if dimensionality < 10 {
1641 0.9
1642 } else {
1643 1.0
1644 };
1645
1646 ((scaled_features as f64 * data_adjustment * dimension_adjustment) as usize)
1647 .max(32)
1648 .min(1024)
1649 }
1650}
1651
/// Plain-data, serde-serializable description of a kernel-approximation
/// configuration; kernel and method are stored by name rather than by type.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SerializableKernelConfig {
    /// Kernel name, e.g. "RBF".
    pub kernel_type: String,

    /// Approximation method name, e.g. "RandomFourierFeatures".
    pub approximation_method: String,

    /// Number of approximation components (features/landmarks).
    pub n_components: usize,

    /// Kernel bandwidth parameter.
    pub bandwidth: f64,

    /// Acceptable approximation-error threshold.
    pub quality_threshold: f64,

    /// Optional RNG seed for reproducibility.
    pub random_state: Option<u64>,

    /// Free-form numeric parameters not covered by the fields above.
    pub additional_params: std::collections::HashMap<String, f64>,
}
1681
/// Serde-serializable snapshot of a fitted approximation model; each field
/// is optional because different methods persist different state.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SerializableFittedParams {
    /// The configuration the model was fitted with.
    pub config: SerializableKernelConfig,

    /// Random projection rows, when the method uses random features.
    /// NOTE(review): row/column layout is defined by the implementor.
    pub random_features: Option<Vec<Vec<f64>>>,

    /// Selected sample indices (e.g. Nyström landmarks), when applicable.
    pub selected_indices: Option<Vec<usize>>,

    /// Eigenvalues of the decomposition, when applicable.
    pub eigenvalues: Option<Vec<f64>>,

    /// Eigenvector rows of the decomposition, when applicable.
    pub eigenvectors: Option<Vec<Vec<f64>>>,

    /// Named quality metrics recorded at fit time.
    pub quality_metrics: std::collections::HashMap<String, f64>,

    /// Unix timestamp of the fit, when recorded.
    pub fitted_timestamp: Option<u64>,
}
1707
1708pub trait SerializableKernelApproximation {
1710 fn export_config(&self) -> Result<SerializableKernelConfig>;
1712
1713 fn import_config(config: &SerializableKernelConfig) -> Result<Self>
1715 where
1716 Self: Sized;
1717
1718 fn export_fitted_params(&self) -> Result<SerializableFittedParams>;
1720
1721 fn import_fitted_params(&mut self, params: &SerializableFittedParams) -> Result<()>;
1723
1724 fn save_to_file<P: AsRef<std::path::Path>>(&self, path: P) -> Result<()> {
1726 let config = self.export_config()?;
1727 let fitted_params = self.export_fitted_params()?;
1728
1729 let model_data = serde_json::json!({
1730 "config": config,
1731 "fitted_params": fitted_params,
1732 "version": "1.0",
1733 "sklears_version": env!("CARGO_PKG_VERSION"),
1734 });
1735
1736 std::fs::write(
1737 path,
1738 serde_json::to_string_pretty(&model_data).expect("operation should succeed"),
1739 )
1740 .map_err(SklearsError::from)?;
1741
1742 Ok(())
1743 }
1744
1745 fn load_from_file<P: AsRef<std::path::Path>>(path: P) -> Result<Self>
1747 where
1748 Self: Sized,
1749 {
1750 let content = std::fs::read_to_string(path).map_err(SklearsError::from)?;
1751
1752 let model_data: serde_json::Value = serde_json::from_str(&content).map_err(|e| {
1753 SklearsError::SerializationError(format!("Failed to parse JSON: {}", e))
1754 })?;
1755
1756 let config: SerializableKernelConfig = serde_json::from_value(model_data["config"].clone())
1757 .map_err(|e| {
1758 SklearsError::SerializationError(format!("Failed to deserialize config: {}", e))
1759 })?;
1760
1761 let fitted_params: SerializableFittedParams =
1762 serde_json::from_value(model_data["fitted_params"].clone()).map_err(|e| {
1763 SklearsError::SerializationError(format!(
1764 "Failed to deserialize fitted params: {}",
1765 e
1766 ))
1767 })?;
1768
1769 let mut model = Self::import_config(&config)?;
1770 model.import_fitted_params(&fitted_params)?;
1771
1772 Ok(model)
1773 }
1774}
1775
#[allow(non_snake_case)]
#[cfg(test)]
mod preset_tests {
    use super::*;

    /// Each preset must construct and report the optimal performance tier.
    #[test]
    fn test_kernel_presets() {
        let _fast = KernelPresets::fast_rbf_128();
        assert_eq!(
            ValidatedRBFRandomFourier::<128>::performance_tier(),
            PerformanceTier::Optimal
        );

        let _balanced = KernelPresets::balanced_rbf_256();
        assert_eq!(
            ValidatedRBFRandomFourier::<256>::performance_tier(),
            PerformanceTier::Optimal
        );

        let _accurate = KernelPresets::accurate_rbf_512();
        assert_eq!(
            ValidatedRBFRandomFourier::<512>::performance_tier(),
            PerformanceTier::Optimal
        );
    }

    /// Builder methods must record every setting, and the recommendation
    /// must respect its clamp range.
    #[test]
    fn test_profile_guided_config() {
        let cfg = ProfileGuidedConfig::new()
            .enable_feature_selection()
            .enable_bandwidth_optimization()
            .target_architecture(TargetArchitecture::X86_64AVX2)
            .optimization_level(OptimizationLevel::Aggressive);

        assert!(cfg.enable_pgo_feature_selection);
        assert!(cfg.enable_pgo_bandwidth_optimization);
        assert_eq!(cfg.target_architecture, TargetArchitecture::X86_64AVX2);
        assert_eq!(cfg.optimization_level, OptimizationLevel::Aggressive);

        let features = cfg.recommended_feature_count(5000, 50);
        assert!((32..=1024).contains(&features));
    }

    /// JSON round-trip of the serializable config preserves every field.
    #[test]
    fn test_serializable_config() {
        let config = SerializableKernelConfig {
            kernel_type: "RBF".to_string(),
            approximation_method: "RandomFourierFeatures".to_string(),
            n_components: 256,
            bandwidth: 1.5,
            quality_threshold: 0.1,
            random_state: Some(42),
            additional_params: std::collections::HashMap::new(),
        };

        let json = serde_json::to_string(&config).expect("operation should succeed");
        assert!(json.contains("RBF"));
        assert!(json.contains("RandomFourierFeatures"));

        let roundtripped: SerializableKernelConfig =
            serde_json::from_str(&json).expect("operation should succeed");
        assert_eq!(roundtripped.kernel_type, "RBF");
        assert_eq!(roundtripped.n_components, 256);
        assert_eq!(roundtripped.bandwidth, 1.5);
    }
}