scirs2_io/pipeline/
stages.rs

1//! Common pipeline stages for data processing
2
3#![allow(dead_code)]
4#![allow(missing_docs)]
5
6use super::*;
7use crate::csv::{read_csv, write_csv};
8use crate::error::Result;
9use scirs2_core::ndarray::Array2;
10use std::fs::File;
11use std::path::{Path, PathBuf};
12use std::time::Duration;
13
/// Pipeline stage that loads a file from disk into the pipeline payload.
///
/// The type of the payload produced depends on the configured
/// [`FileFormat`]; the decoding itself happens in this type's
/// `PipelineStage::execute` implementation.
pub struct FileReadStage {
    /// Location of the file to read.
    path: PathBuf,
    /// How the file is decoded (`Auto` selects by file extension).
    format: FileFormat,
}
19
/// File formats understood by the file read and write stages.
///
/// The enum is a plain tag without payload, so it derives equality and
/// hashing in addition to `Debug`/`Clone`; this lets callers compare formats
/// or use them as map keys. The `Debug` representation is used to build stage
/// names (for example `read_Csv`).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum FileFormat {
    /// Comma-separated values, handled through the crate's CSV reader/writer.
    Csv,
    /// JSON document, handled as `serde_json::Value`.
    Json,
    /// Raw bytes (`Vec<u8>`).
    Binary,
    /// UTF-8 text (`String`).
    Text,
    /// Select by file extension: `csv`, `json`, `txt`/`text`; anything else
    /// is treated as binary.
    Auto,
}
28
29impl FileReadStage {
30    pub fn new(path: impl AsRef<Path>, format: FileFormat) -> Self {
31        Self {
32            path: path.as_ref().to_path_buf(),
33            format,
34        }
35    }
36}
37
38impl PipelineStage for FileReadStage {
39    fn execute(
40        &self,
41        mut input: PipelineData<Box<dyn Any + Send + Sync>>,
42    ) -> Result<PipelineData<Box<dyn Any + Send + Sync>>> {
43        let data = match self.format {
44            FileFormat::Csv => {
45                let data = read_csv(&self.path, None)?;
46                Box::new(data) as Box<dyn Any + Send + Sync>
47            }
48            FileFormat::Json => {
49                let file = File::open(&self.path).map_err(IoError::Io)?;
50                let value: serde_json::Value = serde_json::from_reader(file)
51                    .map_err(|e| IoError::SerializationError(e.to_string()))?;
52                Box::new(value) as Box<dyn Any + Send + Sync>
53            }
54            FileFormat::Binary => {
55                let data = std::fs::read(&self.path).map_err(IoError::Io)?;
56                Box::new(data) as Box<dyn Any + Send + Sync>
57            }
58            FileFormat::Text => {
59                let data = std::fs::read_to_string(&self.path).map_err(IoError::Io)?;
60                Box::new(data) as Box<dyn Any + Send + Sync>
61            }
62            FileFormat::Auto => {
63                // Auto-detect format based on file extension
64                let extension = self
65                    .path
66                    .extension()
67                    .and_then(|ext| ext.to_str())
68                    .unwrap_or("");
69
70                match extension.to_lowercase().as_str() {
71                    "csv" => {
72                        let data = read_csv(&self.path, None)?;
73                        Box::new(data) as Box<dyn Any + Send + Sync>
74                    }
75                    "json" => {
76                        let file = File::open(&self.path).map_err(IoError::Io)?;
77                        let value: serde_json::Value = serde_json::from_reader(file)
78                            .map_err(|e| IoError::SerializationError(e.to_string()))?;
79                        Box::new(value) as Box<dyn Any + Send + Sync>
80                    }
81                    "txt" | "text" => {
82                        let data = std::fs::read_to_string(&self.path).map_err(IoError::Io)?;
83                        Box::new(data) as Box<dyn Any + Send + Sync>
84                    }
85                    _ => {
86                        // Default to binary for unknown extensions
87                        let data = std::fs::read(&self.path).map_err(IoError::Io)?;
88                        Box::new(data) as Box<dyn Any + Send + Sync>
89                    }
90                }
91            }
92        };
93
94        input.data = data;
95        input
96            .metadata
97            .set("source_file", self.path.to_string_lossy().to_string());
98        Ok(input)
99    }
100
101    fn name(&self) -> String {
102        format!("read_{:?}", self.format)
103    }
104
105    fn stage_type(&self) -> String {
106        "input".to_string()
107    }
108}
109
110/// File writing stage
111pub struct FileWriteStage {
112    path: PathBuf,
113    format: FileFormat,
114}
115
116impl FileWriteStage {
117    pub fn new(path: impl AsRef<Path>, format: FileFormat) -> Self {
118        Self {
119            path: path.as_ref().to_path_buf(),
120            format,
121        }
122    }
123}
124
125impl PipelineStage for FileWriteStage {
126    fn execute(
127        &self,
128        input: PipelineData<Box<dyn Any + Send + Sync>>,
129    ) -> Result<PipelineData<Box<dyn Any + Send + Sync>>> {
130        match self.format {
131            FileFormat::Csv => {
132                if let Some(data) = input.data.downcast_ref::<Array2<f64>>() {
133                    write_csv(&self.path, data, None, None)?;
134                }
135            }
136            FileFormat::Json => {
137                if let Some(value) = input.data.downcast_ref::<serde_json::Value>() {
138                    let file = File::create(&self.path).map_err(IoError::Io)?;
139                    serde_json::to_writer_pretty(file, value)
140                        .map_err(|e| IoError::SerializationError(e.to_string()))?;
141                }
142            }
143            FileFormat::Binary => {
144                if let Some(data) = input.data.downcast_ref::<Vec<u8>>() {
145                    std::fs::write(&self.path, data).map_err(IoError::Io)?;
146                }
147            }
148            FileFormat::Text => {
149                if let Some(data) = input.data.downcast_ref::<String>() {
150                    std::fs::write(&self.path, data).map_err(IoError::Io)?;
151                }
152            }
153            FileFormat::Auto => {
154                // Auto-detect format based on file extension
155                let extension = self
156                    .path
157                    .extension()
158                    .and_then(|ext| ext.to_str())
159                    .unwrap_or("");
160
161                match extension.to_lowercase().as_str() {
162                    "csv" => {
163                        if let Some(data) = input.data.downcast_ref::<Array2<f64>>() {
164                            write_csv(&self.path, data, None, None)?;
165                        }
166                    }
167                    "json" => {
168                        if let Some(value) = input.data.downcast_ref::<serde_json::Value>() {
169                            let file = File::create(&self.path).map_err(IoError::Io)?;
170                            serde_json::to_writer_pretty(file, value)
171                                .map_err(|e| IoError::SerializationError(e.to_string()))?;
172                        }
173                    }
174                    "txt" | "text" => {
175                        if let Some(data) = input.data.downcast_ref::<String>() {
176                            std::fs::write(&self.path, data).map_err(IoError::Io)?;
177                        }
178                    }
179                    _ => {
180                        // Default to binary for unknown extensions
181                        if let Some(data) = input.data.downcast_ref::<Vec<u8>>() {
182                            std::fs::write(&self.path, data).map_err(IoError::Io)?;
183                        }
184                    }
185                }
186            }
187        }
188
189        Ok(input)
190    }
191
192    fn name(&self) -> String {
193        format!("write_{:?}", self.format)
194    }
195
196    fn stage_type(&self) -> String {
197        "output".to_string()
198    }
199}
200
/// Pipeline stage that runs a list of validators against the payload.
///
/// Validators are applied in the order they were added; the payload itself is
/// passed through unchanged.
pub struct ValidationStage {
    /// Validators to run, in insertion order.
    validators: Vec<Box<dyn Validator>>,
}
205
/// A single check that can be plugged into a `ValidationStage`.
pub trait Validator: Send + Sync {
    /// Inspects the type-erased payload.
    ///
    /// Return `Ok(())` to accept the data. An `Err` is propagated by
    /// `ValidationStage::execute`, which stops at the first failing validator.
    fn validate(&self, data: &dyn Any) -> Result<()>;
    /// Name identifying this validator. `ValidationStage` in this file does
    /// not read it.
    fn name(&self) -> &str;
}
210
211impl Default for ValidationStage {
212    fn default() -> Self {
213        Self::new()
214    }
215}
216
217impl ValidationStage {
218    pub fn new() -> Self {
219        Self {
220            validators: Vec::new(),
221        }
222    }
223
224    pub fn add_validator(mut self, validator: Box<dyn Validator>) -> Self {
225        self.validators.push(validator);
226        self
227    }
228}
229
230impl PipelineStage for ValidationStage {
231    fn execute(
232        &self,
233        input: PipelineData<Box<dyn Any + Send + Sync>>,
234    ) -> Result<PipelineData<Box<dyn Any + Send + Sync>>> {
235        for validator in &self.validators {
236            validator.validate(input.data.as_ref())?;
237        }
238        Ok(input)
239    }
240
241    fn name(&self) -> String {
242        "validation".to_string()
243    }
244
245    fn stage_type(&self) -> String {
246        "validation".to_string()
247    }
248}
249
/// Pipeline stage that replaces the payload with the output of a
/// user-supplied transformer.
pub struct TransformStage {
    /// Name reported by `PipelineStage::name`.
    name: String,
    /// Transformation applied to the payload on every execution.
    transformer: Box<dyn DataTransformer>,
}
255
/// Transformation that can be plugged into a `TransformStage`.
pub trait DataTransformer: Send + Sync {
    /// Takes ownership of the type-erased payload and returns the payload
    /// that replaces it. An `Err` is propagated by `TransformStage::execute`.
    fn transform(&self, data: Box<dyn Any + Send + Sync>) -> Result<Box<dyn Any + Send + Sync>>;
}
259
260impl TransformStage {
261    pub fn new(name: &str, transformer: Box<dyn DataTransformer>) -> Self {
262        Self {
263            name: name.to_string(),
264            transformer,
265        }
266    }
267}
268
269impl PipelineStage for TransformStage {
270    fn execute(
271        &self,
272        mut input: PipelineData<Box<dyn Any + Send + Sync>>,
273    ) -> Result<PipelineData<Box<dyn Any + Send + Sync>>> {
274        input.data = self.transformer.transform(input.data)?;
275        Ok(input)
276    }
277
278    fn name(&self) -> String {
279        self.name.clone()
280    }
281
282    fn stage_type(&self) -> String {
283        "transform".to_string()
284    }
285}
286
287/// Aggregation stage
288pub struct AggregationStage<T> {
289    name: String,
290    aggregator: Box<dyn Fn(Vec<T>) -> Result<T> + Send + Sync>,
291}
292
293impl<T: 'static + Send + Sync> AggregationStage<T> {
294    pub fn new<F>(name: &str, aggregator: F) -> Self
295    where
296        F: Fn(Vec<T>) -> Result<T> + Send + Sync + 'static,
297    {
298        Self {
299            name: name.to_string(),
300            aggregator: Box::new(aggregator),
301        }
302    }
303}
304
305impl<T: 'static + Send + Sync> PipelineStage for AggregationStage<T> {
306    fn execute(
307        &self,
308        mut input: PipelineData<Box<dyn Any + Send + Sync>>,
309    ) -> Result<PipelineData<Box<dyn Any + Send + Sync>>> {
310        if let Ok(data) = input.data.downcast::<Vec<T>>() {
311            let aggregated = (self.aggregator)(*data)?;
312            input.data = Box::new(aggregated) as Box<dyn Any + Send + Sync>;
313            Ok(input)
314        } else {
315            Err(IoError::Other(
316                "Type mismatch in aggregation stage".to_string(),
317            ))
318        }
319    }
320
321    fn name(&self) -> String {
322        self.name.clone()
323    }
324
325    fn stage_type(&self) -> String {
326        "aggregation".to_string()
327    }
328}
329
/// Pipeline stage that keeps only the items of a `Vec<T>` payload that
/// satisfy a predicate.
pub struct FilterStage<T> {
    /// Name reported by `PipelineStage::name`.
    name: String,
    /// Returns `true` for items that should be kept.
    predicate: Box<dyn Fn(&T) -> bool + Send + Sync>,
}

impl<T> FilterStage<T>
where
    T: 'static + Send + Sync + Clone,
{
    /// Creates a filter stage called `name` that keeps items for which
    /// `predicate` returns `true`.
    pub fn new<F>(name: &str, predicate: F) -> Self
    where
        F: Fn(&T) -> bool + Send + Sync + 'static,
    {
        let predicate: Box<dyn Fn(&T) -> bool + Send + Sync> = Box::new(predicate);
        Self {
            name: name.to_owned(),
            predicate,
        }
    }
}
347
348impl<T: 'static + Send + Sync + Clone> PipelineStage for FilterStage<T> {
349    fn execute(
350        &self,
351        mut input: PipelineData<Box<dyn Any + Send + Sync>>,
352    ) -> Result<PipelineData<Box<dyn Any + Send + Sync>>> {
353        if let Ok(data) = input.data.downcast::<Vec<T>>() {
354            let filtered: Vec<T> = data
355                .iter()
356                .filter(|item| (self.predicate)(item))
357                .cloned()
358                .collect();
359            input.data = Box::new(filtered) as Box<dyn Any + Send + Sync>;
360            Ok(input)
361        } else {
362            Err(IoError::Other("Type mismatch in filter stage".to_string()))
363        }
364    }
365
366    fn name(&self) -> String {
367        self.name.clone()
368    }
369
370    fn stage_type(&self) -> String {
371        "filter".to_string()
372    }
373}
374
/// Enrichment stage - lets a user-supplied enricher modify the pipeline data
/// (payload, metadata or context) in place.
pub struct EnrichmentStage {
    /// Name reported by `PipelineStage::name`.
    name: String,
    /// Enricher invoked on every execution.
    enricher: Box<dyn DataEnricher>,
}
380
/// Enricher that can be plugged into an `EnrichmentStage`.
pub trait DataEnricher: Send + Sync {
    /// Receives mutable access to the whole `PipelineData` and may change it
    /// in place. An `Err` is propagated by `EnrichmentStage::execute`.
    fn enrich(&self, data: &mut PipelineData<Box<dyn Any + Send + Sync>>) -> Result<()>;
}
384
385impl EnrichmentStage {
386    pub fn new(name: &str, enricher: Box<dyn DataEnricher>) -> Self {
387        Self {
388            name: name.to_string(),
389            enricher,
390        }
391    }
392}
393
394impl PipelineStage for EnrichmentStage {
395    fn execute(
396        &self,
397        mut input: PipelineData<Box<dyn Any + Send + Sync>>,
398    ) -> Result<PipelineData<Box<dyn Any + Send + Sync>>> {
399        self.enricher.enrich(&mut input)?;
400        Ok(input)
401    }
402
403    fn name(&self) -> String {
404        self.name.clone()
405    }
406
407    fn stage_type(&self) -> String {
408        "enrichment".to_string()
409    }
410}
411
/// Cache stage - records, per cache key, a marker file in a cache directory.
pub struct CacheStage {
    /// Key identifying the cache entry; also used as the file stem.
    cache_key: String,
    /// Directory that holds the `<cache_key>.cache` files.
    cache_dir: PathBuf,
}

impl CacheStage {
    /// Creates a cache stage for `cache_key` stored under `cache_dir`.
    ///
    /// Both arguments are copied; the directory is not created here.
    pub fn new(cache_key: &str, cache_dir: impl AsRef<Path>) -> Self {
        let cache_dir = PathBuf::from(cache_dir.as_ref());
        Self {
            cache_key: cache_key.to_owned(),
            cache_dir,
        }
    }
}
426
427impl PipelineStage for CacheStage {
428    fn execute(
429        &self,
430        mut input: PipelineData<Box<dyn Any + Send + Sync>>,
431    ) -> Result<PipelineData<Box<dyn Any + Send + Sync>>> {
432        // Create cache directory if needed
433        std::fs::create_dir_all(&self.cache_dir).map_err(IoError::Io)?;
434
435        let cache_path = self.cache_dir.join(format!("{}.cache", self.cache_key));
436
437        // Check if cache exists
438        if cache_path.exists() {
439            // Try to load from cache
440            if let Ok(_cache_data) = std::fs::read(&cache_path) {
441                // Update metadata to indicate cache hit
442                input.metadata.set("cache_hit", true);
443                input.metadata.set("cache_key", self.cache_key.clone());
444
445                // For demonstration, we'll store a simple flag in context
446                input.context.set("cached_from", self.cache_key.clone());
447
448                return Ok(input);
449            }
450        }
451
452        // Cache miss - save data for future use
453        // Note: In a real implementation, we would serialize the actual data
454        // For now, we'll just create a marker file
455        let cache_marker = format!(
456            "Cache entry for: {}\nCreated: {:?}\n",
457            self.cache_key,
458            chrono::Utc::now()
459        );
460        std::fs::write(&cache_path, cache_marker).map_err(IoError::Io)?;
461
462        // Update metadata
463        input.metadata.set("cache_hit", false);
464        input.metadata.set("cache_key", self.cache_key.clone());
465
466        Ok(input)
467    }
468
469    fn name(&self) -> String {
470        format!("cache_{}", self.cache_key)
471    }
472
473    fn stage_type(&self) -> String {
474        "cache".to_string()
475    }
476}
477
/// Monitoring stage - shows the pipeline data to a user-supplied monitor and
/// passes it on unchanged.
pub struct MonitoringStage {
    /// Name reported by `PipelineStage::name`.
    name: String,
    /// Observer invoked on every execution.
    monitor: Box<dyn Monitor>,
}
483
/// Observer that can be plugged into a `MonitoringStage`.
pub trait Monitor: Send + Sync {
    /// Called with a shared reference to the pipeline data. The method
    /// returns nothing, so it has no way to fail the stage.
    fn monitor(&self, data: &PipelineData<Box<dyn Any + Send + Sync>>);
}
487
488impl MonitoringStage {
489    pub fn new(name: &str, monitor: Box<dyn Monitor>) -> Self {
490        Self {
491            name: name.to_string(),
492            monitor,
493        }
494    }
495}
496
497impl PipelineStage for MonitoringStage {
498    fn execute(
499        &self,
500        input: PipelineData<Box<dyn Any + Send + Sync>>,
501    ) -> Result<PipelineData<Box<dyn Any + Send + Sync>>> {
502        self.monitor.monitor(&input);
503        Ok(input)
504    }
505
506    fn name(&self) -> String {
507        self.name.clone()
508    }
509
510    fn stage_type(&self) -> String {
511        "monitoring".to_string()
512    }
513}
514
/// Error handling stage - delegates to an `ErrorHandler` when the pipeline
/// context carries a `pipeline_error` message.
pub struct ErrorHandlingStage {
    /// Name reported by `PipelineStage::name`.
    name: String,
    /// Strategy that decides how a reported error is dealt with.
    handler: Box<dyn ErrorHandler>,
}
520
/// Strategy used by `ErrorHandlingStage` to react to a reported error.
pub trait ErrorHandler: Send + Sync {
    /// Decides what happens after `error` was reported for `data`.
    ///
    /// `ErrorHandlingStage::execute` returns this method's result directly:
    /// `Ok` lets the pipeline continue with the returned data, `Err` fails
    /// the stage.
    fn handle_error(
        &self,
        error: IoError,
        data: PipelineData<Box<dyn Any + Send + Sync>>,
    ) -> Result<PipelineData<Box<dyn Any + Send + Sync>>>;
}
528
529impl ErrorHandlingStage {
530    pub fn new(name: &str, handler: Box<dyn ErrorHandler>) -> Self {
531        Self {
532            name: name.to_string(),
533            handler,
534        }
535    }
536}
537
538impl PipelineStage for ErrorHandlingStage {
539    fn execute(
540        &self,
541        input: PipelineData<Box<dyn Any + Send + Sync>>,
542    ) -> Result<PipelineData<Box<dyn Any + Send + Sync>>> {
543        // In a real pipeline, this would wrap the next stage's execution
544        // For now, we'll simulate error handling by checking context for errors
545
546        // Check if there's an error flag in the context
547        if let Some(error_msg) = input.context.get::<String>("pipeline_error") {
548            // Create an error from the message
549            let error = IoError::Other(error_msg);
550
551            // Let the handler decide what to do
552            self.handler.handle_error(error, input)
553        } else {
554            // No error, pass through
555            Ok(input)
556        }
557    }
558
559    fn name(&self) -> String {
560        self.name.clone()
561    }
562
563    fn stage_type(&self) -> String {
564        "error_handling".to_string()
565    }
566}
567
/// Error handler that allows a bounded number of retries before giving up.
pub struct RetryErrorHandler {
    /// Maximum number of retries granted before the error is returned.
    max_retries: usize,
    /// Configured delay between retries (one second unless overridden).
    retry_delay: Duration,
}

impl RetryErrorHandler {
    /// Creates a handler that permits `max_retries` retries with a delay of
    /// one second.
    pub fn new(max_retries: usize) -> Self {
        let retry_delay = Duration::from_secs(1);
        Self {
            max_retries,
            retry_delay,
        }
    }

    /// Replaces the retry delay and returns the handler, builder-style.
    pub fn with_delay(self, delay: Duration) -> Self {
        Self {
            retry_delay: delay,
            ..self
        }
    }
}
587
588impl ErrorHandler for RetryErrorHandler {
589    fn handle_error(
590        &self,
591        error: IoError,
592        mut data: PipelineData<Box<dyn Any + Send + Sync>>,
593    ) -> Result<PipelineData<Box<dyn Any + Send + Sync>>> {
594        // Get current retry count
595        let retry_count = data.context.get::<usize>("retry_count").unwrap_or(0);
596
597        if retry_count < self.max_retries {
598            // Increment retry count
599            data.context.set("retry_count", retry_count + 1);
600
601            // Log retry attempt
602            data.metadata.set("last_error", format!("{:?}", error));
603            data.metadata.set("retry_attempt", (retry_count + 1) as i64);
604
605            // Clear error flag to retry
606            data.context.set::<Option<String>>("pipeline_error", None);
607
608            Ok(data)
609        } else {
610            // Max retries exceeded
611            Err(error)
612        }
613    }
614}
615
616/// Skip error handler that continues on error
617pub struct SkipErrorHandler;
618
619impl ErrorHandler for SkipErrorHandler {
620    fn handle_error(
621        &self,
622        _error: IoError,
623        mut data: PipelineData<Box<dyn Any + Send + Sync>>,
624    ) -> Result<PipelineData<Box<dyn Any + Send + Sync>>> {
625        // Mark as skipped in metadata
626        data.metadata.set("skipped", true);
627        data.metadata.set("skip_reason", "error_occurred");
628
629        // Continue processing
630        Ok(data)
631    }
632}
633
/// Error handler that substitutes a fixed fallback value for the payload.
pub struct FallbackErrorHandler<T>
where
    T: Any + Send + Sync + Clone + 'static,
{
    /// Value cloned into the payload whenever an error is handled.
    fallback_value: T,
}

impl<T> FallbackErrorHandler<T>
where
    T: Any + Send + Sync + Clone + 'static,
{
    /// Creates a handler that falls back to `fallback_value`.
    pub fn new(fallback_value: T) -> Self {
        FallbackErrorHandler { fallback_value }
    }
}
644
645impl<T: Any + Send + Sync + Clone + 'static> ErrorHandler for FallbackErrorHandler<T> {
646    fn handle_error(
647        &self,
648        _error: IoError,
649        mut data: PipelineData<Box<dyn Any + Send + Sync>>,
650    ) -> Result<PipelineData<Box<dyn Any + Send + Sync>>> {
651        // Replace data with fallback value
652        data.data = Box::new(self.fallback_value.clone());
653        data.metadata.set("used_fallback", true);
654
655        Ok(data)
656    }
657}
658
#[cfg(test)]
mod tests {
    use super::*;

    /// Test validator: rejects an empty `Vec<i32>` and accepts anything else,
    /// including payloads of other types.
    struct SimpleValidator;

    impl Validator for SimpleValidator {
        fn validate(&self, data: &dyn Any) -> Result<()> {
            match data.downcast_ref::<Vec<i32>>() {
                Some(nums) if nums.is_empty() => {
                    Err(IoError::ValidationError("Empty data".to_string()))
                }
                _ => Ok(()),
            }
        }

        fn name(&self) -> &str {
            "simple"
        }
    }

    #[test]
    fn test_validation_stage() {
        let stage = ValidationStage::new().add_validator(Box::new(SimpleValidator));

        // Non-empty data passes validation.
        let filled = PipelineData::new(Box::new(vec![1, 2, 3]) as Box<dyn Any + Send + Sync>);
        assert!(stage.execute(filled).is_ok());

        // Empty data is rejected.
        let empty = PipelineData::new(Box::new(Vec::<i32>::new()) as Box<dyn Any + Send + Sync>);
        assert!(stage.execute(empty).is_err());
    }
}
693}