use crate::error::Result;
use crate::ml_framework::{DataType, MLFramework, MLModel};
use std::collections::{BTreeMap, HashMap, HashSet};

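/// Validates that an [`MLModel`] produced by one [`MLFramework`] can be
/// converted to another, according to a [`ValidationConfig`].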
pub struct ModelValidator {
    source_framework: MLFramework,
    target_framework: MLFramework,
    validation_config: ValidationConfig,
}

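/// Controls which checks [`ModelValidator::validate`] runs and how strict they are.
///
/// `max_shape_dimension` bounds the tensor rank accepted by the target framework,
/// and `supported_dtypes`, when set, restricts the data types that pass without a
/// conversion warning or error.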
#[derive(Debug, Clone)]
pub struct ValidationConfig {
    pub check_data_types: bool,
    pub check_tensor_shapes: bool,
    pub check_operations: bool,
    pub check_metadata: bool,
    pub strict_mode: bool,
    pub allow_type_conversion: bool,
    pub max_shape_dimension: Option<usize>,
    pub supported_dtypes: Option<HashSet<DataType>>,
}

impl Default for ValidationConfig {
    fn default() -> Self {
        Self {
            check_data_types: true,
            check_tensor_shapes: true,
            check_operations: true,
            check_metadata: true,
            strict_mode: false,
            allow_type_conversion: true,
            max_shape_dimension: Some(8),
            supported_dtypes: None,
        }
    }
}

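/// Outcome of a validation run: an overall verdict, a score in `[0.0, 1.0]`,
/// the individual findings, and (when the model is compatible) a suggested
/// [`ConversionPath`].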
#[derive(Debug, Clone)]
pub struct ValidationReport {
    pub is_compatible: bool,
    pub compatibility_score: f32,
    pub errors: Vec<ValidationError>,
    pub warnings: Vec<ValidationWarning>,
    pub recommendations: Vec<ValidationRecommendation>,
    pub conversion_path: Option<ConversionPath>,
}

#[derive(Debug, Clone)]
pub struct ValidationError {
    pub category: ErrorCategory,
    pub severity: ErrorSeverity,
    pub message: String,
    pub location: Option<String>,
    pub fix_suggestion: Option<String>,
}

#[derive(Debug, Clone)]
pub struct ValidationWarning {
    pub category: WarningCategory,
    pub message: String,
    pub location: Option<String>,
    pub impact: WarningImpact,
}

#[derive(Debug, Clone)]
pub struct ValidationRecommendation {
    pub category: RecommendationCategory,
    pub message: String,
    pub priority: RecommendationPriority,
    pub estimated_effort: EstimatedEffort,
}

#[derive(Debug, Clone, PartialEq)]
pub enum ErrorCategory {
    DataType,
    Shape,
    Operation,
    Metadata,
    Framework,
    Version,
}

#[derive(Debug, Clone, PartialEq)]
pub enum ErrorSeverity {
    Critical,
    High,
    Medium,
    Low,
}

#[derive(Debug, Clone, PartialEq)]
pub enum WarningCategory {
    Performance,
    Precision,
    Compatibility,
    BestPractice,
}

#[derive(Debug, Clone, PartialEq)]
pub enum WarningImpact {
    High,
    Medium,
    Low,
}

#[derive(Debug, Clone, PartialEq)]
pub enum RecommendationCategory {
    Optimization,
    Conversion,
    Preprocessing,
    Alternative,
    BestPractice,
}

#[derive(Debug, Clone, PartialEq)]
pub enum RecommendationPriority {
    High,
    Medium,
    Low,
}

#[derive(Debug, Clone)]
pub enum EstimatedEffort {
    Minimal,
    Low,
    Medium,
    High,
    VeryHigh,
}

#[derive(Debug, Clone)]
pub struct ConversionPath {
    pub steps: Vec<ConversionStep>,
    pub estimated_accuracy_loss: f32,
    pub estimated_performance_impact: f32,
    pub complexity: ConversionComplexity,
}

#[derive(Debug, Clone)]
pub struct ConversionStep {
    pub operation: ConversionOperation,
    pub description: String,
    pub required_tools: Vec<String>,
    pub estimated_time: EstimatedEffort,
}

#[derive(Debug, Clone)]
pub enum ConversionOperation {
    DirectConversion,
    TypeConversion,
    ShapeReshaping,
    OperationMapping,
    ManualIntervention,
    AlternativeImplementation,
}

#[derive(Debug, Clone)]
pub enum ConversionComplexity {
    Trivial,
    Simple,
    Moderate,
    Complex,
    VeryComplex,
}

impl ModelValidator {
    pub fn new(source: MLFramework, target: MLFramework, config: ValidationConfig) -> Self {
        Self {
            source_framework: source,
            target_framework: target,
            validation_config: config,
        }
    }

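    /// Runs every enabled check and aggregates the findings into a
    /// [`ValidationReport`].
    ///
    /// A model is reported as compatible when the computed score exceeds 0.7 and
    /// no error has `ErrorSeverity::Critical`; only then is a conversion path
    /// generated.
    ///
    /// Illustrative sketch (not a doctest): how the model is loaded is outside
    /// this module, so `load_model` below is a hypothetical helper.
    ///
    /// ```ignore
    /// let validator = ModelValidator::new(
    ///     MLFramework::PyTorch,
    ///     MLFramework::ONNX,
    ///     ValidationConfig::default(),
    /// );
    /// let model = load_model("model.pt")?; // hypothetical loader
    /// let report = validator.validate(&model)?;
    /// println!("compatible: {} (score {:.2})", report.is_compatible, report.compatibility_score);
    /// ```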
    pub fn validate(&self, model: &MLModel) -> Result<ValidationReport> {
        let mut errors = Vec::new();
        let mut warnings = Vec::new();
        let mut recommendations = Vec::new();

        let framework_compatibility = self.check_framework_compatibility(model);
        if let Some(error) = framework_compatibility.error {
            errors.push(error);
        }
        warnings.extend(framework_compatibility.warnings);
        recommendations.extend(framework_compatibility.recommendations);

        if self.validation_config.check_data_types {
            let dtype_check = self.check_data_types(model);
            errors.extend(dtype_check.errors);
            warnings.extend(dtype_check.warnings);
            recommendations.extend(dtype_check.recommendations);
        }

        if self.validation_config.check_tensor_shapes {
            let shape_check = self.check_tensor_shapes(model);
            errors.extend(shape_check.errors);
            warnings.extend(shape_check.warnings);
            recommendations.extend(shape_check.recommendations);
        }

        if self.validation_config.check_operations {
            let ops_check = self.check_operations(model);
            errors.extend(ops_check.errors);
            warnings.extend(ops_check.warnings);
            recommendations.extend(ops_check.recommendations);
        }

        if self.validation_config.check_metadata {
            let metadata_check = self.check_metadata(model);
            errors.extend(metadata_check.errors);
            warnings.extend(metadata_check.warnings);
            recommendations.extend(metadata_check.recommendations);
        }

        let compatibility_score = self.calculate_compatibility_score(&errors, &warnings);
        let is_compatible = compatibility_score > 0.7
            && errors.iter().all(|e| e.severity != ErrorSeverity::Critical);

        let conversion_path = if is_compatible {
            Some(self.generate_conversion_path(model, &errors, &warnings)?)
        } else {
            None
        };

        Ok(ValidationReport {
            is_compatible,
            compatibility_score,
            errors,
            warnings,
            recommendations,
            conversion_path,
        })
    }

    fn check_framework_compatibility(&self, _model: &MLModel) -> FrameworkCompatibilityResult {
        let mut warnings = Vec::new();
        let mut recommendations = Vec::new();

        if self.source_framework == self.target_framework {
            return FrameworkCompatibilityResult {
                error: None,
                warnings,
                recommendations,
            };
        }

        let compatibility_score = crate::ml_framework::validation::utils::quick_compatibility_check(
            self.source_framework,
            self.target_framework,
        );

        if compatibility_score < 0.5 {
            warnings.push(ValidationWarning {
                category: WarningCategory::Compatibility,
                message: format!(
                    "Low compatibility between {:?} and {:?} (score: {:.2})",
                    self.source_framework, self.target_framework, compatibility_score
                ),
                location: None,
                impact: WarningImpact::High,
            });

            recommendations.push(ValidationRecommendation {
                category: RecommendationCategory::Alternative,
                message: "Consider using ONNX as an intermediate format".to_string(),
                priority: RecommendationPriority::High,
                estimated_effort: EstimatedEffort::Medium,
            });
        }

        FrameworkCompatibilityResult {
            error: None,
            warnings,
            recommendations,
        }
    }

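    /// Checks every weight tensor's dtype against `supported_dtypes` (if set) and
    /// flags framework-specific precision pitfalls, e.g. Float64 weights when
    /// converting from PyTorch to CoreML.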
    fn check_data_types(&self, model: &MLModel) -> ValidationCheckResult {
        let mut errors = Vec::new();
        let mut warnings = Vec::new();
        let recommendations = Vec::new();

        for (tensor_name, tensor) in &model.weights {
            if let Some(ref supported_dtypes) = self.validation_config.supported_dtypes {
                if !supported_dtypes.contains(&tensor.metadata.dtype) {
                    if self.validation_config.allow_type_conversion {
                        warnings.push(ValidationWarning {
                            category: WarningCategory::Precision,
                            message: format!(
                                "Tensor '{}' has unsupported data type {:?}, conversion may be needed",
                                tensor_name, tensor.metadata.dtype
                            ),
                            location: Some(tensor_name.clone()),
                            impact: WarningImpact::Medium,
                        });
                    } else {
                        errors.push(ValidationError {
                            category: ErrorCategory::DataType,
                            severity: ErrorSeverity::High,
                            message: format!(
                                "Tensor '{}' has unsupported data type {:?}",
                                tensor_name, tensor.metadata.dtype
                            ),
                            location: Some(tensor_name.clone()),
                            fix_suggestion: Some(
                                "Enable type conversion or change tensor data type".to_string(),
                            ),
                        });
                    }
                }
            }

            if let (MLFramework::PyTorch, MLFramework::CoreML, DataType::Float64) = (
                &self.source_framework,
                &self.target_framework,
                &tensor.metadata.dtype,
            ) {
                warnings.push(ValidationWarning {
                    category: WarningCategory::Precision,
                    message: format!(
                        "Tensor '{}' uses Float64 which may be converted to Float32 in CoreML",
                        tensor_name
                    ),
                    location: Some(tensor_name.clone()),
                    impact: WarningImpact::Medium,
                });
            }
        }

        ValidationCheckResult {
            errors,
            warnings,
            recommendations,
        }
    }

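    /// Checks tensor ranks against `max_shape_dimension`, flags zero-sized
    /// (dynamic) dimensions, and warns about very large tensors (over 1e9 elements).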
    fn check_tensor_shapes(&self, model: &MLModel) -> ValidationCheckResult {
        let mut errors = Vec::new();
        let mut warnings = Vec::new();
        let recommendations = Vec::new();

        for (tensor_name, tensor) in &model.weights {
            let shape = &tensor.metadata.shape;

            if let Some(max_dims) = self.validation_config.max_shape_dimension {
                if shape.len() > max_dims {
                    errors.push(ValidationError {
                        category: ErrorCategory::Shape,
                        severity: ErrorSeverity::High,
                        message: format!(
                            "Tensor '{}' has {} dimensions, but the target framework supports at most {}",
                            tensor_name,
                            shape.len(),
                            max_dims
                        ),
                        location: Some(tensor_name.clone()),
                        fix_suggestion: Some(
                            "Reshape the tensor or use tensor decomposition".to_string(),
                        ),
                    });
                }
            }

            if shape.contains(&0) {
                warnings.push(ValidationWarning {
                    category: WarningCategory::Compatibility,
                    message: format!(
                        "Tensor '{}' has dynamic shape dimensions which may not be supported",
                        tensor_name
                    ),
                    location: Some(tensor_name.clone()),
                    impact: WarningImpact::High,
                });
            }

            let total_elements: usize = shape.iter().product();
            if total_elements > 1_000_000_000 {
                warnings.push(ValidationWarning {
                    category: WarningCategory::Performance,
                    message: format!(
                        "Tensor '{}' is very large ({} elements) and may cause memory issues",
                        tensor_name, total_elements
                    ),
                    location: Some(tensor_name.clone()),
                    impact: WarningImpact::Medium,
                });
            }
        }

        ValidationCheckResult {
            errors,
            warnings,
            recommendations,
        }
    }

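    /// Emits framework-pair specific warnings and recommendations about operation
    /// coverage; currently only the PyTorch -> CoreML and TensorFlow -> PyTorch
    /// pairs are special-cased.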
    fn check_operations(&self, _model: &MLModel) -> ValidationCheckResult {
        let errors = Vec::new();
        let mut warnings = Vec::new();
        let mut recommendations = Vec::new();

        match (&self.source_framework, &self.target_framework) {
            (MLFramework::PyTorch, MLFramework::CoreML) => {
                warnings.push(ValidationWarning {
                    category: WarningCategory::Compatibility,
                    message: "Some PyTorch operations may not have direct CoreML equivalents"
                        .to_string(),
                    location: None,
                    impact: WarningImpact::Medium,
                });
            }
            (MLFramework::TensorFlow, MLFramework::PyTorch) => {
                recommendations.push(ValidationRecommendation {
                    category: RecommendationCategory::Conversion,
                    message: "Consider using ONNX as an intermediate format for TensorFlow -> PyTorch conversion"
                        .to_string(),
                    priority: RecommendationPriority::Medium,
                    estimated_effort: EstimatedEffort::Low,
                });
            }
            _ => {}
        }

        ValidationCheckResult {
            errors,
            warnings,
            recommendations,
        }
    }

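    /// Inspects model metadata: a pre-release framework version, a missing model
    /// name, and an empty configuration each produce a warning or recommendation.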
    fn check_metadata(&self, model: &MLModel) -> ValidationCheckResult {
        let errors = Vec::new();
        let mut warnings = Vec::new();
        let mut recommendations = Vec::new();

        if let Some(ref framework_version) = model.metadata.framework_version {
            if framework_version.starts_with("0.") {
                warnings.push(ValidationWarning {
                    category: WarningCategory::Compatibility,
                    message: format!(
                        "Framework version {} appears to be a pre-release version",
                        framework_version
                    ),
                    location: None,
                    impact: WarningImpact::Low,
                });
            }
        }

        if model.metadata.model_name.is_none() {
            recommendations.push(ValidationRecommendation {
                category: RecommendationCategory::BestPractice,
                message: "Consider adding a model name for better tracking".to_string(),
                priority: RecommendationPriority::Low,
                estimated_effort: EstimatedEffort::Minimal,
            });
        }

        if model.config.is_empty() {
            warnings.push(ValidationWarning {
                category: WarningCategory::BestPractice,
                message: "Model configuration is empty, which may cause issues during conversion"
                    .to_string(),
                location: None,
                impact: WarningImpact::Low,
            });
        }

        ValidationCheckResult {
            errors,
            warnings,
            recommendations,
        }
    }

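    /// Starts from the pairwise framework score and subtracts fixed penalties per
    /// finding (0.5/0.3/0.1/0.05 per error by severity, 0.1/0.05/0.02 per warning
    /// by impact), clamping the result to `[0.0, 1.0]`.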
    fn calculate_compatibility_score(
        &self,
        errors: &[ValidationError],
        warnings: &[ValidationWarning],
    ) -> f32 {
        let base_score = crate::ml_framework::validation::utils::quick_compatibility_check(
            self.source_framework,
            self.target_framework,
        );

        let error_penalty: f32 = errors
            .iter()
            .map(|e| match e.severity {
                ErrorSeverity::Critical => 0.5,
                ErrorSeverity::High => 0.3,
                ErrorSeverity::Medium => 0.1,
                ErrorSeverity::Low => 0.05,
            })
            .sum();

        let warning_penalty: f32 = warnings
            .iter()
            .map(|w| match w.impact {
                WarningImpact::High => 0.1,
                WarningImpact::Medium => 0.05,
                WarningImpact::Low => 0.02,
            })
            .sum();

        (base_score - error_penalty - warning_penalty).clamp(0.0, 1.0)
    }

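    /// Builds an ordered list of conversion steps: a type-conversion step if any
    /// dtype findings exist, a reshaping step if any shape errors exist, and
    /// always a final direct-conversion step for the framework pair.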
    fn generate_conversion_path(
        &self,
        _model: &MLModel,
        errors: &[ValidationError],
        warnings: &[ValidationWarning],
    ) -> Result<ConversionPath> {
        let mut steps = Vec::new();

        let has_dtype_issues = errors.iter().any(|e| e.category == ErrorCategory::DataType)
            || warnings
                .iter()
                .any(|w| w.category == WarningCategory::Precision);

        let has_shape_issues = errors.iter().any(|e| e.category == ErrorCategory::Shape);

        if has_dtype_issues {
            steps.push(ConversionStep {
                operation: ConversionOperation::TypeConversion,
                description: "Convert incompatible data types".to_string(),
                required_tools: vec!["dtype_converter".to_string()],
                estimated_time: EstimatedEffort::Low,
            });
        }

        if has_shape_issues {
            steps.push(ConversionStep {
                operation: ConversionOperation::ShapeReshaping,
                description: "Reshape tensors for the target framework".to_string(),
                required_tools: vec!["shape_converter".to_string()],
                estimated_time: EstimatedEffort::Medium,
            });
        }

        let conversion_complexity = if steps.is_empty() {
            ConversionComplexity::Trivial
        } else if steps.len() <= 2 {
            ConversionComplexity::Simple
        } else {
            ConversionComplexity::Moderate
        };

        steps.push(ConversionStep {
            operation: ConversionOperation::DirectConversion,
            description: format!(
                "Convert from {:?} to {:?}",
                self.source_framework, self.target_framework
            ),
            required_tools: vec![format!("{:?}_converter", self.target_framework)],
            estimated_time: match conversion_complexity {
                ConversionComplexity::Trivial => EstimatedEffort::Minimal,
                ConversionComplexity::Simple => EstimatedEffort::Low,
                _ => EstimatedEffort::Medium,
            },
        });

        Ok(ConversionPath {
            steps,
            estimated_accuracy_loss: if has_dtype_issues { 0.05 } else { 0.01 },
            estimated_performance_impact: if has_shape_issues { 0.1 } else { 0.02 },
            complexity: conversion_complexity,
        })
    }
}

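/// Runs several source/target validations over a set of models.
///
/// Illustrative sketch (not a doctest); `models` is assumed to be a slice of
/// already-loaded [`MLModel`]s obtained elsewhere.
///
/// ```ignore
/// let mut batch = BatchValidator::new();
/// batch.add_validation(MLFramework::PyTorch, MLFramework::ONNX, ValidationConfig::default());
/// batch.add_validation(MLFramework::PyTorch, MLFramework::CoreML, ValidationConfig::default());
/// let reports = batch.validate_all(&models)?; // one report per (model, validation) pair
/// ```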
pub struct BatchValidator {
    validators: Vec<ModelValidator>,
    #[allow(dead_code)]
    parallel: bool,
}

impl Default for BatchValidator {
    fn default() -> Self {
        Self::new()
    }
}

impl BatchValidator {
    pub fn new() -> Self {
        Self {
            validators: Vec::new(),
            parallel: true,
        }
    }

    pub fn add_validation(
        &mut self,
        source: MLFramework,
        target: MLFramework,
        config: ValidationConfig,
    ) {
        self.validators
            .push(ModelValidator::new(source, target, config));
    }

    pub fn validate_all(&self, models: &[MLModel]) -> Result<Vec<ValidationReport>> {
        let mut reports = Vec::new();

        for model in models {
            for validator in &self.validators {
                reports.push(validator.validate(model)?);
            }
        }

        Ok(reports)
    }
}

pub mod utils {
    use super::*;

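    /// Heuristic pairwise score: 1.0 for identical frameworks, 0.9 for direct
    /// ONNX interchange with PyTorch or TensorFlow, and 0.5 for every other pair.
    ///
    /// ```ignore
    /// assert_eq!(quick_compatibility_check(MLFramework::PyTorch, MLFramework::ONNX), 0.9);
    /// assert_eq!(quick_compatibility_check(MLFramework::JAX, MLFramework::CoreML), 0.5);
    /// ```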
    pub fn quick_compatibility_check(source: MLFramework, target: MLFramework) -> f32 {
        if source == target {
            1.0
        } else if matches!(
            (source, target),
            (MLFramework::PyTorch, MLFramework::ONNX)
                | (MLFramework::TensorFlow, MLFramework::ONNX)
                | (MLFramework::ONNX, MLFramework::PyTorch)
                | (MLFramework::ONNX, MLFramework::TensorFlow)
        ) {
            0.9
        } else {
            0.5
        }
    }

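    /// Builds the full pairwise score table (source framework name -> target
    /// framework name -> score) from [`quick_compatibility_check`].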
    pub fn generate_compatibility_matrix() -> BTreeMap<String, BTreeMap<String, f32>> {
        let frameworks = [
            MLFramework::PyTorch,
            MLFramework::TensorFlow,
            MLFramework::ONNX,
            MLFramework::SafeTensors,
            MLFramework::JAX,
            MLFramework::MXNet,
            MLFramework::CoreML,
            MLFramework::HuggingFace,
        ];

        let mut matrix = BTreeMap::new();

        for source in &frameworks {
            let mut row = BTreeMap::new();
            for target in &frameworks {
                let score = quick_compatibility_check(*source, *target);
                row.insert(format!("{:?}", target), score);
            }
            matrix.insert(format!("{:?}", source), row);
        }

        matrix
    }

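    /// Returns the framework sequence to convert through: the identity path,
    /// a direct hop when the pair scores above 0.7, an ONNX bridge when both
    /// legs score above 0.7, and otherwise a direct (possibly lossy) path.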
    pub fn find_best_conversion_path(source: MLFramework, target: MLFramework) -> Vec<MLFramework> {
        if source == target {
            return vec![source];
        }

        if quick_compatibility_check(source, target) > 0.7 {
            return vec![source, target];
        }

        if quick_compatibility_check(source, MLFramework::ONNX) > 0.7
            && quick_compatibility_check(MLFramework::ONNX, target) > 0.7
        {
            return vec![source, MLFramework::ONNX, target];
        }

        vec![source, target]
    }
}

#[derive(Debug, Clone)]
struct FrameworkCompatibilityResult {
    error: Option<ValidationError>,
    warnings: Vec<ValidationWarning>,
    recommendations: Vec<ValidationRecommendation>,
}

#[derive(Debug, Clone)]
struct ValidationCheckResult {
    errors: Vec<ValidationError>,
    warnings: Vec<ValidationWarning>,
    recommendations: Vec<ValidationRecommendation>,
}

#[derive(Debug, Clone)]
struct FrameworkCompatibility {
    level: CompatibilityLevel,
    recommendations: Vec<ValidationRecommendation>,
}

#[derive(Debug, Clone)]
enum CompatibilityLevel {
    FullyCompatible,
    MostlyCompatible,
    PartiallyCompatible,
    #[allow(dead_code)]
    Incompatible,
}