1use crate::error::CoreError;
7use std::collections::HashMap;
8use std::fmt;
9
10use crate::ndarray::compat::ArrayStatCompat;
12use ::ndarray::{ArrayBase, Data, Dimension, ScalarOperand};
13use num_traits::{Float, FromPrimitive, ToPrimitive, Zero};
14
15#[cfg(feature = "parallel")]
16use crate::parallel_ops::*;
17
18use super::config::{ErrorSeverity, ValidationErrorType};
19use super::constraints::{ArrayValidationConstraints, StatisticalConstraints};
20use super::errors::{ValidationError, ValidationResult, ValidationStats};
21#[cfg(feature = "validation")]
22use statrs::statistics::Statistics;
23
/// Stateless validator that checks `ndarray` arrays against
/// `ArrayValidationConstraints` (shape, numeric quality, statistical bounds,
/// size heuristics and a custom per-element predicate).
///
/// The type carries no data, so it is cheap to copy and a single instance can
/// be reused for any number of arrays.
#[derive(Debug, Clone, Copy)]
pub struct ArrayValidator;
26
27impl ArrayValidator {
28 pub fn new() -> Self {
30 Self
31 }
32
33 pub fn validate_ndarray<S, D>(
35 &self,
36 array: &ArrayBase<S, D>,
37 constraints: &ArrayValidationConstraints,
38 config: &super::config::ValidationConfig,
39 ) -> Result<ValidationResult, CoreError>
40 where
41 S: Data,
42 D: Dimension,
43 S::Elem: Float + fmt::Debug + Send + Sync + ScalarOperand + FromPrimitive,
44 {
45 let start_time = std::time::Instant::now();
46 let mut errors = Vec::new();
47 let mut warnings = Vec::new();
48 let mut stats = ValidationStats::default();
49
50 if let Some(expectedshape) = &constraints.expectedshape {
52 if !self.validate_arrayshape(array, expectedshape)? {
53 errors.push(ValidationError {
54 errortype: ValidationErrorType::ShapeError,
55 fieldpath: constraints
56 .fieldname
57 .clone()
58 .unwrap_or_else(|| "array".to_string()),
59 message: format!(
60 "Array shape {:?} does not match expected {:?}",
61 array.shape(),
62 expectedshape
63 ),
64 expected: Some(format!("{expectedshape:?}")),
65 actual: Some(format!("{:?}", array.shape())),
66 constraint: Some(format!("{:?}", expectedshape)),
67 severity: ErrorSeverity::Error,
68 context: HashMap::new(),
69 });
70 }
71 }
72
73 if constraints.check_numeric_quality {
75 self.validate_numeric_quality(array, &mut errors, &mut warnings, &mut stats)?;
76 }
77
78 if let Some(stat_constraints) = &constraints.statistical_constraints {
80 self.validate_statistical_properties(
81 array,
82 stat_constraints,
83 &mut errors,
84 &mut warnings,
85 )?;
86 }
87
88 if constraints.check_performance {
90 self.validate_array_performance(array, &mut warnings)?;
91 }
92
93 if let Some(ref validator) = constraints.element_validator {
95 self.validate_elements(array, validator, &mut errors, &mut warnings)?;
96 }
97
98 let valid = errors.is_empty()
99 && !warnings
100 .iter()
101 .any(|w| w.severity == ErrorSeverity::Critical);
102 let duration = start_time.elapsed();
103
104 Ok(ValidationResult {
105 valid,
106 errors,
107 warnings,
108 stats,
109 duration,
110 })
111 }
112
113 fn validate_arrayshape<S, D>(
115 &self,
116 array: &ArrayBase<S, D>,
117 expectedshape: &[usize],
118 ) -> Result<bool, CoreError>
119 where
120 S: Data,
121 D: Dimension,
122 {
123 let actualshape = array.shape();
124 Ok(actualshape == expectedshape)
125 }
126
127 #[allow(clippy::ptr_arg)]
129 fn validate_numeric_quality<S, D>(
130 &self,
131 array: &ArrayBase<S, D>,
132 errors: &mut Vec<ValidationError>,
133 warnings: &mut Vec<ValidationError>,
134 stats: &mut ValidationStats,
135 ) -> Result<(), CoreError>
136 where
137 S: Data,
138 D: Dimension,
139 S::Elem: Float + fmt::Debug + Send + Sync,
140 {
141 let mut nan_count = 0;
142 let mut inf_count = 0;
143 let total_count = array.len();
144
145 #[cfg(feature = "parallel")]
146 let check_parallel = array.len() > 10000;
147
148 #[cfg(feature = "parallel")]
149 if check_parallel {
150 if let Some(slice) = array.as_slice() {
151 let results: Vec<_> = slice
152 .par_iter()
153 .map(|&value| {
154 let is_nan = value.is_nan();
155 let is_inf = value.is_infinite();
156 (is_nan, is_inf)
157 })
158 .collect();
159
160 for (is_nan, is_inf) in results {
161 if is_nan {
162 nan_count += 1;
163 }
164 if is_inf {
165 inf_count += 1;
166 }
167 }
168 } else {
169 for value in array.iter() {
171 if value.is_nan() {
172 nan_count += 1;
173 }
174 if value.is_infinite() {
175 inf_count += 1;
176 }
177 }
178 }
179 }
180
181 #[cfg(not(feature = "parallel"))]
182 {
183 for value in array.iter() {
184 if value.is_nan() {
185 nan_count += 1;
186 }
187 if value.is_infinite() {
188 inf_count += 1;
189 }
190 }
191 }
192
193 #[cfg(feature = "parallel")]
194 if !check_parallel {
195 for value in array.iter() {
196 if value.is_nan() {
197 nan_count += 1;
198 }
199 if value.is_infinite() {
200 inf_count += 1;
201 }
202 }
203 }
204
205 stats.fields_validated += 1;
206 stats.constraints_checked += 2; stats.elements_processed += total_count;
208
209 if nan_count > 0 {
210 warnings.push(ValidationError {
211 errortype: ValidationErrorType::InvalidNumeric,
212 fieldpath: "array".to_string(),
213 message: format!(
214 "Found {} NaN values out of {} total",
215 nan_count, total_count
216 ),
217 expected: Some("finite values".to_string()),
218 actual: Some(format!("{} NaN values", nan_count)),
219 constraint: Some("numeric_quality".to_string()),
220 severity: ErrorSeverity::Warning,
221 context: HashMap::new(),
222 });
223 }
224
225 if inf_count > 0 {
226 warnings.push(ValidationError {
227 errortype: ValidationErrorType::InvalidNumeric,
228 fieldpath: "array".to_string(),
229 message: format!(
230 "Found {} infinite values out of {} total",
231 inf_count, total_count
232 ),
233 expected: Some("finite values".to_string()),
234 actual: Some(format!("{} infinite values", inf_count)),
235 constraint: Some("numeric_quality".to_string()),
236 severity: ErrorSeverity::Warning,
237 context: HashMap::new(),
238 });
239 }
240
241 Ok(())
242 }
243
244 fn validate_statistical_properties<S, D>(
246 &self,
247 array: &ArrayBase<S, D>,
248 constraints: &StatisticalConstraints,
249 errors: &mut Vec<ValidationError>,
250 warnings: &mut Vec<ValidationError>,
251 ) -> Result<(), CoreError>
252 where
253 S: Data,
254 D: Dimension,
255 S::Elem: Float + fmt::Debug + ScalarOperand + FromPrimitive,
256 {
257 if array.is_empty() {
258 return Ok(());
259 }
260
261 let mean = array.mean_or(S::Elem::zero());
263 let std_dev = array.std(num_traits::cast(1.0).expect("Operation failed"));
264
265 if let Some(min_mean) = constraints.min_mean {
267 let min_mean_typed: S::Elem = num_traits::cast(min_mean).unwrap_or(S::Elem::zero());
268 if mean < min_mean_typed {
269 errors.push(ValidationError {
270 errortype: ValidationErrorType::ConstraintViolation,
271 fieldpath: "array.mean".to_string(),
272 message: format!("Mean {:?} is less than minimum {:?}", mean, min_mean),
273 expected: Some(format!("{min_mean}")),
274 actual: Some(format!("{mean:?}")),
275 constraint: Some("statistical.min_mean".to_string()),
276 severity: ErrorSeverity::Error,
277 context: HashMap::new(),
278 });
279 }
280 }
281
282 if let Some(max_mean) = constraints.max_mean {
283 let max_mean_typed: S::Elem = num_traits::cast(max_mean).unwrap_or(S::Elem::zero());
284 if mean > max_mean_typed {
285 errors.push(ValidationError {
286 errortype: ValidationErrorType::ConstraintViolation,
287 fieldpath: "array.mean".to_string(),
288 message: format!("Mean {:?} is greater than maximum {:?}", mean, max_mean),
289 expected: Some(format!("{max_mean}")),
290 actual: Some(format!("{mean:?}")),
291 constraint: Some("statistical.max_mean".to_string()),
292 severity: ErrorSeverity::Error,
293 context: HashMap::new(),
294 });
295 }
296 }
297
298 if let Some(min_std) = constraints.min_std {
300 let min_std_typed: S::Elem = num_traits::cast(min_std).unwrap_or(S::Elem::zero());
301 if std_dev < min_std_typed {
302 warnings.push(ValidationError {
303 errortype: ValidationErrorType::ConstraintViolation,
304 fieldpath: "array.std".to_string(),
305 message: format!(
306 "Array standard deviation {:?} is below minimum {:?}",
307 std_dev, min_std
308 ),
309 expected: Some(format!("{min_std}")),
310 actual: Some(format!("{std_dev:?}")),
311 constraint: Some("statistical.min_std".to_string()),
312 severity: ErrorSeverity::Warning,
313 context: HashMap::new(),
314 });
315 }
316 }
317
318 if let Some(max_std) = constraints.max_std {
319 let max_std_typed: S::Elem = num_traits::cast(max_std).unwrap_or(S::Elem::zero());
320 if std_dev > max_std_typed {
321 warnings.push(ValidationError {
322 errortype: ValidationErrorType::ConstraintViolation,
323 fieldpath: "array.std".to_string(),
324 message: format!(
325 "Array standard deviation {:?} exceeds maximum {:?}",
326 std_dev, max_std
327 ),
328 expected: Some(format!("{max_std}")),
329 actual: Some(format!("{std_dev:?}")),
330 constraint: Some("statistical.max_std".to_string()),
331 severity: ErrorSeverity::Warning,
332 context: HashMap::new(),
333 });
334 }
335 }
336
337 Ok(())
338 }
339
340 fn validate_array_performance<S, D>(
342 &self,
343 array: &ArrayBase<S, D>,
344 warnings: &mut Vec<ValidationError>,
345 ) -> Result<(), CoreError>
346 where
347 S: Data,
348 D: Dimension,
349 S::Elem: fmt::Debug,
350 {
351 let element_count = array.len();
352 let element_size = std::mem::size_of::<S::Elem>();
353 let total_size = element_count * element_size;
354
355 const LARGE_ARRAY_THRESHOLD: usize = 100_000_000; if element_count > LARGE_ARRAY_THRESHOLD {
358 warnings.push(ValidationError {
359 errortype: ValidationErrorType::Performance,
360 fieldpath: "array.size".to_string(),
361 message: format!(
362 "Large array detected: {} elements ({} bytes). Consider chunking for better performance.",
363 element_count, total_size
364 ),
365 expected: Some(format!("<= {} elements", LARGE_ARRAY_THRESHOLD)),
366 actual: Some(format!("{} elements", element_count)),
367 constraint: Some("performance.max_elements".to_string()),
368 severity: ErrorSeverity::Warning,
369 context: HashMap::new(),
370 });
371 }
372
373 const LARGE_MEMORY_THRESHOLD: usize = 1_000_000_000; if total_size > LARGE_MEMORY_THRESHOLD {
376 warnings.push(ValidationError {
377 errortype: ValidationErrorType::Performance,
378 fieldpath: "array.memory".to_string(),
379 message: format!(
380 "High memory usage: {} bytes. Consider memory-efficient operations.",
381 total_size
382 ),
383 expected: Some(format!("<= {} bytes", LARGE_MEMORY_THRESHOLD)),
384 actual: Some(format!("{} bytes", total_size)),
385 constraint: Some("performance.max_memory".to_string()),
386 severity: ErrorSeverity::Warning,
387 context: HashMap::new(),
388 });
389 }
390
391 Ok(())
392 }
393
394 #[allow(clippy::ptr_arg)]
396 fn validate_elements<S, D>(
397 &self,
398 array: &ArrayBase<S, D>,
399 validator: &super::constraints::ElementValidatorFn<f64>,
400 errors: &mut Vec<ValidationError>,
401 warnings: &mut Vec<ValidationError>,
402 ) -> Result<(), CoreError>
403 where
404 S: Data,
405 D: Dimension,
406 S::Elem: Float + fmt::Debug + num_traits::cast::ToPrimitive,
407 {
408 let mut invalid_count = 0;
409
410 for (index, element) in array.iter().enumerate() {
411 if let Some(value) = element.to_f64() {
412 if !validator(&value) {
413 invalid_count += 1;
414 if invalid_count <= 10 {
415 errors.push(ValidationError {
417 errortype: ValidationErrorType::CustomRuleFailure,
418 fieldpath: format!("array[{}]", index),
419 message: format!("Element {:?} failed custom validation", element),
420 expected: Some("valid element".to_string()),
421 actual: Some(format!("{element:?}")),
422 constraint: Some("custom_element_validator".to_string()),
423 severity: ErrorSeverity::Error,
424 context: HashMap::new(),
425 });
426 }
427 }
428 }
429 }
430
431 if invalid_count > 10 {
433 errors.push(ValidationError {
434 errortype: ValidationErrorType::CustomRuleFailure,
435 fieldpath: "array".to_string(),
436 message: format!(
437 "Total of {} elements failed custom validation (showing first 10)",
438 invalid_count
439 ),
440 expected: Some("all elements to pass validation".to_string()),
441 actual: Some(format!("{} failed elements", invalid_count)),
442 constraint: Some("custom_element_validator".to_string()),
443 severity: ErrorSeverity::Error,
444 context: HashMap::new(),
445 });
446 }
447
448 Ok(())
449 }
450}
451
452impl Default for ArrayValidator {
453 fn default() -> Self {
454 Self::new()
455 }
456}
457
#[cfg(test)]
mod tests {
    use super::*;
    use ::ndarray::Array1;

    #[test]
    fn test_array_validator() {
        let validator = ArrayValidator::new();
        let array = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let config = super::super::config::ValidationConfig::default();

        let constraints = ArrayValidationConstraints::new()
            .withshape(vec![5])
            .with_fieldname("test_array")
            .check_numeric_quality();

        let result = validator
            .validate_ndarray(&array, &constraints, &config)
            .expect("Operation failed");
        assert!(result.is_valid());
    }

    #[test]
    fn testshape_validation() {
        let validator = ArrayValidator::new();
        let array = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let config = super::super::config::ValidationConfig::default();

        let constraints = ArrayValidationConstraints::new().withshape(vec![3]);
        let result = validator
            .validate_ndarray(&array, &constraints, &config)
            .expect("Operation failed");
        assert!(result.is_valid());

        let constraints = ArrayValidationConstraints::new().withshape(vec![5]);
        let result = validator
            .validate_ndarray(&array, &constraints, &config)
            .expect("Operation failed");
        assert!(!result.is_valid());
        assert_eq!(result.errors().len(), 1);
        assert_eq!(
            result.errors()[0].errortype,
            ValidationErrorType::ShapeError
        );
    }

    #[test]
    fn test_numeric_quality_validation() {
        let validator = ArrayValidator::new();
        let config = super::super::config::ValidationConfig::default();

        let array = Array1::from_vec(vec![1.0, f64::NAN, 3.0]);
        let constraints = ArrayValidationConstraints::new().check_numeric_quality();

        let result = validator
            .validate_ndarray(&array, &constraints, &config)
            .expect("Operation failed");
        // NaN values are reported as warnings and do not invalidate the result.
        assert!(result.is_valid());
        assert!(result.has_warnings());
        assert_eq!(result.warnings().len(), 1);
        assert_eq!(
            result.warnings()[0].errortype,
            ValidationErrorType::InvalidNumeric
        );
        assert_eq!(
            result.warnings()[0].message,
            "Found 1 NaN values out of 3 total"
        );
    }

    /// Big enough to take the parallel scan path when the `parallel` feature is
    /// enabled; the counts must match what the sequential scan would report.
    #[test]
    fn test_numeric_quality_counts_large_array() {
        let validator = ArrayValidator::new();
        let config = super::super::config::ValidationConfig::default();

        let mut values: Vec<f64> = (0..20_000u32).map(f64::from).collect();
        values[10] = f64::NAN;
        values[5_000] = f64::NAN;
        values[19_999] = f64::NAN;
        values[7] = f64::INFINITY;
        values[12_345] = f64::NEG_INFINITY;
        let array = Array1::from_vec(values);
        let constraints = ArrayValidationConstraints::new().check_numeric_quality();

        let result = validator
            .validate_ndarray(&array, &constraints, &config)
            .expect("Operation failed");
        assert!(result.is_valid());
        assert_eq!(result.warnings().len(), 2);
        assert_eq!(
            result.warnings()[0].message,
            "Found 3 NaN values out of 20000 total"
        );
        assert_eq!(
            result.warnings()[1].message,
            "Found 2 infinite values out of 20000 total"
        );
    }

    #[test]
    fn test_statistical_constraints() {
        let validator = ArrayValidator::new();
        let config = super::super::config::ValidationConfig::default();

        // Mean is 3.0.
        let array = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);

        let constraints = ArrayValidationConstraints::new()
            .with_statistical_constraints(StatisticalConstraints::new().with_mean_range(2.0, 4.0));

        let result = validator
            .validate_ndarray(&array, &constraints, &config)
            .expect("Operation failed");
        assert!(result.is_valid());

        let failing_constraints = ArrayValidationConstraints::new().with_statistical_constraints(
            StatisticalConstraints::new().with_mean_range(5.0, 6.0),
        );

        let result = validator
            .validate_ndarray(&array, &failing_constraints, &config)
            .expect("Operation failed");
        assert!(!result.is_valid());
        assert!(result
            .errors()
            .iter()
            .all(|e| e.errortype == ValidationErrorType::ConstraintViolation));
    }

    #[test]
    fn test_performance_validation() {
        let validator = ArrayValidator::new();
        let config = super::super::config::ValidationConfig::default();

        let small_array = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let constraints = ArrayValidationConstraints::new().check_performance();

        let result = validator
            .validate_ndarray(&small_array, &constraints, &config)
            .expect("Operation failed");
        assert!(result.is_valid());
        assert!(result.warnings().is_empty());
    }

    #[test]
    fn test_element_validation() {
        let validator = ArrayValidator::new();
        let config = super::super::config::ValidationConfig::default();
        let array = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);

        let mut constraints = ArrayValidationConstraints::new();
        constraints.element_validator = Some(Box::new(|&x| x <= 3.0));

        let result = validator
            .validate_ndarray(&array, &constraints, &config)
            .expect("Operation failed");
        assert!(!result.is_valid());
        // Only 4.0 and 5.0 (indices 3 and 4) fail `x <= 3.0`.
        assert_eq!(result.errors().len(), 2);
        assert_eq!(result.errors()[0].fieldpath, "array[3]");
        assert_eq!(result.errors()[1].fieldpath, "array[4]");
        assert!(result
            .errors()
            .iter()
            .all(|e| e.errortype == ValidationErrorType::CustomRuleFailure));
    }

    #[test]
    fn test_element_validation_caps_reported_failures() {
        let validator = ArrayValidator::new();
        let config = super::super::config::ValidationConfig::default();
        let array = Array1::from_vec((0..20u32).map(f64::from).collect::<Vec<_>>());

        let mut constraints = ArrayValidationConstraints::new();
        constraints.element_validator = Some(Box::new(|&x| x < 0.0));

        let result = validator
            .validate_ndarray(&array, &constraints, &config)
            .expect("Operation failed");
        assert!(!result.is_valid());
        // Ten individual failures plus one summary error for all 20.
        assert_eq!(result.errors().len(), 11);
        assert_eq!(result.errors()[10].fieldpath, "array");
        assert_eq!(
            result.errors()[10].message,
            "Total of 20 elements failed custom validation (showing first 10)"
        );
    }
}