Skip to main content

tenflowers_dataset/config/
validation.rs

1//! Configuration validation utilities
2//!
3//! This module provides comprehensive validation for configuration values
4//! with descriptive error messages and suggestions for fixing issues.
5
6use super::{
7    AsyncIoConfig, AudioConfig, CacheConfig, DataLoaderConfig, DatasetConfig, FormatConfig,
8    GlobalConfig, GpuConfig, Hdf5FormatConfig, ImageFormatConfig, LoggingConfig, MonitoringConfig,
9    ParquetFormatConfig, PerformanceConfig, TextFormatConfig, TransformConfig,
10};
11use crate::{Result, TensorError};
12use std::collections::HashMap;
13
/// Result type for validation operations; the success type defaults to `()`.
pub type ValidationResult<T = ()> = std::result::Result<T, ValidationError>;

/// A single configuration validation failure.
///
/// Carries the dotted path of the offending field, a human-readable
/// explanation, optionally the rejected value rendered as text, and zero or
/// more hints on how to fix it. The `Display` output is what users see.
#[derive(Debug, Clone)]
pub struct ValidationError {
    /// Dotted path of the field that failed validation (e.g. `"cache.size_mb"`).
    pub field: String,
    /// Explanation of what is wrong.
    pub message: String,
    /// The rejected value rendered as text, when known.
    pub current_value: Option<String>,
    /// Hints listing valid values or ranges.
    pub suggestions: Vec<String>,
}

impl ValidationError {
    /// Build an error for `field` with `message`, no recorded value and no
    /// suggestions.
    pub fn new(field: &str, message: &str) -> Self {
        Self {
            field: String::from(field),
            message: String::from(message),
            current_value: None,
            suggestions: Vec::new(),
        }
    }

    /// Record the rejected value (builder style: consumes and returns `self`).
    pub fn with_current_value(self, value: &str) -> Self {
        Self {
            current_value: Some(value.to_owned()),
            ..self
        }
    }

    /// Replace the suggestion list (builder style: consumes and returns `self`).
    ///
    /// Any previously set suggestions are discarded, not appended to.
    pub fn with_suggestions(self, suggestions: Vec<&str>) -> Self {
        Self {
            suggestions: suggestions.iter().map(|&hint| hint.to_owned()).collect(),
            ..self
        }
    }
}

impl std::fmt::Display for ValidationError {
    /// Renders `Validation error in '<field>': <message>`, followed by
    /// ` (current value: <value>)` and ` - Suggestions: <a>, <b>` when those
    /// parts are present.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Validation error in '{}': {}", self.field, self.message)?;

        if let Some(value) = &self.current_value {
            write!(f, " (current value: {})", value)?;
        }

        match self.suggestions.as_slice() {
            [] => Ok(()),
            hints => write!(f, " - Suggestions: {}", hints.join(", ")),
        }
    }
}

/// Lets `ValidationError` be used wherever a `std::error::Error` is expected.
impl std::error::Error for ValidationError {}
71
72impl From<ValidationError> for TensorError {
73    fn from(err: ValidationError) -> Self {
74        TensorError::invalid_argument(err.to_string())
75    }
76}
77
78/// Configuration validation trait
79pub trait ConfigValidation {
80    /// Validate the configuration
81    fn validate(&self) -> Result<()>;
82
83    /// Get validation warnings (non-fatal issues)
84    fn get_warnings(&self) -> Vec<String> {
85        Vec::new()
86    }
87}
88
89impl ConfigValidation for GlobalConfig {
90    fn validate(&self) -> Result<()> {
91        let mut errors = Vec::new();
92
93        // Validate dataset configuration
94        if let Err(e) = self.dataset.validate() {
95            errors.push(e);
96        }
97
98        // Validate dataloader configuration
99        if let Err(e) = self.dataloader.validate() {
100            errors.push(e);
101        }
102
103        // Validate transforms configuration
104        if let Err(e) = self.transforms.validate() {
105            errors.push(e);
106        }
107
108        // Validate performance configuration
109        if let Err(e) = self.performance.validate() {
110            errors.push(e);
111        }
112
113        // Validate cache configuration
114        if let Err(e) = self.cache.validate() {
115            errors.push(e);
116        }
117
118        // Validate GPU configuration
119        if let Err(e) = self.gpu.validate() {
120            errors.push(e);
121        }
122
123        // Validate audio configuration
124        if let Err(e) = self.audio.validate() {
125            errors.push(e);
126        }
127
128        // Validate format configurations
129        if let Err(e) = self.formats.validate() {
130            errors.push(e);
131        }
132
133        // Validate logging configuration
134        if let Err(e) = self.logging.validate() {
135            errors.push(e);
136        }
137
138        // Cross-configuration validation
139        self.validate_cross_config_constraints(&mut errors);
140
141        if !errors.is_empty() {
142            let error_messages: Vec<String> = errors.into_iter().map(|e| e.to_string()).collect();
143            return Err(TensorError::invalid_argument(format!(
144                "Configuration validation failed:\n{}",
145                error_messages.join("\n")
146            )));
147        }
148
149        Ok(())
150    }
151
152    fn get_warnings(&self) -> Vec<String> {
153        let mut warnings = Vec::new();
154
155        // Performance warnings
156        if self.dataloader.num_workers > num_cpus::get() * 2 {
157            warnings.push(format!(
158                "dataloader.num_workers ({}) is much higher than CPU count ({}). This may cause performance degradation.",
159                self.dataloader.num_workers,
160                num_cpus::get()
161            ));
162        }
163
164        if self.performance.memory_pool_size > 8192 {
165            warnings.push("performance.memory_pool_size is very large (>8GB). Make sure you have sufficient RAM.".to_string());
166        }
167
168        // GPU warnings
169        if self.gpu.enabled && self.gpu.memory_pool_mb > 4096 {
170            warnings.push(
171                "gpu.memory_pool_mb is very large (>4GB). Make sure your GPU has sufficient VRAM."
172                    .to_string(),
173            );
174        }
175
176        // Cache warnings
177        if self.cache.enabled && self.cache.size_mb > self.performance.memory_pool_size {
178            warnings.push("cache.size_mb is larger than performance.memory_pool_size. This may cause memory pressure.".to_string());
179        }
180
181        warnings
182    }
183}
184
185impl GlobalConfig {
186    fn validate_cross_config_constraints(&self, errors: &mut Vec<TensorError>) {
187        // Validate that cache size doesn't exceed memory pool
188        if self.cache.enabled && self.cache.size_mb > self.performance.memory_pool_size {
189            errors.push(
190                ValidationError::new(
191                    "cache.size_mb",
192                    "Cache size cannot be larger than memory pool size",
193                )
194                .with_current_value(&self.cache.size_mb.to_string())
195                .with_suggestions(vec![
196                    &format!("Set to {} or less", self.performance.memory_pool_size),
197                    "Increase performance.memory_pool_size",
198                ])
199                .into(),
200            );
201        }
202
203        // Validate GPU settings consistency
204        if self.gpu.enabled && self.transforms.enable_gpu && self.gpu.device_id.is_none() {
205            errors.push(
206                ValidationError::new(
207                    "gpu.device_id",
208                    "GPU device ID should be specified when GPU acceleration is enabled",
209                )
210                .with_suggestions(vec!["Set gpu.device_id to a valid GPU index"])
211                .into(),
212            );
213        }
214
215        // Validate async I/O settings
216        if self.performance.async_io.enabled && self.performance.async_io.io_threads == 0 {
217            errors.push(
218                ValidationError::new(
219                    "performance.async_io.io_threads",
220                    "Async I/O threads must be greater than 0 when async I/O is enabled",
221                )
222                .with_current_value("0")
223                .with_suggestions(vec!["Set to 1 or more", "Disable async I/O"])
224                .into(),
225            );
226        }
227    }
228}
229
230impl ConfigValidation for DatasetConfig {
231    fn validate(&self) -> Result<()> {
232        let mut errors = Vec::new();
233
234        if self.batch_size == 0 {
235            errors.push(
236                ValidationError::new("dataset.batch_size", "Batch size must be greater than 0")
237                    .with_current_value("0")
238                    .with_suggestions(vec!["Set to 1 or more"]),
239            );
240        }
241
242        if self.batch_size > 10000 {
243            errors.push(
244                ValidationError::new(
245                    "dataset.batch_size",
246                    "Batch size is very large and may cause memory issues",
247                )
248                .with_current_value(&self.batch_size.to_string())
249                .with_suggestions(vec!["Consider reducing to 1000 or less"]),
250            );
251        }
252
253        if !errors.is_empty() {
254            return Err(errors
255                .into_iter()
256                .next()
257                .expect("errors vec validated as non-empty")
258                .into());
259        }
260
261        Ok(())
262    }
263}
264
265impl ConfigValidation for DataLoaderConfig {
266    fn validate(&self) -> Result<()> {
267        let mut errors = Vec::new();
268
269        if self.num_workers == 0 {
270            errors.push(
271                ValidationError::new(
272                    "dataloader.num_workers",
273                    "Number of workers must be greater than 0",
274                )
275                .with_current_value("0")
276                .with_suggestions(vec!["Set to 1 or more"]),
277            );
278        }
279
280        if self.prefetch_factor == 0 {
281            errors.push(
282                ValidationError::new(
283                    "dataloader.prefetch_factor",
284                    "Prefetch factor must be greater than 0",
285                )
286                .with_current_value("0")
287                .with_suggestions(vec!["Set to 1 or more"]),
288            );
289        }
290
291        if self.prefetch_factor > 100 {
292            errors.push(
293                ValidationError::new(
294                    "dataloader.prefetch_factor",
295                    "Prefetch factor is very large and may cause memory issues",
296                )
297                .with_current_value(&self.prefetch_factor.to_string())
298                .with_suggestions(vec!["Consider reducing to 10 or less"]),
299            );
300        }
301
302        if !errors.is_empty() {
303            return Err(errors
304                .into_iter()
305                .next()
306                .expect("errors vec validated as non-empty")
307                .into());
308        }
309
310        Ok(())
311    }
312}
313
314impl ConfigValidation for TransformConfig {
315    fn validate(&self) -> Result<()> {
316        let valid_resize_strategies = ["nearest", "bilinear", "bicubic", "lanczos"];
317
318        if !valid_resize_strategies.contains(&self.default_resize_strategy.as_str()) {
319            return Err(ValidationError::new(
320                "transforms.default_resize_strategy",
321                "Invalid resize strategy",
322            )
323            .with_current_value(&self.default_resize_strategy)
324            .with_suggestions(valid_resize_strategies.to_vec())
325            .into());
326        }
327
328        if self.augmentation_probability < 0.0 || self.augmentation_probability > 1.0 {
329            return Err(ValidationError::new(
330                "transforms.augmentation_probability",
331                "Augmentation probability must be between 0.0 and 1.0",
332            )
333            .with_current_value(&self.augmentation_probability.to_string())
334            .with_suggestions(vec!["Set to a value between 0.0 and 1.0"])
335            .into());
336        }
337
338        Ok(())
339    }
340}
341
342impl ConfigValidation for PerformanceConfig {
343    fn validate(&self) -> Result<()> {
344        if self.num_threads == 0 {
345            return Err(ValidationError::new(
346                "performance.num_threads",
347                "Number of threads must be greater than 0",
348            )
349            .with_current_value("0")
350            .with_suggestions(vec!["Set to 1 or more"])
351            .into());
352        }
353
354        if self.memory_pool_size == 0 {
355            return Err(ValidationError::new(
356                "performance.memory_pool_size",
357                "Memory pool size must be greater than 0",
358            )
359            .with_current_value("0")
360            .with_suggestions(vec!["Set to 64 MB or more"])
361            .into());
362        }
363
364        self.async_io.validate()?;
365        self.monitoring.validate()?;
366
367        Ok(())
368    }
369}
370
371impl ConfigValidation for AsyncIoConfig {
372    fn validate(&self) -> Result<()> {
373        if self.enabled && self.io_threads == 0 {
374            return Err(ValidationError::new(
375                "performance.async_io.io_threads",
376                "I/O threads must be greater than 0 when async I/O is enabled",
377            )
378            .with_current_value("0")
379            .with_suggestions(vec!["Set to 1 or more", "Disable async I/O"])
380            .into());
381        }
382
383        if self.buffer_size == 0 {
384            return Err(ValidationError::new(
385                "performance.async_io.buffer_size",
386                "Buffer size must be greater than 0",
387            )
388            .with_current_value("0")
389            .with_suggestions(vec!["Set to 4096 or more"])
390            .into());
391        }
392
393        if self.queue_depth == 0 {
394            return Err(ValidationError::new(
395                "performance.async_io.queue_depth",
396                "Queue depth must be greater than 0",
397            )
398            .with_current_value("0")
399            .with_suggestions(vec!["Set to 1 or more"])
400            .into());
401        }
402
403        Ok(())
404    }
405}
406
407impl ConfigValidation for MonitoringConfig {
408    fn validate(&self) -> Result<()> {
409        if self.enabled && self.interval == 0 {
410            return Err(ValidationError::new(
411                "performance.monitoring.interval",
412                "Monitoring interval must be greater than 0 when monitoring is enabled",
413            )
414            .with_current_value("0")
415            .with_suggestions(vec!["Set to 1 or more seconds", "Disable monitoring"])
416            .into());
417        }
418
419        Ok(())
420    }
421}
422
423impl ConfigValidation for CacheConfig {
424    fn validate(&self) -> Result<()> {
425        if self.enabled && self.size_mb == 0 {
426            return Err(ValidationError::new(
427                "cache.size_mb",
428                "Cache size must be greater than 0 when caching is enabled",
429            )
430            .with_current_value("0")
431            .with_suggestions(vec!["Set to 64 MB or more", "Disable caching"])
432            .into());
433        }
434
435        let valid_policies = ["lru", "lfu", "fifo", "random"];
436        if !valid_policies.contains(&self.eviction_policy.as_str()) {
437            return Err(ValidationError::new(
438                "cache.eviction_policy",
439                "Invalid cache eviction policy",
440            )
441            .with_current_value(&self.eviction_policy)
442            .with_suggestions(valid_policies.to_vec())
443            .into());
444        }
445
446        Ok(())
447    }
448}
449
450impl ConfigValidation for GpuConfig {
451    fn validate(&self) -> Result<()> {
452        if self.enabled && self.memory_pool_mb == 0 {
453            return Err(ValidationError::new(
454                "gpu.memory_pool_mb",
455                "GPU memory pool size must be greater than 0 when GPU is enabled",
456            )
457            .with_current_value("0")
458            .with_suggestions(vec!["Set to 256 MB or more", "Disable GPU"])
459            .into());
460        }
461
462        let valid_precisions = ["fp16", "fp32", "fp64", "bf16"];
463        if !valid_precisions.contains(&self.precision.as_str()) {
464            return Err(
465                ValidationError::new("gpu.precision", "Invalid GPU precision setting")
466                    .with_current_value(&self.precision)
467                    .with_suggestions(valid_precisions.to_vec())
468                    .into(),
469            );
470        }
471
472        Ok(())
473    }
474}
475
476impl ConfigValidation for AudioConfig {
477    fn validate(&self) -> Result<()> {
478        if self.sample_rate == 0 {
479            return Err(ValidationError::new(
480                "audio.sample_rate",
481                "Sample rate must be greater than 0",
482            )
483            .with_current_value("0")
484            .with_suggestions(vec!["Set to 44100, 48000, or other valid sample rate"])
485            .into());
486        }
487
488        if self.channels == 0 {
489            return Err(ValidationError::new(
490                "audio.channels",
491                "Number of channels must be greater than 0",
492            )
493            .with_current_value("0")
494            .with_suggestions(vec!["Set to 1 (mono) or 2 (stereo)"])
495            .into());
496        }
497
498        if self.buffer_size == 0 {
499            return Err(ValidationError::new(
500                "audio.buffer_size",
501                "Buffer size must be greater than 0",
502            )
503            .with_current_value("0")
504            .with_suggestions(vec!["Set to 1024 or other power of 2"])
505            .into());
506        }
507
508        let valid_formats = ["wav", "mp3", "flac", "ogg", "aac"];
509        if !valid_formats.contains(&self.preferred_format.as_str()) {
510            return Err(
511                ValidationError::new("audio.preferred_format", "Invalid audio format")
512                    .with_current_value(&self.preferred_format)
513                    .with_suggestions(valid_formats.to_vec())
514                    .into(),
515            );
516        }
517
518        Ok(())
519    }
520}
521
522impl ConfigValidation for FormatConfig {
523    fn validate(&self) -> Result<()> {
524        self.image.validate()?;
525        self.text.validate()?;
526        self.parquet.validate()?;
527        self.hdf5.validate()?;
528        Ok(())
529    }
530}
531
532impl ConfigValidation for ImageFormatConfig {
533    fn validate(&self) -> Result<()> {
534        if self.default_size.0 == 0 || self.default_size.1 == 0 {
535            return Err(ValidationError::new(
536                "formats.image.default_size",
537                "Image size dimensions must be greater than 0",
538            )
539            .with_current_value(&format!("{:?}", self.default_size))
540            .with_suggestions(vec!["Set to (224, 224) or other valid dimensions"])
541            .into());
542        }
543
544        Ok(())
545    }
546}
547
548impl ConfigValidation for TextFormatConfig {
549    fn validate(&self) -> Result<()> {
550        let valid_encodings = ["utf-8", "utf-16", "latin-1", "ascii"];
551        if !valid_encodings.contains(&self.encoding.as_str()) {
552            return Err(
553                ValidationError::new("formats.text.encoding", "Invalid text encoding")
554                    .with_current_value(&self.encoding)
555                    .with_suggestions(valid_encodings.to_vec())
556                    .into(),
557            );
558        }
559
560        Ok(())
561    }
562}
563
564impl ConfigValidation for ParquetFormatConfig {
565    fn validate(&self) -> Result<()> {
566        if self.batch_size == 0 {
567            return Err(ValidationError::new(
568                "formats.parquet.batch_size",
569                "Parquet batch size must be greater than 0",
570            )
571            .with_current_value("0")
572            .with_suggestions(vec!["Set to 1024 or more"])
573            .into());
574        }
575
576        Ok(())
577    }
578}
579
580impl ConfigValidation for Hdf5FormatConfig {
581    fn validate(&self) -> Result<()> {
582        if self.chunk_cache_size == 0 {
583            return Err(ValidationError::new(
584                "formats.hdf5.chunk_cache_size",
585                "HDF5 chunk cache size must be greater than 0",
586            )
587            .with_current_value("0")
588            .with_suggestions(vec!["Set to 1048576 (1MB) or more"])
589            .into());
590        }
591
592        if let Some(level) = self.compression_level {
593            if level > 9 {
594                return Err(ValidationError::new(
595                    "formats.hdf5.compression_level",
596                    "HDF5 compression level must be between 0 and 9",
597                )
598                .with_current_value(&level.to_string())
599                .with_suggestions(vec!["Set to a value between 0 and 9"])
600                .into());
601            }
602        }
603
604        Ok(())
605    }
606}
607
608impl ConfigValidation for LoggingConfig {
609    fn validate(&self) -> Result<()> {
610        let valid_levels = ["trace", "debug", "info", "warn", "error"];
611        if !valid_levels.contains(&self.level.as_str()) {
612            return Err(ValidationError::new("logging.level", "Invalid log level")
613                .with_current_value(&self.level)
614                .with_suggestions(valid_levels.to_vec())
615                .into());
616        }
617
618        let valid_formats = ["json", "text", "compact"];
619        if !valid_formats.contains(&self.format.as_str()) {
620            return Err(ValidationError::new("logging.format", "Invalid log format")
621                .with_current_value(&self.format)
622                .with_suggestions(valid_formats.to_vec())
623                .into());
624        }
625
626        Ok(())
627    }
628}
629
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_validation_error_creation() {
        let error = ValidationError::new("test.field", "Test error message")
            .with_current_value("invalid_value")
            .with_suggestions(vec!["suggestion1", "suggestion2"]);

        assert_eq!(error.field, "test.field");
        assert_eq!(error.message, "Test error message");
        assert_eq!(error.current_value, Some("invalid_value".to_string()));
        assert_eq!(error.suggestions, vec!["suggestion1", "suggestion2"]);
    }

    /// Pins the exact user-facing rendering, since it is what `From` forwards
    /// into `TensorError`.
    #[test]
    fn test_validation_error_display() {
        let error = ValidationError::new("cache.size_mb", "Too big")
            .with_current_value("4096")
            .with_suggestions(vec!["Set to 1024 or less", "Disable caching"]);
        assert_eq!(
            error.to_string(),
            "Validation error in 'cache.size_mb': Too big (current value: 4096) - Suggestions: Set to 1024 or less, Disable caching"
        );

        assert_eq!(
            ValidationError::new("a", "b").to_string(),
            "Validation error in 'a': b"
        );
    }

    #[test]
    fn test_valid_global_config() {
        let config = GlobalConfig::default();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_invalid_batch_size() {
        let mut config = GlobalConfig::default();
        config.dataset.batch_size = 0;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_invalid_resize_strategy() {
        let mut config = GlobalConfig::default();
        config.transforms.default_resize_strategy = "invalid_strategy".to_string();
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_invalid_cache_policy() {
        let mut config = GlobalConfig::default();
        config.cache.eviction_policy = "invalid_policy".to_string();
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_invalid_log_level() {
        let mut config = GlobalConfig::default();
        config.logging.level = "verbose".to_string();
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_invalid_hdf5_compression_level() {
        let mut config = GlobalConfig::default();
        config.formats.hdf5.compression_level = Some(10);
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_cross_config_validation() {
        let mut config = GlobalConfig::default();
        config.cache.size_mb = 2048;
        config.performance.memory_pool_size = 1024;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_warnings_generation() {
        let mut config = GlobalConfig::default();
        config.dataloader.num_workers = num_cpus::get() * 4;

        let warnings = config.get_warnings();
        assert!(!warnings.is_empty());
        assert!(warnings[0].contains("num_workers"));
    }

    #[test]
    fn test_audio_config_validation() {
        let mut config = AudioConfig {
            sample_rate: 0,
            ..Default::default()
        };
        assert!(config.validate().is_err());

        config.sample_rate = 44100;
        config.channels = 0;
        assert!(config.validate().is_err());

        config.channels = 2;
        config.preferred_format = "invalid".to_string();
        assert!(config.validate().is_err());

        config.preferred_format = "wav".to_string();
        assert!(config.validate().is_ok());
    }
}