1use anyhow::Result;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::fmt;
5
6pub trait Validatable {
8 fn validate(&self) -> Result<(), Vec<ValidationError>>;
9}
10
11#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
12pub struct ValidationError {
13 pub field: String,
14 pub message: String,
15 pub severity: Severity,
16 pub error_code: String,
17}
18
19#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
20pub enum Severity {
21 Error,
22 Warning,
23 Info,
24}
25
26impl fmt::Display for ValidationError {
27 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28 write!(f, "[{}] {}: {}", self.severity, self.field, self.message)
29 }
30}
31
32impl fmt::Display for Severity {
33 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34 match self {
35 Severity::Error => write!(f, "ERROR"),
36 Severity::Warning => write!(f, "WARNING"),
37 Severity::Info => write!(f, "INFO"),
38 }
39 }
40}
41
42pub struct ConfigValidator {
44 rules: Vec<Box<dyn ValidationRule>>,
45}
46
47pub trait ValidationRule: Send + Sync {
48 fn validate(&self, config: &dyn std::any::Any) -> Vec<ValidationError>;
49 fn rule_name(&self) -> &str;
50}
51
52impl Default for ConfigValidator {
53 fn default() -> Self {
54 Self::new()
55 }
56}
57
58impl ConfigValidator {
59 pub fn new() -> Self {
60 Self { rules: Vec::new() }
61 }
62
63 pub fn add_rule<R: ValidationRule + 'static>(mut self, rule: R) -> Self {
64 self.rules.push(Box::new(rule));
65 self
66 }
67
68 pub fn validate<T: 'static>(
69 &self,
70 config: &T,
71 ) -> Result<ValidationReport, Vec<ValidationError>> {
72 let mut all_errors = Vec::new();
73 let mut warnings = Vec::new();
74 let mut infos = Vec::new();
75
76 for rule in &self.rules {
77 let errors = rule.validate(config as &dyn std::any::Any);
78 for error in errors {
79 match error.severity {
80 Severity::Error => all_errors.push(error),
81 Severity::Warning => warnings.push(error),
82 Severity::Info => infos.push(error),
83 }
84 }
85 }
86
87 if !all_errors.is_empty() {
88 return Err(all_errors);
89 }
90
91 Ok(ValidationReport {
92 is_valid: true,
93 errors: Vec::new(),
94 warnings,
95 infos,
96 rules_applied: self.rules.len(),
97 })
98 }
99}
100
101#[derive(Debug, Clone, Serialize, Deserialize)]
102pub struct ValidationReport {
103 pub is_valid: bool,
104 pub errors: Vec<ValidationError>,
105 pub warnings: Vec<ValidationError>,
106 pub infos: Vec<ValidationError>,
107 pub rules_applied: usize,
108}
109
110impl ValidationReport {
111 pub fn has_warnings(&self) -> bool {
112 !self.warnings.is_empty()
113 }
114
115 pub fn has_infos(&self) -> bool {
116 !self.infos.is_empty()
117 }
118
119 pub fn print_summary(&self) {
120 println!("đ Validation Report");
121 println!(
122 " Status: {}",
123 if self.is_valid { "â
Valid" } else { "â Invalid" }
124 );
125 println!(" Rules Applied: {}", self.rules_applied);
126
127 if !self.errors.is_empty() {
128 println!(" â Errors: {}", self.errors.len());
129 for error in &self.errors {
130 println!(" {}", error);
131 }
132 }
133
134 if !self.warnings.is_empty() {
135 println!(" â ī¸ Warnings: {}", self.warnings.len());
136 for warning in &self.warnings {
137 println!(" {}", warning);
138 }
139 }
140
141 if !self.infos.is_empty() {
142 println!(" âšī¸ Infos: {}", self.infos.len());
143 for info in &self.infos {
144 println!(" {}", info);
145 }
146 }
147 }
148}
149
/// Rule that checks an extracted value against optional lower and upper bounds.
pub struct RangeRule<T> {
    /// Field name reported in findings.
    field_name: String,
    /// Smallest accepted value, if any.
    min: Option<T>,
    /// Largest accepted value, if any.
    max: Option<T>,
    /// Pulls the value out of the type-erased configuration; `None` skips the check.
    extractor: fn(&dyn std::any::Any) -> Option<T>,
}
157
158impl<T> RangeRule<T>
159where
160 T: PartialOrd + Copy + fmt::Display + 'static,
161{
162 pub fn new(
163 field_name: String,
164 min: Option<T>,
165 max: Option<T>,
166 extractor: fn(&dyn std::any::Any) -> Option<T>,
167 ) -> Self {
168 Self {
169 field_name,
170 min,
171 max,
172 extractor,
173 }
174 }
175}
176
177impl<T> ValidationRule for RangeRule<T>
178where
179 T: PartialOrd + Copy + fmt::Display + 'static + Sync + Send,
180{
181 fn validate(&self, config: &dyn std::any::Any) -> Vec<ValidationError> {
182 let mut errors = Vec::new();
183
184 if let Some(value) = (self.extractor)(config) {
185 if let Some(min) = self.min {
186 if value < min {
187 errors.push(ValidationError {
188 field: self.field_name.clone(),
189 message: format!("Value {} is below minimum {}", value, min),
190 severity: Severity::Error,
191 error_code: "RANGE_BELOW_MIN".to_string(),
192 });
193 }
194 }
195
196 if let Some(max) = self.max {
197 if value > max {
198 errors.push(ValidationError {
199 field: self.field_name.clone(),
200 message: format!("Value {} exceeds maximum {}", value, max),
201 severity: Severity::Error,
202 error_code: "RANGE_ABOVE_MAX".to_string(),
203 });
204 }
205 }
206 }
207
208 errors
209 }
210
211 fn rule_name(&self) -> &str {
212 "RangeRule"
213 }
214}
215
/// Rule that reports an error when a required field is absent or unusable.
pub struct RequiredFieldRule {
    /// Field name reported in findings.
    field_name: String,
    /// Returns `true` when the field is present and acceptable.
    checker: fn(&dyn std::any::Any) -> bool,
}
220
221impl RequiredFieldRule {
222 pub fn new(field_name: String, checker: fn(&dyn std::any::Any) -> bool) -> Self {
223 Self {
224 field_name,
225 checker,
226 }
227 }
228}
229
230impl ValidationRule for RequiredFieldRule {
231 fn validate(&self, config: &dyn std::any::Any) -> Vec<ValidationError> {
232 if !(self.checker)(config) {
233 vec![ValidationError {
234 field: self.field_name.clone(),
235 message: "Required field is missing or invalid".to_string(),
236 severity: Severity::Error,
237 error_code: "REQUIRED_FIELD_MISSING".to_string(),
238 }]
239 } else {
240 Vec::new()
241 }
242 }
243
244 fn rule_name(&self) -> &str {
245 "RequiredFieldRule"
246 }
247}
248
249pub struct CompatibilityRule {
250 name: String,
251 checker: fn(&dyn std::any::Any) -> Vec<ValidationError>,
252}
253
254impl CompatibilityRule {
255 pub fn new(name: String, checker: fn(&dyn std::any::Any) -> Vec<ValidationError>) -> Self {
256 Self { name, checker }
257 }
258}
259
260impl ValidationRule for CompatibilityRule {
261 fn validate(&self, config: &dyn std::any::Any) -> Vec<ValidationError> {
262 (self.checker)(config)
263 }
264
265 fn rule_name(&self) -> &str {
266 &self.name
267 }
268}
269
/// Builds a `RangeRule` for a field of the enclosing `Self` type.
///
/// Arguments: the field identifier, the optional minimum, the optional
/// maximum, and the type the field is cast to with `as`. Must be expanded
/// where `Self` and `RangeRule` are in scope.
#[macro_export]
macro_rules! validate_range {
    ($field:ident, $min:expr, $max:expr, $type:ty) => {
        RangeRule::new(String::from(stringify!($field)), $min, $max, |config| {
            let typed = config.downcast_ref::<Self>();
            typed.map(|c| c.$field as $type)
        })
    };
}
279
/// Builds a `RequiredFieldRule` for a field of the enclosing `Self` type.
///
/// The field passes when `is_empty()` returns `false`; a configuration of a
/// different type fails the check. Must be expanded where `Self` and
/// `RequiredFieldRule` are in scope.
#[macro_export]
macro_rules! validate_required {
    ($field:ident) => {
        RequiredFieldRule::new(String::from(stringify!($field)), |config| {
            match config.downcast_ref::<Self>() {
                Some(c) => !c.$field.is_empty(),
                None => false,
            }
        })
    };
}
288
289use crate::training_args::TrainingArguments;
291
292impl Validatable for TrainingArguments {
293 fn validate(&self) -> Result<(), Vec<ValidationError>> {
294 let validator = ConfigValidator::new()
295 .add_rule(RangeRule::new(
296 "learning_rate".to_string(),
297 Some(1e-10_f64),
298 Some(1.0_f64),
299 |config| {
300 config
301 .downcast_ref::<TrainingArguments>()
302 .map(|c| c.learning_rate as f64)
303 },
304 ))
305 .add_rule(RangeRule::new(
306 "per_device_train_batch_size".to_string(),
307 Some(1_usize),
308 Some(1024_usize),
309 |config| {
310 config
311 .downcast_ref::<TrainingArguments>()
312 .map(|c| c.per_device_train_batch_size)
313 },
314 ))
315 .add_rule(RangeRule::new(
316 "num_train_epochs".to_string(),
317 Some(1_u32),
318 Some(10000_u32),
319 |config| {
320 config
321 .downcast_ref::<TrainingArguments>()
322 .map(|c| c.num_train_epochs as u32)
323 },
324 ))
325 .add_rule(RequiredFieldRule::new(
326 "output_dir".to_string(),
327 |config| {
328 config
329 .downcast_ref::<TrainingArguments>()
330 .map(|c| !c.output_dir.to_string_lossy().is_empty())
331 .unwrap_or(false)
332 },
333 ))
334 .add_rule(CompatibilityRule::new(
335 "eval_strategy_compatibility".to_string(),
336 |config| {
337 if let Some(args) = config.downcast_ref::<TrainingArguments>() {
338 let mut errors = Vec::new();
339
340 if args.evaluation_strategy == crate::training_args::EvaluationStrategy::Steps && args.eval_steps == 0 {
342 errors.push(ValidationError {
343 field: "eval_steps".to_string(),
344 message: "eval_steps must be greater than 0 when evaluation_strategy is Steps".to_string(),
345 severity: Severity::Warning,
346 error_code: "EVAL_STRATEGY_COMPATIBILITY".to_string(),
347 });
348 }
349
350 if args.gradient_accumulation_steps > 1 && args.per_device_train_batch_size > 64 {
352 errors.push(ValidationError {
353 field: "batch_size_gradient_accumulation".to_string(),
354 message: "Large batch size with gradient accumulation may cause memory issues".to_string(),
355 severity: Severity::Warning,
356 error_code: "MEMORY_WARNING".to_string(),
357 });
358 }
359
360 errors
361 } else {
362 Vec::new()
363 }
364 },
365 ));
366
367 match validator.validate(self) {
368 Ok(report) => {
369 if report.has_warnings() || report.has_infos() {
370 report.print_summary();
371 }
372 Ok(())
373 },
374 Err(errors) => Err(errors),
375 }
376 }
377}
378
379#[derive(Debug, Clone, Serialize, Deserialize)]
381pub struct ValidatedConfig<T> {
382 inner: T,
383 validation_report: Option<ValidationReport>,
384}
385
386impl<T> ValidatedConfig<T>
387where
388 T: Validatable + Clone + 'static,
389{
390 pub fn new(config: T) -> Result<Self, Vec<ValidationError>> {
391 match config.validate() {
392 Ok(_) => Ok(Self {
393 inner: config,
394 validation_report: None,
395 }),
396 Err(errors) => Err(errors),
397 }
398 }
399
400 pub fn new_with_warnings(config: T) -> Result<Self, Vec<ValidationError>> {
401 let validator = ConfigValidator::new();
402 match validator.validate(&config) {
403 Ok(report) => Ok(Self {
404 inner: config,
405 validation_report: Some(report),
406 }),
407 Err(errors) => Err(errors),
408 }
409 }
410
411 pub fn into_inner(self) -> T {
412 self.inner
413 }
414
415 pub fn get(&self) -> &T {
416 &self.inner
417 }
418
419 pub fn get_validation_report(&self) -> Option<&ValidationReport> {
420 self.validation_report.as_ref()
421 }
422
423 pub fn update<F>(mut self, updater: F) -> Result<Self, Vec<ValidationError>>
425 where
426 F: FnOnce(&mut T),
427 {
428 updater(&mut self.inner);
429 match self.inner.validate() {
430 Ok(_) => Ok(self),
431 Err(errors) => Err(errors),
432 }
433 }
434}
435
436#[derive(Debug, Clone, Serialize, Deserialize)]
438pub struct ConfigSchema {
439 pub fields: HashMap<String, FieldSchema>,
440 pub required_fields: Vec<String>,
441 pub dependencies: HashMap<String, Vec<String>>,
442}
443
444#[derive(Debug, Clone, Serialize, Deserialize)]
445pub struct FieldSchema {
446 pub field_type: FieldType,
447 pub constraints: Vec<Constraint>,
448 pub description: String,
449 pub default_value: Option<serde_json::Value>,
450}
451
452#[derive(Debug, Clone, Serialize, Deserialize)]
453pub enum FieldType {
454 String,
455 Integer,
456 Float,
457 Boolean,
458 Array { item_type: Box<FieldType> },
459 Object { schema: Box<ConfigSchema> },
460 Union { types: Vec<FieldType> },
461}
462
463#[derive(Debug, Clone, Serialize, Deserialize)]
464pub enum Constraint {
465 Range {
466 min: Option<f64>,
467 max: Option<f64>,
468 },
469 Length {
470 min: Option<usize>,
471 max: Option<usize>,
472 },
473 Pattern {
474 regex: String,
475 },
476 OneOf {
477 values: Vec<serde_json::Value>,
478 },
479 Custom {
480 name: String,
481 description: String,
482 },
483}
484
485impl Default for ConfigSchema {
486 fn default() -> Self {
487 Self::new()
488 }
489}
490
491impl ConfigSchema {
492 pub fn new() -> Self {
493 Self {
494 fields: HashMap::new(),
495 required_fields: Vec::new(),
496 dependencies: HashMap::new(),
497 }
498 }
499
500 pub fn add_field(mut self, name: String, schema: FieldSchema) -> Self {
501 self.fields.insert(name, schema);
502 self
503 }
504
505 pub fn require_field(mut self, name: String) -> Self {
506 self.required_fields.push(name);
507 self
508 }
509
510 pub fn add_dependency(mut self, field: String, depends_on: Vec<String>) -> Self {
511 self.dependencies.insert(field, depends_on);
512 self
513 }
514
515 pub fn validate_json(&self, value: &serde_json::Value) -> Vec<ValidationError> {
516 let mut errors = Vec::new();
517
518 if let serde_json::Value::Object(obj) = value {
519 for required in &self.required_fields {
521 if !obj.contains_key(required) {
522 errors.push(ValidationError {
523 field: required.clone(),
524 message: "Required field is missing".to_string(),
525 severity: Severity::Error,
526 error_code: "REQUIRED_FIELD_MISSING".to_string(),
527 });
528 }
529 }
530
531 for (field_name, field_value) in obj {
533 if let Some(field_schema) = self.fields.get(field_name) {
534 errors.extend(self.validate_field_value(field_name, field_value, field_schema));
535 }
536 }
537
538 for (field_name, dependencies) in &self.dependencies {
540 if obj.contains_key(field_name) {
541 for dep in dependencies {
542 if !obj.contains_key(dep) {
543 errors.push(ValidationError {
544 field: field_name.clone(),
545 message: format!(
546 "Field {} requires {} to be present",
547 field_name, dep
548 ),
549 severity: Severity::Error,
550 error_code: "DEPENDENCY_MISSING".to_string(),
551 });
552 }
553 }
554 }
555 }
556 } else {
557 errors.push(ValidationError {
558 field: "root".to_string(),
559 message: "Expected object at root level".to_string(),
560 severity: Severity::Error,
561 error_code: "INVALID_ROOT_TYPE".to_string(),
562 });
563 }
564
565 errors
566 }
567
568 fn validate_field_value(
569 &self,
570 field_name: &str,
571 value: &serde_json::Value,
572 schema: &FieldSchema,
573 ) -> Vec<ValidationError> {
574 let mut errors = Vec::new();
575
576 if !self.is_type_compatible(value, &schema.field_type) {
578 errors.push(ValidationError {
579 field: field_name.to_string(),
580 message: format!("Type mismatch: expected {:?}", schema.field_type),
581 severity: Severity::Error,
582 error_code: "TYPE_MISMATCH".to_string(),
583 });
584 return errors; }
586
587 for constraint in &schema.constraints {
589 if let Some(error) = self.check_constraint(field_name, value, constraint) {
590 errors.push(error);
591 }
592 }
593
594 errors
595 }
596
597 fn is_type_compatible(&self, value: &serde_json::Value, field_type: &FieldType) -> bool {
598 match (value, field_type) {
599 (serde_json::Value::String(_), FieldType::String) => true,
600 (serde_json::Value::Number(n), FieldType::Integer) => n.is_i64(),
601 (serde_json::Value::Number(_), FieldType::Float) => true,
602 (serde_json::Value::Bool(_), FieldType::Boolean) => true,
603 (serde_json::Value::Array(_), FieldType::Array { .. }) => true,
604 (serde_json::Value::Object(_), FieldType::Object { .. }) => true,
605 (val, FieldType::Union { types }) => {
606 types.iter().any(|t| self.is_type_compatible(val, t))
607 },
608 _ => false,
609 }
610 }
611
612 fn check_constraint(
613 &self,
614 field_name: &str,
615 value: &serde_json::Value,
616 constraint: &Constraint,
617 ) -> Option<ValidationError> {
618 match constraint {
619 Constraint::Range { min, max } => {
620 if let serde_json::Value::Number(n) = value {
621 let val = match n.as_f64() {
622 Some(v) => v,
623 None => {
624 log::warn!("Failed to convert number to f64 for field: {}", field_name);
625 return Some(ValidationError {
626 field: field_name.to_string(),
627 message: format!(
628 "Number value cannot be represented as f64: {}",
629 n
630 ),
631 severity: Severity::Error,
632 error_code: "INVALID_NUMBER_CONVERSION".to_string(),
633 });
634 },
635 };
636 if let Some(min_val) = min {
637 if val < *min_val {
638 return Some(ValidationError {
639 field: field_name.to_string(),
640 message: format!("Value {} is below minimum {}", val, min_val),
641 severity: Severity::Error,
642 error_code: "RANGE_BELOW_MIN".to_string(),
643 });
644 }
645 }
646 if let Some(max_val) = max {
647 if val > *max_val {
648 return Some(ValidationError {
649 field: field_name.to_string(),
650 message: format!("Value {} exceeds maximum {}", val, max_val),
651 severity: Severity::Error,
652 error_code: "RANGE_ABOVE_MAX".to_string(),
653 });
654 }
655 }
656 }
657 },
658 Constraint::Length { min, max } => {
659 let len = match value {
660 serde_json::Value::String(s) => s.len(),
661 serde_json::Value::Array(a) => a.len(),
662 _ => return None,
663 };
664
665 if let Some(min_len) = min {
666 if len < *min_len {
667 return Some(ValidationError {
668 field: field_name.to_string(),
669 message: format!("Length {} is below minimum {}", len, min_len),
670 severity: Severity::Error,
671 error_code: "LENGTH_BELOW_MIN".to_string(),
672 });
673 }
674 }
675 if let Some(max_len) = max {
676 if len > *max_len {
677 return Some(ValidationError {
678 field: field_name.to_string(),
679 message: format!("Length {} exceeds maximum {}", len, max_len),
680 severity: Severity::Error,
681 error_code: "LENGTH_ABOVE_MAX".to_string(),
682 });
683 }
684 }
685 },
686 Constraint::OneOf { values } => {
687 if !values.contains(value) {
688 return Some(ValidationError {
689 field: field_name.to_string(),
690 message: format!("Value must be one of: {:?}", values),
691 severity: Severity::Error,
692 error_code: "VALUE_NOT_IN_SET".to_string(),
693 });
694 }
695 },
696 Constraint::Pattern { regex } => {
697 if let serde_json::Value::String(s) = value {
698 if s.is_empty() {
700 return Some(ValidationError {
701 field: field_name.to_string(),
702 message: format!("String does not match pattern: {}", regex),
703 severity: Severity::Error,
704 error_code: "PATTERN_MISMATCH".to_string(),
705 });
706 }
707 }
708 },
709 Constraint::Custom {
710 name: _,
711 description: _,
712 } => {
713 },
716 }
717
718 None
719 }
720}
721
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal configuration used to exercise the validator.
    #[derive(Clone)]
    struct TestConfig {
        learning_rate: f64,
        #[allow(dead_code)]
        batch_size: usize,
        output_dir: String,
    }

    /// Builds a test configuration with a fixed batch size of 32.
    fn config_with(learning_rate: f64, output_dir: &str) -> TestConfig {
        TestConfig {
            learning_rate,
            batch_size: 32,
            output_dir: output_dir.to_string(),
        }
    }

    impl Validatable for TestConfig {
        fn validate(&self) -> Result<(), Vec<ValidationError>> {
            let learning_rate_rule = RangeRule::new(
                "learning_rate".to_string(),
                Some(1e-6_f64),
                Some(1.0_f64),
                |config| config.downcast_ref::<TestConfig>().map(|c| c.learning_rate),
            );
            let output_dir_rule = RequiredFieldRule::new("output_dir".to_string(), |config| {
                match config.downcast_ref::<TestConfig>() {
                    Some(c) => !c.output_dir.is_empty(),
                    None => false,
                }
            });

            ConfigValidator::new()
                .add_rule(learning_rate_rule)
                .add_rule(output_dir_rule)
                .validate(self)
                .map(|_| ())
        }
    }

    #[test]
    fn test_valid_config() {
        let config = config_with(0.001, "/tmp/output");

        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_invalid_learning_rate() {
        // 2.0 lies above the configured maximum of 1.0.
        let config = config_with(2.0, "/tmp/output");

        let result = config.validate();
        assert!(result.is_err());

        let errors = result.unwrap_err();
        assert_eq!(errors.len(), 1);
        let only = &errors[0];
        assert_eq!(only.field, "learning_rate");
        assert_eq!(only.error_code, "RANGE_ABOVE_MAX");
    }

    #[test]
    fn test_missing_output_dir() {
        // An empty output directory counts as missing.
        let config = config_with(0.001, "");

        let result = config.validate();
        assert!(result.is_err());

        let errors = result.unwrap_err();
        assert_eq!(errors.len(), 1);
        let only = &errors[0];
        assert_eq!(only.field, "output_dir");
        assert_eq!(only.error_code, "REQUIRED_FIELD_MISSING");
    }

    #[test]
    fn test_validated_config() {
        let config = config_with(0.001, "/tmp/output");

        let validated = ValidatedConfig::new(config.clone()).expect("operation failed in test");
        assert_eq!(validated.get().learning_rate, 0.001);

        let inner = validated.into_inner();
        assert_eq!(inner.learning_rate, 0.001);
    }

    #[test]
    fn test_config_schema_validation() {
        let learning_rate_schema = FieldSchema {
            field_type: FieldType::Float,
            constraints: vec![Constraint::Range {
                min: Some(1e-6),
                max: Some(1.0),
            }],
            description: "Learning rate for training".to_string(),
            default_value: Some(serde_json::json!(0.001)),
        };
        let schema = ConfigSchema::new()
            .add_field("learning_rate".to_string(), learning_rate_schema)
            .require_field("learning_rate".to_string());

        // A value inside the range produces no findings.
        let in_range = serde_json::json!({ "learning_rate": 0.001 });
        assert!(schema.validate_json(&in_range).is_empty());

        // The required field is absent.
        let missing = serde_json::json!({});
        let errors = schema.validate_json(&missing);
        assert_eq!(errors.len(), 1);
        assert_eq!(errors[0].error_code, "REQUIRED_FIELD_MISSING");

        // The value lies above the maximum.
        let too_high = serde_json::json!({ "learning_rate": 2.0 });
        let errors = schema.validate_json(&too_high);
        assert_eq!(errors.len(), 1);
        assert_eq!(errors[0].error_code, "RANGE_ABOVE_MAX");
    }

    #[test]
    fn test_validation_report() {
        let warning = ValidationError {
            field: "test".to_string(),
            message: "Test warning".to_string(),
            severity: Severity::Warning,
            error_code: "TEST_WARNING".to_string(),
        };
        let report = ValidationReport {
            is_valid: true,
            errors: Vec::new(),
            warnings: vec![warning],
            infos: Vec::new(),
            rules_applied: 1,
        };

        assert!(report.has_warnings());
        assert!(!report.has_infos());

        // Printing must not panic.
        report.print_summary();
    }
}