1use crate::Dataset;
2use std::collections::HashMap;
3use tenflowers_core::{Result, Tensor};
4
5#[allow(dead_code)]
7type SampleData<T> = (usize, (Tensor<T>, Tensor<T>));
8type SampleList<T> = [(usize, (Tensor<T>, Tensor<T>))];
9
/// Switches and tuning knobs for the checks run by `DataValidator`.
#[derive(Debug, Clone)]
pub struct ValidationConfig {
    /// Compare each sample's feature/label shapes with the configured schema.
    pub check_schema: bool,
    /// Compare each feature/label element with the configured value ranges.
    pub check_ranges: bool,
    /// Flag samples whose features repeat an earlier sample's features
    /// (values are compared after rounding to six decimal places).
    pub check_duplicates: bool,
    /// Flag samples containing a feature whose absolute z-score exceeds
    /// `outlier_threshold`.
    pub check_outliers: bool,
    /// Absolute z-score (in population standard deviations) above which a
    /// feature value counts as an outlier.
    pub outlier_threshold: f64,
}

impl Default for ValidationConfig {
    /// Turns every check on and flags values more than three standard
    /// deviations from their column mean.
    fn default() -> Self {
        let enabled = true;
        Self {
            outlier_threshold: 3.0,
            check_schema: enabled,
            check_ranges: enabled,
            check_duplicates: enabled,
            check_outliers: enabled,
        }
    }
}
30
/// Expected per-sample layout consulted by the schema check of `DataValidator`.
#[derive(Debug, Clone)]
pub struct SchemaInfo {
    /// Exact dimensions every sample's feature tensor must have.
    pub feature_shape: Vec<usize>,
    /// Exact dimensions every sample's label tensor must have.
    pub label_shape: Vec<usize>,
    /// Name of the intended feature element type (e.g. `"f32"`).
    /// Informational: the schema check compares shapes only.
    pub expected_feature_type: String,
    /// Name of the intended label element type.
    /// Informational: the schema check compares shapes only.
    pub expected_label_type: String,
}
38
/// Optional inclusive lower/upper bounds for element values.
///
/// A bound set to `None` is not enforced.
#[derive(Debug, Clone)]
pub struct RangeConstraint<T> {
    /// Smallest permitted value, if any.
    pub min_value: Option<T>,
    /// Largest permitted value, if any.
    pub max_value: Option<T>,
}

impl<T> RangeConstraint<T> {
    /// Builds a constraint from explicit optional bounds.
    pub fn new(min_value: Option<T>, max_value: Option<T>) -> Self {
        RangeConstraint {
            min_value,
            max_value,
        }
    }

    /// Lower bound only.
    pub fn min(min_value: T) -> Self {
        Self::new(Some(min_value), None)
    }

    /// Upper bound only.
    pub fn max(max_value: T) -> Self {
        Self::new(None, Some(max_value))
    }

    /// Both a lower and an upper bound.
    pub fn range(min_value: T, max_value: T) -> Self {
        Self::new(Some(min_value), Some(max_value))
    }
}
74
/// Findings collected by a validation run.
///
/// Every `add_*` method both records the finding and clears `is_valid`, so a
/// result stays valid only while nothing has been recorded through them.
#[derive(Debug, Clone)]
pub struct ValidationResult {
    /// `false` once any `add_*` method has been called.
    pub is_valid: bool,
    /// Messages about shape mismatches, empty datasets or unreadable samples.
    pub schema_errors: Vec<String>,
    /// Messages about element values outside a configured range.
    pub range_errors: Vec<String>,
    /// Dataset indices of samples reported as duplicates.
    pub duplicate_indices: Vec<usize>,
    /// Dataset indices of samples reported as outliers.
    pub outlier_indices: Vec<usize>,
}

impl ValidationResult {
    /// Returns an empty result that is marked valid.
    pub fn new() -> Self {
        ValidationResult {
            is_valid: true,
            schema_errors: vec![],
            range_errors: vec![],
            duplicate_indices: vec![],
            outlier_indices: vec![],
        }
    }

    /// Reports whether any category contains at least one finding.
    pub fn has_errors(&self) -> bool {
        let nothing_recorded = self.schema_errors.is_empty()
            && self.range_errors.is_empty()
            && self.duplicate_indices.is_empty()
            && self.outlier_indices.is_empty();
        !nothing_recorded
    }

    /// Records a schema problem and marks the result invalid.
    pub fn add_schema_error(&mut self, error: String) {
        self.is_valid = false;
        self.schema_errors.push(error);
    }

    /// Records an out-of-range value and marks the result invalid.
    pub fn add_range_error(&mut self, error: String) {
        self.is_valid = false;
        self.range_errors.push(error);
    }

    /// Records a duplicate sample index and marks the result invalid.
    pub fn add_duplicate(&mut self, index: usize) {
        self.is_valid = false;
        self.duplicate_indices.push(index);
    }

    /// Records an outlier sample index and marks the result invalid.
    pub fn add_outlier(&mut self, index: usize) {
        self.is_valid = false;
        self.outlier_indices.push(index);
    }
}

impl Default for ValidationResult {
    /// Same as [`ValidationResult::new`]: empty and valid.
    fn default() -> Self {
        ValidationResult::new()
    }
}
128
129pub struct DataValidator<T> {
130 config: ValidationConfig,
131 schema_info: Option<SchemaInfo>,
132 feature_range: Option<RangeConstraint<T>>,
133 label_range: Option<RangeConstraint<T>>,
134}
135
136impl<T> DataValidator<T>
137where
138 T: Clone
139 + Default
140 + PartialEq
141 + PartialOrd
142 + std::fmt::Display
143 + scirs2_core::numeric::Float
144 + Send
145 + Sync
146 + 'static,
147{
148 pub fn new(config: ValidationConfig) -> Self {
149 Self {
150 config,
151 schema_info: None,
152 feature_range: None,
153 label_range: None,
154 }
155 }
156
157 pub fn with_schema(mut self, schema: SchemaInfo) -> Self {
158 self.schema_info = Some(schema);
159 self
160 }
161
162 pub fn with_feature_range(mut self, range: RangeConstraint<T>) -> Self {
163 self.feature_range = Some(range);
164 self
165 }
166
167 pub fn with_label_range(mut self, range: RangeConstraint<T>) -> Self {
168 self.label_range = Some(range);
169 self
170 }
171
172 pub fn validate<D: Dataset<T>>(&self, dataset: &D) -> Result<ValidationResult> {
173 let mut result = ValidationResult::new();
174
175 if dataset.is_empty() {
176 result.add_schema_error("Dataset is empty".to_string());
177 return Ok(result);
178 }
179
180 let mut samples = Vec::new();
182 for i in 0..dataset.len() {
183 match dataset.get(i) {
184 Ok(sample) => samples.push((i, sample)),
185 Err(e) => {
186 result.add_schema_error(format!("Failed to get sample {i}: {e:?}"));
187 }
188 }
189 }
190
191 if self.config.check_schema {
192 self.validate_schema(&samples, &mut result)?;
193 }
194
195 if self.config.check_ranges {
196 self.validate_ranges(&samples, &mut result)?;
197 }
198
199 if self.config.check_duplicates {
200 self.validate_duplicates(&samples, &mut result)?;
201 }
202
203 if self.config.check_outliers {
204 self.validate_outliers(&samples, &mut result)?;
205 }
206
207 Ok(result)
208 }
209
210 fn validate_schema(
211 &self,
212 samples: &SampleList<T>,
213 result: &mut ValidationResult,
214 ) -> Result<()> {
215 if let Some(ref schema) = self.schema_info {
216 for (index, (features, labels)) in samples {
217 if features.shape().dims() != schema.feature_shape {
219 result.add_schema_error(format!(
220 "Sample {}: Feature shape mismatch. Expected {:?}, got {:?}",
221 index,
222 schema.feature_shape,
223 features.shape().dims()
224 ));
225 }
226
227 if labels.shape().dims() != schema.label_shape {
229 result.add_schema_error(format!(
230 "Sample {}: Label shape mismatch. Expected {:?}, got {:?}",
231 index,
232 schema.label_shape,
233 labels.shape().dims()
234 ));
235 }
236 }
237 }
238 Ok(())
239 }
240
241 fn validate_ranges(
242 &self,
243 samples: &SampleList<T>,
244 result: &mut ValidationResult,
245 ) -> Result<()> {
246 for (index, (features, labels)) in samples {
247 if let Some(ref range) = self.feature_range {
249 if let Some(feature_data) = features.as_slice() {
250 for (i, &value) in feature_data.iter().enumerate() {
251 if let Some(min_val) = &range.min_value {
252 if value < *min_val {
253 result.add_range_error(format!(
254 "Sample {index}: Feature {i} value {value} below minimum {min_val}"
255 ));
256 }
257 }
258 if let Some(max_val) = &range.max_value {
259 if value > *max_val {
260 result.add_range_error(format!(
261 "Sample {index}: Feature {i} value {value} above maximum {max_val}"
262 ));
263 }
264 }
265 }
266 }
267 }
268
269 if let Some(ref range) = self.label_range {
271 if let Some(label_data) = labels.as_slice() {
272 for (i, &value) in label_data.iter().enumerate() {
273 if let Some(min_val) = &range.min_value {
274 if value < *min_val {
275 result.add_range_error(format!(
276 "Sample {index}: Label {i} value {value} below minimum {min_val}"
277 ));
278 }
279 }
280 if let Some(max_val) = &range.max_value {
281 if value > *max_val {
282 result.add_range_error(format!(
283 "Sample {index}: Label {i} value {value} above maximum {max_val}"
284 ));
285 }
286 }
287 }
288 }
289 }
290 }
291 Ok(())
292 }
293
294 fn validate_duplicates(
295 &self,
296 samples: &SampleList<T>,
297 result: &mut ValidationResult,
298 ) -> Result<()> {
299 let mut seen_features: HashMap<Vec<String>, Vec<usize>> = HashMap::new();
300
301 for (index, (features, _)) in samples {
302 if let Some(feature_data) = features.as_slice() {
303 let feature_key: Vec<String> = feature_data
305 .iter()
306 .map(|&x| format!("{x:.6}")) .collect();
308
309 seen_features.entry(feature_key).or_default().push(*index);
310 }
311 }
312
313 for (_, indices) in seen_features {
315 if indices.len() > 1 {
316 for &index in &indices[1..] {
317 result.add_duplicate(index);
319 }
320 }
321 }
322
323 Ok(())
324 }
325
326 fn validate_outliers(
327 &self,
328 samples: &SampleList<T>,
329 result: &mut ValidationResult,
330 ) -> Result<()> {
331 if samples.is_empty() {
332 return Ok(());
333 }
334
335 let mut feature_values: Vec<Vec<T>> = Vec::new();
337 let feature_size = if let Some((_, (features, _))) = samples.first() {
338 if let Some(data) = features.as_slice() {
339 data.len()
340 } else {
341 return Ok(()); }
343 } else {
344 return Ok(());
345 };
346
347 for _ in 0..feature_size {
349 feature_values.push(Vec::new());
350 }
351
352 for (_, (features, _)) in samples {
354 if let Some(data) = features.as_slice() {
355 for (i, &value) in data.iter().enumerate() {
356 if i < feature_values.len() {
357 feature_values[i].push(value);
358 }
359 }
360 }
361 }
362
363 let mut means = Vec::new();
365 let mut stds = Vec::new();
366
367 for values in &feature_values {
368 if values.is_empty() {
369 continue;
370 }
371
372 let mean = values.iter().copied().fold(T::zero(), |acc, x| acc + x)
373 / T::from(values.len()).expect("values length should convert to float");
374 means.push(mean);
375
376 let variance = values
377 .iter()
378 .map(|&x| {
379 let diff = x - mean;
380 diff * diff
381 })
382 .fold(T::zero(), |acc, x| acc + x)
383 / T::from(values.len()).expect("values length should convert to float");
384
385 let std = variance.sqrt();
386 stds.push(std);
387 }
388
389 let threshold = T::from(self.config.outlier_threshold)
391 .expect("outlier threshold should convert to float");
392
393 for (index, (features, _)) in samples {
394 if let Some(data) = features.as_slice() {
395 for (i, &value) in data.iter().enumerate() {
396 if i < means.len() && i < stds.len() {
397 let mean = means[i];
398 let std = stds[i];
399
400 if std > T::zero() {
401 let z_score = ((value - mean) / std).abs();
402 if z_score > threshold {
403 result.add_outlier(*index);
404 break; }
406 }
407 }
408 }
409 }
410 }
411
412 Ok(())
413 }
414}
415
416pub trait DatasetValidationExt<T> {
417 fn validate(&self, validator: &DataValidator<T>) -> Result<ValidationResult>;
418 fn validate_with_config(&self, config: ValidationConfig) -> Result<ValidationResult>;
419 fn is_valid(&self) -> Result<bool>;
420}
421
422impl<T, D: Dataset<T>> DatasetValidationExt<T> for D
423where
424 T: Clone
425 + Default
426 + PartialEq
427 + PartialOrd
428 + std::fmt::Display
429 + scirs2_core::numeric::Float
430 + Send
431 + Sync
432 + 'static,
433{
434 fn validate(&self, validator: &DataValidator<T>) -> Result<ValidationResult> {
435 validator.validate(self)
436 }
437
438 fn validate_with_config(&self, config: ValidationConfig) -> Result<ValidationResult> {
439 let validator = DataValidator::new(config);
440 validator.validate(self)
441 }
442
443 fn is_valid(&self) -> Result<bool> {
444 let config = ValidationConfig::default();
445 let result = self.validate_with_config(config)?;
446 Ok(!result.has_errors())
447 }
448}
449
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TensorDataset;
    use tenflowers_core::Tensor;

    /// A configuration with every check disabled; tests re-enable what they need.
    fn all_checks_off() -> ValidationConfig {
        ValidationConfig {
            check_schema: false,
            check_ranges: false,
            check_duplicates: false,
            check_outliers: false,
            ..ValidationConfig::default()
        }
    }

    #[test]
    fn test_validation_config() {
        let defaults = ValidationConfig::default();
        assert!(defaults.check_schema);
        assert!(defaults.check_ranges);
        assert!(defaults.check_duplicates);
        assert!(defaults.check_outliers);
        assert_eq!(defaults.outlier_threshold, 3.0);
    }

    #[test]
    fn test_range_constraint() {
        let bounded = RangeConstraint::range(0.0f32, 1.0f32);
        assert_eq!(bounded.min_value, Some(0.0));
        assert_eq!(bounded.max_value, Some(1.0));

        let floor = RangeConstraint::min(-1.0f32);
        assert_eq!(floor.min_value, Some(-1.0));
        assert_eq!(floor.max_value, None);

        let ceiling = RangeConstraint::max(10.0f32);
        assert_eq!(ceiling.min_value, None);
        assert_eq!(ceiling.max_value, Some(10.0));
    }

    #[test]
    fn test_validation_result() {
        let mut report = ValidationResult::new();
        assert!(report.is_valid);
        assert!(!report.has_errors());

        report.add_schema_error("Schema error".to_string());
        assert!(!report.is_valid);
        assert!(report.has_errors());
        assert_eq!(report.schema_errors.len(), 1);

        report.add_duplicate(5);
        assert_eq!(report.duplicate_indices.len(), 1);
        assert_eq!(report.duplicate_indices[0], 5);
    }

    #[test]
    fn test_schema_validation() {
        let feature_matrix = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])
            .expect("test: tensor creation should succeed");
        let label_vector = Tensor::<f32>::from_vec(vec![0.0, 1.0], &[2])
            .expect("test: tensor creation should succeed");
        let dataset = TensorDataset::new(feature_matrix, label_vector);

        // Assumes TensorDataset yields each row as a [2] feature vector and
        // each label as a scalar (shape []) — confirm against TensorDataset::get.
        let expected = SchemaInfo {
            feature_shape: vec![2],
            label_shape: vec![],
            expected_feature_type: "f32".to_string(),
            expected_label_type: "f32".to_string(),
        };

        let report = DataValidator::new(ValidationConfig::default())
            .with_schema(expected)
            .validate(&dataset)
            .expect("test: operation should succeed");
        assert!(report.is_valid);
        assert!(!report.has_errors());
    }

    #[test]
    fn test_range_validation() {
        // 1.2 lies outside [0, 1].
        let feature_matrix = Tensor::<f32>::from_vec(vec![0.5, 0.8, 1.2, 0.3], &[2, 2])
            .expect("test: operation should succeed");
        let label_vector = Tensor::<f32>::from_vec(vec![0.0, 1.0], &[2])
            .expect("test: tensor creation should succeed");
        let dataset = TensorDataset::new(feature_matrix, label_vector);

        let report = DataValidator::new(ValidationConfig::default())
            .with_feature_range(RangeConstraint::range(0.0f32, 1.0f32))
            .validate(&dataset)
            .expect("test: operation should succeed");
        assert!(!report.is_valid);
        assert!(report.has_errors());
        assert!(!report.range_errors.is_empty());
    }

    #[test]
    fn test_duplicate_detection() {
        // Rows 0 and 1 are both [1.0, 2.0].
        let feature_matrix =
            Tensor::<f32>::from_vec(vec![1.0, 2.0, 1.0, 2.0, 3.0, 4.0], &[3, 2])
                .expect("test: operation should succeed");
        let label_vector = Tensor::<f32>::from_vec(vec![0.0, 1.0, 2.0], &[3])
            .expect("test: tensor creation should succeed");
        let dataset = TensorDataset::new(feature_matrix, label_vector);

        let cfg = ValidationConfig {
            check_duplicates: true,
            ..all_checks_off()
        };
        let report = DataValidator::new(cfg)
            .validate(&dataset)
            .expect("test: operation should succeed");

        assert!(!report.is_valid);
        assert!(report.has_errors());
        assert!(!report.duplicate_indices.is_empty());
    }

    #[test]
    fn test_outlier_detection() {
        // The first feature of the last row (100.0) is far from the others.
        let feature_matrix = Tensor::<f32>::from_vec(
            vec![1.0, 1.0, 1.1, 1.0, 1.2, 1.0, 1.0, 1.0, 100.0, 1.0],
            &[5, 2],
        )
        .expect("test: operation should succeed");
        let label_vector = Tensor::<f32>::from_vec(vec![0.0, 1.0, 2.0, 3.0, 4.0], &[5])
            .expect("test: tensor creation should succeed");
        let dataset = TensorDataset::new(feature_matrix, label_vector);

        // With only five samples the largest attainable population z-score is
        // below 3, so a lower threshold is used here.
        let cfg = ValidationConfig {
            check_outliers: true,
            outlier_threshold: 1.0,
            ..all_checks_off()
        };
        let report = DataValidator::new(cfg)
            .validate(&dataset)
            .expect("test: operation should succeed");

        assert!(!report.is_valid);
        assert!(report.has_errors());
        assert!(!report.outlier_indices.is_empty());
    }

    #[test]
    fn test_dataset_validation_ext() {
        let feature_matrix = Tensor::<f32>::from_vec(vec![0.5, 0.8, 0.3, 0.7], &[2, 2])
            .expect("test: tensor creation should succeed");
        let label_vector = Tensor::<f32>::from_vec(vec![0.0, 1.0], &[2])
            .expect("test: tensor creation should succeed");
        let dataset = TensorDataset::new(feature_matrix, label_vector);

        let ok = dataset.is_valid().expect("test: operation should succeed");
        assert!(ok);

        let report = dataset
            .validate_with_config(ValidationConfig::default())
            .expect("test: operation should succeed");
        assert!(report.is_valid);
    }
}