1use scirs2_core::ndarray::{Array1, Array2};
16use serde::{Deserialize, Serialize};
17use sklears_core::{
18 error::{Result, SklearsError},
19 types::Float,
20};
21use std::any::Any;
22use std::collections::HashMap;
23use std::sync::Arc;
24
25pub trait DecompositionAlgorithm: Send + Sync {
27 fn name(&self) -> &str;
29
30 fn description(&self) -> &str;
32
33 fn capabilities(&self) -> AlgorithmCapabilities;
35
36 fn validate_params(&self, params: &DecompositionParams) -> Result<()>;
38
39 fn fit(&mut self, data: &Array2<Float>, params: &DecompositionParams) -> Result<()>;
41
42 fn transform(&self, data: &Array2<Float>) -> Result<Array2<Float>>;
44
45 fn inverse_transform(&self, _data: &Array2<Float>) -> Result<Array2<Float>> {
47 Err(SklearsError::InvalidInput(
48 "Inverse transform not supported by this algorithm".to_string(),
49 ))
50 }
51
52 fn get_components(&self) -> Result<DecompositionComponents>;
54
55 fn is_fitted(&self) -> bool;
57
58 fn clone_algorithm(&self) -> Box<dyn DecompositionAlgorithm>;
60
61 fn as_any(&self) -> &dyn Any;
63}
64
65#[derive(Debug, Clone, PartialEq, Eq)]
67pub struct AlgorithmCapabilities {
68 pub supports_non_square: bool,
70 pub supports_sparse: bool,
72 pub supports_incremental: bool,
74 pub supports_inverse_transform: bool,
76 pub supports_partial_fit: bool,
78 pub required_properties: Vec<MatrixProperty>,
80 pub complexity: ComputationalComplexity,
82}
83
84impl Default for AlgorithmCapabilities {
85 fn default() -> Self {
86 Self {
87 supports_non_square: true,
88 supports_sparse: false,
89 supports_incremental: false,
90 supports_inverse_transform: false,
91 supports_partial_fit: false,
92 required_properties: Vec::new(),
93 complexity: ComputationalComplexity::Cubic,
94 }
95 }
96}
97
/// A structural property an input matrix may be required to have.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MatrixProperty {
    /// Every entry is `>= 0`.
    NonNegative,
    /// The matrix equals its transpose.
    Symmetric,
    /// Symmetric with strictly positive eigenvalues.
    PositiveDefinite,
    /// The rank equals the smaller of the two dimensions.
    FullRank,
}
106
/// Coarse asymptotic cost class used to describe algorithms.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ComputationalComplexity {
    /// Grows linearly with the input size.
    Linear,
    /// Grows quadratically with the input size.
    Quadratic,
    /// Grows cubically with the input size.
    Cubic,
    /// Grows exponentially with the input size.
    Exponential,
}
115
116#[derive(Debug, Clone, Serialize, Deserialize)]
118pub struct DecompositionParams {
119 pub n_components: Option<usize>,
120 pub tolerance: Option<Float>,
121 pub max_iterations: Option<usize>,
122 pub random_seed: Option<u64>,
123 pub algorithm_specific: HashMap<String, ParamValue>,
124}
125
126impl Default for DecompositionParams {
127 fn default() -> Self {
128 Self {
129 n_components: None,
130 tolerance: Some(1e-6),
131 max_iterations: Some(100),
132 random_seed: None,
133 algorithm_specific: HashMap::new(),
134 }
135 }
136}
137
138#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
140pub enum ParamValue {
141 Integer(i64),
142 Float(Float),
143 Boolean(bool),
144 String(String),
145 Array(Vec<Float>),
146}
147
148#[derive(Debug, Clone)]
150pub struct DecompositionComponents {
151 pub components: Option<Array2<Float>>,
152 pub singular_values: Option<Array1<Float>>,
153 pub eigenvalues: Option<Array1<Float>>,
154 pub mean: Option<Array1<Float>>,
155 pub explained_variance_ratio: Option<Array1<Float>>,
156 pub factor_loadings: Option<Array2<Float>>,
157 pub metadata: HashMap<String, String>,
158}
159
160impl Default for DecompositionComponents {
161 fn default() -> Self {
162 Self {
163 components: None,
164 singular_values: None,
165 eigenvalues: None,
166 mean: None,
167 explained_variance_ratio: None,
168 factor_loadings: None,
169 metadata: HashMap::new(),
170 }
171 }
172}
173
174pub trait PreprocessingStep: Send + Sync {
176 fn name(&self) -> &str;
178
179 fn process(&mut self, data: &Array2<Float>) -> Result<Array2<Float>>;
181
182 fn inverse_process(&self, _data: &Array2<Float>) -> Result<Array2<Float>> {
184 Err(SklearsError::InvalidInput(
185 "Inverse processing not supported".to_string(),
186 ))
187 }
188
189 fn is_fitted(&self) -> bool;
191
192 fn clone_step(&self) -> Box<dyn PreprocessingStep>;
194}
195
196pub trait PostprocessingStep: Send + Sync {
198 fn name(&self) -> &str;
200
201 fn process(&self, components: DecompositionComponents) -> Result<DecompositionComponents>;
203
204 fn clone_step(&self) -> Box<dyn PostprocessingStep>;
206}
207
208pub struct AlgorithmRegistry {
210 algorithms: HashMap<String, Box<dyn Fn() -> Box<dyn DecompositionAlgorithm> + Send + Sync>>,
211 metadata: HashMap<String, AlgorithmMetadata>,
212}
213
214impl AlgorithmRegistry {
215 pub fn new() -> Self {
217 Self {
218 algorithms: HashMap::new(),
219 metadata: HashMap::new(),
220 }
221 }
222
223 pub fn register<F>(&mut self, name: String, factory: F, metadata: AlgorithmMetadata)
225 where
226 F: Fn() -> Box<dyn DecompositionAlgorithm> + Send + Sync + 'static,
227 {
228 self.algorithms.insert(name.clone(), Box::new(factory));
229 self.metadata.insert(name, metadata);
230 }
231
232 pub fn create_algorithm(&self, name: &str) -> Result<Box<dyn DecompositionAlgorithm>> {
234 if let Some(factory) = self.algorithms.get(name) {
235 Ok(factory())
236 } else {
237 Err(SklearsError::InvalidInput(format!(
238 "Algorithm '{}' not found in registry",
239 name
240 )))
241 }
242 }
243
244 pub fn list_algorithms(&self) -> Vec<String> {
246 self.algorithms.keys().cloned().collect()
247 }
248
249 pub fn get_metadata(&self, name: &str) -> Option<&AlgorithmMetadata> {
251 self.metadata.get(name)
252 }
253
254 pub fn find_by_capability(&self, capability: AlgorithmCapability) -> Vec<String> {
256 self.metadata
257 .iter()
258 .filter(|(_, metadata)| metadata.capabilities.contains(&capability))
259 .map(|(name, _)| name.clone())
260 .collect()
261 }
262}
263
264impl Default for AlgorithmRegistry {
265 fn default() -> Self {
266 Self::new()
267 }
268}
269
270#[derive(Debug, Clone)]
272pub struct AlgorithmMetadata {
273 pub description: String,
274 pub version: String,
275 pub author: String,
276 pub capabilities: Vec<AlgorithmCapability>,
277 pub computational_complexity: ComputationalComplexity,
278 pub memory_complexity: ComputationalComplexity,
279}
280
/// A task category an algorithm can be registered for and searched by.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum AlgorithmCapability {
    /// Reduces the number of features.
    DimensionalityReduction,
    /// Derives new features from the input.
    FeatureExtraction,
    /// Expresses the input matrix as a product of factors.
    MatrixFactorization,
    /// Separates signal from noise.
    NoiseReduction,
    /// Produces a more compact representation of the data.
    DataCompression,
    /// Exposes latent structure in the data.
    PatternRecognition,
}
291
292pub struct DecompositionPipeline {
294 preprocessing_steps: Vec<Box<dyn PreprocessingStep>>,
295 algorithm: Box<dyn DecompositionAlgorithm>,
296 postprocessing_steps: Vec<Box<dyn PostprocessingStep>>,
297 fallback_algorithms: Vec<Box<dyn DecompositionAlgorithm>>,
298 pipeline_config: PipelineConfig,
299}
300
301impl DecompositionPipeline {
302 pub fn new(algorithm: Box<dyn DecompositionAlgorithm>) -> Self {
304 Self {
305 preprocessing_steps: Vec::new(),
306 algorithm,
307 postprocessing_steps: Vec::new(),
308 fallback_algorithms: Vec::new(),
309 pipeline_config: PipelineConfig::default(),
310 }
311 }
312
313 pub fn add_preprocessing(mut self, step: Box<dyn PreprocessingStep>) -> Self {
315 self.preprocessing_steps.push(step);
316 self
317 }
318
319 pub fn add_postprocessing(mut self, step: Box<dyn PostprocessingStep>) -> Self {
321 self.postprocessing_steps.push(step);
322 self
323 }
324
325 pub fn add_fallback(mut self, algorithm: Box<dyn DecompositionAlgorithm>) -> Self {
327 self.fallback_algorithms.push(algorithm);
328 self
329 }
330
331 pub fn with_config(mut self, config: PipelineConfig) -> Self {
333 self.pipeline_config = config;
334 self
335 }
336
337 pub fn fit_transform(
339 &mut self,
340 data: &Array2<Float>,
341 params: &DecompositionParams,
342 ) -> Result<PipelineResult> {
343 let start_time = std::time::Instant::now();
344
345 let mut processed_data = data.clone();
347 for step in &mut self.preprocessing_steps {
348 processed_data = step.process(&processed_data)?;
349 }
350
351 let mut components = {
353 let algorithm = &mut self.algorithm;
354 match Self::try_algorithm_static(algorithm, &processed_data, params) {
355 Ok(result) => result,
356 Err(error) if self.pipeline_config.use_fallbacks => {
357 let mut last_error = error;
359 let mut success = false;
360 let mut result_components = DecompositionComponents::default();
361
362 for fallback in &mut self.fallback_algorithms {
363 match Self::try_algorithm_static(fallback, &processed_data, params) {
364 Ok(components) => {
365 result_components = components;
366 success = true;
367 break;
368 }
369 Err(err) => last_error = err,
370 }
371 }
372
373 if !success {
374 return Err(last_error);
375 }
376 result_components
377 }
378 Err(error) => return Err(error),
379 }
380 };
381
382 for step in &self.postprocessing_steps {
384 components = step.process(components)?;
385 }
386
387 let execution_time = start_time.elapsed();
388
389 Ok(PipelineResult {
390 components,
391 execution_time,
392 algorithm_used: self.algorithm.name().to_string(),
393 preprocessing_steps: self
394 .preprocessing_steps
395 .iter()
396 .map(|s| s.name().to_string())
397 .collect(),
398 postprocessing_steps: self
399 .postprocessing_steps
400 .iter()
401 .map(|s| s.name().to_string())
402 .collect(),
403 pipeline_metadata: HashMap::new(),
404 })
405 }
406
407 pub fn transform(&self, data: &Array2<Float>) -> Result<Array2<Float>> {
409 if !self.is_fitted() {
410 return Err(SklearsError::InvalidInput(
411 "Pipeline not fitted".to_string(),
412 ));
413 }
414
415 let processed_data = data.clone();
417 self.algorithm.transform(&processed_data)
424 }
425
426 pub fn is_fitted(&self) -> bool {
428 self.algorithm.is_fitted()
429 }
430
431 fn try_algorithm_static(
433 algorithm: &mut Box<dyn DecompositionAlgorithm>,
434 data: &Array2<Float>,
435 params: &DecompositionParams,
436 ) -> Result<DecompositionComponents> {
437 algorithm.validate_params(params)?;
438 algorithm.fit(data, params)?;
439 algorithm.get_components()
440 }
441}
442
/// Behavioural switches for a `DecompositionPipeline`.
#[derive(Debug, Clone)]
pub struct PipelineConfig {
    /// Try the registered fallback algorithms when the primary algorithm fails.
    pub use_fallbacks: bool,
    /// Request caching of intermediate results.
    /// NOTE(review): not read by `DecompositionPipeline` in this file.
    pub enable_caching: bool,
    /// Optional upper bound on execution time.
    /// NOTE(review): not enforced by `DecompositionPipeline` in this file.
    pub max_execution_time: Option<std::time::Duration>,
    /// Validate input data before processing.
    pub validate_inputs: bool,
}

impl Default for PipelineConfig {
    /// Fallbacks and input validation on, caching off, no time limit.
    fn default() -> Self {
        Self {
            validate_inputs: true,
            max_execution_time: None,
            enable_caching: false,
            use_fallbacks: true,
        }
    }
}
466
467#[derive(Debug, Clone)]
469pub struct PipelineResult {
470 pub components: DecompositionComponents,
471 pub execution_time: std::time::Duration,
472 pub algorithm_used: String,
473 pub preprocessing_steps: Vec<String>,
474 pub postprocessing_steps: Vec<String>,
475 pub pipeline_metadata: HashMap<String, String>,
476}
477
478pub struct DecompositionWorkflowBuilder {
480 registry: Arc<AlgorithmRegistry>,
481 pipeline: Option<DecompositionPipeline>,
482 config: PipelineConfig,
483}
484
485impl DecompositionWorkflowBuilder {
486 pub fn new(registry: Arc<AlgorithmRegistry>) -> Self {
488 Self {
489 registry,
490 pipeline: None,
491 config: PipelineConfig::default(),
492 }
493 }
494
495 pub fn with_algorithm(mut self, algorithm_name: &str) -> Result<Self> {
497 let algorithm = self.registry.create_algorithm(algorithm_name)?;
498 self.pipeline = Some(DecompositionPipeline::new(algorithm));
499 Ok(self)
500 }
501
502 pub fn with_preprocessing(mut self, step: Box<dyn PreprocessingStep>) -> Result<Self> {
504 if let Some(pipeline) = self.pipeline.take() {
505 self.pipeline = Some(pipeline.add_preprocessing(step));
506 } else {
507 return Err(SklearsError::InvalidInput(
508 "Must set algorithm before adding preprocessing steps".to_string(),
509 ));
510 }
511 Ok(self)
512 }
513
514 pub fn with_postprocessing(mut self, step: Box<dyn PostprocessingStep>) -> Result<Self> {
516 if let Some(pipeline) = self.pipeline.take() {
517 self.pipeline = Some(pipeline.add_postprocessing(step));
518 } else {
519 return Err(SklearsError::InvalidInput(
520 "Must set algorithm before adding postprocessing steps".to_string(),
521 ));
522 }
523 Ok(self)
524 }
525
526 pub fn with_fallback(mut self, algorithm_name: &str) -> Result<Self> {
528 let algorithm = self.registry.create_algorithm(algorithm_name)?;
529 if let Some(pipeline) = self.pipeline.take() {
530 self.pipeline = Some(pipeline.add_fallback(algorithm));
531 } else {
532 return Err(SklearsError::InvalidInput(
533 "Must set primary algorithm before adding fallbacks".to_string(),
534 ));
535 }
536 Ok(self)
537 }
538
539 pub fn with_config(mut self, config: PipelineConfig) -> Self {
541 self.config = config;
542 self
543 }
544
545 pub fn build(mut self) -> Result<DecompositionPipeline> {
547 match self.pipeline.take() {
548 Some(pipeline) => Ok(pipeline.with_config(self.config)),
549 None => Err(SklearsError::InvalidInput(
550 "No algorithm specified for workflow".to_string(),
551 )),
552 }
553 }
554}
555
556#[derive(Debug, Clone)]
558pub struct StandardizationStep {
559 mean: Option<Array1<Float>>,
560 std: Option<Array1<Float>>,
561 fitted: bool,
562}
563
564impl StandardizationStep {
565 pub fn new() -> Self {
566 Self {
567 mean: None,
568 std: None,
569 fitted: false,
570 }
571 }
572}
573
574impl Default for StandardizationStep {
575 fn default() -> Self {
576 Self::new()
577 }
578}
579
580impl PreprocessingStep for StandardizationStep {
581 fn name(&self) -> &str {
582 "standardization"
583 }
584
585 fn process(&mut self, data: &Array2<Float>) -> Result<Array2<Float>> {
586 if !self.fitted {
587 let mean = data.mean_axis(scirs2_core::ndarray::Axis(0)).unwrap();
589 let std = data
590 .var_axis(scirs2_core::ndarray::Axis(0), 0.0)
591 .mapv(|x| x.sqrt());
592
593 self.mean = Some(mean);
594 self.std = Some(std);
595 self.fitted = true;
596 }
597
598 let mean = self.mean.as_ref().unwrap();
600 let std = self.std.as_ref().unwrap();
601
602 let mean_broadcast = mean.clone().insert_axis(scirs2_core::ndarray::Axis(0));
603 let std_broadcast = std.clone().insert_axis(scirs2_core::ndarray::Axis(0));
604 let standardized = (data - &mean_broadcast) / &std_broadcast;
605
606 Ok(standardized)
607 }
608
609 fn inverse_process(&self, data: &Array2<Float>) -> Result<Array2<Float>> {
610 if !self.fitted {
611 return Err(SklearsError::InvalidInput(
612 "Standardization step not fitted".to_string(),
613 ));
614 }
615
616 let mean = self.mean.as_ref().unwrap();
617 let std = self.std.as_ref().unwrap();
618
619 let mean_broadcast = mean.clone().insert_axis(scirs2_core::ndarray::Axis(0));
620 let std_broadcast = std.clone().insert_axis(scirs2_core::ndarray::Axis(0));
621 let unstandardized = data * &std_broadcast + &mean_broadcast;
622
623 Ok(unstandardized)
624 }
625
626 fn is_fitted(&self) -> bool {
627 self.fitted
628 }
629
630 fn clone_step(&self) -> Box<dyn PreprocessingStep> {
631 Box::new(self.clone())
632 }
633}
634
635#[derive(Debug, Clone)]
637pub struct VarimaxRotationStep {
638 max_iterations: usize,
639 tolerance: Float,
640}
641
642impl VarimaxRotationStep {
643 pub fn new() -> Self {
644 Self {
645 max_iterations: 100,
646 tolerance: 1e-6,
647 }
648 }
649
650 pub fn with_max_iterations(mut self, max_iterations: usize) -> Self {
651 self.max_iterations = max_iterations;
652 self
653 }
654
655 pub fn with_tolerance(mut self, tolerance: Float) -> Self {
656 self.tolerance = tolerance;
657 self
658 }
659}
660
661impl Default for VarimaxRotationStep {
662 fn default() -> Self {
663 Self::new()
664 }
665}
666
667impl PostprocessingStep for VarimaxRotationStep {
668 fn name(&self) -> &str {
669 "varimax_rotation"
670 }
671
672 fn process(&self, mut components: DecompositionComponents) -> Result<DecompositionComponents> {
673 if let Some(ref mut loadings) = components.factor_loadings {
674 *loadings = self.apply_varimax_rotation(loadings)?;
676 } else if let Some(ref mut comps) = components.components {
677 *comps = self.apply_varimax_rotation(comps)?;
679 }
680
681 components
682 .metadata
683 .insert("rotation_applied".to_string(), "varimax".to_string());
684
685 Ok(components)
686 }
687
688 fn clone_step(&self) -> Box<dyn PostprocessingStep> {
689 Box::new(self.clone())
690 }
691}
692
693impl VarimaxRotationStep {
694 fn apply_varimax_rotation(&self, matrix: &Array2<Float>) -> Result<Array2<Float>> {
695 Ok(matrix.clone())
698 }
699}
700
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal stand-in algorithm.
    ///
    /// `fit` only records `n_components`, `transform` returns zeros and `get_components`
    /// returns an identity basis.
    #[derive(Debug, Clone)]
    struct MockPCA {
        fitted: bool,
        n_components: usize,
    }

    impl MockPCA {
        fn new() -> Self {
            Self {
                fitted: false,
                n_components: 2,
            }
        }
    }

    impl DecompositionAlgorithm for MockPCA {
        fn name(&self) -> &str {
            "mock_pca"
        }

        fn description(&self) -> &str {
            "Mock PCA for testing"
        }

        fn capabilities(&self) -> AlgorithmCapabilities {
            AlgorithmCapabilities {
                supports_inverse_transform: true,
                ..AlgorithmCapabilities::default()
            }
        }

        fn validate_params(&self, _params: &DecompositionParams) -> Result<()> {
            Ok(())
        }

        fn fit(&mut self, _data: &Array2<Float>, params: &DecompositionParams) -> Result<()> {
            if let Some(n_comp) = params.n_components {
                self.n_components = n_comp;
            }
            self.fitted = true;
            Ok(())
        }

        fn transform(&self, data: &Array2<Float>) -> Result<Array2<Float>> {
            if !self.fitted {
                return Err(SklearsError::InvalidInput(
                    "Algorithm not fitted".to_string(),
                ));
            }

            let (rows, _) = data.dim();
            Ok(Array2::zeros((rows, self.n_components)))
        }

        fn get_components(&self) -> Result<DecompositionComponents> {
            if !self.fitted {
                return Err(SklearsError::InvalidInput(
                    "Algorithm not fitted".to_string(),
                ));
            }

            Ok(DecompositionComponents {
                components: Some(Array2::eye(self.n_components)),
                eigenvalues: Some(Array1::ones(self.n_components)),
                ..DecompositionComponents::default()
            })
        }

        fn is_fitted(&self) -> bool {
            self.fitted
        }

        fn clone_algorithm(&self) -> Box<dyn DecompositionAlgorithm> {
            Box::new(self.clone())
        }

        fn as_any(&self) -> &dyn Any {
            self
        }
    }

    /// Registry metadata shared by the registry and builder tests.
    fn mock_metadata() -> AlgorithmMetadata {
        AlgorithmMetadata {
            description: "Mock PCA".to_string(),
            version: "1.0".to_string(),
            author: "Test".to_string(),
            capabilities: vec![AlgorithmCapability::DimensionalityReduction],
            computational_complexity: ComputationalComplexity::Cubic,
            memory_complexity: ComputationalComplexity::Quadratic,
        }
    }

    #[test]
    fn test_algorithm_capabilities() {
        let capabilities = AlgorithmCapabilities::default();
        assert!(capabilities.supports_non_square);
        assert!(!capabilities.supports_sparse);
        assert_eq!(capabilities.complexity, ComputationalComplexity::Cubic);
    }

    #[test]
    fn test_decomposition_params() {
        let mut params = DecompositionParams {
            n_components: Some(5),
            ..DecompositionParams::default()
        };
        params
            .algorithm_specific
            .insert("test_param".to_string(), ParamValue::Float(2.5));

        assert_eq!(params.n_components, Some(5));
        assert_eq!(
            params.algorithm_specific.get("test_param"),
            Some(&ParamValue::Float(2.5))
        );
    }

    #[test]
    fn test_algorithm_registry() {
        let mut registry = AlgorithmRegistry::new();

        registry.register(
            "mock_pca".to_string(),
            || Box::new(MockPCA::new()),
            mock_metadata(),
        );

        let algorithms = registry.list_algorithms();
        assert_eq!(algorithms, vec!["mock_pca"]);

        let algorithm = registry.create_algorithm("mock_pca").unwrap();
        assert_eq!(algorithm.name(), "mock_pca");

        let dim_red_algorithms =
            registry.find_by_capability(AlgorithmCapability::DimensionalityReduction);
        assert_eq!(dim_red_algorithms, vec!["mock_pca"]);
    }

    #[test]
    fn test_standardization_step() {
        let mut step = StandardizationStep::new();
        assert!(!step.is_fitted());
        assert_eq!(step.name(), "standardization");

        let data = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();

        let processed = step.process(&data).unwrap();
        assert!(step.is_fitted());
        assert_eq!(processed.shape(), data.shape());
    }

    #[test]
    fn test_varimax_rotation_step() {
        let step = VarimaxRotationStep::new();
        assert_eq!(step.name(), "varimax_rotation");

        let mut components = DecompositionComponents::default();
        components.components = Some(Array2::eye(3));

        let processed = step.process(components).unwrap();
        assert!(processed.metadata.contains_key("rotation_applied"));
    }

    #[test]
    fn test_decomposition_pipeline() {
        let mut pipeline = DecompositionPipeline::new(Box::new(MockPCA::new()));

        let data = Array2::from_shape_vec(
            (4, 3),
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
        )
        .unwrap();

        let params = DecompositionParams {
            n_components: Some(2),
            ..DecompositionParams::default()
        };

        let result = pipeline.fit_transform(&data, &params).unwrap();
        assert_eq!(result.algorithm_used, "mock_pca");
        // Asserting `execution_time > 0` would be flaky on coarse-resolution timers.
        assert!(result.preprocessing_steps.is_empty());
        assert!(result.postprocessing_steps.is_empty());
        assert!(pipeline.is_fitted());

        let transformed = pipeline.transform(&data).unwrap();
        assert_eq!(transformed.shape(), &[4, 2]);
    }

    #[test]
    fn test_workflow_builder() {
        let mut registry = AlgorithmRegistry::new();

        registry.register(
            "mock_pca".to_string(),
            || Box::new(MockPCA::new()),
            mock_metadata(),
        );

        let registry = Arc::new(registry);
        let builder = DecompositionWorkflowBuilder::new(registry);

        let pipeline = builder
            .with_algorithm("mock_pca")
            .unwrap()
            .with_preprocessing(Box::new(StandardizationStep::new()))
            .unwrap()
            .with_postprocessing(Box::new(VarimaxRotationStep::new()))
            .unwrap()
            .build()
            .unwrap();

        assert_eq!(pipeline.algorithm.name(), "mock_pca");
        assert_eq!(pipeline.preprocessing_steps.len(), 1);
        assert_eq!(pipeline.postprocessing_steps.len(), 1);
    }

    #[test]
    fn test_param_values() {
        let int_param = ParamValue::Integer(42);
        let float_param = ParamValue::Float(2.5);
        let bool_param = ParamValue::Boolean(true);
        let string_param = ParamValue::String("test".to_string());
        let array_param = ParamValue::Array(vec![1.0, 2.0, 3.0]);

        assert_eq!(int_param, ParamValue::Integer(42));
        assert_eq!(float_param, ParamValue::Float(2.5));
        assert_eq!(bool_param, ParamValue::Boolean(true));
        assert_eq!(string_param, ParamValue::String("test".to_string()));
        assert_eq!(array_param, ParamValue::Array(vec![1.0, 2.0, 3.0]));
    }
}