Skip to main content

torsh_package/
format_compat.rs

1//! Format compatibility layer for different model package formats
2//!
3//! This module provides compatibility with external package formats including
4//! PyTorch torch.package, HuggingFace Hub, ONNX models, and MLflow packages.
5
6use std::collections::HashMap;
7use std::fs;
8
9use oxiarc_archive::zip::{ZipCompressionLevel, ZipReader, ZipWriter};
10use serde::{Deserialize, Serialize};
11use torsh_core::error::{Result, TorshError};
12
13use crate::package::Package;
14use crate::resources::{Resource, ResourceType};
15
/// Package formats understood by the compatibility layer.
///
/// Each external variant has a matching [`FormatConverter`] implementation
/// in this module.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum PackageFormat {
    /// PyTorch `torch.package` archive
    PyTorch,
    /// HuggingFace Hub model directory
    HuggingFace,
    /// ONNX model file
    Onnx,
    /// MLflow model directory
    MLflow,
    /// Native ToRSh package
    ToRSh,
}
30
31/// Format-specific converter trait
32pub trait FormatConverter {
33    /// Convert from the external format to ToRSh Package
34    fn import_from_format(&self, path: &std::path::Path) -> Result<Package>;
35
36    /// Convert from ToRSh Package to the external format
37    fn export_to_format(&self, package: &Package, path: &std::path::Path) -> Result<()>;
38
39    /// Get the format this converter handles
40    fn format(&self) -> PackageFormat;
41
42    /// Validate if a path contains a valid package of this format
43    fn is_valid_format(&self, path: &std::path::Path) -> bool;
44}
45
/// PyTorch `torch.package` compatibility layer.
///
/// Configure it through the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct PyTorchConverter {
    /// Whether `.py` archive entries are imported as source resources.
    preserve_python_code: bool,
    /// Whether `.pkl` archive entries are imported as model resources.
    extract_models: bool,
}
51
/// HuggingFace Hub compatibility layer.
///
/// Configure it through the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct HuggingFaceConverter {
    /// Whether tokenizer / auxiliary JSON files are imported.
    include_tokenizer: bool,
    /// Whether `config.json` is parsed and imported.
    include_config: bool,
    /// Model type used as a fallback when no `config.json` is loaded.
    model_type: Option<String>,
}
58
/// ONNX model compatibility layer.
///
/// Configure it through the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct OnnxConverter {
    /// Whether ONNX metadata is copied into the package manifest on import.
    include_metadata: bool,
    /// Inference-optimization flag.
    // NOTE(review): not consulted by any code visible in this section of the file.
    optimize_for_inference: bool,
}
64
/// MLflow model compatibility layer.
///
/// Configure it through the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct MLflowConverter {
    /// Whether the conda environment should be included.
    // NOTE(review): not consulted by the import path visible in this section.
    include_conda_env: bool,
    /// Whether `requirements.txt` should be included.
    // NOTE(review): not consulted by the import path visible in this section.
    include_requirements: bool,
    /// Optional MLflow flavor, recorded as `flavor` metadata on import.
    flavor: Option<String>,
}
71
72/// PyTorch package manifest structure (simplified)
73#[derive(Debug, Clone, Serialize, Deserialize)]
74struct PyTorchManifest {
75    code_version: String,
76    main_module: String,
77    dependencies: Vec<String>,
78    python_version: Option<String>,
79}
80
81/// HuggingFace model configuration
82#[derive(Debug, Clone, Serialize, Deserialize)]
83struct HuggingFaceConfig {
84    model_type: String,
85    task: Option<String>,
86    architectures: Option<Vec<String>>,
87    tokenizer_class: Option<String>,
88    vocab_size: Option<u64>,
89}
90
91/// ONNX model metadata
92#[derive(Debug, Clone, Serialize, Deserialize)]
93struct OnnxMetadata {
94    ir_version: i64,
95    producer_name: String,
96    producer_version: String,
97    domain: String,
98    model_version: i64,
99    doc_string: String,
100}
101
102/// MLflow model metadata
103#[derive(Debug, Clone, Serialize, Deserialize)]
104struct MLflowMetadata {
105    artifact_path: String,
106    flavors: HashMap<String, serde_json::Value>,
107    model_uuid: String,
108    run_id: String,
109    utc_time_created: String,
110    mlflow_version: String,
111}
112
113impl Default for PyTorchConverter {
114    fn default() -> Self {
115        Self {
116            preserve_python_code: true,
117            extract_models: true,
118        }
119    }
120}
121
122impl PyTorchConverter {
123    /// Create a new PyTorch converter
124    pub fn new() -> Self {
125        Self::default()
126    }
127
128    /// Configure whether to preserve Python code
129    pub fn with_preserve_python_code(mut self, preserve: bool) -> Self {
130        self.preserve_python_code = preserve;
131        self
132    }
133
134    /// Configure whether to extract model weights
135    pub fn with_extract_models(mut self, extract: bool) -> Self {
136        self.extract_models = extract;
137        self
138    }
139
140    /// Extract PyTorch package contents
141    fn extract_pytorch_package(
142        &self,
143        path: &std::path::Path,
144    ) -> Result<(PyTorchManifest, Vec<Resource>)> {
145        let file = fs::File::open(path)
146            .map_err(|e| TorshError::IoError(format!("Failed to open PyTorch package: {}", e)))?;
147
148        let mut archive = ZipReader::new(file)
149            .map_err(|e| TorshError::InvalidArgument(format!("Invalid ZIP archive: {}", e)))?;
150
151        let mut manifest = None;
152        let mut resources = Vec::new();
153
154        // Collect entries to avoid borrow checker issues
155        let entries: Vec<_> = archive.entries().to_vec();
156
157        for entry in entries {
158            let file_name = entry.name.clone();
159
160            // Read file contents
161            let contents = archive
162                .extract(&entry)
163                .map_err(|e| TorshError::IoError(format!("Failed to read archive entry: {}", e)))?;
164
165            if file_name == ".data/version" {
166                // PyTorch package version info - convert to manifest
167                let version_str = String::from_utf8(contents).map_err(|_| {
168                    TorshError::InvalidArgument("Invalid UTF-8 in version file".to_string())
169                })?;
170
171                manifest = Some(PyTorchManifest {
172                    code_version: version_str.trim().to_string(),
173                    main_module: "main".to_string(), // Default
174                    dependencies: Vec::new(),
175                    python_version: None,
176                });
177            } else if file_name.ends_with(".py") && self.preserve_python_code {
178                // Python source code
179                resources.push(Resource {
180                    name: file_name.clone(),
181                    resource_type: ResourceType::Source,
182                    data: contents,
183                    metadata: {
184                        let mut meta = HashMap::new();
185                        meta.insert("language".to_string(), "python".to_string());
186                        meta.insert("original_format".to_string(), "pytorch".to_string());
187                        meta
188                    },
189                });
190            } else if file_name.ends_with(".pkl") && self.extract_models {
191                // Pickle files (likely model weights)
192                resources.push(Resource {
193                    name: file_name.clone(),
194                    resource_type: ResourceType::Model,
195                    data: contents,
196                    metadata: {
197                        let mut meta = HashMap::new();
198                        meta.insert("format".to_string(), "pickle".to_string());
199                        meta.insert("original_format".to_string(), "pytorch".to_string());
200                        meta
201                    },
202                });
203            } else {
204                // Other data files
205                resources.push(Resource {
206                    name: file_name.clone(),
207                    resource_type: ResourceType::Data,
208                    data: contents,
209                    metadata: {
210                        let mut meta = HashMap::new();
211                        meta.insert("original_format".to_string(), "pytorch".to_string());
212                        meta
213                    },
214                });
215            }
216        }
217
218        let manifest = manifest.unwrap_or_else(|| PyTorchManifest {
219            code_version: "1.0.0".to_string(),
220            main_module: "main".to_string(),
221            dependencies: Vec::new(),
222            python_version: None,
223        });
224
225        Ok((manifest, resources))
226    }
227}
228
229impl FormatConverter for PyTorchConverter {
230    fn import_from_format(&self, path: &std::path::Path) -> Result<Package> {
231        let (pytorch_manifest, resources) = self.extract_pytorch_package(path)?;
232
233        let package_name = path
234            .file_stem()
235            .and_then(|s| s.to_str())
236            .unwrap_or("imported_pytorch_model")
237            .to_string();
238
239        let mut package = Package::new(package_name, pytorch_manifest.code_version);
240
241        // Add resources
242        for resource in resources {
243            package.add_resource(resource);
244        }
245
246        // Add PyTorch-specific metadata
247        package
248            .manifest_mut()
249            .metadata
250            .insert("original_format".to_string(), "pytorch".to_string());
251        package
252            .manifest_mut()
253            .metadata
254            .insert("main_module".to_string(), pytorch_manifest.main_module);
255
256        if let Some(python_version) = pytorch_manifest.python_version {
257            package
258                .manifest_mut()
259                .metadata
260                .insert("python_version".to_string(), python_version);
261        }
262
263        // Add dependencies
264        for dep in pytorch_manifest.dependencies {
265            package.add_dependency(&dep, "*");
266        }
267
268        Ok(package)
269    }
270
271    fn export_to_format(&self, package: &Package, path: &std::path::Path) -> Result<()> {
272        let file = fs::File::create(path)
273            .map_err(|e| TorshError::IoError(format!("Failed to create output file: {}", e)))?;
274
275        let mut zip = ZipWriter::new(file);
276        zip.set_compression(ZipCompressionLevel::Normal);
277
278        // Add version file
279        let version_data = package.get_version().as_bytes();
280        zip.add_file(".data/version", version_data)
281            .map_err(|e| TorshError::IoError(format!("Failed to create version file: {}", e)))?;
282
283        // Add resources
284        for (name, resource) in package.resources() {
285            let file_path =
286                if resource.resource_type == ResourceType::Source && name.ends_with(".py") {
287                    format!("code/{}", name)
288                } else if resource.resource_type == ResourceType::Model {
289                    format!("data/{}", name)
290                } else {
291                    name.clone()
292                };
293
294            zip.add_file(&file_path, &resource.data).map_err(|e| {
295                TorshError::IoError(format!("Failed to create file {}: {}", file_path, e))
296            })?;
297        }
298
299        zip.finish()
300            .map_err(|e| TorshError::IoError(format!("Failed to finalize ZIP archive: {}", e)))?;
301
302        Ok(())
303    }
304
305    fn format(&self) -> PackageFormat {
306        PackageFormat::PyTorch
307    }
308
309    fn is_valid_format(&self, path: &std::path::Path) -> bool {
310        // Check if it's a ZIP file with PyTorch package structure
311        if let Ok(file) = fs::File::open(path) {
312            if let Ok(archive) = ZipReader::new(file) {
313                // Look for characteristic PyTorch package files
314                for entry in archive.entries() {
315                    let name = &entry.name;
316                    if name == ".data/version"
317                        || name.starts_with("code/")
318                        || name.ends_with(".pkl")
319                    {
320                        return true;
321                    }
322                }
323            }
324        }
325        false
326    }
327}
328
329impl Default for HuggingFaceConverter {
330    fn default() -> Self {
331        Self {
332            include_tokenizer: true,
333            include_config: true,
334            model_type: None,
335        }
336    }
337}
338
339impl HuggingFaceConverter {
340    /// Create a new HuggingFace converter
341    pub fn new() -> Self {
342        Self::default()
343    }
344
345    /// Configure whether to include tokenizer files
346    pub fn with_include_tokenizer(mut self, include: bool) -> Self {
347        self.include_tokenizer = include;
348        self
349    }
350
351    /// Configure whether to include model configuration
352    pub fn with_include_config(mut self, include: bool) -> Self {
353        self.include_config = include;
354        self
355    }
356
357    /// Set expected model type
358    pub fn with_model_type(mut self, model_type: String) -> Self {
359        self.model_type = Some(model_type);
360        self
361    }
362
363    /// Load HuggingFace model directory
364    fn load_huggingface_model(
365        &self,
366        path: &std::path::Path,
367    ) -> Result<(HuggingFaceConfig, Vec<Resource>)> {
368        let model_dir = path;
369
370        if !model_dir.is_dir() {
371            return Err(TorshError::InvalidArgument(
372                "HuggingFace path must be a directory".to_string(),
373            ));
374        }
375
376        let mut config = None;
377        let mut resources = Vec::new();
378
379        // Read configuration file
380        let config_path = model_dir.join("config.json");
381        if config_path.exists() && self.include_config {
382            let config_data = fs::read(&config_path)
383                .map_err(|e| TorshError::IoError(format!("Failed to read config.json: {}", e)))?;
384
385            config = Some(
386                serde_json::from_slice::<HuggingFaceConfig>(&config_data).map_err(|e| {
387                    TorshError::SerializationError(format!("Invalid config.json: {}", e))
388                })?,
389            );
390
391            resources.push(Resource {
392                name: "config.json".to_string(),
393                resource_type: ResourceType::Config,
394                data: config_data,
395                metadata: {
396                    let mut meta = HashMap::new();
397                    meta.insert("original_format".to_string(), "huggingface".to_string());
398                    meta
399                },
400            });
401        }
402
403        // Read model files (pytorch_model.bin, model.safetensors, etc.)
404        for entry in fs::read_dir(model_dir)
405            .map_err(|e| TorshError::IoError(format!("Failed to read model directory: {}", e)))?
406        {
407            let entry = entry.map_err(|e| {
408                TorshError::IoError(format!("Failed to read directory entry: {}", e))
409            })?;
410            let file_path = entry.path();
411            let file_name = file_path
412                .file_name()
413                .and_then(|n| n.to_str())
414                .unwrap_or("")
415                .to_string();
416
417            if file_name.ends_with(".bin") || file_name.ends_with(".safetensors") {
418                // Model weight files
419                let data = fs::read(&file_path).map_err(|e| {
420                    TorshError::IoError(format!("Failed to read {}: {}", file_name, e))
421                })?;
422
423                resources.push(Resource {
424                    name: file_name.clone(),
425                    resource_type: ResourceType::Model,
426                    data,
427                    metadata: {
428                        let mut meta = HashMap::new();
429                        meta.insert("original_format".to_string(), "huggingface".to_string());
430                        if file_name.ends_with(".safetensors") {
431                            meta.insert("format".to_string(), "safetensors".to_string());
432                        } else {
433                            meta.insert("format".to_string(), "pytorch".to_string());
434                        }
435                        meta
436                    },
437                });
438            } else if self.include_tokenizer
439                && (file_name.starts_with("tokenizer") || file_name.ends_with(".json"))
440            {
441                // Tokenizer files
442                let data = fs::read(&file_path).map_err(|e| {
443                    TorshError::IoError(format!("Failed to read {}: {}", file_name, e))
444                })?;
445
446                resources.push(Resource {
447                    name: file_name,
448                    resource_type: ResourceType::Data,
449                    data,
450                    metadata: {
451                        let mut meta = HashMap::new();
452                        meta.insert("original_format".to_string(), "huggingface".to_string());
453                        meta.insert("type".to_string(), "tokenizer".to_string());
454                        meta
455                    },
456                });
457            }
458        }
459
460        let config = config.unwrap_or_else(|| HuggingFaceConfig {
461            model_type: self
462                .model_type
463                .clone()
464                .unwrap_or_else(|| "unknown".to_string()),
465            task: None,
466            architectures: None,
467            tokenizer_class: None,
468            vocab_size: None,
469        });
470
471        Ok((config, resources))
472    }
473}
474
475impl FormatConverter for HuggingFaceConverter {
476    fn import_from_format(&self, path: &std::path::Path) -> Result<Package> {
477        let (hf_config, resources) = self.load_huggingface_model(path)?;
478
479        let package_name = path
480            .file_stem()
481            .and_then(|s| s.to_str())
482            .unwrap_or("imported_huggingface_model")
483            .to_string();
484
485        let mut package = Package::new(package_name, "1.0.0".to_string());
486
487        // Add resources
488        for resource in resources {
489            package.add_resource(resource);
490        }
491
492        // Add HuggingFace-specific metadata
493        package
494            .manifest_mut()
495            .metadata
496            .insert("original_format".to_string(), "huggingface".to_string());
497        package
498            .manifest_mut()
499            .metadata
500            .insert("model_type".to_string(), hf_config.model_type);
501
502        if let Some(task) = hf_config.task {
503            package
504                .manifest_mut()
505                .metadata
506                .insert("task".to_string(), task);
507        }
508
509        if let Some(architectures) = hf_config.architectures {
510            package.manifest_mut().metadata.insert(
511                "architectures".to_string(),
512                serde_json::to_string(&architectures).unwrap_or_default(),
513            );
514        }
515
516        Ok(package)
517    }
518
519    fn export_to_format(&self, package: &Package, path: &std::path::Path) -> Result<()> {
520        let output_dir = path;
521
522        if !output_dir.exists() {
523            fs::create_dir_all(output_dir).map_err(|e| {
524                TorshError::IoError(format!("Failed to create output directory: {}", e))
525            })?;
526        }
527
528        // Export resources to appropriate files
529        for (name, resource) in package.resources() {
530            let file_path = output_dir.join(name);
531            fs::write(&file_path, &resource.data)
532                .map_err(|e| TorshError::IoError(format!("Failed to write {}: {}", name, e)))?;
533        }
534
535        // Create or update config.json if not present
536        let config_path = output_dir.join("config.json");
537        if !config_path.exists() {
538            let default_config = HuggingFaceConfig {
539                model_type: package
540                    .metadata()
541                    .metadata
542                    .get("model_type")
543                    .cloned()
544                    .unwrap_or_else(|| "unknown".to_string()),
545                task: package.metadata().metadata.get("task").cloned(),
546                architectures: package
547                    .metadata()
548                    .metadata
549                    .get("architectures")
550                    .and_then(|s| serde_json::from_str(s).ok()),
551                tokenizer_class: None,
552                vocab_size: None,
553            };
554
555            let config_json = serde_json::to_string_pretty(&default_config).map_err(|e| {
556                TorshError::SerializationError(format!("Failed to serialize config: {}", e))
557            })?;
558
559            fs::write(&config_path, config_json)
560                .map_err(|e| TorshError::IoError(format!("Failed to write config.json: {}", e)))?;
561        }
562
563        Ok(())
564    }
565
566    fn format(&self) -> PackageFormat {
567        PackageFormat::HuggingFace
568    }
569
570    fn is_valid_format(&self, path: &std::path::Path) -> bool {
571        let model_dir = path;
572
573        if !model_dir.is_dir() {
574            return false;
575        }
576
577        // Check for characteristic HuggingFace files
578        let config_path = model_dir.join("config.json");
579        if config_path.exists() {
580            return true;
581        }
582
583        // Check for model weight files
584        if let Ok(entries) = fs::read_dir(model_dir) {
585            for entry in entries {
586                if let Ok(entry) = entry {
587                    let file_name = entry.file_name();
588                    let file_name_str = file_name.to_string_lossy();
589                    if file_name_str.ends_with(".bin") || file_name_str.ends_with(".safetensors") {
590                        return true;
591                    }
592                }
593            }
594        }
595
596        false
597    }
598}
599
600impl Default for OnnxConverter {
601    fn default() -> Self {
602        Self {
603            include_metadata: true,
604            optimize_for_inference: false,
605        }
606    }
607}
608
609impl OnnxConverter {
610    /// Create a new ONNX converter
611    pub fn new() -> Self {
612        Self::default()
613    }
614
615    /// Configure whether to include model metadata
616    pub fn with_include_metadata(mut self, include: bool) -> Self {
617        self.include_metadata = include;
618        self
619    }
620
621    /// Configure optimization for inference
622    pub fn with_optimize_for_inference(mut self, optimize: bool) -> Self {
623        self.optimize_for_inference = optimize;
624        self
625    }
626
627    /// Extract ONNX model metadata
628    fn extract_onnx_metadata(&self, path: &std::path::Path) -> Result<OnnxMetadata> {
629        // For now, return a basic metadata structure
630        // In a real implementation, you would parse the ONNX protobuf format
631        Ok(OnnxMetadata {
632            ir_version: 8,
633            producer_name: "torsh-package".to_string(),
634            producer_version: "1.0.0".to_string(),
635            domain: "ai.onnx".to_string(),
636            model_version: 1,
637            doc_string: format!("ONNX model imported from {:?}", path),
638        })
639    }
640}
641
642impl FormatConverter for OnnxConverter {
643    fn import_from_format(&self, path: &std::path::Path) -> Result<Package> {
644        let model_data = fs::read(path)
645            .map_err(|e| TorshError::IoError(format!("Failed to read ONNX model: {}", e)))?;
646
647        let package_name = path
648            .file_stem()
649            .and_then(|s| s.to_str())
650            .unwrap_or("imported_onnx_model")
651            .to_string();
652
653        let mut package = Package::new(package_name, "1.0.0".to_string());
654
655        // Add ONNX model as a resource
656        let model_resource = Resource {
657            name: "model.onnx".to_string(),
658            resource_type: ResourceType::Model,
659            data: model_data,
660            metadata: {
661                let mut meta = HashMap::new();
662                meta.insert("original_format".to_string(), "onnx".to_string());
663                meta.insert("format".to_string(), "onnx".to_string());
664                meta
665            },
666        };
667        package.add_resource(model_resource);
668
669        // Add metadata if requested
670        if self.include_metadata {
671            let onnx_metadata = self.extract_onnx_metadata(path)?;
672            package.manifest_mut().metadata.insert(
673                "onnx_ir_version".to_string(),
674                onnx_metadata.ir_version.to_string(),
675            );
676            package
677                .manifest_mut()
678                .metadata
679                .insert("onnx_producer".to_string(), onnx_metadata.producer_name);
680            package.manifest_mut().metadata.insert(
681                "onnx_producer_version".to_string(),
682                onnx_metadata.producer_version,
683            );
684        }
685
686        package
687            .manifest_mut()
688            .metadata
689            .insert("original_format".to_string(), "onnx".to_string());
690
691        Ok(package)
692    }
693
694    fn export_to_format(&self, package: &Package, path: &std::path::Path) -> Result<()> {
695        // Find the ONNX model resource
696        let model_resource = package
697            .resources()
698            .iter()
699            .find(|(_, resource)| {
700                resource.resource_type == ResourceType::Model
701                    && (resource.name.ends_with(".onnx")
702                        || resource
703                            .metadata
704                            .get("format")
705                            .map_or(false, |f| f == "onnx"))
706            })
707            .map(|(_, resource)| resource)
708            .ok_or_else(|| {
709                TorshError::InvalidArgument("No ONNX model found in package".to_string())
710            })?;
711
712        // Write the ONNX model to file
713        fs::write(path, &model_resource.data)
714            .map_err(|e| TorshError::IoError(format!("Failed to write ONNX model: {}", e)))?;
715
716        Ok(())
717    }
718
719    fn format(&self) -> PackageFormat {
720        PackageFormat::Onnx
721    }
722
723    fn is_valid_format(&self, path: &std::path::Path) -> bool {
724        if let Ok(file) = fs::File::open(path) {
725            use std::io::Read;
726            let mut buffer = [0u8; 16];
727            let mut reader = std::io::BufReader::new(file);
728
729            // Check for ONNX magic bytes (protobuf)
730            if reader.read_exact(&mut buffer).is_ok() {
731                // ONNX files typically start with protobuf headers
732                // This is a simplified check; real implementation would parse protobuf
733                return path
734                    .extension()
735                    .and_then(|e| e.to_str())
736                    .map_or(false, |e| e == "onnx");
737            }
738        }
739        false
740    }
741}
742
743impl Default for MLflowConverter {
744    fn default() -> Self {
745        Self {
746            include_conda_env: true,
747            include_requirements: true,
748            flavor: None,
749        }
750    }
751}
752
753impl MLflowConverter {
754    /// Create a new MLflow converter
755    pub fn new() -> Self {
756        Self::default()
757    }
758
759    /// Configure whether to include conda environment
760    pub fn with_include_conda_env(mut self, include: bool) -> Self {
761        self.include_conda_env = include;
762        self
763    }
764
765    /// Configure whether to include requirements.txt
766    pub fn with_include_requirements(mut self, include: bool) -> Self {
767        self.include_requirements = include;
768        self
769    }
770
771    /// Set MLflow flavor
772    pub fn with_flavor(mut self, flavor: String) -> Self {
773        self.flavor = Some(flavor);
774        self
775    }
776
777    /// Load MLflow model directory
778    fn load_mlflow_model(&self, path: &std::path::Path) -> Result<(MLflowMetadata, Vec<Resource>)> {
779        if !path.is_dir() {
780            return Err(TorshError::InvalidArgument(
781                "MLflow path must be a directory".to_string(),
782            ));
783        }
784
785        let mut metadata = None;
786        let mut resources = Vec::new();
787
788        // Read MLmodel file
789        let mlmodel_path = path.join("MLmodel");
790        if mlmodel_path.exists() {
791            let mlmodel_data = fs::read_to_string(&mlmodel_path)
792                .map_err(|e| TorshError::IoError(format!("Failed to read MLmodel: {}", e)))?;
793
794            // Parse MLmodel (YAML format)
795            // For simplicity, we'll create a basic metadata structure
796            metadata = Some(MLflowMetadata {
797                artifact_path: path.to_string_lossy().to_string(),
798                flavors: HashMap::new(),
799                model_uuid: uuid::Uuid::new_v4().to_string(),
800                run_id: "imported".to_string(),
801                utc_time_created: chrono::Utc::now().to_rfc3339(),
802                mlflow_version: "2.0.0".to_string(),
803            });
804
805            resources.push(Resource {
806                name: "MLmodel".to_string(),
807                resource_type: ResourceType::Config,
808                data: mlmodel_data.into_bytes(),
809                metadata: {
810                    let mut meta = HashMap::new();
811                    meta.insert("original_format".to_string(), "mlflow".to_string());
812                    meta
813                },
814            });
815        }
816
817        // Read model files from subdirectories
818        for entry in fs::read_dir(path)
819            .map_err(|e| TorshError::IoError(format!("Failed to read MLflow directory: {}", e)))?
820        {
821            let entry = entry.map_err(|e| {
822                TorshError::IoError(format!("Failed to read directory entry: {}", e))
823            })?;
824            let file_path = entry.path();
825
826            if file_path.is_file() {
827                let file_name = file_path
828                    .file_name()
829                    .and_then(|n| n.to_str())
830                    .unwrap_or("")
831                    .to_string();
832
833                if file_name != "MLmodel" {
834                    let data = fs::read(&file_path).map_err(|e| {
835                        TorshError::IoError(format!("Failed to read {}: {}", file_name, e))
836                    })?;
837
838                    let resource_type = if file_name.ends_with(".pkl")
839                        || file_name.ends_with(".pt")
840                        || file_name.ends_with(".h5")
841                    {
842                        ResourceType::Model
843                    } else if file_name.ends_with(".json") || file_name.ends_with(".yaml") {
844                        ResourceType::Config
845                    } else if file_name == "requirements.txt" || file_name == "conda.yaml" {
846                        ResourceType::Documentation
847                    } else {
848                        ResourceType::Data
849                    };
850
851                    resources.push(Resource {
852                        name: file_name,
853                        resource_type,
854                        data,
855                        metadata: {
856                            let mut meta = HashMap::new();
857                            meta.insert("original_format".to_string(), "mlflow".to_string());
858                            meta
859                        },
860                    });
861                }
862            }
863        }
864
865        let metadata = metadata.unwrap_or_else(|| MLflowMetadata {
866            artifact_path: path.to_string_lossy().to_string(),
867            flavors: HashMap::new(),
868            model_uuid: uuid::Uuid::new_v4().to_string(),
869            run_id: "imported".to_string(),
870            utc_time_created: chrono::Utc::now().to_rfc3339(),
871            mlflow_version: "2.0.0".to_string(),
872        });
873
874        Ok((metadata, resources))
875    }
876}
877
878impl FormatConverter for MLflowConverter {
879    fn import_from_format(&self, path: &std::path::Path) -> Result<Package> {
880        let (mlflow_metadata, resources) = self.load_mlflow_model(path)?;
881
882        let package_name = path
883            .file_stem()
884            .and_then(|s| s.to_str())
885            .unwrap_or("imported_mlflow_model")
886            .to_string();
887
888        let mut package = Package::new(package_name, "1.0.0".to_string());
889
890        // Add resources
891        for resource in resources {
892            package.add_resource(resource);
893        }
894
895        // Add MLflow-specific metadata
896        package
897            .manifest_mut()
898            .metadata
899            .insert("original_format".to_string(), "mlflow".to_string());
900        package
901            .manifest_mut()
902            .metadata
903            .insert("mlflow_version".to_string(), mlflow_metadata.mlflow_version);
904        package
905            .manifest_mut()
906            .metadata
907            .insert("model_uuid".to_string(), mlflow_metadata.model_uuid);
908        package
909            .manifest_mut()
910            .metadata
911            .insert("run_id".to_string(), mlflow_metadata.run_id);
912
913        if let Some(flavor) = &self.flavor {
914            package
915                .manifest_mut()
916                .metadata
917                .insert("flavor".to_string(), flavor.clone());
918        }
919
920        Ok(package)
921    }
922
923    fn export_to_format(&self, package: &Package, path: &std::path::Path) -> Result<()> {
924        let output_dir = path;
925
926        if !output_dir.exists() {
927            fs::create_dir_all(output_dir).map_err(|e| {
928                TorshError::IoError(format!("Failed to create output directory: {}", e))
929            })?;
930        }
931
932        // Export resources to appropriate files
933        for (name, resource) in package.resources() {
934            let file_path = output_dir.join(name);
935            fs::write(&file_path, &resource.data)
936                .map_err(|e| TorshError::IoError(format!("Failed to write {}: {}", name, e)))?;
937        }
938
939        // Create MLmodel file if not present
940        let mlmodel_path = output_dir.join("MLmodel");
941        if !mlmodel_path.exists() {
942            let mlmodel_content = format!(
943                r#"artifact_path: {}
944flavors:
945  python_function:
946    env: conda.yaml
947    loader_module: mlflow.pyfunc.model
948    python_version: 3.9
949model_uuid: {}
950run_id: {}
951utc_time_created: '{}'
952mlflow_version: 2.0.0
953"#,
954                output_dir.to_string_lossy(),
955                package
956                    .metadata()
957                    .metadata
958                    .get("model_uuid")
959                    .cloned()
960                    .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
961                package
962                    .metadata()
963                    .metadata
964                    .get("run_id")
965                    .cloned()
966                    .unwrap_or_else(|| "exported".to_string()),
967                chrono::Utc::now().to_rfc3339()
968            );
969
970            fs::write(&mlmodel_path, mlmodel_content)
971                .map_err(|e| TorshError::IoError(format!("Failed to write MLmodel: {}", e)))?;
972        }
973
974        Ok(())
975    }
976
977    fn format(&self) -> PackageFormat {
978        PackageFormat::MLflow
979    }
980
981    fn is_valid_format(&self, path: &std::path::Path) -> bool {
982        if !path.is_dir() {
983            return false;
984        }
985
986        // Check for MLmodel file
987        let mlmodel_path = path.join("MLmodel");
988        mlmodel_path.exists()
989    }
990}
991
/// Format compatibility manager
///
/// Owns one boxed [`FormatConverter`] per [`PackageFormat`] and dispatches
/// import and export requests to the matching converter.
pub struct FormatCompatibilityManager {
    /// Registered converters, keyed by the format each converter reports.
    converters: HashMap<PackageFormat, Box<dyn FormatConverter>>,
}
996
997impl Default for FormatCompatibilityManager {
998    fn default() -> Self {
999        let mut manager = Self {
1000            converters: HashMap::new(),
1001        };
1002
1003        // Register default converters
1004        manager.register_converter(Box::new(PyTorchConverter::new()));
1005        manager.register_converter(Box::new(HuggingFaceConverter::new()));
1006        manager.register_converter(Box::new(OnnxConverter::new()));
1007        manager.register_converter(Box::new(MLflowConverter::new()));
1008
1009        manager
1010    }
1011}
1012
1013impl FormatCompatibilityManager {
1014    /// Create a new format compatibility manager
1015    pub fn new() -> Self {
1016        Self::default()
1017    }
1018
1019    /// Register a format converter
1020    pub fn register_converter(&mut self, converter: Box<dyn FormatConverter>) {
1021        let format = converter.format();
1022        self.converters.insert(format, converter);
1023    }
1024
1025    /// Auto-detect format and import package
1026    pub fn import_package(&self, path: &std::path::Path) -> Result<(PackageFormat, Package)> {
1027        for (format, converter) in &self.converters {
1028            if converter.is_valid_format(path) {
1029                let package = converter.import_from_format(path)?;
1030                return Ok((*format, package));
1031            }
1032        }
1033
1034        Err(TorshError::InvalidArgument(
1035            "Unrecognized package format".to_string(),
1036        ))
1037    }
1038
1039    /// Export package to specific format
1040    pub fn export_package(
1041        &self,
1042        package: &Package,
1043        format: PackageFormat,
1044        path: &std::path::Path,
1045    ) -> Result<()> {
1046        let converter = self.converters.get(&format).ok_or_else(|| {
1047            TorshError::InvalidArgument(format!("Unsupported export format: {:?}", format))
1048        })?;
1049
1050        converter.export_to_format(package, path)
1051    }
1052
1053    /// List supported formats
1054    pub fn supported_formats(&self) -> Vec<PackageFormat> {
1055        self.converters.keys().copied().collect()
1056    }
1057}
1058
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::TempDir;

    /// The PyTorch converter reports the PyTorch format.
    #[test]
    fn test_pytorch_converter_format_detection() {
        let reported = PyTorchConverter::new().format();
        assert_eq!(reported, PackageFormat::PyTorch);
    }

    /// Builder options set on the HuggingFace converter are stored on it.
    #[test]
    fn test_huggingface_converter_creation() {
        let converter = HuggingFaceConverter::new();
        let converter = converter.with_include_tokenizer(false);
        let converter = converter.with_model_type("bert".to_string());

        assert_eq!(converter.format(), PackageFormat::HuggingFace);
        assert!(!converter.include_tokenizer);
        assert_eq!(converter.model_type.as_deref(), Some("bert"));
    }

    /// A freshly created manager lists the PyTorch and HuggingFace formats.
    #[test]
    fn test_format_manager() {
        let formats = FormatCompatibilityManager::new().supported_formats();

        for expected in [PackageFormat::PyTorch, PackageFormat::HuggingFace] {
            assert!(formats.contains(&expected));
        }
    }

    /// A directory holding a `config.json` is accepted as a HuggingFace model.
    #[test]
    fn test_huggingface_directory_validation() {
        let temp_dir = TempDir::new().expect("Failed to create temp directory for test");
        let model_dir = temp_dir.path();

        // Create a mock HuggingFace model directory
        let mut config_file = fs::File::create(model_dir.join("config.json")).unwrap();
        writeln!(
            config_file,
            r#"{{"model_type": "bert", "task": "text-classification"}}"#
        )
        .unwrap();

        assert!(HuggingFaceConverter::new().is_valid_format(model_dir));
    }

    /// Equality on `PackageFormat` distinguishes variants.
    #[test]
    fn test_package_format_enum() {
        let pytorch = PackageFormat::PyTorch;

        assert_eq!(pytorch, PackageFormat::PyTorch);
        assert_ne!(pytorch, PackageFormat::HuggingFace);
    }
}