// sklears_preprocessing/pipeline_validation.rs

1//! Pipeline Validation Utilities
2//!
3//! Comprehensive validation framework for preprocessing pipelines,
4//! ensuring correctness, compatibility, and optimal configuration.
5
6use scirs2_core::ndarray::Array2;
7use sklears_core::prelude::SklearsError;
8
/// Pipeline validation result
///
/// Aggregate outcome of a validation run: hard errors (which make the
/// pipeline invalid), non-fatal warnings, performance recommendations,
/// and rough resource estimates.
#[derive(Debug, Clone)]
pub struct ValidationResult {
    /// Whether the pipeline is valid (true iff `errors` is empty)
    pub is_valid: bool,
    /// Validation warnings (non-fatal issues)
    pub warnings: Vec<ValidationWarning>,
    /// Validation errors (fatal issues)
    pub errors: Vec<ValidationError>,
    /// Performance recommendations
    pub recommendations: Vec<PerformanceRecommendation>,
    /// Estimated memory usage (bytes)
    pub estimated_memory: Option<usize>,
    /// Estimated computation time (milliseconds)
    pub estimated_time: Option<f64>,
}
25
/// Validation warning
///
/// A non-fatal issue tied to a single pipeline step; warnings do not
/// make the pipeline invalid on their own.
#[derive(Debug, Clone)]
pub struct ValidationWarning {
    /// Name of the affected step ("pipeline" for pipeline-wide warnings)
    pub step: String,
    /// Human-readable description of the issue
    pub message: String,
    /// Severity of the issue
    pub severity: WarningSeverity,
}
33
/// Warning severity levels, ordered from least to most serious.
///
/// `Copy` and `Hash` are derived in addition to the comparison traits:
/// the enum is fieldless, so copying is free and callers can use values
/// directly without cloning or borrowing.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WarningSeverity {
    /// Minor issue; generally safe to ignore.
    Low,
    /// Worth reviewing (used e.g. for redundant transformations).
    Medium,
    /// Likely to cause problems (used e.g. when the estimated runtime
    /// exceeds the configured budget).
    High,
}
41
/// Validation error
///
/// A fatal issue; any error makes the whole pipeline invalid.
#[derive(Debug, Clone)]
pub struct ValidationError {
    /// Name of the affected step ("pipeline" or "data" for global errors)
    pub step: String,
    /// Human-readable description of the problem
    pub message: String,
    /// Machine-readable error category
    pub error_type: ValidationErrorType,
}
49
/// Validation error types
///
/// `Copy` and `Hash` are derived in addition to the comparison traits:
/// the enum is fieldless, so values can be passed and compared freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ValidationErrorType {
    /// Dimensions are incompatible between steps.
    IncompatibleDimensions,
    /// A step precondition is not met (e.g. too few samples, NaNs with
    /// no imputation step).
    MissingRequirement,
    /// The pipeline or a step is misconfigured (e.g. empty pipeline).
    InvalidConfiguration,
    /// Data type does not match what a step expects.
    DataTypeMismatch,
    /// Estimated resource usage exceeds a configured limit.
    ResourceExceeded,
}
59
/// Performance recommendation
///
/// A purely informational suggestion for improving pipeline efficiency.
#[derive(Debug, Clone)]
pub struct PerformanceRecommendation {
    /// Area the recommendation applies to
    pub category: RecommendationCategory,
    /// Human-readable suggestion
    pub message: String,
    /// Expected speedup factor (e.g. 2.0 means roughly twice as fast),
    /// when one can be estimated
    pub expected_improvement: Option<f64>,
}
67
/// Recommendation categories
///
/// `Copy` and `Hash` are derived in addition to the comparison traits:
/// the enum is fieldless, so values can be passed and compared freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RecommendationCategory {
    /// Reordering steps (e.g. imputation before scaling).
    OrderOptimization,
    /// Running independent steps in parallel.
    ParallelProcessing,
    /// Reducing memory footprint.
    MemoryEfficiency,
    /// Reducing computation cost.
    ComputationEfficiency,
    /// Improving input data quality.
    DataQuality,
}
77
/// Pipeline validator configuration
///
/// Thresholds and toggles controlling which checks `PipelineValidator`
/// performs; the `Default` impl provides the standard settings.
#[derive(Debug, Clone)]
pub struct PipelineValidatorConfig {
    /// Maximum allowed memory usage (bytes); `None` disables the memory check
    pub max_memory: Option<usize>,
    /// Maximum allowed computation time (milliseconds); `None` disables the time check
    pub max_time: Option<f64>,
    /// Check for redundant transformations
    pub check_redundancy: bool,
    /// Check for optimal ordering
    pub check_ordering: bool,
    /// Validate data compatibility
    pub validate_data: bool,
}
92
93impl Default for PipelineValidatorConfig {
94    fn default() -> Self {
95        Self {
96            max_memory: Some(8 * 1024 * 1024 * 1024), // 8GB
97            max_time: Some(60000.0),                  // 60 seconds
98            check_redundancy: true,
99            check_ordering: true,
100            validate_data: true,
101        }
102    }
103}
104
/// Pipeline validator
///
/// Performs static checks over an ordered list of preprocessing step
/// names (steps are identified by substring matching on their names)
/// plus optional sample data.
pub struct PipelineValidator {
    // Thresholds and feature toggles for the individual checks.
    config: PipelineValidatorConfig,
}
109
110impl PipelineValidator {
111    /// Create a new pipeline validator
112    pub fn new() -> Self {
113        Self {
114            config: PipelineValidatorConfig::default(),
115        }
116    }
117
118    /// Create a validator with custom configuration
119    pub fn with_config(config: PipelineValidatorConfig) -> Self {
120        Self { config }
121    }
122
123    /// Validate a preprocessing pipeline
124    pub fn validate(
125        &self,
126        steps: &[String],
127        sample_data: Option<&Array2<f64>>,
128    ) -> Result<ValidationResult, SklearsError> {
129        let mut warnings = Vec::new();
130        let mut errors = Vec::new();
131        let mut recommendations = Vec::new();
132
133        // Check for empty pipeline
134        if steps.is_empty() {
135            errors.push(ValidationError {
136                step: "pipeline".to_string(),
137                message: "Pipeline has no steps".to_string(),
138                error_type: ValidationErrorType::InvalidConfiguration,
139            });
140        }
141
142        // Check for redundant transformations
143        if self.config.check_redundancy {
144            let redundant = self.check_redundancy(steps);
145            for (step1, step2) in redundant {
146                warnings.push(ValidationWarning {
147                    step: step1.clone(),
148                    message: format!("Redundant with step: {}", step2),
149                    severity: WarningSeverity::Medium,
150                });
151            }
152        }
153
154        // Check for optimal ordering
155        if self.config.check_ordering {
156            let ordering_issues = self.check_ordering(steps);
157            for (step, suggestion) in ordering_issues {
158                recommendations.push(PerformanceRecommendation {
159                    category: RecommendationCategory::OrderOptimization,
160                    message: format!("Step '{}': {}", step, suggestion),
161                    expected_improvement: Some(1.5), // Estimated 50% improvement
162                });
163            }
164        }
165
166        // Validate data if provided
167        if self.config.validate_data {
168            if let Some(data) = sample_data {
169                let data_issues = self.validate_data(data, steps);
170                errors.extend(data_issues);
171            }
172        }
173
174        // Estimate resource usage
175        let estimated_memory = self.estimate_memory(steps, sample_data);
176        let estimated_time = self.estimate_time(steps, sample_data);
177
178        // Check resource limits
179        if let (Some(est_mem), Some(max_mem)) = (estimated_memory, self.config.max_memory) {
180            if est_mem > max_mem {
181                errors.push(ValidationError {
182                    step: "pipeline".to_string(),
183                    message: format!(
184                        "Estimated memory usage ({} bytes) exceeds limit ({} bytes)",
185                        est_mem, max_mem
186                    ),
187                    error_type: ValidationErrorType::ResourceExceeded,
188                });
189            }
190        }
191
192        if let (Some(est_time), Some(max_time)) = (estimated_time, self.config.max_time) {
193            if est_time > max_time {
194                warnings.push(ValidationWarning {
195                    step: "pipeline".to_string(),
196                    message: format!(
197                        "Estimated computation time ({:.2}ms) may exceed limit ({:.2}ms)",
198                        est_time, max_time
199                    ),
200                    severity: WarningSeverity::High,
201                });
202            }
203        }
204
205        // Add general recommendations
206        if steps.len() > 10 {
207            recommendations.push(PerformanceRecommendation {
208                category: RecommendationCategory::ComputationEfficiency,
209                message: "Consider using FeatureUnion or ColumnTransformer for parallel processing"
210                    .to_string(),
211                expected_improvement: Some(2.0), // 2x improvement
212            });
213        }
214
215        let is_valid = errors.is_empty();
216
217        Ok(ValidationResult {
218            is_valid,
219            warnings,
220            errors,
221            recommendations,
222            estimated_memory,
223            estimated_time,
224        })
225    }
226
227    /// Check for redundant transformations
228    fn check_redundancy(&self, steps: &[String]) -> Vec<(String, String)> {
229        let mut redundant = Vec::new();
230
231        // Check for duplicate scaling steps
232        let scaling_steps: Vec<_> = steps
233            .iter()
234            .enumerate()
235            .filter(|(_, s)| {
236                s.contains("Scaler")
237                    || s.contains("Normalizer")
238                    || s.contains("StandardScaler")
239                    || s.contains("MinMaxScaler")
240            })
241            .collect();
242
243        if scaling_steps.len() > 1 {
244            for i in 1..scaling_steps.len() {
245                redundant.push((scaling_steps[i].1.clone(), scaling_steps[0].1.clone()));
246            }
247        }
248
249        // Check for duplicate imputation steps
250        let imputation_steps: Vec<_> = steps
251            .iter()
252            .enumerate()
253            .filter(|(_, s)| s.contains("Imputer"))
254            .collect();
255
256        if imputation_steps.len() > 1 {
257            for i in 1..imputation_steps.len() {
258                redundant.push((imputation_steps[i].1.clone(), imputation_steps[0].1.clone()));
259            }
260        }
261
262        redundant
263    }
264
265    /// Check for suboptimal ordering
266    fn check_ordering(&self, steps: &[String]) -> Vec<(String, String)> {
267        let mut issues = Vec::new();
268
269        for (i, step) in steps.iter().enumerate() {
270            // Imputation should come before scaling
271            if step.contains("Scaler") || step.contains("Normalizer") {
272                if steps[..i].iter().any(|s| s.contains("Imputer")) {
273                    // Good order
274                } else if steps[i..].iter().any(|s| s.contains("Imputer")) {
275                    issues.push((
276                        step.clone(),
277                        "Consider moving imputation before scaling".to_string(),
278                    ));
279                }
280            }
281
282            // Feature selection should come after feature generation
283            if (step.contains("FeatureSelector") || step.contains("SelectK"))
284                && !steps[..i]
285                    .iter()
286                    .any(|s| s.contains("PolynomialFeatures") || s.contains("FeatureUnion"))
287                && steps[i..]
288                    .iter()
289                    .any(|s| s.contains("PolynomialFeatures") || s.contains("FeatureUnion"))
290            {
291                issues.push((
292                    step.clone(),
293                    "Consider moving feature selection after feature generation".to_string(),
294                ));
295            }
296
297            // Encoding should come before numerical transformations
298            if step.contains("Encoder")
299                && steps[..i].iter().any(|s| {
300                    s.contains("Scaler") || s.contains("Normalizer") || s.contains("Transformer")
301                })
302            {
303                issues.push((
304                    step.clone(),
305                    "Consider moving encoding before numerical transformations".to_string(),
306                ));
307            }
308        }
309
310        issues
311    }
312
313    /// Validate data compatibility
314    fn validate_data(&self, data: &Array2<f64>, steps: &[String]) -> Vec<ValidationError> {
315        let mut errors = Vec::new();
316
317        let (n_samples, n_features) = (data.nrows(), data.ncols());
318
319        // Check for insufficient samples
320        if n_samples < 2 {
321            errors.push(ValidationError {
322                step: "data".to_string(),
323                message: "Insufficient samples (need at least 2)".to_string(),
324                error_type: ValidationErrorType::InvalidConfiguration,
325            });
326        }
327
328        // Check for steps that require minimum samples
329        for step in steps {
330            if step.contains("KNN") && n_samples < 5 {
331                errors.push(ValidationError {
332                    step: step.clone(),
333                    message: format!(
334                        "KNN-based methods require at least 5 samples, found {}",
335                        n_samples
336                    ),
337                    error_type: ValidationErrorType::MissingRequirement,
338                });
339            }
340
341            if step.contains("PCA") && n_samples < n_features {
342                errors.push(ValidationError {
343                    step: step.clone(),
344                    message: format!(
345                        "PCA requires n_samples >= n_features ({} < {})",
346                        n_samples, n_features
347                    ),
348                    error_type: ValidationErrorType::InvalidConfiguration,
349                });
350            }
351        }
352
353        // Check for NaN values
354        let has_nan = data.iter().any(|v| v.is_nan());
355        if has_nan {
356            let has_imputer = steps.iter().any(|s| s.contains("Imputer"));
357            if !has_imputer {
358                errors.push(ValidationError {
359                    step: "data".to_string(),
360                    message: "Data contains NaN values but no imputation step".to_string(),
361                    error_type: ValidationErrorType::MissingRequirement,
362                });
363            }
364        }
365
366        errors
367    }
368
369    /// Estimate memory usage
370    fn estimate_memory(
371        &self,
372        steps: &[String],
373        sample_data: Option<&Array2<f64>>,
374    ) -> Option<usize> {
375        let base_size = if let Some(data) = sample_data {
376            data.nrows() * data.ncols() * std::mem::size_of::<f64>()
377        } else {
378            0
379        };
380
381        let mut total = base_size;
382
383        for step in steps {
384            if step.contains("PolynomialFeatures") {
385                // Polynomial features can significantly increase memory
386                total = total.saturating_mul(3); // Rough estimate
387            } else if step.contains("OneHotEncoder") {
388                // One-hot encoding increases dimensions
389                total = total.saturating_mul(2);
390            } else {
391                // Most transformations keep similar size
392                total = total.saturating_add(base_size / 2);
393            }
394        }
395
396        Some(total)
397    }
398
399    /// Estimate computation time
400    fn estimate_time(&self, steps: &[String], sample_data: Option<&Array2<f64>>) -> Option<f64> {
401        let n_operations = if let Some(data) = sample_data {
402            (data.nrows() * data.ncols()) as f64
403        } else {
404            10000.0 // Default estimate
405        };
406
407        let mut total_time = 0.0;
408
409        for step in steps {
410            let step_time = if step.contains("KNN") {
411                n_operations * 0.001 // KNN is O(n²) but we approximate
412            } else if step.contains("PCA") {
413                n_operations * 0.0005 // Matrix operations
414            } else if step.contains("PolynomialFeatures") {
415                n_operations * 0.0002
416            } else {
417                n_operations * 0.00001 // Simple operations
418            };
419
420            total_time += step_time;
421        }
422
423        Some(total_time)
424    }
425}
426
427impl Default for PipelineValidator {
428    fn default() -> Self {
429        Self::new()
430    }
431}
432
433impl ValidationResult {
434    /// Print a summary of the validation result
435    pub fn print_summary(&self) {
436        println!("Pipeline Validation Result");
437        println!("==========================");
438        println!(
439            "Status: {}",
440            if self.is_valid { "VALID" } else { "INVALID" }
441        );
442        println!();
443
444        if !self.errors.is_empty() {
445            println!("Errors: {}", self.errors.len());
446            for error in &self.errors {
447                println!("  [ERROR] {}: {}", error.step, error.message);
448            }
449            println!();
450        }
451
452        if !self.warnings.is_empty() {
453            println!("Warnings: {}", self.warnings.len());
454            for warning in &self.warnings {
455                let severity = match warning.severity {
456                    WarningSeverity::Low => "LOW",
457                    WarningSeverity::Medium => "MEDIUM",
458                    WarningSeverity::High => "HIGH",
459                };
460                println!("  [{}] {}: {}", severity, warning.step, warning.message);
461            }
462            println!();
463        }
464
465        if !self.recommendations.is_empty() {
466            println!("Recommendations: {}", self.recommendations.len());
467            for rec in &self.recommendations {
468                let improvement = if let Some(imp) = rec.expected_improvement {
469                    format!(" (expected {:.1}x improvement)", imp)
470                } else {
471                    String::new()
472                };
473                println!("  [RECOMMEND] {}{}", rec.message, improvement);
474            }
475            println!();
476        }
477
478        if let Some(mem) = self.estimated_memory {
479            println!("Estimated Memory: {:.2} MB", mem as f64 / 1024.0 / 1024.0);
480        }
481
482        if let Some(time) = self.estimated_time {
483            println!("Estimated Time: {:.2} ms", time);
484        }
485    }
486
487    /// Get high severity warnings
488    pub fn high_severity_warnings(&self) -> Vec<&ValidationWarning> {
489        self.warnings
490            .iter()
491            .filter(|w| w.severity == WarningSeverity::High)
492            .collect()
493    }
494
495    /// Get errors of specific type
496    pub fn errors_of_type(&self, error_type: ValidationErrorType) -> Vec<&ValidationError> {
497        self.errors
498            .iter()
499            .filter(|e| e.error_type == error_type)
500            .collect()
501    }
502
503    /// Get recommendations by category
504    pub fn recommendations_by_category(
505        &self,
506        category: RecommendationCategory,
507    ) -> Vec<&PerformanceRecommendation> {
508        self.recommendations
509            .iter()
510            .filter(|r| r.category == category)
511            .collect()
512    }
513}
514
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::random::essentials::Normal;
    use scirs2_core::random::{seeded_rng, Distribution};

    /// Build a deterministic `nrows x ncols` matrix of standard-normal
    /// samples; the fixed `seed` keeps every test reproducible.
    fn generate_test_data(nrows: usize, ncols: usize, seed: u64) -> Array2<f64> {
        let mut rng = seeded_rng(seed);
        let normal = Normal::new(0.0, 1.0).expect("operation should succeed");

        let data: Vec<f64> = (0..nrows * ncols)
            .map(|_| normal.sample(&mut rng))
            .collect();

        Array2::from_shape_vec((nrows, ncols), data).expect("shape and data length should match")
    }

    // An empty pipeline must be reported as invalid with at least one error.
    #[test]
    fn test_pipeline_validator_empty() {
        let validator = PipelineValidator::new();
        let result = validator
            .validate(&[], None)
            .expect("operation should succeed");

        assert!(!result.is_valid);
        assert!(!result.errors.is_empty());
    }

    // Two scaling steps in one pipeline should trigger a redundancy warning.
    #[test]
    fn test_pipeline_validator_redundancy() {
        let validator = PipelineValidator::new();
        let steps = vec!["StandardScaler".to_string(), "MinMaxScaler".to_string()];

        let result = validator
            .validate(&steps, None)
            .expect("operation should succeed");

        // Should warn about redundant scaling
        assert!(!result.warnings.is_empty());
    }

    // Scaling before imputation should yield an ordering recommendation.
    #[test]
    fn test_pipeline_validator_ordering() {
        let validator = PipelineValidator::new();
        let steps = vec![
            "StandardScaler".to_string(),
            "SimpleImputer".to_string(), // Imputation should come first
        ];

        let result = validator
            .validate(&steps, None)
            .expect("operation should succeed");

        // Should recommend reordering
        assert!(!result.recommendations.is_empty());
    }

    // A well-formed pipeline on clean data should validate cleanly.
    #[test]
    fn test_pipeline_validator_data() {
        let data = generate_test_data(100, 10, 42);
        let validator = PipelineValidator::new();
        let steps = vec!["StandardScaler".to_string()];

        let result = validator
            .validate(&steps, Some(&data))
            .expect("operation should succeed");

        assert!(result.is_valid);
    }

    // A single-row matrix violates both the 2-sample minimum and the
    // 5-sample KNN requirement.
    #[test]
    fn test_pipeline_validator_insufficient_samples() {
        let data = Array2::from_elem((1, 5), 1.0);
        let validator = PipelineValidator::new();
        let steps = vec!["KNNImputer".to_string()];

        let result = validator
            .validate(&steps, Some(&data))
            .expect("operation should succeed");

        // Should error on insufficient samples
        assert!(!result.is_valid);
        assert!(!result.errors.is_empty());
    }

    // NaNs in the data without any imputation step must invalidate the run.
    #[test]
    fn test_pipeline_validator_nan_without_imputer() {
        let mut data = generate_test_data(50, 5, 123);
        data[[0, 0]] = f64::NAN;

        let validator = PipelineValidator::new();
        let steps = vec!["StandardScaler".to_string()];

        let result = validator
            .validate(&steps, Some(&data))
            .expect("operation should succeed");

        // Should error on NaN without imputation
        assert!(!result.is_valid);
    }

    // With sample data present the memory estimate must be a positive value.
    #[test]
    fn test_memory_estimation() {
        let data = generate_test_data(1000, 100, 456);
        let validator = PipelineValidator::new();
        let steps = vec![
            "StandardScaler".to_string(),
            "PolynomialFeatures".to_string(),
        ];

        let result = validator
            .validate(&steps, Some(&data))
            .expect("operation should succeed");

        assert!(result.estimated_memory.is_some());
        assert!(result.estimated_memory.expect("operation should succeed") > 0);
    }

    // With sample data present the time estimate must be a positive value.
    #[test]
    fn test_time_estimation() {
        let data = generate_test_data(1000, 50, 789);
        let validator = PipelineValidator::new();
        let steps = vec!["StandardScaler".to_string(), "PCA".to_string()];

        let result = validator
            .validate(&steps, Some(&data))
            .expect("operation should succeed");

        assert!(result.estimated_time.is_some());
        assert!(result.estimated_time.expect("operation should succeed") > 0.0);
    }

    // Severity-based filtering must never return more than the full set.
    #[test]
    fn test_validation_result_filtering() {
        let validator = PipelineValidator::new();
        let steps = vec![
            "StandardScaler".to_string(),
            "MinMaxScaler".to_string(),
            "SimpleImputer".to_string(),
        ];

        let result = validator
            .validate(&steps, None)
            .expect("operation should succeed");

        let high_warnings = result.high_severity_warnings();
        assert!(high_warnings.len() <= result.warnings.len());
    }
}