// sklears_preprocessing/pipeline_validation.rs

1//! Pipeline Validation Utilities
2//!
3//! Comprehensive validation framework for preprocessing pipelines,
4//! ensuring correctness, compatibility, and optimal configuration.
5
6use scirs2_core::ndarray::Array2;
7use sklears_core::prelude::SklearsError;
8
/// Pipeline validation result
///
/// Aggregate outcome produced by `PipelineValidator::validate`: collected
/// errors, warnings, recommendations, and rough resource estimates.
#[derive(Debug, Clone)]
pub struct ValidationResult {
    /// Whether the pipeline is valid (`true` exactly when `errors` is empty)
    pub is_valid: bool,
    /// Validation warnings (non-fatal; do not affect `is_valid`)
    pub warnings: Vec<ValidationWarning>,
    /// Validation errors (any entry makes the pipeline invalid)
    pub errors: Vec<ValidationError>,
    /// Performance recommendations
    pub recommendations: Vec<PerformanceRecommendation>,
    /// Estimated memory usage (bytes)
    pub estimated_memory: Option<usize>,
    /// Estimated computation time (milliseconds)
    pub estimated_time: Option<f64>,
}
25
/// Validation warning
///
/// A non-fatal issue attached to a specific step, or to the step name
/// "pipeline" for whole-pipeline concerns.
#[derive(Debug, Clone)]
pub struct ValidationWarning {
    /// Name of the step the warning refers to
    pub step: String,
    /// Human-readable description of the issue
    pub message: String,
    /// How serious the warning is
    pub severity: WarningSeverity,
}
33
/// Warning severity levels
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WarningSeverity {
    /// Minor issue; generally safe to ignore
    Low,
    /// Worth reviewing (used e.g. for redundant transformations)
    Medium,
    /// Should be addressed (used e.g. when the estimated time exceeds the limit)
    High,
}
41
/// Validation error
///
/// A fatal issue; any error makes the whole pipeline invalid.
#[derive(Debug, Clone)]
pub struct ValidationError {
    /// Name of the offending step ("pipeline" or "data" for global issues)
    pub step: String,
    /// Human-readable description of the problem
    pub message: String,
    /// Machine-readable error classification
    pub error_type: ValidationErrorType,
}
49
/// Validation error types
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ValidationErrorType {
    /// Step input/output dimensions do not line up
    IncompatibleDimensions,
    /// A prerequisite is missing (e.g. NaN data without an imputation step)
    MissingRequirement,
    /// The pipeline or a step is misconfigured (e.g. empty pipeline)
    InvalidConfiguration,
    /// Data type is not compatible with a step
    DataTypeMismatch,
    /// Estimated resource usage exceeds a configured limit
    ResourceExceeded,
}
59
/// Performance recommendation
#[derive(Debug, Clone)]
pub struct PerformanceRecommendation {
    /// Which aspect of the pipeline the recommendation targets
    pub category: RecommendationCategory,
    /// Human-readable suggestion
    pub message: String,
    /// Expected speedup as a multiplier (e.g. `2.0` = "2x improvement"), if known
    pub expected_improvement: Option<f64>,
}
67
/// Recommendation categories
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RecommendationCategory {
    /// Reorder steps for better results or performance
    OrderOptimization,
    /// Run independent steps in parallel
    ParallelProcessing,
    /// Reduce memory footprint
    MemoryEfficiency,
    /// Reduce computation cost
    ComputationEfficiency,
    /// Improve input data quality
    DataQuality,
}
77
/// Pipeline validator configuration
///
/// Controls which checks `PipelineValidator` runs and the resource limits
/// it enforces; see the `Default` impl for the default values.
#[derive(Debug, Clone)]
pub struct PipelineValidatorConfig {
    /// Maximum allowed memory usage (bytes); `None` disables the check
    pub max_memory: Option<usize>,
    /// Maximum allowed computation time (milliseconds); `None` disables the check
    pub max_time: Option<f64>,
    /// Check for redundant transformations
    pub check_redundancy: bool,
    /// Check for optimal ordering
    pub check_ordering: bool,
    /// Validate data compatibility (only applies when sample data is provided)
    pub validate_data: bool,
}
92
93impl Default for PipelineValidatorConfig {
94    fn default() -> Self {
95        Self {
96            max_memory: Some(8 * 1024 * 1024 * 1024), // 8GB
97            max_time: Some(60000.0),                  // 60 seconds
98            check_redundancy: true,
99            check_ordering: true,
100            validate_data: true,
101        }
102    }
103}
104
/// Pipeline validator
///
/// Runs static checks over a list of preprocessing step names (and an
/// optional data sample) and produces a `ValidationResult`.
pub struct PipelineValidator {
    /// Check toggles and resource limits used by `validate`
    config: PipelineValidatorConfig,
}
109
110impl PipelineValidator {
111    /// Create a new pipeline validator
112    pub fn new() -> Self {
113        Self {
114            config: PipelineValidatorConfig::default(),
115        }
116    }
117
118    /// Create a validator with custom configuration
119    pub fn with_config(config: PipelineValidatorConfig) -> Self {
120        Self { config }
121    }
122
123    /// Validate a preprocessing pipeline
124    pub fn validate(
125        &self,
126        steps: &[String],
127        sample_data: Option<&Array2<f64>>,
128    ) -> Result<ValidationResult, SklearsError> {
129        let mut warnings = Vec::new();
130        let mut errors = Vec::new();
131        let mut recommendations = Vec::new();
132
133        // Check for empty pipeline
134        if steps.is_empty() {
135            errors.push(ValidationError {
136                step: "pipeline".to_string(),
137                message: "Pipeline has no steps".to_string(),
138                error_type: ValidationErrorType::InvalidConfiguration,
139            });
140        }
141
142        // Check for redundant transformations
143        if self.config.check_redundancy {
144            let redundant = self.check_redundancy(steps);
145            for (step1, step2) in redundant {
146                warnings.push(ValidationWarning {
147                    step: step1.clone(),
148                    message: format!("Redundant with step: {}", step2),
149                    severity: WarningSeverity::Medium,
150                });
151            }
152        }
153
154        // Check for optimal ordering
155        if self.config.check_ordering {
156            let ordering_issues = self.check_ordering(steps);
157            for (step, suggestion) in ordering_issues {
158                recommendations.push(PerformanceRecommendation {
159                    category: RecommendationCategory::OrderOptimization,
160                    message: format!("Step '{}': {}", step, suggestion),
161                    expected_improvement: Some(1.5), // Estimated 50% improvement
162                });
163            }
164        }
165
166        // Validate data if provided
167        if self.config.validate_data {
168            if let Some(data) = sample_data {
169                let data_issues = self.validate_data(data, steps);
170                errors.extend(data_issues);
171            }
172        }
173
174        // Estimate resource usage
175        let estimated_memory = self.estimate_memory(steps, sample_data);
176        let estimated_time = self.estimate_time(steps, sample_data);
177
178        // Check resource limits
179        if let (Some(est_mem), Some(max_mem)) = (estimated_memory, self.config.max_memory) {
180            if est_mem > max_mem {
181                errors.push(ValidationError {
182                    step: "pipeline".to_string(),
183                    message: format!(
184                        "Estimated memory usage ({} bytes) exceeds limit ({} bytes)",
185                        est_mem, max_mem
186                    ),
187                    error_type: ValidationErrorType::ResourceExceeded,
188                });
189            }
190        }
191
192        if let (Some(est_time), Some(max_time)) = (estimated_time, self.config.max_time) {
193            if est_time > max_time {
194                warnings.push(ValidationWarning {
195                    step: "pipeline".to_string(),
196                    message: format!(
197                        "Estimated computation time ({:.2}ms) may exceed limit ({:.2}ms)",
198                        est_time, max_time
199                    ),
200                    severity: WarningSeverity::High,
201                });
202            }
203        }
204
205        // Add general recommendations
206        if steps.len() > 10 {
207            recommendations.push(PerformanceRecommendation {
208                category: RecommendationCategory::ComputationEfficiency,
209                message: "Consider using FeatureUnion or ColumnTransformer for parallel processing"
210                    .to_string(),
211                expected_improvement: Some(2.0), // 2x improvement
212            });
213        }
214
215        let is_valid = errors.is_empty();
216
217        Ok(ValidationResult {
218            is_valid,
219            warnings,
220            errors,
221            recommendations,
222            estimated_memory,
223            estimated_time,
224        })
225    }
226
227    /// Check for redundant transformations
228    fn check_redundancy(&self, steps: &[String]) -> Vec<(String, String)> {
229        let mut redundant = Vec::new();
230
231        // Check for duplicate scaling steps
232        let scaling_steps: Vec<_> = steps
233            .iter()
234            .enumerate()
235            .filter(|(_, s)| {
236                s.contains("Scaler")
237                    || s.contains("Normalizer")
238                    || s.contains("StandardScaler")
239                    || s.contains("MinMaxScaler")
240            })
241            .collect();
242
243        if scaling_steps.len() > 1 {
244            for i in 1..scaling_steps.len() {
245                redundant.push((scaling_steps[i].1.clone(), scaling_steps[0].1.clone()));
246            }
247        }
248
249        // Check for duplicate imputation steps
250        let imputation_steps: Vec<_> = steps
251            .iter()
252            .enumerate()
253            .filter(|(_, s)| s.contains("Imputer"))
254            .collect();
255
256        if imputation_steps.len() > 1 {
257            for i in 1..imputation_steps.len() {
258                redundant.push((imputation_steps[i].1.clone(), imputation_steps[0].1.clone()));
259            }
260        }
261
262        redundant
263    }
264
265    /// Check for suboptimal ordering
266    fn check_ordering(&self, steps: &[String]) -> Vec<(String, String)> {
267        let mut issues = Vec::new();
268
269        for (i, step) in steps.iter().enumerate() {
270            // Imputation should come before scaling
271            if step.contains("Scaler") || step.contains("Normalizer") {
272                if steps[..i].iter().any(|s| s.contains("Imputer")) {
273                    // Good order
274                } else if steps[i..].iter().any(|s| s.contains("Imputer")) {
275                    issues.push((
276                        step.clone(),
277                        "Consider moving imputation before scaling".to_string(),
278                    ));
279                }
280            }
281
282            // Feature selection should come after feature generation
283            if (step.contains("FeatureSelector") || step.contains("SelectK"))
284                && !steps[..i]
285                    .iter()
286                    .any(|s| s.contains("PolynomialFeatures") || s.contains("FeatureUnion"))
287                && steps[i..]
288                    .iter()
289                    .any(|s| s.contains("PolynomialFeatures") || s.contains("FeatureUnion"))
290            {
291                issues.push((
292                    step.clone(),
293                    "Consider moving feature selection after feature generation".to_string(),
294                ));
295            }
296
297            // Encoding should come before numerical transformations
298            if step.contains("Encoder")
299                && steps[..i].iter().any(|s| {
300                    s.contains("Scaler") || s.contains("Normalizer") || s.contains("Transformer")
301                })
302            {
303                issues.push((
304                    step.clone(),
305                    "Consider moving encoding before numerical transformations".to_string(),
306                ));
307            }
308        }
309
310        issues
311    }
312
313    /// Validate data compatibility
314    fn validate_data(&self, data: &Array2<f64>, steps: &[String]) -> Vec<ValidationError> {
315        let mut errors = Vec::new();
316
317        let (n_samples, n_features) = (data.nrows(), data.ncols());
318
319        // Check for insufficient samples
320        if n_samples < 2 {
321            errors.push(ValidationError {
322                step: "data".to_string(),
323                message: "Insufficient samples (need at least 2)".to_string(),
324                error_type: ValidationErrorType::InvalidConfiguration,
325            });
326        }
327
328        // Check for steps that require minimum samples
329        for step in steps {
330            if step.contains("KNN") && n_samples < 5 {
331                errors.push(ValidationError {
332                    step: step.clone(),
333                    message: format!(
334                        "KNN-based methods require at least 5 samples, found {}",
335                        n_samples
336                    ),
337                    error_type: ValidationErrorType::MissingRequirement,
338                });
339            }
340
341            if step.contains("PCA") && n_samples < n_features {
342                errors.push(ValidationError {
343                    step: step.clone(),
344                    message: format!(
345                        "PCA requires n_samples >= n_features ({} < {})",
346                        n_samples, n_features
347                    ),
348                    error_type: ValidationErrorType::InvalidConfiguration,
349                });
350            }
351        }
352
353        // Check for NaN values
354        let has_nan = data.iter().any(|v| v.is_nan());
355        if has_nan {
356            let has_imputer = steps.iter().any(|s| s.contains("Imputer"));
357            if !has_imputer {
358                errors.push(ValidationError {
359                    step: "data".to_string(),
360                    message: "Data contains NaN values but no imputation step".to_string(),
361                    error_type: ValidationErrorType::MissingRequirement,
362                });
363            }
364        }
365
366        errors
367    }
368
369    /// Estimate memory usage
370    fn estimate_memory(
371        &self,
372        steps: &[String],
373        sample_data: Option<&Array2<f64>>,
374    ) -> Option<usize> {
375        let base_size = if let Some(data) = sample_data {
376            data.nrows() * data.ncols() * std::mem::size_of::<f64>()
377        } else {
378            0
379        };
380
381        let mut total = base_size;
382
383        for step in steps {
384            if step.contains("PolynomialFeatures") {
385                // Polynomial features can significantly increase memory
386                total = total.saturating_mul(3); // Rough estimate
387            } else if step.contains("OneHotEncoder") {
388                // One-hot encoding increases dimensions
389                total = total.saturating_mul(2);
390            } else {
391                // Most transformations keep similar size
392                total = total.saturating_add(base_size / 2);
393            }
394        }
395
396        Some(total)
397    }
398
399    /// Estimate computation time
400    fn estimate_time(&self, steps: &[String], sample_data: Option<&Array2<f64>>) -> Option<f64> {
401        let n_operations = if let Some(data) = sample_data {
402            (data.nrows() * data.ncols()) as f64
403        } else {
404            10000.0 // Default estimate
405        };
406
407        let mut total_time = 0.0;
408
409        for step in steps {
410            let step_time = if step.contains("KNN") {
411                n_operations * 0.001 // KNN is O(n²) but we approximate
412            } else if step.contains("PCA") {
413                n_operations * 0.0005 // Matrix operations
414            } else if step.contains("PolynomialFeatures") {
415                n_operations * 0.0002
416            } else {
417                n_operations * 0.00001 // Simple operations
418            };
419
420            total_time += step_time;
421        }
422
423        Some(total_time)
424    }
425}
426
427impl Default for PipelineValidator {
428    fn default() -> Self {
429        Self::new()
430    }
431}
432
433impl ValidationResult {
434    /// Print a summary of the validation result
435    pub fn print_summary(&self) {
436        println!("Pipeline Validation Result");
437        println!("==========================");
438        println!(
439            "Status: {}",
440            if self.is_valid { "VALID" } else { "INVALID" }
441        );
442        println!();
443
444        if !self.errors.is_empty() {
445            println!("Errors: {}", self.errors.len());
446            for error in &self.errors {
447                println!("  [ERROR] {}: {}", error.step, error.message);
448            }
449            println!();
450        }
451
452        if !self.warnings.is_empty() {
453            println!("Warnings: {}", self.warnings.len());
454            for warning in &self.warnings {
455                let severity = match warning.severity {
456                    WarningSeverity::Low => "LOW",
457                    WarningSeverity::Medium => "MEDIUM",
458                    WarningSeverity::High => "HIGH",
459                };
460                println!("  [{}] {}: {}", severity, warning.step, warning.message);
461            }
462            println!();
463        }
464
465        if !self.recommendations.is_empty() {
466            println!("Recommendations: {}", self.recommendations.len());
467            for rec in &self.recommendations {
468                let improvement = if let Some(imp) = rec.expected_improvement {
469                    format!(" (expected {:.1}x improvement)", imp)
470                } else {
471                    String::new()
472                };
473                println!("  [RECOMMEND] {}{}", rec.message, improvement);
474            }
475            println!();
476        }
477
478        if let Some(mem) = self.estimated_memory {
479            println!("Estimated Memory: {:.2} MB", mem as f64 / 1024.0 / 1024.0);
480        }
481
482        if let Some(time) = self.estimated_time {
483            println!("Estimated Time: {:.2} ms", time);
484        }
485    }
486
487    /// Get high severity warnings
488    pub fn high_severity_warnings(&self) -> Vec<&ValidationWarning> {
489        self.warnings
490            .iter()
491            .filter(|w| w.severity == WarningSeverity::High)
492            .collect()
493    }
494
495    /// Get errors of specific type
496    pub fn errors_of_type(&self, error_type: ValidationErrorType) -> Vec<&ValidationError> {
497        self.errors
498            .iter()
499            .filter(|e| e.error_type == error_type)
500            .collect()
501    }
502
503    /// Get recommendations by category
504    pub fn recommendations_by_category(
505        &self,
506        category: RecommendationCategory,
507    ) -> Vec<&PerformanceRecommendation> {
508        self.recommendations
509            .iter()
510            .filter(|r| r.category == category)
511            .collect()
512    }
513}
514
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::random::essentials::Normal;
    use scirs2_core::random::{seeded_rng, Distribution};

    /// Build an `nrows x ncols` matrix of standard-normal draws from a
    /// seeded RNG, so each test sees deterministic data.
    fn generate_test_data(nrows: usize, ncols: usize, seed: u64) -> Array2<f64> {
        let mut rng = seeded_rng(seed);
        let normal = Normal::new(0.0, 1.0).unwrap();

        let data: Vec<f64> = (0..nrows * ncols)
            .map(|_| normal.sample(&mut rng))
            .collect();

        Array2::from_shape_vec((nrows, ncols), data).unwrap()
    }

    // An empty step list must be rejected as invalid with at least one error.
    #[test]
    fn test_pipeline_validator_empty() {
        let validator = PipelineValidator::new();
        let result = validator.validate(&[], None).unwrap();

        assert!(!result.is_valid);
        assert!(!result.errors.is_empty());
    }

    // Two scaling steps in one pipeline should produce a redundancy warning.
    #[test]
    fn test_pipeline_validator_redundancy() {
        let validator = PipelineValidator::new();
        let steps = vec!["StandardScaler".to_string(), "MinMaxScaler".to_string()];

        let result = validator.validate(&steps, None).unwrap();

        // Should warn about redundant scaling
        assert!(!result.warnings.is_empty());
    }

    // Scaling before imputation should yield an ordering recommendation.
    #[test]
    fn test_pipeline_validator_ordering() {
        let validator = PipelineValidator::new();
        let steps = vec![
            "StandardScaler".to_string(),
            "SimpleImputer".to_string(), // Imputation should come first
        ];

        let result = validator.validate(&steps, None).unwrap();

        // Should recommend reordering
        assert!(!result.recommendations.is_empty());
    }

    // A well-formed pipeline with clean data should pass validation.
    #[test]
    fn test_pipeline_validator_data() {
        let data = generate_test_data(100, 10, 42);
        let validator = PipelineValidator::new();
        let steps = vec!["StandardScaler".to_string()];

        let result = validator.validate(&steps, Some(&data)).unwrap();

        assert!(result.is_valid);
    }

    // A single-row dataset fails both the global (>= 2 samples) and the
    // KNN-specific (>= 5 samples) requirements.
    #[test]
    fn test_pipeline_validator_insufficient_samples() {
        let data = Array2::from_elem((1, 5), 1.0);
        let validator = PipelineValidator::new();
        let steps = vec!["KNNImputer".to_string()];

        let result = validator.validate(&steps, Some(&data)).unwrap();

        // Should error on insufficient samples
        assert!(!result.is_valid);
        assert!(!result.errors.is_empty());
    }

    // NaN in the data without any imputation step is a hard error.
    #[test]
    fn test_pipeline_validator_nan_without_imputer() {
        let mut data = generate_test_data(50, 5, 123);
        data[[0, 0]] = f64::NAN;

        let validator = PipelineValidator::new();
        let steps = vec!["StandardScaler".to_string()];

        let result = validator.validate(&steps, Some(&data)).unwrap();

        // Should error on NaN without imputation
        assert!(!result.is_valid);
    }

    // The memory estimate must be present and positive when sample data is given.
    #[test]
    fn test_memory_estimation() {
        let data = generate_test_data(1000, 100, 456);
        let validator = PipelineValidator::new();
        let steps = vec![
            "StandardScaler".to_string(),
            "PolynomialFeatures".to_string(),
        ];

        let result = validator.validate(&steps, Some(&data)).unwrap();

        assert!(result.estimated_memory.is_some());
        assert!(result.estimated_memory.unwrap() > 0);
    }

    // The time estimate must be present and positive when sample data is given.
    #[test]
    fn test_time_estimation() {
        let data = generate_test_data(1000, 50, 789);
        let validator = PipelineValidator::new();
        let steps = vec!["StandardScaler".to_string(), "PCA".to_string()];

        let result = validator.validate(&steps, Some(&data)).unwrap();

        assert!(result.estimated_time.is_some());
        assert!(result.estimated_time.unwrap() > 0.0);
    }

    // high_severity_warnings must return a subset of all warnings.
    #[test]
    fn test_validation_result_filtering() {
        let validator = PipelineValidator::new();
        let steps = vec![
            "StandardScaler".to_string(),
            "MinMaxScaler".to_string(),
            "SimpleImputer".to_string(),
        ];

        let result = validator.validate(&steps, None).unwrap();

        let high_warnings = result.high_severity_warnings();
        assert!(high_warnings.len() <= result.warnings.len());
    }
}