tenflowers_dataset/config/
validation.rs1use super::{
7 AsyncIoConfig, AudioConfig, CacheConfig, DataLoaderConfig, DatasetConfig, FormatConfig,
8 GlobalConfig, GpuConfig, Hdf5FormatConfig, ImageFormatConfig, LoggingConfig, MonitoringConfig,
9 ParquetFormatConfig, PerformanceConfig, TextFormatConfig, TransformConfig,
10};
11use crate::{Result, TensorError};
12use std::collections::HashMap;
13
14pub type ValidationResult<T = ()> = std::result::Result<T, ValidationError>;
16
/// A single, human-readable configuration validation failure.
///
/// Created with `ValidationError::new` and refined with the `with_*` builder
/// methods. `PartialEq`/`Eq` are derived so errors can be compared directly
/// (for example in tests) instead of field by field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationError {
    /// Dotted path of the offending setting, e.g. `"dataset.batch_size"`.
    pub field: String,
    /// Explanation of why the setting was rejected.
    pub message: String,
    /// Text rendering of the rejected value, when one is worth showing.
    pub current_value: Option<String>,
    /// Hints on how to fix the setting; empty when there are none.
    pub suggestions: Vec<String>,
}
29
30impl ValidationError {
31 pub fn new(field: &str, message: &str) -> Self {
33 Self {
34 field: field.to_string(),
35 message: message.to_string(),
36 current_value: None,
37 suggestions: Vec::new(),
38 }
39 }
40
41 pub fn with_current_value(mut self, value: &str) -> Self {
43 self.current_value = Some(value.to_string());
44 self
45 }
46
47 pub fn with_suggestions(mut self, suggestions: Vec<&str>) -> Self {
49 self.suggestions = suggestions.into_iter().map(|s| s.to_string()).collect();
50 self
51 }
52}
53
54impl std::fmt::Display for ValidationError {
55 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56 write!(f, "Validation error in '{}': {}", self.field, self.message)?;
57
58 if let Some(ref value) = self.current_value {
59 write!(f, " (current value: {})", value)?;
60 }
61
62 if !self.suggestions.is_empty() {
63 write!(f, " - Suggestions: {}", self.suggestions.join(", "))?;
64 }
65
66 Ok(())
67 }
68}
69
70impl std::error::Error for ValidationError {}
71
72impl From<ValidationError> for TensorError {
73 fn from(err: ValidationError) -> Self {
74 TensorError::invalid_argument(err.to_string())
75 }
76}
77
78pub trait ConfigValidation {
80 fn validate(&self) -> Result<()>;
82
83 fn get_warnings(&self) -> Vec<String> {
85 Vec::new()
86 }
87}
88
89impl ConfigValidation for GlobalConfig {
90 fn validate(&self) -> Result<()> {
91 let mut errors = Vec::new();
92
93 if let Err(e) = self.dataset.validate() {
95 errors.push(e);
96 }
97
98 if let Err(e) = self.dataloader.validate() {
100 errors.push(e);
101 }
102
103 if let Err(e) = self.transforms.validate() {
105 errors.push(e);
106 }
107
108 if let Err(e) = self.performance.validate() {
110 errors.push(e);
111 }
112
113 if let Err(e) = self.cache.validate() {
115 errors.push(e);
116 }
117
118 if let Err(e) = self.gpu.validate() {
120 errors.push(e);
121 }
122
123 if let Err(e) = self.audio.validate() {
125 errors.push(e);
126 }
127
128 if let Err(e) = self.formats.validate() {
130 errors.push(e);
131 }
132
133 if let Err(e) = self.logging.validate() {
135 errors.push(e);
136 }
137
138 self.validate_cross_config_constraints(&mut errors);
140
141 if !errors.is_empty() {
142 let error_messages: Vec<String> = errors.into_iter().map(|e| e.to_string()).collect();
143 return Err(TensorError::invalid_argument(format!(
144 "Configuration validation failed:\n{}",
145 error_messages.join("\n")
146 )));
147 }
148
149 Ok(())
150 }
151
152 fn get_warnings(&self) -> Vec<String> {
153 let mut warnings = Vec::new();
154
155 if self.dataloader.num_workers > num_cpus::get() * 2 {
157 warnings.push(format!(
158 "dataloader.num_workers ({}) is much higher than CPU count ({}). This may cause performance degradation.",
159 self.dataloader.num_workers,
160 num_cpus::get()
161 ));
162 }
163
164 if self.performance.memory_pool_size > 8192 {
165 warnings.push("performance.memory_pool_size is very large (>8GB). Make sure you have sufficient RAM.".to_string());
166 }
167
168 if self.gpu.enabled && self.gpu.memory_pool_mb > 4096 {
170 warnings.push(
171 "gpu.memory_pool_mb is very large (>4GB). Make sure your GPU has sufficient VRAM."
172 .to_string(),
173 );
174 }
175
176 if self.cache.enabled && self.cache.size_mb > self.performance.memory_pool_size {
178 warnings.push("cache.size_mb is larger than performance.memory_pool_size. This may cause memory pressure.".to_string());
179 }
180
181 warnings
182 }
183}
184
185impl GlobalConfig {
186 fn validate_cross_config_constraints(&self, errors: &mut Vec<TensorError>) {
187 if self.cache.enabled && self.cache.size_mb > self.performance.memory_pool_size {
189 errors.push(
190 ValidationError::new(
191 "cache.size_mb",
192 "Cache size cannot be larger than memory pool size",
193 )
194 .with_current_value(&self.cache.size_mb.to_string())
195 .with_suggestions(vec![
196 &format!("Set to {} or less", self.performance.memory_pool_size),
197 "Increase performance.memory_pool_size",
198 ])
199 .into(),
200 );
201 }
202
203 if self.gpu.enabled && self.transforms.enable_gpu && self.gpu.device_id.is_none() {
205 errors.push(
206 ValidationError::new(
207 "gpu.device_id",
208 "GPU device ID should be specified when GPU acceleration is enabled",
209 )
210 .with_suggestions(vec!["Set gpu.device_id to a valid GPU index"])
211 .into(),
212 );
213 }
214
215 if self.performance.async_io.enabled && self.performance.async_io.io_threads == 0 {
217 errors.push(
218 ValidationError::new(
219 "performance.async_io.io_threads",
220 "Async I/O threads must be greater than 0 when async I/O is enabled",
221 )
222 .with_current_value("0")
223 .with_suggestions(vec!["Set to 1 or more", "Disable async I/O"])
224 .into(),
225 );
226 }
227 }
228}
229
230impl ConfigValidation for DatasetConfig {
231 fn validate(&self) -> Result<()> {
232 let mut errors = Vec::new();
233
234 if self.batch_size == 0 {
235 errors.push(
236 ValidationError::new("dataset.batch_size", "Batch size must be greater than 0")
237 .with_current_value("0")
238 .with_suggestions(vec!["Set to 1 or more"]),
239 );
240 }
241
242 if self.batch_size > 10000 {
243 errors.push(
244 ValidationError::new(
245 "dataset.batch_size",
246 "Batch size is very large and may cause memory issues",
247 )
248 .with_current_value(&self.batch_size.to_string())
249 .with_suggestions(vec!["Consider reducing to 1000 or less"]),
250 );
251 }
252
253 if !errors.is_empty() {
254 return Err(errors
255 .into_iter()
256 .next()
257 .expect("errors vec validated as non-empty")
258 .into());
259 }
260
261 Ok(())
262 }
263}
264
265impl ConfigValidation for DataLoaderConfig {
266 fn validate(&self) -> Result<()> {
267 let mut errors = Vec::new();
268
269 if self.num_workers == 0 {
270 errors.push(
271 ValidationError::new(
272 "dataloader.num_workers",
273 "Number of workers must be greater than 0",
274 )
275 .with_current_value("0")
276 .with_suggestions(vec!["Set to 1 or more"]),
277 );
278 }
279
280 if self.prefetch_factor == 0 {
281 errors.push(
282 ValidationError::new(
283 "dataloader.prefetch_factor",
284 "Prefetch factor must be greater than 0",
285 )
286 .with_current_value("0")
287 .with_suggestions(vec!["Set to 1 or more"]),
288 );
289 }
290
291 if self.prefetch_factor > 100 {
292 errors.push(
293 ValidationError::new(
294 "dataloader.prefetch_factor",
295 "Prefetch factor is very large and may cause memory issues",
296 )
297 .with_current_value(&self.prefetch_factor.to_string())
298 .with_suggestions(vec!["Consider reducing to 10 or less"]),
299 );
300 }
301
302 if !errors.is_empty() {
303 return Err(errors
304 .into_iter()
305 .next()
306 .expect("errors vec validated as non-empty")
307 .into());
308 }
309
310 Ok(())
311 }
312}
313
314impl ConfigValidation for TransformConfig {
315 fn validate(&self) -> Result<()> {
316 let valid_resize_strategies = ["nearest", "bilinear", "bicubic", "lanczos"];
317
318 if !valid_resize_strategies.contains(&self.default_resize_strategy.as_str()) {
319 return Err(ValidationError::new(
320 "transforms.default_resize_strategy",
321 "Invalid resize strategy",
322 )
323 .with_current_value(&self.default_resize_strategy)
324 .with_suggestions(valid_resize_strategies.to_vec())
325 .into());
326 }
327
328 if self.augmentation_probability < 0.0 || self.augmentation_probability > 1.0 {
329 return Err(ValidationError::new(
330 "transforms.augmentation_probability",
331 "Augmentation probability must be between 0.0 and 1.0",
332 )
333 .with_current_value(&self.augmentation_probability.to_string())
334 .with_suggestions(vec!["Set to a value between 0.0 and 1.0"])
335 .into());
336 }
337
338 Ok(())
339 }
340}
341
342impl ConfigValidation for PerformanceConfig {
343 fn validate(&self) -> Result<()> {
344 if self.num_threads == 0 {
345 return Err(ValidationError::new(
346 "performance.num_threads",
347 "Number of threads must be greater than 0",
348 )
349 .with_current_value("0")
350 .with_suggestions(vec!["Set to 1 or more"])
351 .into());
352 }
353
354 if self.memory_pool_size == 0 {
355 return Err(ValidationError::new(
356 "performance.memory_pool_size",
357 "Memory pool size must be greater than 0",
358 )
359 .with_current_value("0")
360 .with_suggestions(vec!["Set to 64 MB or more"])
361 .into());
362 }
363
364 self.async_io.validate()?;
365 self.monitoring.validate()?;
366
367 Ok(())
368 }
369}
370
371impl ConfigValidation for AsyncIoConfig {
372 fn validate(&self) -> Result<()> {
373 if self.enabled && self.io_threads == 0 {
374 return Err(ValidationError::new(
375 "performance.async_io.io_threads",
376 "I/O threads must be greater than 0 when async I/O is enabled",
377 )
378 .with_current_value("0")
379 .with_suggestions(vec!["Set to 1 or more", "Disable async I/O"])
380 .into());
381 }
382
383 if self.buffer_size == 0 {
384 return Err(ValidationError::new(
385 "performance.async_io.buffer_size",
386 "Buffer size must be greater than 0",
387 )
388 .with_current_value("0")
389 .with_suggestions(vec!["Set to 4096 or more"])
390 .into());
391 }
392
393 if self.queue_depth == 0 {
394 return Err(ValidationError::new(
395 "performance.async_io.queue_depth",
396 "Queue depth must be greater than 0",
397 )
398 .with_current_value("0")
399 .with_suggestions(vec!["Set to 1 or more"])
400 .into());
401 }
402
403 Ok(())
404 }
405}
406
407impl ConfigValidation for MonitoringConfig {
408 fn validate(&self) -> Result<()> {
409 if self.enabled && self.interval == 0 {
410 return Err(ValidationError::new(
411 "performance.monitoring.interval",
412 "Monitoring interval must be greater than 0 when monitoring is enabled",
413 )
414 .with_current_value("0")
415 .with_suggestions(vec!["Set to 1 or more seconds", "Disable monitoring"])
416 .into());
417 }
418
419 Ok(())
420 }
421}
422
423impl ConfigValidation for CacheConfig {
424 fn validate(&self) -> Result<()> {
425 if self.enabled && self.size_mb == 0 {
426 return Err(ValidationError::new(
427 "cache.size_mb",
428 "Cache size must be greater than 0 when caching is enabled",
429 )
430 .with_current_value("0")
431 .with_suggestions(vec!["Set to 64 MB or more", "Disable caching"])
432 .into());
433 }
434
435 let valid_policies = ["lru", "lfu", "fifo", "random"];
436 if !valid_policies.contains(&self.eviction_policy.as_str()) {
437 return Err(ValidationError::new(
438 "cache.eviction_policy",
439 "Invalid cache eviction policy",
440 )
441 .with_current_value(&self.eviction_policy)
442 .with_suggestions(valid_policies.to_vec())
443 .into());
444 }
445
446 Ok(())
447 }
448}
449
450impl ConfigValidation for GpuConfig {
451 fn validate(&self) -> Result<()> {
452 if self.enabled && self.memory_pool_mb == 0 {
453 return Err(ValidationError::new(
454 "gpu.memory_pool_mb",
455 "GPU memory pool size must be greater than 0 when GPU is enabled",
456 )
457 .with_current_value("0")
458 .with_suggestions(vec!["Set to 256 MB or more", "Disable GPU"])
459 .into());
460 }
461
462 let valid_precisions = ["fp16", "fp32", "fp64", "bf16"];
463 if !valid_precisions.contains(&self.precision.as_str()) {
464 return Err(
465 ValidationError::new("gpu.precision", "Invalid GPU precision setting")
466 .with_current_value(&self.precision)
467 .with_suggestions(valid_precisions.to_vec())
468 .into(),
469 );
470 }
471
472 Ok(())
473 }
474}
475
476impl ConfigValidation for AudioConfig {
477 fn validate(&self) -> Result<()> {
478 if self.sample_rate == 0 {
479 return Err(ValidationError::new(
480 "audio.sample_rate",
481 "Sample rate must be greater than 0",
482 )
483 .with_current_value("0")
484 .with_suggestions(vec!["Set to 44100, 48000, or other valid sample rate"])
485 .into());
486 }
487
488 if self.channels == 0 {
489 return Err(ValidationError::new(
490 "audio.channels",
491 "Number of channels must be greater than 0",
492 )
493 .with_current_value("0")
494 .with_suggestions(vec!["Set to 1 (mono) or 2 (stereo)"])
495 .into());
496 }
497
498 if self.buffer_size == 0 {
499 return Err(ValidationError::new(
500 "audio.buffer_size",
501 "Buffer size must be greater than 0",
502 )
503 .with_current_value("0")
504 .with_suggestions(vec!["Set to 1024 or other power of 2"])
505 .into());
506 }
507
508 let valid_formats = ["wav", "mp3", "flac", "ogg", "aac"];
509 if !valid_formats.contains(&self.preferred_format.as_str()) {
510 return Err(
511 ValidationError::new("audio.preferred_format", "Invalid audio format")
512 .with_current_value(&self.preferred_format)
513 .with_suggestions(valid_formats.to_vec())
514 .into(),
515 );
516 }
517
518 Ok(())
519 }
520}
521
522impl ConfigValidation for FormatConfig {
523 fn validate(&self) -> Result<()> {
524 self.image.validate()?;
525 self.text.validate()?;
526 self.parquet.validate()?;
527 self.hdf5.validate()?;
528 Ok(())
529 }
530}
531
532impl ConfigValidation for ImageFormatConfig {
533 fn validate(&self) -> Result<()> {
534 if self.default_size.0 == 0 || self.default_size.1 == 0 {
535 return Err(ValidationError::new(
536 "formats.image.default_size",
537 "Image size dimensions must be greater than 0",
538 )
539 .with_current_value(&format!("{:?}", self.default_size))
540 .with_suggestions(vec!["Set to (224, 224) or other valid dimensions"])
541 .into());
542 }
543
544 Ok(())
545 }
546}
547
548impl ConfigValidation for TextFormatConfig {
549 fn validate(&self) -> Result<()> {
550 let valid_encodings = ["utf-8", "utf-16", "latin-1", "ascii"];
551 if !valid_encodings.contains(&self.encoding.as_str()) {
552 return Err(
553 ValidationError::new("formats.text.encoding", "Invalid text encoding")
554 .with_current_value(&self.encoding)
555 .with_suggestions(valid_encodings.to_vec())
556 .into(),
557 );
558 }
559
560 Ok(())
561 }
562}
563
564impl ConfigValidation for ParquetFormatConfig {
565 fn validate(&self) -> Result<()> {
566 if self.batch_size == 0 {
567 return Err(ValidationError::new(
568 "formats.parquet.batch_size",
569 "Parquet batch size must be greater than 0",
570 )
571 .with_current_value("0")
572 .with_suggestions(vec!["Set to 1024 or more"])
573 .into());
574 }
575
576 Ok(())
577 }
578}
579
580impl ConfigValidation for Hdf5FormatConfig {
581 fn validate(&self) -> Result<()> {
582 if self.chunk_cache_size == 0 {
583 return Err(ValidationError::new(
584 "formats.hdf5.chunk_cache_size",
585 "HDF5 chunk cache size must be greater than 0",
586 )
587 .with_current_value("0")
588 .with_suggestions(vec!["Set to 1048576 (1MB) or more"])
589 .into());
590 }
591
592 if let Some(level) = self.compression_level {
593 if level > 9 {
594 return Err(ValidationError::new(
595 "formats.hdf5.compression_level",
596 "HDF5 compression level must be between 0 and 9",
597 )
598 .with_current_value(&level.to_string())
599 .with_suggestions(vec!["Set to a value between 0 and 9"])
600 .into());
601 }
602 }
603
604 Ok(())
605 }
606}
607
608impl ConfigValidation for LoggingConfig {
609 fn validate(&self) -> Result<()> {
610 let valid_levels = ["trace", "debug", "info", "warn", "error"];
611 if !valid_levels.contains(&self.level.as_str()) {
612 return Err(ValidationError::new("logging.level", "Invalid log level")
613 .with_current_value(&self.level)
614 .with_suggestions(valid_levels.to_vec())
615 .into());
616 }
617
618 let valid_formats = ["json", "text", "compact"];
619 if !valid_formats.contains(&self.format.as_str()) {
620 return Err(ValidationError::new("logging.format", "Invalid log format")
621 .with_current_value(&self.format)
622 .with_suggestions(valid_formats.to_vec())
623 .into());
624 }
625
626 Ok(())
627 }
628}
629
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a default `GlobalConfig` and applies `tweak` to it.
    fn config_with(tweak: impl FnOnce(&mut GlobalConfig)) -> GlobalConfig {
        let mut config = GlobalConfig::default();
        tweak(&mut config);
        config
    }

    #[test]
    fn test_validation_error_creation() {
        let error = ValidationError::new("test.field", "Test error message")
            .with_current_value("invalid_value")
            .with_suggestions(vec!["suggestion1", "suggestion2"]);

        assert_eq!(error.field, "test.field");
        assert_eq!(error.message, "Test error message");
        assert_eq!(error.current_value.as_deref(), Some("invalid_value"));
        assert_eq!(error.suggestions, ["suggestion1", "suggestion2"]);
    }

    #[test]
    fn test_valid_global_config() {
        assert!(GlobalConfig::default().validate().is_ok());
    }

    #[test]
    fn test_invalid_batch_size() {
        let config = config_with(|c| c.dataset.batch_size = 0);
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_invalid_resize_strategy() {
        let config =
            config_with(|c| c.transforms.default_resize_strategy = "invalid_strategy".to_string());
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_invalid_cache_policy() {
        let config = config_with(|c| c.cache.eviction_policy = "invalid_policy".to_string());
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_cross_config_validation() {
        let config = config_with(|c| {
            c.cache.size_mb = 2048;
            c.performance.memory_pool_size = 1024;
        });
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_warnings_generation() {
        let config = config_with(|c| c.dataloader.num_workers = num_cpus::get() * 4);

        let warnings = config.get_warnings();
        let first = warnings.first().expect("expected at least one warning");
        assert!(first.contains("num_workers"));
    }

    #[test]
    fn test_audio_config_validation() {
        let mut audio = AudioConfig {
            sample_rate: 0,
            ..Default::default()
        };
        assert!(audio.validate().is_err());

        audio.sample_rate = 44100;
        audio.channels = 0;
        assert!(audio.validate().is_err());

        audio.channels = 2;
        audio.preferred_format = "invalid".to_string();
        assert!(audio.validate().is_err());

        audio.preferred_format = "wav".to_string();
        assert!(audio.validate().is_ok());
    }
}