1#![allow(dead_code)]
4#![allow(missing_docs)]
5
6use super::*;
7use crate::csv::{read_csv, write_csv};
8use crate::error::Result;
9use scirs2_core::ndarray::Array2;
10use std::fs::File;
11use std::path::{Path, PathBuf};
12use std::time::Duration;
13
14pub struct FileReadStage {
16 path: PathBuf,
17 format: FileFormat,
18}
19
/// File formats understood by the file read and write stages.
#[derive(Debug, Clone)]
pub enum FileFormat {
    /// Comma-separated values, handled through the crate's CSV helpers.
    Csv,
    /// A JSON document, handled through `serde_json`.
    Json,
    /// Raw bytes.
    Binary,
    /// UTF-8 text.
    Text,
    /// Choose the format from the file extension when the stage runs.
    Auto,
}
28
29impl FileReadStage {
30 pub fn new(path: impl AsRef<Path>, format: FileFormat) -> Self {
31 Self {
32 path: path.as_ref().to_path_buf(),
33 format,
34 }
35 }
36}
37
38impl PipelineStage for FileReadStage {
39 fn execute(
40 &self,
41 mut input: PipelineData<Box<dyn Any + Send + Sync>>,
42 ) -> Result<PipelineData<Box<dyn Any + Send + Sync>>> {
43 let data = match self.format {
44 FileFormat::Csv => {
45 let data = read_csv(&self.path, None)?;
46 Box::new(data) as Box<dyn Any + Send + Sync>
47 }
48 FileFormat::Json => {
49 let file = File::open(&self.path).map_err(IoError::Io)?;
50 let value: serde_json::Value = serde_json::from_reader(file)
51 .map_err(|e| IoError::SerializationError(e.to_string()))?;
52 Box::new(value) as Box<dyn Any + Send + Sync>
53 }
54 FileFormat::Binary => {
55 let data = std::fs::read(&self.path).map_err(IoError::Io)?;
56 Box::new(data) as Box<dyn Any + Send + Sync>
57 }
58 FileFormat::Text => {
59 let data = std::fs::read_to_string(&self.path).map_err(IoError::Io)?;
60 Box::new(data) as Box<dyn Any + Send + Sync>
61 }
62 FileFormat::Auto => {
63 let extension = self
65 .path
66 .extension()
67 .and_then(|ext| ext.to_str())
68 .unwrap_or("");
69
70 match extension.to_lowercase().as_str() {
71 "csv" => {
72 let data = read_csv(&self.path, None)?;
73 Box::new(data) as Box<dyn Any + Send + Sync>
74 }
75 "json" => {
76 let file = File::open(&self.path).map_err(IoError::Io)?;
77 let value: serde_json::Value = serde_json::from_reader(file)
78 .map_err(|e| IoError::SerializationError(e.to_string()))?;
79 Box::new(value) as Box<dyn Any + Send + Sync>
80 }
81 "txt" | "text" => {
82 let data = std::fs::read_to_string(&self.path).map_err(IoError::Io)?;
83 Box::new(data) as Box<dyn Any + Send + Sync>
84 }
85 _ => {
86 let data = std::fs::read(&self.path).map_err(IoError::Io)?;
88 Box::new(data) as Box<dyn Any + Send + Sync>
89 }
90 }
91 }
92 };
93
94 input.data = data;
95 input
96 .metadata
97 .set("source_file", self.path.to_string_lossy().to_string());
98 Ok(input)
99 }
100
101 fn name(&self) -> String {
102 format!("read_{:?}", self.format)
103 }
104
105 fn stage_type(&self) -> String {
106 "input".to_string()
107 }
108}
109
110pub struct FileWriteStage {
112 path: PathBuf,
113 format: FileFormat,
114}
115
116impl FileWriteStage {
117 pub fn new(path: impl AsRef<Path>, format: FileFormat) -> Self {
118 Self {
119 path: path.as_ref().to_path_buf(),
120 format,
121 }
122 }
123}
124
125impl PipelineStage for FileWriteStage {
126 fn execute(
127 &self,
128 input: PipelineData<Box<dyn Any + Send + Sync>>,
129 ) -> Result<PipelineData<Box<dyn Any + Send + Sync>>> {
130 match self.format {
131 FileFormat::Csv => {
132 if let Some(data) = input.data.downcast_ref::<Array2<f64>>() {
133 write_csv(&self.path, data, None, None)?;
134 }
135 }
136 FileFormat::Json => {
137 if let Some(value) = input.data.downcast_ref::<serde_json::Value>() {
138 let file = File::create(&self.path).map_err(IoError::Io)?;
139 serde_json::to_writer_pretty(file, value)
140 .map_err(|e| IoError::SerializationError(e.to_string()))?;
141 }
142 }
143 FileFormat::Binary => {
144 if let Some(data) = input.data.downcast_ref::<Vec<u8>>() {
145 std::fs::write(&self.path, data).map_err(IoError::Io)?;
146 }
147 }
148 FileFormat::Text => {
149 if let Some(data) = input.data.downcast_ref::<String>() {
150 std::fs::write(&self.path, data).map_err(IoError::Io)?;
151 }
152 }
153 FileFormat::Auto => {
154 let extension = self
156 .path
157 .extension()
158 .and_then(|ext| ext.to_str())
159 .unwrap_or("");
160
161 match extension.to_lowercase().as_str() {
162 "csv" => {
163 if let Some(data) = input.data.downcast_ref::<Array2<f64>>() {
164 write_csv(&self.path, data, None, None)?;
165 }
166 }
167 "json" => {
168 if let Some(value) = input.data.downcast_ref::<serde_json::Value>() {
169 let file = File::create(&self.path).map_err(IoError::Io)?;
170 serde_json::to_writer_pretty(file, value)
171 .map_err(|e| IoError::SerializationError(e.to_string()))?;
172 }
173 }
174 "txt" | "text" => {
175 if let Some(data) = input.data.downcast_ref::<String>() {
176 std::fs::write(&self.path, data).map_err(IoError::Io)?;
177 }
178 }
179 _ => {
180 if let Some(data) = input.data.downcast_ref::<Vec<u8>>() {
182 std::fs::write(&self.path, data).map_err(IoError::Io)?;
183 }
184 }
185 }
186 }
187 }
188
189 Ok(input)
190 }
191
192 fn name(&self) -> String {
193 format!("write_{:?}", self.format)
194 }
195
196 fn stage_type(&self) -> String {
197 "output".to_string()
198 }
199}
200
/// Pipeline stage that checks the payload with a list of validators.
pub struct ValidationStage {
    /// Validators, kept in the order they were added.
    validators: Vec<Box<dyn Validator>>,
}

/// A check that can be attached to a [`ValidationStage`].
pub trait Validator: Send + Sync {
    /// Inspects `data`; returning an error makes the validation stage fail.
    fn validate(&self, data: &dyn Any) -> Result<()>;

    /// Name of this validator.
    fn name(&self) -> &str;
}

impl Default for ValidationStage {
    /// Same as [`ValidationStage::new`]: a stage without validators.
    fn default() -> Self {
        Self::new()
    }
}

impl ValidationStage {
    /// Creates a stage with no validators.
    pub fn new() -> Self {
        Self { validators: vec![] }
    }

    /// Appends `validator` and returns the stage, so calls can be chained.
    pub fn add_validator(self, validator: Box<dyn Validator>) -> Self {
        let Self { mut validators } = self;
        validators.push(validator);
        Self { validators }
    }
}
229
230impl PipelineStage for ValidationStage {
231 fn execute(
232 &self,
233 input: PipelineData<Box<dyn Any + Send + Sync>>,
234 ) -> Result<PipelineData<Box<dyn Any + Send + Sync>>> {
235 for validator in &self.validators {
236 validator.validate(input.data.as_ref())?;
237 }
238 Ok(input)
239 }
240
241 fn name(&self) -> String {
242 "validation".to_string()
243 }
244
245 fn stage_type(&self) -> String {
246 "validation".to_string()
247 }
248}
249
/// Pipeline stage that replaces the payload with the output of a transformer.
pub struct TransformStage {
    /// Name reported for this stage.
    name: String,
    /// Transformer applied to the payload.
    transformer: Box<dyn DataTransformer>,
}

/// Converts one boxed payload into another.
pub trait DataTransformer: Send + Sync {
    /// Consumes `data` and returns the transformed payload, or an error.
    fn transform(&self, data: Box<dyn Any + Send + Sync>) -> Result<Box<dyn Any + Send + Sync>>;
}

impl TransformStage {
    /// Creates a stage called `name` that applies `transformer`.
    pub fn new(name: &str, transformer: Box<dyn DataTransformer>) -> Self {
        let name = name.to_owned();
        Self { name, transformer }
    }
}
268
269impl PipelineStage for TransformStage {
270 fn execute(
271 &self,
272 mut input: PipelineData<Box<dyn Any + Send + Sync>>,
273 ) -> Result<PipelineData<Box<dyn Any + Send + Sync>>> {
274 input.data = self.transformer.transform(input.data)?;
275 Ok(input)
276 }
277
278 fn name(&self) -> String {
279 self.name.clone()
280 }
281
282 fn stage_type(&self) -> String {
283 "transform".to_string()
284 }
285}
286
/// Pipeline stage that reduces a `Vec<T>` payload to a single `T`.
pub struct AggregationStage<T> {
    /// Name reported for this stage.
    name: String,
    /// Function that folds the collected items into one value.
    aggregator: Box<dyn Fn(Vec<T>) -> Result<T> + Send + Sync>,
}

impl<T: 'static + Send + Sync> AggregationStage<T> {
    /// Creates a stage called `name` that reduces its input with `aggregator`.
    pub fn new<F>(name: &str, aggregator: F) -> Self
    where
        F: Fn(Vec<T>) -> Result<T> + Send + Sync + 'static,
    {
        let name = name.to_owned();
        let aggregator = Box::new(aggregator);
        Self { name, aggregator }
    }
}
304
305impl<T: 'static + Send + Sync> PipelineStage for AggregationStage<T> {
306 fn execute(
307 &self,
308 mut input: PipelineData<Box<dyn Any + Send + Sync>>,
309 ) -> Result<PipelineData<Box<dyn Any + Send + Sync>>> {
310 if let Ok(data) = input.data.downcast::<Vec<T>>() {
311 let aggregated = (self.aggregator)(*data)?;
312 input.data = Box::new(aggregated) as Box<dyn Any + Send + Sync>;
313 Ok(input)
314 } else {
315 Err(IoError::Other(
316 "Type mismatch in aggregation stage".to_string(),
317 ))
318 }
319 }
320
321 fn name(&self) -> String {
322 self.name.clone()
323 }
324
325 fn stage_type(&self) -> String {
326 "aggregation".to_string()
327 }
328}
329
/// Pipeline stage that keeps only the items of a `Vec<T>` payload accepted by a predicate.
pub struct FilterStage<T> {
    /// Name reported for this stage.
    name: String,
    /// Returns `true` for items that should be kept.
    predicate: Box<dyn Fn(&T) -> bool + Send + Sync>,
}

impl<T: 'static + Send + Sync + Clone> FilterStage<T> {
    /// Creates a stage called `name` that keeps the items for which `predicate` is `true`.
    pub fn new<F>(name: &str, predicate: F) -> Self
    where
        F: Fn(&T) -> bool + Send + Sync + 'static,
    {
        let name = name.to_owned();
        let predicate = Box::new(predicate);
        Self { name, predicate }
    }
}
347
348impl<T: 'static + Send + Sync + Clone> PipelineStage for FilterStage<T> {
349 fn execute(
350 &self,
351 mut input: PipelineData<Box<dyn Any + Send + Sync>>,
352 ) -> Result<PipelineData<Box<dyn Any + Send + Sync>>> {
353 if let Ok(data) = input.data.downcast::<Vec<T>>() {
354 let filtered: Vec<T> = data
355 .iter()
356 .filter(|item| (self.predicate)(item))
357 .cloned()
358 .collect();
359 input.data = Box::new(filtered) as Box<dyn Any + Send + Sync>;
360 Ok(input)
361 } else {
362 Err(IoError::Other("Type mismatch in filter stage".to_string()))
363 }
364 }
365
366 fn name(&self) -> String {
367 self.name.clone()
368 }
369
370 fn stage_type(&self) -> String {
371 "filter".to_string()
372 }
373}
374
375pub struct EnrichmentStage {
377 name: String,
378 enricher: Box<dyn DataEnricher>,
379}
380
381pub trait DataEnricher: Send + Sync {
382 fn enrich(&self, data: &mut PipelineData<Box<dyn Any + Send + Sync>>) -> Result<()>;
383}
384
385impl EnrichmentStage {
386 pub fn new(name: &str, enricher: Box<dyn DataEnricher>) -> Self {
387 Self {
388 name: name.to_string(),
389 enricher,
390 }
391 }
392}
393
394impl PipelineStage for EnrichmentStage {
395 fn execute(
396 &self,
397 mut input: PipelineData<Box<dyn Any + Send + Sync>>,
398 ) -> Result<PipelineData<Box<dyn Any + Send + Sync>>> {
399 self.enricher.enrich(&mut input)?;
400 Ok(input)
401 }
402
403 fn name(&self) -> String {
404 self.name.clone()
405 }
406
407 fn stage_type(&self) -> String {
408 "enrichment".to_string()
409 }
410}
411
/// Pipeline stage that tracks a cache entry, identified by a key, inside a cache directory.
pub struct CacheStage {
    /// Key identifying the cache entry; also used to build the cache file name.
    cache_key: String,
    /// Directory that holds the cache files.
    cache_dir: PathBuf,
}

impl CacheStage {
    /// Creates a cache stage for `cache_key` that stores its entry under `cache_dir`.
    ///
    /// Nothing touches the file system until the stage executes.
    pub fn new(cache_key: &str, cache_dir: impl AsRef<Path>) -> Self {
        let cache_key = cache_key.to_owned();
        let cache_dir = cache_dir.as_ref().to_owned();
        Self {
            cache_key,
            cache_dir,
        }
    }
}
426
427impl PipelineStage for CacheStage {
428 fn execute(
429 &self,
430 mut input: PipelineData<Box<dyn Any + Send + Sync>>,
431 ) -> Result<PipelineData<Box<dyn Any + Send + Sync>>> {
432 std::fs::create_dir_all(&self.cache_dir).map_err(IoError::Io)?;
434
435 let cache_path = self.cache_dir.join(format!("{}.cache", self.cache_key));
436
437 if cache_path.exists() {
439 if let Ok(_cache_data) = std::fs::read(&cache_path) {
441 input.metadata.set("cache_hit", true);
443 input.metadata.set("cache_key", self.cache_key.clone());
444
445 input.context.set("cached_from", self.cache_key.clone());
447
448 return Ok(input);
449 }
450 }
451
452 let cache_marker = format!(
456 "Cache entry for: {}\nCreated: {:?}\n",
457 self.cache_key,
458 chrono::Utc::now()
459 );
460 std::fs::write(&cache_path, cache_marker).map_err(IoError::Io)?;
461
462 input.metadata.set("cache_hit", false);
464 input.metadata.set("cache_key", self.cache_key.clone());
465
466 Ok(input)
467 }
468
469 fn name(&self) -> String {
470 format!("cache_{}", self.cache_key)
471 }
472
473 fn stage_type(&self) -> String {
474 "cache".to_string()
475 }
476}
477
478pub struct MonitoringStage {
480 name: String,
481 monitor: Box<dyn Monitor>,
482}
483
484pub trait Monitor: Send + Sync {
485 fn monitor(&self, data: &PipelineData<Box<dyn Any + Send + Sync>>);
486}
487
488impl MonitoringStage {
489 pub fn new(name: &str, monitor: Box<dyn Monitor>) -> Self {
490 Self {
491 name: name.to_string(),
492 monitor,
493 }
494 }
495}
496
497impl PipelineStage for MonitoringStage {
498 fn execute(
499 &self,
500 input: PipelineData<Box<dyn Any + Send + Sync>>,
501 ) -> Result<PipelineData<Box<dyn Any + Send + Sync>>> {
502 self.monitor.monitor(&input);
503 Ok(input)
504 }
505
506 fn name(&self) -> String {
507 self.name.clone()
508 }
509
510 fn stage_type(&self) -> String {
511 "monitoring".to_string()
512 }
513}
514
515pub struct ErrorHandlingStage {
517 name: String,
518 handler: Box<dyn ErrorHandler>,
519}
520
521pub trait ErrorHandler: Send + Sync {
522 fn handle_error(
523 &self,
524 error: IoError,
525 data: PipelineData<Box<dyn Any + Send + Sync>>,
526 ) -> Result<PipelineData<Box<dyn Any + Send + Sync>>>;
527}
528
529impl ErrorHandlingStage {
530 pub fn new(name: &str, handler: Box<dyn ErrorHandler>) -> Self {
531 Self {
532 name: name.to_string(),
533 handler,
534 }
535 }
536}
537
538impl PipelineStage for ErrorHandlingStage {
539 fn execute(
540 &self,
541 input: PipelineData<Box<dyn Any + Send + Sync>>,
542 ) -> Result<PipelineData<Box<dyn Any + Send + Sync>>> {
543 if let Some(error_msg) = input.context.get::<String>("pipeline_error") {
548 let error = IoError::Other(error_msg);
550
551 self.handler.handle_error(error, input)
553 } else {
554 Ok(input)
556 }
557 }
558
559 fn name(&self) -> String {
560 self.name.clone()
561 }
562
563 fn stage_type(&self) -> String {
564 "error_handling".to_string()
565 }
566}
567
/// Error handler that allows a bounded number of retries.
pub struct RetryErrorHandler {
    /// Number of retries granted before the error is returned.
    max_retries: usize,
    /// Delay associated with retries; defaults to one second.
    retry_delay: Duration,
}

impl RetryErrorHandler {
    /// Creates a handler that permits `max_retries` retries with a one-second delay setting.
    pub fn new(max_retries: usize) -> Self {
        let retry_delay = Duration::from_secs(1);
        Self {
            max_retries,
            retry_delay,
        }
    }

    /// Replaces the delay setting and returns the handler, so calls can be chained.
    pub fn with_delay(self, delay: Duration) -> Self {
        Self {
            retry_delay: delay,
            ..self
        }
    }
}
587
588impl ErrorHandler for RetryErrorHandler {
589 fn handle_error(
590 &self,
591 error: IoError,
592 mut data: PipelineData<Box<dyn Any + Send + Sync>>,
593 ) -> Result<PipelineData<Box<dyn Any + Send + Sync>>> {
594 let retry_count = data.context.get::<usize>("retry_count").unwrap_or(0);
596
597 if retry_count < self.max_retries {
598 data.context.set("retry_count", retry_count + 1);
600
601 data.metadata.set("last_error", format!("{:?}", error));
603 data.metadata.set("retry_attempt", (retry_count + 1) as i64);
604
605 data.context.set::<Option<String>>("pipeline_error", None);
607
608 Ok(data)
609 } else {
610 Err(error)
612 }
613 }
614}
615
616pub struct SkipErrorHandler;
618
619impl ErrorHandler for SkipErrorHandler {
620 fn handle_error(
621 &self,
622 _error: IoError,
623 mut data: PipelineData<Box<dyn Any + Send + Sync>>,
624 ) -> Result<PipelineData<Box<dyn Any + Send + Sync>>> {
625 data.metadata.set("skipped", true);
627 data.metadata.set("skip_reason", "error_occurred");
628
629 Ok(data)
631 }
632}
633
/// Error handler that substitutes a fixed value for the payload when an error occurs.
pub struct FallbackErrorHandler<T: Any + Send + Sync + Clone + 'static> {
    /// Value cloned into the pipeline whenever the handler is invoked.
    fallback_value: T,
}

impl<T: Any + Send + Sync + Clone + 'static> FallbackErrorHandler<T> {
    /// Creates a handler that falls back to `fallback_value`.
    pub fn new(fallback_value: T) -> Self {
        FallbackErrorHandler { fallback_value }
    }
}
644
645impl<T: Any + Send + Sync + Clone + 'static> ErrorHandler for FallbackErrorHandler<T> {
646 fn handle_error(
647 &self,
648 _error: IoError,
649 mut data: PipelineData<Box<dyn Any + Send + Sync>>,
650 ) -> Result<PipelineData<Box<dyn Any + Send + Sync>>> {
651 data.data = Box::new(self.fallback_value.clone());
653 data.metadata.set("used_fallback", true);
654
655 Ok(data)
656 }
657}
658
#[cfg(test)]
mod tests {
    use super::*;

    /// Validator that rejects an empty `Vec<i32>` and accepts every other payload.
    struct SimpleValidator;

    impl Validator for SimpleValidator {
        fn validate(&self, data: &dyn Any) -> Result<()> {
            match data.downcast_ref::<Vec<i32>>() {
                Some(nums) if nums.is_empty() => {
                    Err(IoError::ValidationError("Empty data".to_string()))
                }
                _ => Ok(()),
            }
        }

        fn name(&self) -> &str {
            "simple"
        }
    }

    /// A non-empty vector passes validation and an empty one is rejected.
    #[test]
    fn test_validation_stage() {
        let stage = ValidationStage::new().add_validator(Box::new(SimpleValidator));

        let populated: Box<dyn Any + Send + Sync> = Box::new(vec![1, 2, 3]);
        assert!(stage.execute(PipelineData::new(populated)).is_ok());

        let empty: Box<dyn Any + Send + Sync> = Box::new(Vec::<i32>::new());
        assert!(stage.execute(PipelineData::new(empty)).is_err());
    }
}