1use scirs2_core::ndarray::Array2;
7use sklears_core::prelude::SklearsError;
8
9#[derive(Debug, Clone)]
11pub struct ValidationResult {
12 pub is_valid: bool,
14 pub warnings: Vec<ValidationWarning>,
16 pub errors: Vec<ValidationError>,
18 pub recommendations: Vec<PerformanceRecommendation>,
20 pub estimated_memory: Option<usize>,
22 pub estimated_time: Option<f64>,
24}
25
26#[derive(Debug, Clone)]
28pub struct ValidationWarning {
29 pub step: String,
30 pub message: String,
31 pub severity: WarningSeverity,
32}
33
/// Severity attached to a `ValidationWarning`.
///
/// Variants are declared from least to most severe, so the derived ordering
/// satisfies `Low < Medium < High`. The type is `Copy`, so it can be passed
/// around and compared without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum WarningSeverity {
    /// Least severe level.
    Low,
    /// Intermediate level; used for redundant-step warnings.
    Medium,
    /// Most severe level; used when the estimated time exceeds the limit.
    High,
}
41
42#[derive(Debug, Clone)]
44pub struct ValidationError {
45 pub step: String,
46 pub message: String,
47 pub error_type: ValidationErrorType,
48}
49
/// Classification of a `ValidationError`.
///
/// The type is `Copy` and `Hash`, so it can be passed by value repeatedly
/// (e.g. to `ValidationResult::errors_of_type`) and used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ValidationErrorType {
    /// Shapes or dimensions do not fit together.
    IncompatibleDimensions,
    /// Something a step needs is absent (e.g. too few samples for a
    /// KNN-based step, or NaN values without an imputation step).
    MissingRequirement,
    /// The pipeline or its input is configured in an unusable way
    /// (e.g. an empty pipeline).
    InvalidConfiguration,
    /// A data type does not match what is expected.
    DataTypeMismatch,
    /// An estimated resource requirement exceeds its configured limit.
    ResourceExceeded,
}
59
60#[derive(Debug, Clone)]
62pub struct PerformanceRecommendation {
63 pub category: RecommendationCategory,
64 pub message: String,
65 pub expected_improvement: Option<f64>,
66}
67
/// Area a `PerformanceRecommendation` belongs to.
///
/// The type is `Copy` and `Hash`, so it can be passed by value repeatedly
/// (e.g. to `ValidationResult::recommendations_by_category`) and used as a
/// map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RecommendationCategory {
    /// Reordering steps would help.
    OrderOptimization,
    /// Work could be done in parallel.
    ParallelProcessing,
    /// Memory usage could be reduced.
    MemoryEfficiency,
    /// Computation could be made cheaper.
    ComputationEfficiency,
    /// The input data itself could be improved.
    DataQuality,
}
77
/// Settings controlling which checks `PipelineValidator` runs and which
/// resource limits it enforces.
#[derive(Debug, Clone)]
pub struct PipelineValidatorConfig {
    /// Upper bound on the estimated memory usage, in bytes. `None` disables
    /// the memory-limit check.
    pub max_memory: Option<usize>,
    /// Upper bound on the estimated computation time, in milliseconds.
    /// `None` disables the time-limit check.
    pub max_time: Option<f64>,
    /// Whether to report redundant steps.
    pub check_redundancy: bool,
    /// Whether to suggest a better step order.
    pub check_ordering: bool,
    /// Whether to validate the sample data (when provided) against the steps.
    pub validate_data: bool,
}

impl Default for PipelineValidatorConfig {
    /// Returns a configuration with every check enabled, an 8 GiB memory
    /// limit and a 60 000 ms time limit.
    ///
    /// On targets where 8 GiB does not fit in `usize` (e.g. 32-bit), the
    /// memory limit saturates at `usize::MAX` instead of overflowing.
    fn default() -> Self {
        const GIB: usize = 1024 * 1024 * 1024;

        Self {
            // `8 * GIB` overflows a 32-bit `usize`; saturate so the code
            // builds and behaves sensibly on every target.
            max_memory: Some(8_usize.saturating_mul(GIB)),
            max_time: Some(60000.0),
            check_redundancy: true,
            check_ordering: true,
            validate_data: true,
        }
    }
}
104
105pub struct PipelineValidator {
107 config: PipelineValidatorConfig,
108}
109
110impl PipelineValidator {
111 pub fn new() -> Self {
113 Self {
114 config: PipelineValidatorConfig::default(),
115 }
116 }
117
118 pub fn with_config(config: PipelineValidatorConfig) -> Self {
120 Self { config }
121 }
122
123 pub fn validate(
125 &self,
126 steps: &[String],
127 sample_data: Option<&Array2<f64>>,
128 ) -> Result<ValidationResult, SklearsError> {
129 let mut warnings = Vec::new();
130 let mut errors = Vec::new();
131 let mut recommendations = Vec::new();
132
133 if steps.is_empty() {
135 errors.push(ValidationError {
136 step: "pipeline".to_string(),
137 message: "Pipeline has no steps".to_string(),
138 error_type: ValidationErrorType::InvalidConfiguration,
139 });
140 }
141
142 if self.config.check_redundancy {
144 let redundant = self.check_redundancy(steps);
145 for (step1, step2) in redundant {
146 warnings.push(ValidationWarning {
147 step: step1.clone(),
148 message: format!("Redundant with step: {}", step2),
149 severity: WarningSeverity::Medium,
150 });
151 }
152 }
153
154 if self.config.check_ordering {
156 let ordering_issues = self.check_ordering(steps);
157 for (step, suggestion) in ordering_issues {
158 recommendations.push(PerformanceRecommendation {
159 category: RecommendationCategory::OrderOptimization,
160 message: format!("Step '{}': {}", step, suggestion),
161 expected_improvement: Some(1.5), });
163 }
164 }
165
166 if self.config.validate_data {
168 if let Some(data) = sample_data {
169 let data_issues = self.validate_data(data, steps);
170 errors.extend(data_issues);
171 }
172 }
173
174 let estimated_memory = self.estimate_memory(steps, sample_data);
176 let estimated_time = self.estimate_time(steps, sample_data);
177
178 if let (Some(est_mem), Some(max_mem)) = (estimated_memory, self.config.max_memory) {
180 if est_mem > max_mem {
181 errors.push(ValidationError {
182 step: "pipeline".to_string(),
183 message: format!(
184 "Estimated memory usage ({} bytes) exceeds limit ({} bytes)",
185 est_mem, max_mem
186 ),
187 error_type: ValidationErrorType::ResourceExceeded,
188 });
189 }
190 }
191
192 if let (Some(est_time), Some(max_time)) = (estimated_time, self.config.max_time) {
193 if est_time > max_time {
194 warnings.push(ValidationWarning {
195 step: "pipeline".to_string(),
196 message: format!(
197 "Estimated computation time ({:.2}ms) may exceed limit ({:.2}ms)",
198 est_time, max_time
199 ),
200 severity: WarningSeverity::High,
201 });
202 }
203 }
204
205 if steps.len() > 10 {
207 recommendations.push(PerformanceRecommendation {
208 category: RecommendationCategory::ComputationEfficiency,
209 message: "Consider using FeatureUnion or ColumnTransformer for parallel processing"
210 .to_string(),
211 expected_improvement: Some(2.0), });
213 }
214
215 let is_valid = errors.is_empty();
216
217 Ok(ValidationResult {
218 is_valid,
219 warnings,
220 errors,
221 recommendations,
222 estimated_memory,
223 estimated_time,
224 })
225 }
226
227 fn check_redundancy(&self, steps: &[String]) -> Vec<(String, String)> {
229 let mut redundant = Vec::new();
230
231 let scaling_steps: Vec<_> = steps
233 .iter()
234 .enumerate()
235 .filter(|(_, s)| {
236 s.contains("Scaler")
237 || s.contains("Normalizer")
238 || s.contains("StandardScaler")
239 || s.contains("MinMaxScaler")
240 })
241 .collect();
242
243 if scaling_steps.len() > 1 {
244 for i in 1..scaling_steps.len() {
245 redundant.push((scaling_steps[i].1.clone(), scaling_steps[0].1.clone()));
246 }
247 }
248
249 let imputation_steps: Vec<_> = steps
251 .iter()
252 .enumerate()
253 .filter(|(_, s)| s.contains("Imputer"))
254 .collect();
255
256 if imputation_steps.len() > 1 {
257 for i in 1..imputation_steps.len() {
258 redundant.push((imputation_steps[i].1.clone(), imputation_steps[0].1.clone()));
259 }
260 }
261
262 redundant
263 }
264
265 fn check_ordering(&self, steps: &[String]) -> Vec<(String, String)> {
267 let mut issues = Vec::new();
268
269 for (i, step) in steps.iter().enumerate() {
270 if step.contains("Scaler") || step.contains("Normalizer") {
272 if steps[..i].iter().any(|s| s.contains("Imputer")) {
273 } else if steps[i..].iter().any(|s| s.contains("Imputer")) {
275 issues.push((
276 step.clone(),
277 "Consider moving imputation before scaling".to_string(),
278 ));
279 }
280 }
281
282 if (step.contains("FeatureSelector") || step.contains("SelectK"))
284 && !steps[..i]
285 .iter()
286 .any(|s| s.contains("PolynomialFeatures") || s.contains("FeatureUnion"))
287 && steps[i..]
288 .iter()
289 .any(|s| s.contains("PolynomialFeatures") || s.contains("FeatureUnion"))
290 {
291 issues.push((
292 step.clone(),
293 "Consider moving feature selection after feature generation".to_string(),
294 ));
295 }
296
297 if step.contains("Encoder")
299 && steps[..i].iter().any(|s| {
300 s.contains("Scaler") || s.contains("Normalizer") || s.contains("Transformer")
301 })
302 {
303 issues.push((
304 step.clone(),
305 "Consider moving encoding before numerical transformations".to_string(),
306 ));
307 }
308 }
309
310 issues
311 }
312
313 fn validate_data(&self, data: &Array2<f64>, steps: &[String]) -> Vec<ValidationError> {
315 let mut errors = Vec::new();
316
317 let (n_samples, n_features) = (data.nrows(), data.ncols());
318
319 if n_samples < 2 {
321 errors.push(ValidationError {
322 step: "data".to_string(),
323 message: "Insufficient samples (need at least 2)".to_string(),
324 error_type: ValidationErrorType::InvalidConfiguration,
325 });
326 }
327
328 for step in steps {
330 if step.contains("KNN") && n_samples < 5 {
331 errors.push(ValidationError {
332 step: step.clone(),
333 message: format!(
334 "KNN-based methods require at least 5 samples, found {}",
335 n_samples
336 ),
337 error_type: ValidationErrorType::MissingRequirement,
338 });
339 }
340
341 if step.contains("PCA") && n_samples < n_features {
342 errors.push(ValidationError {
343 step: step.clone(),
344 message: format!(
345 "PCA requires n_samples >= n_features ({} < {})",
346 n_samples, n_features
347 ),
348 error_type: ValidationErrorType::InvalidConfiguration,
349 });
350 }
351 }
352
353 let has_nan = data.iter().any(|v| v.is_nan());
355 if has_nan {
356 let has_imputer = steps.iter().any(|s| s.contains("Imputer"));
357 if !has_imputer {
358 errors.push(ValidationError {
359 step: "data".to_string(),
360 message: "Data contains NaN values but no imputation step".to_string(),
361 error_type: ValidationErrorType::MissingRequirement,
362 });
363 }
364 }
365
366 errors
367 }
368
369 fn estimate_memory(
371 &self,
372 steps: &[String],
373 sample_data: Option<&Array2<f64>>,
374 ) -> Option<usize> {
375 let base_size = if let Some(data) = sample_data {
376 data.nrows() * data.ncols() * std::mem::size_of::<f64>()
377 } else {
378 0
379 };
380
381 let mut total = base_size;
382
383 for step in steps {
384 if step.contains("PolynomialFeatures") {
385 total = total.saturating_mul(3); } else if step.contains("OneHotEncoder") {
388 total = total.saturating_mul(2);
390 } else {
391 total = total.saturating_add(base_size / 2);
393 }
394 }
395
396 Some(total)
397 }
398
399 fn estimate_time(&self, steps: &[String], sample_data: Option<&Array2<f64>>) -> Option<f64> {
401 let n_operations = if let Some(data) = sample_data {
402 (data.nrows() * data.ncols()) as f64
403 } else {
404 10000.0 };
406
407 let mut total_time = 0.0;
408
409 for step in steps {
410 let step_time = if step.contains("KNN") {
411 n_operations * 0.001 } else if step.contains("PCA") {
413 n_operations * 0.0005 } else if step.contains("PolynomialFeatures") {
415 n_operations * 0.0002
416 } else {
417 n_operations * 0.00001 };
419
420 total_time += step_time;
421 }
422
423 Some(total_time)
424 }
425}
426
427impl Default for PipelineValidator {
428 fn default() -> Self {
429 Self::new()
430 }
431}
432
433impl ValidationResult {
434 pub fn print_summary(&self) {
436 println!("Pipeline Validation Result");
437 println!("==========================");
438 println!(
439 "Status: {}",
440 if self.is_valid { "VALID" } else { "INVALID" }
441 );
442 println!();
443
444 if !self.errors.is_empty() {
445 println!("Errors: {}", self.errors.len());
446 for error in &self.errors {
447 println!(" [ERROR] {}: {}", error.step, error.message);
448 }
449 println!();
450 }
451
452 if !self.warnings.is_empty() {
453 println!("Warnings: {}", self.warnings.len());
454 for warning in &self.warnings {
455 let severity = match warning.severity {
456 WarningSeverity::Low => "LOW",
457 WarningSeverity::Medium => "MEDIUM",
458 WarningSeverity::High => "HIGH",
459 };
460 println!(" [{}] {}: {}", severity, warning.step, warning.message);
461 }
462 println!();
463 }
464
465 if !self.recommendations.is_empty() {
466 println!("Recommendations: {}", self.recommendations.len());
467 for rec in &self.recommendations {
468 let improvement = if let Some(imp) = rec.expected_improvement {
469 format!(" (expected {:.1}x improvement)", imp)
470 } else {
471 String::new()
472 };
473 println!(" [RECOMMEND] {}{}", rec.message, improvement);
474 }
475 println!();
476 }
477
478 if let Some(mem) = self.estimated_memory {
479 println!("Estimated Memory: {:.2} MB", mem as f64 / 1024.0 / 1024.0);
480 }
481
482 if let Some(time) = self.estimated_time {
483 println!("Estimated Time: {:.2} ms", time);
484 }
485 }
486
487 pub fn high_severity_warnings(&self) -> Vec<&ValidationWarning> {
489 self.warnings
490 .iter()
491 .filter(|w| w.severity == WarningSeverity::High)
492 .collect()
493 }
494
495 pub fn errors_of_type(&self, error_type: ValidationErrorType) -> Vec<&ValidationError> {
497 self.errors
498 .iter()
499 .filter(|e| e.error_type == error_type)
500 .collect()
501 }
502
503 pub fn recommendations_by_category(
505 &self,
506 category: RecommendationCategory,
507 ) -> Vec<&PerformanceRecommendation> {
508 self.recommendations
509 .iter()
510 .filter(|r| r.category == category)
511 .collect()
512 }
513}
514
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::random::essentials::Normal;
    use scirs2_core::random::{seeded_rng, Distribution};

    /// Builds an `nrows x ncols` matrix of standard-normal samples drawn
    /// from a generator seeded with `seed`.
    fn generate_test_data(nrows: usize, ncols: usize, seed: u64) -> Array2<f64> {
        let mut rng = seeded_rng(seed);
        let normal = Normal::new(0.0, 1.0).unwrap();

        let data: Vec<f64> = (0..nrows * ncols)
            .map(|_| normal.sample(&mut rng))
            .collect();

        Array2::from_shape_vec((nrows, ncols), data).unwrap()
    }

    #[test]
    fn test_pipeline_validator_empty() {
        let validator = PipelineValidator::new();
        let result = validator.validate(&[], None).unwrap();

        assert!(!result.is_valid);
        assert!(!result.errors.is_empty());
        // An empty pipeline is reported as a configuration problem.
        assert_eq!(
            result
                .errors_of_type(ValidationErrorType::InvalidConfiguration)
                .len(),
            1
        );
    }

    #[test]
    fn test_pipeline_validator_redundancy() {
        let validator = PipelineValidator::new();
        let steps = vec!["StandardScaler".to_string(), "MinMaxScaler".to_string()];

        let result = validator.validate(&steps, None).unwrap();

        assert!(!result.warnings.is_empty());
        // The second scaler is reported as redundant with the first one.
        assert!(result
            .warnings
            .iter()
            .any(|w| w.step == "MinMaxScaler" && w.message.contains("StandardScaler")));
    }

    #[test]
    fn test_pipeline_validator_ordering() {
        let validator = PipelineValidator::new();
        // Scaling before imputation should trigger an ordering recommendation.
        let steps = vec!["StandardScaler".to_string(), "SimpleImputer".to_string()];

        let result = validator.validate(&steps, None).unwrap();

        assert!(!result.recommendations.is_empty());
        assert!(!result
            .recommendations_by_category(RecommendationCategory::OrderOptimization)
            .is_empty());
    }

    #[test]
    fn test_pipeline_validator_data() {
        let data = generate_test_data(100, 10, 42);
        let validator = PipelineValidator::new();
        let steps = vec!["StandardScaler".to_string()];

        let result = validator.validate(&steps, Some(&data)).unwrap();

        assert!(result.is_valid);
        assert!(result.errors.is_empty());
    }

    #[test]
    fn test_pipeline_validator_insufficient_samples() {
        let data = Array2::from_elem((1, 5), 1.0);
        let validator = PipelineValidator::new();
        let steps = vec!["KNNImputer".to_string()];

        let result = validator.validate(&steps, Some(&data)).unwrap();

        assert!(!result.is_valid);
        assert!(!result.errors.is_empty());
        // One sample is too few for the KNN-based step.
        assert!(!result
            .errors_of_type(ValidationErrorType::MissingRequirement)
            .is_empty());
    }

    #[test]
    fn test_pipeline_validator_nan_without_imputer() {
        let mut data = generate_test_data(50, 5, 123);
        data[[0, 0]] = f64::NAN;

        let validator = PipelineValidator::new();
        let steps = vec!["StandardScaler".to_string()];

        let result = validator.validate(&steps, Some(&data)).unwrap();

        assert!(!result.is_valid);
        // NaN values without an imputation step are a missing requirement.
        assert!(!result
            .errors_of_type(ValidationErrorType::MissingRequirement)
            .is_empty());
    }

    #[test]
    fn test_memory_estimation() {
        let data = generate_test_data(1000, 100, 456);
        let validator = PipelineValidator::new();
        let steps = vec![
            "StandardScaler".to_string(),
            "PolynomialFeatures".to_string(),
        ];

        let result = validator.validate(&steps, Some(&data)).unwrap();

        assert!(result.estimated_memory.is_some());
        assert!(result.estimated_memory.unwrap() > 0);
        // The estimate must cover at least the input matrix itself.
        assert!(result.estimated_memory.unwrap() >= 1000 * 100 * std::mem::size_of::<f64>());
    }

    #[test]
    fn test_time_estimation() {
        let data = generate_test_data(1000, 50, 789);
        let validator = PipelineValidator::new();
        let steps = vec!["StandardScaler".to_string(), "PCA".to_string()];

        let result = validator.validate(&steps, Some(&data)).unwrap();

        assert!(result.estimated_time.is_some());
        assert!(result.estimated_time.unwrap() > 0.0);
    }

    #[test]
    fn test_validation_result_filtering() {
        let validator = PipelineValidator::new();
        let steps = vec![
            "StandardScaler".to_string(),
            "MinMaxScaler".to_string(),
            "SimpleImputer".to_string(),
        ];

        let result = validator.validate(&steps, None).unwrap();

        // The only warning is the medium-severity redundancy warning.
        assert!(!result.warnings.is_empty());
        assert!(result.high_severity_warnings().is_empty());

        // Scaling before imputation yields ordering recommendations, and the
        // category filter returns only matching entries.
        let ordering =
            result.recommendations_by_category(RecommendationCategory::OrderOptimization);
        assert!(!ordering.is_empty());
        assert!(ordering
            .iter()
            .all(|r| r.category == RecommendationCategory::OrderOptimization));

        // Nothing in this pipeline exceeds a resource limit.
        assert!(result
            .errors_of_type(ValidationErrorType::ResourceExceeded)
            .is_empty());
        assert!(result.is_valid);
    }
}