sklears_utils/
data_pipeline.rs

//! Data pipeline utilities for ML workflows
//!
//! This module provides comprehensive data pipeline functionality for machine learning
//! workflows, including data transformations, validation, caching, and orchestration.
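//!
//! # Example
//!
//! A minimal sketch of composing and running a pipeline, assuming this module is
//! exported as `sklears_utils::data_pipeline`:
//!
//! ```
//! use sklears_utils::data_pipeline::DataPipeline;
//!
//! let pipeline = DataPipeline::new()
//!     .add_transform("add_one".to_string(), |x: f64| Ok(x + 1.0))
//!     .add_transform("double".to_string(), |x: f64| Ok(x * 2.0));
//!
//! let result = pipeline.execute(1.0).unwrap();
//! assert_eq!(result.data, 4.0); // (1 + 1) * 2
//! ```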

use crate::UtilsError;
use scirs2_core::ndarray::{s, Array1, Array2};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt;
use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, Instant};

/// Represents a step in the data pipeline
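///
/// # Example
///
/// A minimal sketch of a custom step; the `Negate` type here is hypothetical:
///
/// ```
/// use sklears_utils::data_pipeline::PipelineStep;
///
/// struct Negate;
///
/// impl PipelineStep for Negate {
///     type Input = f64;
///     type Output = f64;
///     type Error = std::convert::Infallible;
///
///     fn process(&self, input: f64) -> Result<f64, Self::Error> {
///         Ok(-input)
///     }
///
///     fn name(&self) -> &str {
///         "negate"
///     }
/// }
///
/// let step = Negate;
/// assert_eq!(step.process(2.0), Ok(-2.0));
/// ```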
pub trait PipelineStep: Send + Sync {
    type Input;
    type Output;
    type Error;

    fn process(&self, input: Self::Input) -> Result<Self::Output, Self::Error>;
    fn name(&self) -> &str;
    fn description(&self) -> Option<&str> {
        None
    }
}

/// A transformation function that can be applied to data
pub type TransformFn<T, U> = Box<dyn Fn(T) -> Result<U, UtilsError> + Send + Sync>;

/// Generic pipeline step that applies a transformation function
pub struct TransformStep<T, U> {
    name: String,
    description: Option<String>,
    transform_fn: TransformFn<T, U>,
}

impl<T, U> TransformStep<T, U> {
    pub fn new(name: String, transform_fn: TransformFn<T, U>) -> Self {
        Self {
            name,
            description: None,
            transform_fn,
        }
    }

    pub fn with_description(mut self, description: String) -> Self {
        self.description = Some(description);
        self
    }
}

impl<T, U> PipelineStep for TransformStep<T, U>
where
    T: Send + Sync,
    U: Send + Sync,
{
    type Input = T;
    type Output = U;
    type Error = UtilsError;

    fn process(&self, input: T) -> Result<U, UtilsError> {
        (self.transform_fn)(input)
    }

    fn name(&self) -> &str {
        &self.name
    }

    fn description(&self) -> Option<&str> {
        self.description.as_deref()
    }
}

/// Pipeline execution context with metadata and caching
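///
/// # Example
///
/// A small sketch of the metadata and cache round trip:
///
/// ```
/// use sklears_utils::data_pipeline::PipelineContext;
///
/// let ctx = PipelineContext::new()
///     .with_metadata("run_id".to_string(), "42".to_string());
/// ctx.cache_set("features".to_string(), vec![1, 2, 3]);
/// assert_eq!(ctx.cache_get("features"), Some(vec![1, 2, 3]));
/// ```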
#[derive(Debug, Clone)]
pub struct PipelineContext {
    pub metadata: HashMap<String, String>,
    pub start_time: Instant,
    cache: Arc<RwLock<HashMap<String, Vec<u8>>>>,
}

impl Default for PipelineContext {
    fn default() -> Self {
        Self {
            metadata: HashMap::new(),
            start_time: Instant::now(),
            cache: Arc::new(RwLock::new(HashMap::new())),
        }
    }
}

impl PipelineContext {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn with_metadata(mut self, key: String, value: String) -> Self {
        self.metadata.insert(key, value);
        self
    }

    pub fn elapsed(&self) -> Duration {
        self.start_time.elapsed()
    }

    pub fn cache_get(&self, key: &str) -> Option<Vec<u8>> {
        self.cache.read().ok()?.get(key).cloned()
    }

    pub fn cache_set(&self, key: String, value: Vec<u8>) {
        if let Ok(mut cache) = self.cache.write() {
            cache.insert(key, value);
        }
    }

    pub fn cache_clear(&self) {
        if let Ok(mut cache) = self.cache.write() {
            cache.clear();
        }
    }
}

/// Result of pipeline execution with timing and metadata
#[derive(Debug, Clone)]
pub struct PipelineResult<T> {
    pub data: T,
    pub execution_time: Duration,
    pub steps_executed: Vec<String>,
    pub metadata: HashMap<String, String>,
}

impl<T> PipelineResult<T> {
    pub fn new(data: T, execution_time: Duration, steps_executed: Vec<String>) -> Self {
        Self {
            data,
            execution_time,
            steps_executed,
            metadata: HashMap::new(),
        }
    }

    pub fn with_metadata(mut self, metadata: HashMap<String, String>) -> Self {
        self.metadata = metadata;
        self
    }
}

/// Data pipeline orchestrator
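///
/// # Example
///
/// A sketch combining a boxed [`TransformStep`] with a closure-based transform:
///
/// ```
/// use sklears_utils::data_pipeline::{DataPipeline, TransformStep};
///
/// let step = TransformStep::new("square".to_string(), Box::new(|x: f64| Ok(x * x)));
/// let pipeline = DataPipeline::new()
///     .add_step(Box::new(step))
///     .add_transform("halve".to_string(), |x: f64| Ok(x / 2.0));
///
/// let result = pipeline.execute(4.0).unwrap();
/// assert_eq!(result.data, 8.0); // (4 * 4) / 2
/// ```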
pub struct DataPipeline<T> {
    steps: Vec<Box<dyn PipelineStep<Input = T, Output = T, Error = UtilsError>>>,
    context: PipelineContext,
    // Configuration flags; `execute` currently runs the steps sequentially and
    // does not yet consult these.
    validation_enabled: bool,
    parallel_execution: bool,
}

impl<T> Default for DataPipeline<T>
where
    T: Clone + Send + Sync + 'static,
{
    fn default() -> Self {
        Self::new()
    }
}

impl<T> DataPipeline<T>
where
    T: Clone + Send + Sync + 'static,
{
    pub fn new() -> Self {
        Self {
            steps: Vec::new(),
            context: PipelineContext::new(),
            validation_enabled: true,
            parallel_execution: false,
        }
    }

    pub fn with_context(mut self, context: PipelineContext) -> Self {
        self.context = context;
        self
    }

    pub fn enable_validation(mut self, enabled: bool) -> Self {
        self.validation_enabled = enabled;
        self
    }

    pub fn enable_parallel_execution(mut self, enabled: bool) -> Self {
        self.parallel_execution = enabled;
        self
    }

    pub fn add_step(
        mut self,
        step: Box<dyn PipelineStep<Input = T, Output = T, Error = UtilsError>>,
    ) -> Self {
        self.steps.push(step);
        self
    }

    pub fn add_transform<F>(self, name: String, transform_fn: F) -> Self
    where
        F: Fn(T) -> Result<T, UtilsError> + Send + Sync + 'static,
    {
        let step = TransformStep::new(name, Box::new(transform_fn));
        self.add_step(Box::new(step))
    }

    pub fn execute(&self, mut data: T) -> Result<PipelineResult<T>, UtilsError> {
        let start_time = Instant::now();
        let mut steps_executed = Vec::new();

        for step in &self.steps {
            let step_start = Instant::now();

            data = step.process(data).map_err(|e| {
                UtilsError::InvalidParameter(format!(
                    "Pipeline step '{}' failed: {}",
                    step.name(),
                    e
                ))
            })?;

            steps_executed.push(format!(
                "{} ({}ms)",
                step.name(),
                step_start.elapsed().as_millis()
            ));
        }

        let execution_time = start_time.elapsed();
        Ok(PipelineResult::new(data, execution_time, steps_executed)
            .with_metadata(self.context.metadata.clone()))
    }
}

/// Builder for creating common ML data pipelines
pub struct MLPipelineBuilder;

impl MLPipelineBuilder {
    /// Create a basic data cleaning pipeline
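    ///
    /// # Example
    ///
    /// A sketch of cleaning a small matrix with a missing value:
    ///
    /// ```
    /// use scirs2_core::ndarray::array;
    /// use sklears_utils::data_pipeline::MLPipelineBuilder;
    ///
    /// let data = array![[1.0, f64::NAN], [3.0, 4.0], [5.0, 6.0]];
    /// let result = MLPipelineBuilder::data_cleaning().execute(data).unwrap();
    ///
    /// // NaNs are replaced and every column is z-score normalized.
    /// assert!(result.data.iter().all(|&x| x.is_finite()));
    /// ```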
    pub fn data_cleaning() -> DataPipeline<Array2<f64>> {
        DataPipeline::new()
            .add_transform("remove_duplicates".to_string(), |data: Array2<f64>| {
                // Simple duplicate removal: drop rows that are identical to the
                // immediately preceding row (consecutive duplicates only).
                let mut unique_rows = Vec::new();
                let mut prev_row: Option<Array1<f64>> = None;

                for row in data.rows() {
                    let current_row = row.to_owned();
                    if prev_row.as_ref() != Some(&current_row) {
                        unique_rows.push(current_row.clone());
                    }
                    prev_row = Some(current_row);
                }

                if unique_rows.is_empty() {
                    return Err(UtilsError::EmptyInput);
                }

                let n_cols = unique_rows[0].len();
                let mut result = Array2::zeros((unique_rows.len(), n_cols));
                for (i, row) in unique_rows.iter().enumerate() {
                    result.row_mut(i).assign(row);
                }
                Ok(result)
            })
            .add_transform("handle_missing_values".to_string(), |mut data| {
                // Replace non-finite values (NaN, ±inf) with the mean of each
                // column's finite values.
                for mut col in data.columns_mut() {
                    let valid_values: Vec<f64> =
                        col.iter().filter(|&&x| x.is_finite()).copied().collect();

                    if !valid_values.is_empty() {
                        let mean = valid_values.iter().sum::<f64>() / valid_values.len() as f64;
                        for val in col.iter_mut() {
                            if !val.is_finite() {
                                *val = mean;
                            }
                        }
                    }
                }
                Ok(data)
            })
            .add_transform("normalize_data".to_string(), |mut data| {
                // Z-score normalization per column; columns with near-zero
                // standard deviation are left unchanged.
                for mut col in data.columns_mut() {
                    let mean = col.mean().unwrap_or(0.0);
                    let std = {
                        let variance =
                            col.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / col.len() as f64;
                        variance.sqrt()
                    };

                    if std > 1e-10 {
                        for val in col.iter_mut() {
                            *val = (*val - mean) / std;
                        }
                    }
                }
                Ok(data)
            })
    }

    /// Create a feature engineering pipeline
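    ///
    /// # Example
    ///
    /// A sketch showing the column expansion (2 original features + 1 pairwise
    /// interaction, then 3 per-row statistics = 6 columns):
    ///
    /// ```
    /// use scirs2_core::ndarray::array;
    /// use sklears_utils::data_pipeline::MLPipelineBuilder;
    ///
    /// let data = array![[1.0, 2.0], [3.0, 4.0]];
    /// let result = MLPipelineBuilder::feature_engineering().execute(data).unwrap();
    /// assert_eq!(result.data.ncols(), 6);
    /// ```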
    pub fn feature_engineering() -> DataPipeline<Array2<f64>> {
        DataPipeline::new()
            .add_transform(
                "add_polynomial_features".to_string(),
                |data: Array2<f64>| {
                    let (n_rows, n_cols) = data.dim();
                    // One extra column per unordered feature pair (i, j), i < j.
                    let n_interactions = n_cols * n_cols.saturating_sub(1) / 2;
                    let mut result = Array2::zeros((n_rows, n_cols + n_interactions));

                    // Copy original features
                    result.slice_mut(s![.., ..n_cols]).assign(&data);

                    // Add pairwise interaction terms (x_i * x_j for i < j)
                    let mut col_idx = n_cols;
                    for i in 0..n_cols {
                        for j in (i + 1)..n_cols {
                            for row in 0..n_rows {
                                result[[row, col_idx]] = data[[row, i]] * data[[row, j]];
                            }
                            col_idx += 1;
                        }
                    }

                    Ok(result)
                },
            )
            .add_transform("add_statistical_features".to_string(), |data| {
                let (n_rows, n_cols) = data.dim();
                let mut result = Array2::zeros((n_rows, n_cols + 3)); // mean, std, range per row

                // Copy original features
                result.slice_mut(s![.., ..n_cols]).assign(&data);

                // Add per-row statistical features: mean, std, and range
                for (i, row) in data.rows().into_iter().enumerate() {
                    let mean = row.mean().unwrap_or(0.0);
                    let std = {
                        let variance =
                            row.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / row.len() as f64;
                        variance.sqrt()
                    };
                    let min = row.iter().copied().fold(f64::INFINITY, f64::min);
                    let max = row.iter().copied().fold(f64::NEG_INFINITY, f64::max);
                    let range = max - min;

                    result[[i, n_cols]] = mean;
                    result[[i, n_cols + 1]] = std;
                    result[[i, n_cols + 2]] = range;
                }

                Ok(result)
            })
    }

    /// Create a data validation pipeline
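    ///
    /// # Example
    ///
    /// A sketch of a validation failure on a constant (zero-variance) feature:
    ///
    /// ```
    /// use scirs2_core::ndarray::array;
    /// use sklears_utils::data_pipeline::MLPipelineBuilder;
    ///
    /// let data = array![[1.0, 1.0], [2.0, 1.0], [3.0, 1.0]];
    /// let result = MLPipelineBuilder::data_validation().execute(data);
    /// assert!(result.is_err());
    /// ```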
    pub fn data_validation() -> DataPipeline<Array2<f64>> {
        DataPipeline::new()
            .add_transform(
                "check_shape_consistency".to_string(),
                |data: Array2<f64>| {
                    // An Array2 is empty exactly when it has zero rows or columns.
                    if data.is_empty() {
                        return Err(UtilsError::EmptyInput);
                    }
                    Ok(data)
                },
            )
            .add_transform("check_data_quality".to_string(), |data| {
                let total_elements = data.len();
                let non_finite_count = data.iter().filter(|&&x| !x.is_finite()).count();
                let non_finite_ratio = non_finite_count as f64 / total_elements as f64;

                if non_finite_ratio > 0.5 {
                    return Err(UtilsError::InvalidParameter(format!(
                        "Too many missing values: {:.2}%",
                        non_finite_ratio * 100.0
                    )));
                }

                Ok(data)
            })
            .add_transform("check_feature_variance".to_string(), |data| {
                for (i, col) in data.columns().into_iter().enumerate() {
                    let mean = col.mean().unwrap_or(0.0);
                    let variance =
                        col.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / col.len() as f64;

                    // Treat near-zero variance as zero (constant feature).
                    if variance < 1e-10 {
                        return Err(UtilsError::InvalidParameter(format!(
                            "Feature {i} has zero variance"
                        )));
                    }
                }
                Ok(data)
            })
    }
}

/// Pipeline metrics and monitoring
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PipelineMetrics {
    pub total_executions: u64,
    pub successful_executions: u64,
    pub failed_executions: u64,
    pub average_execution_time: Duration,
    pub total_execution_time: Duration,
    pub step_metrics: HashMap<String, StepMetrics>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StepMetrics {
    pub executions: u64,
    pub average_time: Duration,
    pub total_time: Duration,
    pub success_rate: f64,
}

impl Default for PipelineMetrics {
    fn default() -> Self {
        Self {
            total_executions: 0,
            successful_executions: 0,
            failed_executions: 0,
            average_execution_time: Duration::from_secs(0),
            total_execution_time: Duration::from_secs(0),
            step_metrics: HashMap::new(),
        }
    }
}

impl PipelineMetrics {
    pub fn success_rate(&self) -> f64 {
        if self.total_executions == 0 {
            0.0
        } else {
            self.successful_executions as f64 / self.total_executions as f64
        }
    }

    pub fn record_execution(&mut self, result: &PipelineResult<impl Clone>, success: bool) {
        self.total_executions += 1;
        if success {
            self.successful_executions += 1;
        } else {
            self.failed_executions += 1;
        }

        self.total_execution_time += result.execution_time;
        self.average_execution_time = Duration::from_nanos(
            (self.total_execution_time.as_nanos() / self.total_executions as u128) as u64,
        );
    }
}

/// Pipeline monitor for tracking execution statistics
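///
/// # Example
///
/// A sketch of recording a successful run and reading back the metrics:
///
/// ```
/// use std::time::Duration;
/// use sklears_utils::data_pipeline::{PipelineMonitor, PipelineResult};
///
/// let monitor = PipelineMonitor::new();
/// let result = PipelineResult::new(42.0, Duration::from_millis(100), vec!["step1".to_string()]);
/// monitor.record_execution(&result, true);
///
/// let metrics = monitor.get_metrics().unwrap();
/// assert_eq!(metrics.success_rate(), 1.0);
/// ```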
pub struct PipelineMonitor {
    metrics: Arc<Mutex<PipelineMetrics>>,
    enabled: bool,
}

impl Default for PipelineMonitor {
    fn default() -> Self {
        Self {
            metrics: Arc::new(Mutex::new(PipelineMetrics::default())),
            enabled: true,
        }
    }
}

impl PipelineMonitor {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn enable(&mut self, enabled: bool) {
        self.enabled = enabled;
    }

    pub fn record_execution<T: Clone>(&self, result: &PipelineResult<T>, success: bool) {
        if !self.enabled {
            return;
        }

        if let Ok(mut metrics) = self.metrics.lock() {
            metrics.record_execution(result, success);
        }
    }

    pub fn get_metrics(&self) -> Option<PipelineMetrics> {
        self.metrics.lock().ok().map(|m| m.clone())
    }

    pub fn reset_metrics(&self) {
        if let Ok(mut metrics) = self.metrics.lock() {
            *metrics = PipelineMetrics::default();
        }
    }
}

impl fmt::Display for PipelineMetrics {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(f, "Pipeline Metrics:")?;
        writeln!(f, "  Total Executions: {}", self.total_executions)?;
        writeln!(f, "  Success Rate: {:.2}%", self.success_rate() * 100.0)?;
        writeln!(
            f,
            "  Average Execution Time: {:?}",
            self.average_execution_time
        )?;
        writeln!(f, "  Total Execution Time: {:?}", self.total_execution_time)?;
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_pipeline_context() {
        let context = PipelineContext::new().with_metadata("user".to_string(), "test".to_string());

        assert_eq!(context.metadata.get("user"), Some(&"test".to_string()));

        context.cache_set("key1".to_string(), vec![1, 2, 3]);
        assert_eq!(context.cache_get("key1"), Some(vec![1, 2, 3]));

        context.cache_clear();
        assert_eq!(context.cache_get("key1"), None);
    }

    #[test]
    fn test_transform_step() {
        let step = TransformStep::new("double".to_string(), Box::new(|x: f64| Ok(x * 2.0)))
            .with_description("Doubles the input value".to_string());

        assert_eq!(step.name(), "double");
        assert_eq!(step.description(), Some("Doubles the input value"));
        assert_eq!(step.process(5.0).unwrap(), 10.0);
    }

    #[test]
    fn test_data_pipeline_execution() {
        let pipeline = DataPipeline::new()
            .add_transform("add_one".to_string(), |x: f64| Ok(x + 1.0))
            .add_transform("multiply_two".to_string(), |x: f64| Ok(x * 2.0));

        let result = pipeline.execute(5.0).unwrap();
        assert_eq!(result.data, 12.0); // (5 + 1) * 2
        assert_eq!(result.steps_executed.len(), 2);
    }

    #[test]
    fn test_ml_pipeline_data_cleaning() {
        let data = array![[1.0, 2.0, f64::NAN], [3.0, f64::NAN, 4.0], [5.0, 6.0, 7.0]];

        let pipeline = MLPipelineBuilder::data_cleaning();
        let result = pipeline.execute(data).unwrap();

        // Check that NaN values were replaced
        assert!(result.data.iter().all(|&x| x.is_finite()));
        assert_eq!(result.steps_executed.len(), 3);
    }

    #[test]
    fn test_ml_pipeline_feature_engineering() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let pipeline = MLPipelineBuilder::feature_engineering();
        let result = pipeline.execute(data).unwrap();

        // Original 2 features + 1 interaction + 3 statistical features = 6 total
        assert_eq!(result.data.ncols(), 6);
        assert_eq!(result.steps_executed.len(), 2);
    }

    #[test]
    fn test_ml_pipeline_validation() {
        let data = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];

        let pipeline = MLPipelineBuilder::data_validation();
        let result = pipeline.execute(data).unwrap();

        assert_eq!(result.data.shape(), &[3, 3]);
        assert_eq!(result.steps_executed.len(), 3);
    }

    #[test]
    fn test_pipeline_validation_failure() {
        // Test with a constant feature (zero variance)
        let data = array![[1.0, 1.0], [2.0, 1.0], [3.0, 1.0]];

        let pipeline = MLPipelineBuilder::data_validation();
        let result = pipeline.execute(data);

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("zero variance"));
    }

    #[test]
    fn test_pipeline_monitor() {
        let monitor = PipelineMonitor::new();

        let result =
            PipelineResult::new(42.0, Duration::from_millis(100), vec!["step1".to_string()]);

        monitor.record_execution(&result, true);

        let metrics = monitor.get_metrics().unwrap();
        assert_eq!(metrics.total_executions, 1);
        assert_eq!(metrics.successful_executions, 1);
        assert_eq!(metrics.success_rate(), 1.0);

        monitor.reset_metrics();
        let metrics = monitor.get_metrics().unwrap();
        assert_eq!(metrics.total_executions, 0);
    }

    #[test]
    fn test_pipeline_metrics_display() {
        let metrics = PipelineMetrics {
            total_executions: 10,
            successful_executions: 8,
            average_execution_time: Duration::from_millis(50),
            ..Default::default()
        };

        let display = format!("{metrics}");
        assert!(display.contains("Total Executions: 10"));
        assert!(display.contains("Success Rate: 80.00%"));
    }
}