1use std::collections::HashMap;
11use std::hash::{Hash, Hasher};
12use std::sync::{Arc, RwLock};
13use std::time::Instant;
14
15use scirs2_core::ndarray::{s, Array2, ArrayView2};
16
17#[cfg(feature = "parallel")]
18use rayon::prelude::*;
19
20#[cfg(feature = "serde")]
21use serde::{Deserialize, Serialize};
22
23use sklears_core::{
24 error::{Result, SklearsError},
25 traits::Transform,
26};
27
28use crate::streaming::{StreamingConfig, StreamingStats, StreamingTransformer};
29
/// A single cached transformation result plus the bookkeeping used for
/// TTL expiry and eviction.
#[derive(Clone, Debug)]
struct CacheEntry<T> {
    /// The cached transformation output.
    result: T,
    /// Insertion time; compared against the configured TTL on lookup.
    timestamp: Instant,
    /// Number of times this entry has been read; the least-accessed
    /// entry is evicted first when the cache is full.
    access_count: usize,
}
37
/// Configuration for `TransformationCache`.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct CacheConfig {
    /// Maximum number of entries kept before one is evicted.
    pub max_entries: usize,
    /// Time-to-live for an entry, in seconds; expired entries are
    /// removed lazily when looked up.
    pub ttl_seconds: u64,
    /// When `false`, `get` always misses and `put` is a no-op.
    pub enabled: bool,
}
49
50impl Default for CacheConfig {
51 fn default() -> Self {
52 Self {
53 max_entries: 100,
54 ttl_seconds: 3600, enabled: true,
56 }
57 }
58}
59
/// Thread-safe, TTL-aware cache for transformation results, keyed by a
/// 64-bit hash of the input.
pub struct TransformationCache<T> {
    /// Shared map from hashed key to cached entry.
    cache: Arc<RwLock<HashMap<u64, CacheEntry<T>>>>,
    /// Capacity, TTL and on/off switch.
    config: CacheConfig,
}
65
66impl<T: Clone> TransformationCache<T> {
67 pub fn new(config: CacheConfig) -> Self {
68 Self {
69 cache: Arc::new(RwLock::new(HashMap::new())),
70 config,
71 }
72 }
73
74 fn generate_key<U: Hash>(&self, input: U) -> u64 {
76 let mut hasher = std::collections::hash_map::DefaultHasher::new();
77 input.hash(&mut hasher);
78 hasher.finish()
79 }
80
81 pub fn get(&self, key: u64) -> Option<T> {
83 if !self.config.enabled {
84 return None;
85 }
86
87 let mut cache = self.cache.write().ok()?;
88
89 if let Some(entry) = cache.get_mut(&key) {
91 let age = entry.timestamp.elapsed();
92 if age.as_secs() <= self.config.ttl_seconds {
93 entry.access_count += 1;
94 return Some(entry.result.clone());
95 } else {
96 cache.remove(&key);
98 }
99 }
100
101 None
102 }
103
104 pub fn put(&self, key: u64, value: T) {
106 if !self.config.enabled {
107 return;
108 }
109
110 let mut cache = self.cache.write().expect("operation should succeed");
111
112 if cache.len() >= self.config.max_entries {
114 self.evict_lru(&mut cache);
115 }
116
117 cache.insert(
118 key,
119 CacheEntry {
120 result: value,
121 timestamp: Instant::now(),
122 access_count: 1,
123 },
124 );
125 }
126
127 fn evict_lru(&self, cache: &mut HashMap<u64, CacheEntry<T>>) {
129 if let Some((key_to_remove, _)) = cache.iter().min_by_key(|(_, entry)| entry.access_count) {
130 let key_to_remove = *key_to_remove;
131 cache.remove(&key_to_remove);
132 }
133 }
134
135 pub fn clear(&self) {
137 if let Ok(mut cache) = self.cache.write() {
138 cache.clear();
139 }
140 }
141
142 pub fn stats(&self) -> CacheStats {
144 let cache = self.cache.read().expect("operation should succeed");
145 CacheStats {
146 entries: cache.len(),
147 max_entries: self.config.max_entries,
148 enabled: self.config.enabled,
149 }
150 }
151}
152
/// Point-in-time snapshot of a `TransformationCache`.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Current number of cached entries.
    pub entries: usize,
    /// Configured capacity.
    pub max_entries: usize,
    /// Whether caching is currently enabled.
    pub enabled: bool,
}
160
/// Predicate deciding whether a conditional pipeline step should run,
/// evaluated against a view of the incoming data.
pub type ConditionFn = Box<dyn Fn(&ArrayView2<f64>) -> bool + Send + Sync>;
163
/// Configuration for a pipeline step that only runs when a predicate
/// on the input data holds.
pub struct ConditionalStepConfig<T> {
    /// Transformer executed when the condition is true.
    pub transformer: T,
    /// Predicate evaluated against the input before each transform.
    pub condition: ConditionFn,
    /// Human-readable step name, used in error messages.
    pub name: String,
    /// If true, a false condition passes data through unchanged;
    /// if false, a false condition is an error.
    pub skip_on_false: bool,
}
175
176impl<T: std::fmt::Debug> std::fmt::Debug for ConditionalStepConfig<T> {
177 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
178 f.debug_struct("ConditionalStepConfig")
179 .field("transformer", &self.transformer)
180 .field("condition", &"<function>")
181 .field("name", &self.name)
182 .field("skip_on_false", &self.skip_on_false)
183 .finish()
184 }
185}
186
/// A pipeline step wrapping a transformer behind a runtime condition.
pub struct ConditionalStep<T> {
    /// Transformer, predicate and skip behavior.
    config: ConditionalStepConfig<T>,
    /// Set to `false` at construction; `transform` does not consult it
    /// in this file.
    fitted: bool,
}
192
193impl<T> ConditionalStep<T>
194where
195 T: Transform<Array2<f64>, Array2<f64>> + Clone,
196{
197 pub fn new(config: ConditionalStepConfig<T>) -> Self {
198 Self {
199 config,
200 fitted: false,
201 }
202 }
203
204 pub fn check_condition(&self, data: &ArrayView2<f64>) -> bool {
206 (self.config.condition)(data)
207 }
208}
209
210impl<T> Transform<Array2<f64>, Array2<f64>> for ConditionalStep<T>
211where
212 T: Transform<Array2<f64>, Array2<f64>> + Clone,
213{
214 fn transform(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
215 let data_view = data.view();
216
217 if self.check_condition(&data_view) {
218 self.config.transformer.transform(data)
219 } else if self.config.skip_on_false {
220 Ok(data.clone()) } else {
222 Err(SklearsError::InvalidInput(format!(
223 "Condition not met for step: {}",
224 self.config.name
225 )))
226 }
227 }
228}
229
/// Configuration for running several transformers on the same input.
#[derive(Debug)]
pub struct ParallelBranchConfig<T> {
    /// One transformer per branch.
    pub transformers: Vec<T>,
    /// One display name per branch (used in error messages); must match
    /// `transformers` in length.
    pub branch_names: Vec<String>,
    /// How the branch outputs are merged into a single matrix.
    pub combination_strategy: BranchCombinationStrategy,
}
240
/// How the outputs of parallel branches are merged.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum BranchCombinationStrategy {
    /// Stack outputs horizontally (all branches need the same row count).
    Concatenate,
    /// Element-wise mean (all branches need the same shape).
    Average,
    /// Use the result of the first branch that succeeds.
    FirstSuccess,
    /// Element-wise weighted sum; one weight per branch.
    WeightedCombination(Vec<f64>),
}
254
/// A pipeline step that fans the input out to several transformers and
/// merges their outputs.
pub struct ParallelBranches<T> {
    /// Branch transformers, names and combination strategy.
    config: ParallelBranchConfig<T>,
    /// Set to `false` at construction; `transform` does not consult it
    /// in this file.
    fitted: bool,
}
260
261impl<T> ParallelBranches<T>
262where
263 T: Transform<Array2<f64>, Array2<f64>> + Clone + Send + Sync,
264{
265 pub fn new(config: ParallelBranchConfig<T>) -> Result<Self> {
266 if config.transformers.len() != config.branch_names.len() {
267 return Err(SklearsError::InvalidInput(
268 "Number of transformers must match number of branch names".to_string(),
269 ));
270 }
271
272 if let BranchCombinationStrategy::WeightedCombination(ref weights) =
273 config.combination_strategy
274 {
275 if weights.len() != config.transformers.len() {
276 return Err(SklearsError::InvalidInput(
277 "Number of weights must match number of transformers".to_string(),
278 ));
279 }
280 }
281
282 Ok(Self {
283 config,
284 fitted: false,
285 })
286 }
287}
288
289impl<T> Transform<Array2<f64>, Array2<f64>> for ParallelBranches<T>
290where
291 T: Transform<Array2<f64>, Array2<f64>> + Clone + Send + Sync,
292{
293 fn transform(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
294 #[cfg(feature = "parallel")]
296 let results: Result<Vec<Array2<f64>>> = self
297 .config
298 .transformers
299 .par_iter()
300 .zip(self.config.branch_names.par_iter())
301 .map(|(transformer, name)| {
302 transformer.transform(data).map_err(|e| {
303 SklearsError::TransformError(format!("Error in branch '{}': {}", name, e))
304 })
305 })
306 .collect();
307
308 #[cfg(not(feature = "parallel"))]
309 let results: Result<Vec<Array2<f64>>> = self
310 .config
311 .transformers
312 .iter()
313 .zip(self.config.branch_names.iter())
314 .map(|(transformer, name)| {
315 transformer.transform(data).map_err(|e| {
316 SklearsError::TransformError(format!("Error in branch '{}': {}", name, e))
317 })
318 })
319 .collect();
320
321 let branch_results = results?;
322
323 match &self.config.combination_strategy {
325 BranchCombinationStrategy::Concatenate => self.concatenate_results(branch_results),
326 BranchCombinationStrategy::Average => self.average_results(branch_results),
327 BranchCombinationStrategy::FirstSuccess => Ok(branch_results
328 .into_iter()
329 .next()
330 .expect("operation should succeed")),
331 BranchCombinationStrategy::WeightedCombination(weights) => {
332 self.weighted_combination(branch_results, weights)
333 }
334 }
335 }
336}
337
338impl<T> ParallelBranches<T> {
339 fn concatenate_results(&self, results: Vec<Array2<f64>>) -> Result<Array2<f64>> {
341 if results.is_empty() {
342 return Err(SklearsError::InvalidInput(
343 "No results to concatenate".to_string(),
344 ));
345 }
346
347 let n_rows = results[0].nrows();
348 if !results.iter().all(|r| r.nrows() == n_rows) {
349 return Err(SklearsError::InvalidInput(
350 "All results must have the same number of rows for concatenation".to_string(),
351 ));
352 }
353
354 let total_cols: usize = results.iter().map(|r| r.ncols()).sum();
355 let mut combined = Array2::zeros((n_rows, total_cols));
356
357 let mut col_offset = 0;
358 for result in results {
359 let n_cols = result.ncols();
360 combined
361 .slice_mut(s![.., col_offset..col_offset + n_cols])
362 .assign(&result);
363 col_offset += n_cols;
364 }
365
366 Ok(combined)
367 }
368
369 fn average_results(&self, results: Vec<Array2<f64>>) -> Result<Array2<f64>> {
371 if results.is_empty() {
372 return Err(SklearsError::InvalidInput(
373 "No results to average".to_string(),
374 ));
375 }
376
377 let shape = results[0].raw_dim();
378 if !results.iter().all(|r| r.raw_dim() == shape) {
379 return Err(SklearsError::InvalidInput(
380 "All results must have the same shape for averaging".to_string(),
381 ));
382 }
383
384 let mut sum = Array2::zeros(shape);
385 for result in &results {
386 sum += result;
387 }
388 sum /= results.len() as f64;
389
390 Ok(sum)
391 }
392
393 fn weighted_combination(
395 &self,
396 results: Vec<Array2<f64>>,
397 weights: &[f64],
398 ) -> Result<Array2<f64>> {
399 if results.is_empty() {
400 return Err(SklearsError::InvalidInput(
401 "No results to combine".to_string(),
402 ));
403 }
404
405 let shape = results[0].raw_dim();
406 if !results.iter().all(|r| r.raw_dim() == shape) {
407 return Err(SklearsError::InvalidInput(
408 "All results must have the same shape for weighted combination".to_string(),
409 ));
410 }
411
412 let mut combined = Array2::zeros(shape);
413 for (result, &weight) in results.iter().zip(weights.iter()) {
414 combined += &(result * weight);
415 }
416
417 Ok(combined)
418 }
419}
420
/// Adapts a boxed `StreamingTransformer` to the pipeline's `Transform`
/// interface, tracking whether it has been fitted.
pub struct StreamingTransformerWrapper {
    /// The wrapped incremental transformer.
    transformer: Box<dyn StreamingTransformer + Send + Sync>,
    /// Display name used in errors and stats.
    name: String,
    /// True once `partial_fit` has succeeded at least once.
    fitted: bool,
}
427
428impl StreamingTransformerWrapper {
429 pub fn new<S>(transformer: S, name: String) -> Self
431 where
432 S: StreamingTransformer + Send + Sync + 'static,
433 {
434 Self {
435 transformer: Box::new(transformer),
436 name,
437 fitted: false,
438 }
439 }
440
441 pub fn partial_fit(&mut self, data: &Array2<f64>) -> Result<()> {
443 self.transformer.partial_fit(data).map_err(|e| {
444 SklearsError::InvalidInput(format!("Streaming transformer error: {}", e))
445 })?;
446 self.fitted = true;
447 Ok(())
448 }
449
450 pub fn is_fitted(&self) -> bool {
452 self.fitted && self.transformer.is_fitted()
453 }
454
455 pub fn get_streaming_stats(&self) -> Option<StreamingStats> {
457 Some(self.transformer.get_stats())
458 }
459
460 pub fn reset(&mut self) {
462 self.transformer.reset();
463 self.fitted = false;
464 }
465
466 pub fn name(&self) -> &str {
468 &self.name
469 }
470}
471
472impl Transform<Array2<f64>, Array2<f64>> for StreamingTransformerWrapper {
473 fn transform(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
474 if !self.is_fitted() {
475 return Err(SklearsError::NotFitted {
476 operation: format!("transform on streaming transformer '{}'", self.name),
477 });
478 }
479 self.transformer
480 .transform(data)
481 .map_err(|e| SklearsError::InvalidInput(e.to_string()))
482 }
483}
484
impl Clone for StreamingTransformerWrapper {
    /// NOTE(review): this is a lossy "clone". The boxed trait object
    /// cannot be duplicated, so the clone always holds a fresh
    /// `StreamingStandardScaler` with default config and is marked
    /// unfitted — regardless of the original's wrapped type or learned
    /// state. Confirm callers only clone before fitting; a proper fix
    /// needs a `box_clone`-style method on `StreamingTransformer`.
    fn clone(&self) -> Self {
        Self {
            transformer: Box::new(crate::streaming::StreamingStandardScaler::new(
                StreamingConfig::default(),
            )),
            name: self.name.clone(),
            fitted: false,
        }
    }
}
498
/// A pipeline of heterogeneous steps (simple, conditional, parallel,
/// cached, streaming) with result caching and configurable error handling.
pub struct AdvancedPipeline<T> {
    /// Ordered steps executed by `transform`.
    steps: Vec<PipelineStep<T>>,
    /// Cache for the outputs of `Cached` steps.
    cache: TransformationCache<Array2<f64>>,
    /// Error strategy and cache settings.
    config: AdvancedPipelineConfig,
}
505
/// One stage of an `AdvancedPipeline` / `DynamicPipeline`.
pub enum PipelineStep<T> {
    /// Always-on transformer.
    Simple(T),
    /// Transformer gated by a runtime predicate.
    Conditional(ConditionalStep<T>),
    /// Several transformers run on the same input, outputs combined.
    Parallel(ParallelBranches<T>),
    /// Transformer whose results may be cached; the `String` is a
    /// cache-key prefix.
    Cached(T, String),
    /// Incrementally trainable transformer.
    Streaming(StreamingTransformerWrapper),
}
519
/// Top-level settings for `AdvancedPipeline` / `DynamicPipeline`.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct AdvancedPipelineConfig {
    /// Settings for the internal result cache.
    pub cache_config: CacheConfig,
    /// Intended to toggle parallel step execution; not consulted
    /// anywhere in this file (parallelism is decided by the `parallel`
    /// cargo feature) — TODO confirm whether it is used elsewhere.
    pub parallel_execution: bool,
    /// What to do when a step fails during `transform`.
    pub error_strategy: ErrorHandlingStrategy,
}
531
/// Per-step failure policy used by the pipelines' `transform`.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum ErrorHandlingStrategy {
    /// Abort the whole pipeline on the first failing step.
    StopOnError,
    /// Log and skip the failing step, passing data through unchanged.
    SkipOnError,
    /// Log and fall back to passthrough (currently identical in effect
    /// to `SkipOnError`, differing only in the warning message).
    Fallback,
}
543
544impl Default for AdvancedPipelineConfig {
545 fn default() -> Self {
546 Self {
547 cache_config: CacheConfig::default(),
548 parallel_execution: true,
549 error_strategy: ErrorHandlingStrategy::StopOnError,
550 }
551 }
552}
553
impl<T> AdvancedPipeline<T>
where
    T: Transform<Array2<f64>, Array2<f64>> + Clone + Send + Sync,
{
    /// Creates an empty pipeline; the result cache is built from
    /// `config.cache_config`.
    pub fn new(config: AdvancedPipelineConfig) -> Self {
        Self {
            steps: Vec::new(),
            cache: TransformationCache::new(config.cache_config.clone()),
            config,
        }
    }

    /// Appends an unconditional transformer step (builder style).
    pub fn add_step(mut self, transformer: T) -> Self {
        self.steps.push(PipelineStep::Simple(transformer));
        self
    }

    /// Appends a step guarded by a runtime condition.
    pub fn add_conditional_step(mut self, config: ConditionalStepConfig<T>) -> Self {
        self.steps
            .push(PipelineStep::Conditional(ConditionalStep::new(config)));
        self
    }

    /// Appends a set of parallel branches; fails if the branch
    /// configuration is inconsistent (see `ParallelBranches::new`).
    pub fn add_parallel_branches(mut self, config: ParallelBranchConfig<T>) -> Result<Self> {
        let branches = ParallelBranches::new(config)?;
        self.steps.push(PipelineStep::Parallel(branches));
        Ok(self)
    }

    /// Appends a step whose results may be cached under keys derived
    /// from `cache_key_prefix`.
    pub fn add_cached_step(mut self, transformer: T, cache_key_prefix: String) -> Self {
        self.steps
            .push(PipelineStep::Cached(transformer, cache_key_prefix));
        self
    }

    /// Appends an incrementally trainable (streaming) step.
    pub fn add_streaming_step<S>(mut self, transformer: S, name: String) -> Self
    where
        S: StreamingTransformer + Send + Sync + 'static,
    {
        let wrapper = StreamingTransformerWrapper::new(transformer, name);
        self.steps.push(PipelineStep::Streaming(wrapper));
        self
    }

    /// NOTE(review): stub — the PCA argument is ignored and no step is
    /// added. Confirm whether PCA support is still planned here.
    pub fn add_pca_step(self, _pca: crate::dimensionality_reduction::PCA) -> Self {
        self
    }

    /// Snapshot of the internal result cache's occupancy.
    pub fn cache_stats(&self) -> CacheStats {
        self.cache.stats()
    }

    /// Empties the internal result cache.
    pub fn clear_cache(&self) {
        self.cache.clear();
    }

    /// Incrementally fits every streaming step on `data`, feeding each
    /// step the output of the previous one.
    ///
    /// Non-streaming steps are applied in passing so downstream
    /// streaming steps see transformed data; note their errors are
    /// silently ignored here (the step is skipped), unlike in
    /// `transform`, which honors the configured error strategy.
    pub fn partial_fit(&mut self, data: &Array2<f64>) -> Result<()> {
        let mut current_data = data.clone();

        for step in &mut self.steps {
            match step {
                PipelineStep::Streaming(ref mut streaming_wrapper) => {
                    streaming_wrapper.partial_fit(&current_data)?;
                    // Only propagate data once the step can transform.
                    if streaming_wrapper.is_fitted() {
                        current_data = streaming_wrapper.transform(&current_data)?;
                    }
                }
                PipelineStep::Simple(transformer) => {
                    if let Ok(transformed) = transformer.transform(&current_data) {
                        current_data = transformed;
                    }
                }
                PipelineStep::Conditional(conditional) => {
                    if let Ok(transformed) = conditional.transform(&current_data) {
                        current_data = transformed;
                    }
                }
                PipelineStep::Parallel(parallel) => {
                    if let Ok(transformed) = parallel.transform(&current_data) {
                        current_data = transformed;
                    }
                }
                PipelineStep::Cached(transformer, _) => {
                    if let Ok(transformed) = transformer.transform(&current_data) {
                        current_data = transformed;
                    }
                }
            }
        }

        Ok(())
    }

    /// `(name, stats)` for every streaming step, in pipeline order.
    pub fn get_streaming_stats(&self) -> Vec<(String, Option<StreamingStats>)> {
        let mut stats = Vec::new();

        for step in &self.steps {
            if let PipelineStep::Streaming(streaming_wrapper) = step {
                stats.push((
                    streaming_wrapper.name().to_string(),
                    streaming_wrapper.get_streaming_stats(),
                ));
            }
        }

        stats
    }

    /// Resets every streaming step to its unfitted state.
    pub fn reset_streaming(&mut self) {
        for step in &mut self.steps {
            if let PipelineStep::Streaming(ref mut streaming_wrapper) = step {
                streaming_wrapper.reset();
            }
        }
    }
}
686
687impl<T> Transform<Array2<f64>, Array2<f64>> for AdvancedPipeline<T>
688where
689 T: Transform<Array2<f64>, Array2<f64>> + Clone + Send + Sync,
690{
691 fn transform(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
692 let mut current_data = data.clone();
693 for (step_idx, step) in self.steps.iter().enumerate() {
694 let step_result = match step {
695 PipelineStep::Simple(transformer) => transformer.transform(¤t_data),
696 PipelineStep::Conditional(conditional) => conditional.transform(¤t_data),
697 PipelineStep::Parallel(parallel) => parallel.transform(¤t_data),
698 PipelineStep::Cached(transformer, _cache_key_prefix) => {
699 transformer.transform(¤t_data)
702 }
703 PipelineStep::Streaming(streaming_wrapper) => {
704 streaming_wrapper.transform(¤t_data)
705 }
706 };
707
708 match step_result {
710 Ok(result) => {
711 current_data = result;
712 }
713 Err(e) => {
714 match self.config.error_strategy {
715 ErrorHandlingStrategy::StopOnError => return Err(e),
716 ErrorHandlingStrategy::SkipOnError => {
717 eprintln!("Warning: Step {} failed: {}. Skipping...", step_idx, e);
719 }
721 ErrorHandlingStrategy::Fallback => {
722 eprintln!(
725 "Warning: Step {} failed: {}. Using fallback (passthrough)...",
726 step_idx, e
727 );
728 }
729 }
730 }
731 }
732 }
733
734 Ok(current_data)
735 }
736}
737
/// Fluent builder for `AdvancedPipeline`; keeps its own copy of the
/// configuration alongside the pipeline being assembled.
pub struct AdvancedPipelineBuilder<T> {
    /// Builder-side configuration copy.
    config: AdvancedPipelineConfig,
    /// The pipeline under construction.
    pipeline: AdvancedPipeline<T>,
}
743
744impl<T> AdvancedPipelineBuilder<T>
745where
746 T: Transform<Array2<f64>, Array2<f64>> + Clone + Send + Sync,
747{
748 pub fn new() -> Self {
749 let config = AdvancedPipelineConfig::default();
750 let pipeline = AdvancedPipeline::new(config.clone());
751 Self { config, pipeline }
752 }
753
754 pub fn with_cache_config(mut self, cache_config: CacheConfig) -> Self {
755 self.config.cache_config = cache_config;
756 self.pipeline.cache = TransformationCache::new(self.config.cache_config.clone());
757 self
758 }
759
760 pub fn with_error_strategy(mut self, strategy: ErrorHandlingStrategy) -> Self {
761 self.config.error_strategy = strategy;
762 self.pipeline.config.error_strategy = strategy;
763 self
764 }
765
766 pub fn add_step(mut self, transformer: T) -> Self {
767 self.pipeline = self.pipeline.add_step(transformer);
768 self
769 }
770
771 pub fn add_conditional_step(mut self, config: ConditionalStepConfig<T>) -> Self {
772 self.pipeline = self.pipeline.add_conditional_step(config);
773 self
774 }
775
776 pub fn add_parallel_branches(mut self, config: ParallelBranchConfig<T>) -> Result<Self> {
777 self.pipeline = self.pipeline.add_parallel_branches(config)?;
778 Ok(self)
779 }
780
781 pub fn add_cached_step(mut self, transformer: T, cache_key_prefix: String) -> Self {
782 self.pipeline = self.pipeline.add_cached_step(transformer, cache_key_prefix);
783 self
784 }
785
786 pub fn add_streaming_step<S>(mut self, transformer: S, name: String) -> Self
787 where
788 S: StreamingTransformer + Send + Sync + 'static,
789 {
790 self.pipeline = self.pipeline.add_streaming_step(transformer, name);
791 self
792 }
793
794 pub fn build(self) -> AdvancedPipeline<T> {
795 self.pipeline
796 }
797}
798
799impl<T> Default for AdvancedPipelineBuilder<T>
800where
801 T: Transform<Array2<f64>, Array2<f64>> + Clone + Send + Sync,
802{
803 fn default() -> Self {
804 Self::new()
805 }
806}
807
/// Like `AdvancedPipeline`, but the step list can be mutated at
/// runtime from multiple threads via an `RwLock`.
pub struct DynamicPipeline<T> {
    /// Shared, lockable list of steps.
    steps: Arc<RwLock<Vec<PipelineStep<T>>>>,
    /// Cache for the outputs of `Cached` steps.
    cache: TransformationCache<Array2<f64>>,
    /// Error strategy and cache settings.
    config: AdvancedPipelineConfig,
}
814
impl<T> DynamicPipeline<T>
where
    T: Transform<Array2<f64>, Array2<f64>> + Clone + Send + Sync,
{
    /// Creates an empty dynamic pipeline.
    pub fn new(config: AdvancedPipelineConfig) -> Self {
        Self {
            steps: Arc::new(RwLock::new(Vec::new())),
            cache: TransformationCache::new(config.cache_config.clone()),
            config,
        }
    }

    /// Appends a simple step at runtime. Errors if the lock is poisoned.
    pub fn add_step_runtime(&self, transformer: T) -> Result<()> {
        let mut steps = self
            .steps
            .write()
            .map_err(|_| SklearsError::InvalidInput("Failed to acquire write lock".to_string()))?;
        steps.push(PipelineStep::Simple(transformer));
        Ok(())
    }

    /// Appends a streaming step at runtime. Errors if the lock is poisoned.
    pub fn add_streaming_step_runtime<S>(&self, transformer: S, name: String) -> Result<()>
    where
        S: StreamingTransformer + Send + Sync + 'static,
    {
        let mut steps = self
            .steps
            .write()
            .map_err(|_| SklearsError::InvalidInput("Failed to acquire write lock".to_string()))?;
        let wrapper = StreamingTransformerWrapper::new(transformer, name);
        steps.push(PipelineStep::Streaming(wrapper));
        Ok(())
    }

    /// Removes the step at `index`. Errors on a poisoned lock or an
    /// out-of-bounds index.
    pub fn remove_step(&self, index: usize) -> Result<()> {
        let mut steps = self
            .steps
            .write()
            .map_err(|_| SklearsError::InvalidInput("Failed to acquire write lock".to_string()))?;

        if index >= steps.len() {
            return Err(SklearsError::InvalidInput(
                "Step index out of bounds".to_string(),
            ));
        }

        steps.remove(index);
        Ok(())
    }

    /// Current number of steps; reports 0 if the lock is poisoned.
    pub fn len(&self) -> usize {
        self.steps.read().map(|s| s.len()).unwrap_or(0)
    }

    /// True when the pipeline has no steps.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Incrementally fits streaming steps on `data`, holding the write
    /// lock for the whole pass and feeding each step the output of the
    /// previous one. Errors of non-streaming steps are silently ignored
    /// here (the step is skipped), unlike in `transform`.
    pub fn partial_fit(&self, data: &Array2<f64>) -> Result<()> {
        let mut current_data = data.clone();
        let mut steps = self
            .steps
            .write()
            .map_err(|_| SklearsError::InvalidInput("Failed to acquire write lock".to_string()))?;

        for step in steps.iter_mut() {
            match step {
                PipelineStep::Streaming(ref mut streaming_wrapper) => {
                    streaming_wrapper.partial_fit(&current_data)?;
                    // Only propagate data once the step can transform.
                    if streaming_wrapper.is_fitted() {
                        current_data = streaming_wrapper.transform(&current_data)?;
                    }
                }
                PipelineStep::Simple(transformer) => {
                    if let Ok(transformed) = transformer.transform(&current_data) {
                        current_data = transformed;
                    }
                }
                PipelineStep::Conditional(conditional) => {
                    if let Ok(transformed) = conditional.transform(&current_data) {
                        current_data = transformed;
                    }
                }
                PipelineStep::Parallel(parallel) => {
                    if let Ok(transformed) = parallel.transform(&current_data) {
                        current_data = transformed;
                    }
                }
                PipelineStep::Cached(transformer, _) => {
                    if let Ok(transformed) = transformer.transform(&current_data) {
                        current_data = transformed;
                    }
                }
            }
        }

        Ok(())
    }

    /// `(name, stats)` for every streaming step, in order. Errors if
    /// the lock is poisoned.
    pub fn get_streaming_stats(&self) -> Result<Vec<(String, Option<StreamingStats>)>> {
        let mut stats = Vec::new();
        let steps = self
            .steps
            .read()
            .map_err(|_| SklearsError::InvalidInput("Failed to acquire read lock".to_string()))?;

        for step in steps.iter() {
            if let PipelineStep::Streaming(streaming_wrapper) = step {
                stats.push((
                    streaming_wrapper.name().to_string(),
                    streaming_wrapper.get_streaming_stats(),
                ));
            }
        }

        Ok(stats)
    }

    /// Resets every streaming step to its unfitted state. Errors if the
    /// lock is poisoned.
    pub fn reset_streaming(&self) -> Result<()> {
        let mut steps = self
            .steps
            .write()
            .map_err(|_| SklearsError::InvalidInput("Failed to acquire write lock".to_string()))?;

        for step in steps.iter_mut() {
            if let PipelineStep::Streaming(ref mut streaming_wrapper) = step {
                streaming_wrapper.reset();
            }
        }

        Ok(())
    }
}
958
959impl<T> Transform<Array2<f64>, Array2<f64>> for DynamicPipeline<T>
960where
961 T: Transform<Array2<f64>, Array2<f64>> + Clone + Send + Sync,
962{
963 fn transform(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
964 let mut current_data = data.clone();
965 let steps = self
966 .steps
967 .read()
968 .map_err(|_| SklearsError::InvalidInput("Failed to acquire read lock".to_string()))?;
969
970 for (step_idx, step) in steps.iter().enumerate() {
971 let step_result = match step {
972 PipelineStep::Simple(transformer) => transformer.transform(¤t_data),
973 PipelineStep::Conditional(conditional) => conditional.transform(¤t_data),
974 PipelineStep::Parallel(parallel) => parallel.transform(¤t_data),
975 PipelineStep::Cached(transformer, _cache_key_prefix) => {
976 transformer.transform(¤t_data)
978 }
979 PipelineStep::Streaming(streaming_wrapper) => {
980 streaming_wrapper.transform(¤t_data)
981 }
982 };
983
984 match step_result {
985 Ok(result) => {
986 current_data = result;
987 }
988 Err(e) => match self.config.error_strategy {
989 ErrorHandlingStrategy::StopOnError => return Err(e),
990 ErrorHandlingStrategy::SkipOnError => {
991 eprintln!("Warning: Step {} failed: {}. Skipping...", step_idx, e);
992 }
993 ErrorHandlingStrategy::Fallback => {
994 eprintln!(
995 "Warning: Step {} failed: {}. Using fallback (passthrough)...",
996 step_idx, e
997 );
998 }
999 },
1000 }
1001 }
1002
1003 Ok(current_data)
1004 }
1005}
1006
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::arr2;

    /// Miss -> put -> hit -> stats round-trip for `TransformationCache`.
    #[test]
    fn test_transformation_cache() {
        let config = CacheConfig {
            max_entries: 2,
            ttl_seconds: 1,
            enabled: true,
        };

        let cache = TransformationCache::new(config);
        let data = arr2(&[[1.0, 2.0], [3.0, 4.0]]);

        // Fresh cache: an unknown key must miss.
        let key = cache.generate_key("test_key");
        assert!(cache.get(key).is_none());

        // After insertion the same key hits (within the 1 s TTL).
        cache.put(key, data.clone());
        assert!(cache.get(key).is_some());

        let stats = cache.stats();
        assert_eq!(stats.entries, 1);
        assert!(stats.enabled);
    }

    /// Full lifecycle of `StreamingTransformerWrapper`:
    /// unfitted -> partial_fit -> transform -> stats -> reset.
    #[test]
    fn test_streaming_transformer_wrapper() {
        use crate::streaming::{StreamingConfig, StreamingStandardScaler};
        use scirs2_core::ndarray::Array2;

        let scaler = StreamingStandardScaler::new(StreamingConfig::default());
        let mut wrapper = StreamingTransformerWrapper::new(scaler, "test_scaler".to_string());

        // Not fitted until partial_fit has succeeded.
        assert!(!wrapper.is_fitted());

        let data = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("shape and data length should match");
        wrapper
            .partial_fit(&data)
            .expect("operation should succeed");

        assert!(wrapper.is_fitted());

        // Transform preserves the input's dimensions.
        let result = wrapper
            .transform(&data)
            .expect("transformation should succeed");
        assert_eq!(result.dim(), data.dim());

        let stats = wrapper.get_streaming_stats();
        assert!(stats.is_some());

        assert_eq!(wrapper.name(), "test_scaler");

        // Reset returns the wrapper to the unfitted state.
        wrapper.reset();
        assert!(!wrapper.is_fitted());
    }
}