1use std::collections::HashMap;
7use std::fs;
8
9use oxiarc_archive::zip::{ZipCompressionLevel, ZipReader, ZipWriter};
10use serde::{Deserialize, Serialize};
11use torsh_core::error::{Result, TorshError};
12
13use crate::package::Package;
14use crate::resources::{Resource, ResourceType};
15
/// External package formats that can be imported to or exported from the
/// internal [`Package`] representation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PackageFormat {
    /// PyTorch `torch.package`-style ZIP archives.
    PyTorch,
    /// HuggingFace model directories (`config.json`, weights, tokenizer files).
    HuggingFace,
    /// Single-file ONNX models.
    Onnx,
    /// MLflow model directories (an `MLmodel` file plus artifacts).
    MLflow,
    /// Native ToRSh package format.
    ToRSh,
}
30
/// Bidirectional conversion between one external format and [`Package`].
pub trait FormatConverter {
    /// Reads the artifact at `path` and converts it into a [`Package`].
    fn import_from_format(&self, path: &std::path::Path) -> Result<Package>;

    /// Writes `package` to `path` in this converter's format.
    fn export_to_format(&self, package: &Package, path: &std::path::Path) -> Result<()>;

    /// The format this converter handles.
    fn format(&self) -> PackageFormat;

    /// Cheap heuristic check of whether `path` looks like this format.
    fn is_valid_format(&self, path: &std::path::Path) -> bool;
}
45
/// Converter for PyTorch `torch.package`-style ZIP archives.
pub struct PyTorchConverter {
    // When true, `.py` archive entries are imported as Source resources.
    preserve_python_code: bool,
    // When true, `.pkl` archive entries are imported as Model resources.
    extract_models: bool,
}
51
/// Converter for HuggingFace model directories.
pub struct HuggingFaceConverter {
    // When true, tokenizer files (tokenizer* and *.json) are imported.
    include_tokenizer: bool,
    // When true, `config.json` is parsed and imported as a Config resource.
    include_config: bool,
    // Fallback model type used when no `config.json` is present.
    model_type: Option<String>,
}
58
/// Converter for single-file ONNX models.
pub struct OnnxConverter {
    // When true, synthesized ONNX metadata is attached to the imported package.
    include_metadata: bool,
    // NOTE(review): stored but not read anywhere in this file — confirm
    // whether inference optimization is implemented elsewhere or pending.
    optimize_for_inference: bool,
}
64
/// Converter for MLflow model directories.
pub struct MLflowConverter {
    // NOTE(review): stored but not read anywhere in this file — confirm usage.
    include_conda_env: bool,
    // NOTE(review): stored but not read anywhere in this file — confirm usage.
    include_requirements: bool,
    // MLflow flavor recorded in the imported package's manifest metadata.
    flavor: Option<String>,
}
71
/// Manifest synthesized while importing a PyTorch archive.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct PyTorchManifest {
    // Taken from the archive's `.data/version` entry when present.
    code_version: String,
    // Entry-point module name; currently always the default "main".
    main_module: String,
    dependencies: Vec<String>,
    python_version: Option<String>,
}
80
/// Subset of a HuggingFace `config.json` that this module reads.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct HuggingFaceConfig {
    model_type: String,
    task: Option<String>,
    architectures: Option<Vec<String>>,
    tokenizer_class: Option<String>,
    vocab_size: Option<u64>,
}
90
/// ONNX model header fields mirrored into package metadata.
// NOTE(review): currently populated with placeholder values only — see
// `OnnxConverter::extract_onnx_metadata`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct OnnxMetadata {
    ir_version: i64,
    producer_name: String,
    producer_version: String,
    domain: String,
    model_version: i64,
    doc_string: String,
}
101
/// MLflow model metadata synthesized on import.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct MLflowMetadata {
    artifact_path: String,
    flavors: HashMap<String, serde_json::Value>,
    model_uuid: String,
    run_id: String,
    utc_time_created: String,
    mlflow_version: String,
}
112
impl Default for PyTorchConverter {
    /// Defaults to importing both Python sources and pickled models.
    fn default() -> Self {
        Self {
            preserve_python_code: true,
            extract_models: true,
        }
    }
}
121
impl PyTorchConverter {
    /// Creates a converter with the default settings.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets whether `.py` entries are imported as Source resources.
    pub fn with_preserve_python_code(mut self, preserve: bool) -> Self {
        self.preserve_python_code = preserve;
        self
    }

    /// Sets whether `.pkl` entries are imported as Model resources.
    pub fn with_extract_models(mut self, extract: bool) -> Self {
        self.extract_models = extract;
        self
    }

    /// Walks the ZIP archive at `path`, classifying each entry into a
    /// [`Resource`] and synthesizing a [`PyTorchManifest`] from the
    /// `.data/version` entry (defaults are used if it is absent).
    ///
    /// # Errors
    /// Fails if the file cannot be opened, is not a valid ZIP archive, an
    /// entry cannot be extracted, or `.data/version` is not valid UTF-8.
    fn extract_pytorch_package(
        &self,
        path: &std::path::Path,
    ) -> Result<(PyTorchManifest, Vec<Resource>)> {
        let file = fs::File::open(path)
            .map_err(|e| TorshError::IoError(format!("Failed to open PyTorch package: {}", e)))?;

        let mut archive = ZipReader::new(file)
            .map_err(|e| TorshError::InvalidArgument(format!("Invalid ZIP archive: {}", e)))?;

        let mut manifest = None;
        let mut resources = Vec::new();

        // Entry list is cloned out first, presumably so the immutable borrow of
        // `archive` ends before `extract` borrows it mutably in the loop.
        let entries: Vec<_> = archive.entries().to_vec();

        for entry in entries {
            let file_name = entry.name.clone();

            let contents = archive
                .extract(&entry)
                .map_err(|e| TorshError::IoError(format!("Failed to read archive entry: {}", e)))?;

            if file_name == ".data/version" {
                // Only code_version comes from the archive; the remaining
                // manifest fields are filled with defaults.
                let version_str = String::from_utf8(contents).map_err(|_| {
                    TorshError::InvalidArgument("Invalid UTF-8 in version file".to_string())
                })?;

                manifest = Some(PyTorchManifest {
                    code_version: version_str.trim().to_string(),
                    main_module: "main".to_string(),
                    dependencies: Vec::new(),
                    python_version: None,
                });
            } else if file_name.ends_with(".py") && self.preserve_python_code {
                resources.push(Resource {
                    name: file_name.clone(),
                    resource_type: ResourceType::Source,
                    data: contents,
                    metadata: {
                        let mut meta = HashMap::new();
                        meta.insert("language".to_string(), "python".to_string());
                        meta.insert("original_format".to_string(), "pytorch".to_string());
                        meta
                    },
                });
            } else if file_name.ends_with(".pkl") && self.extract_models {
                resources.push(Resource {
                    name: file_name.clone(),
                    resource_type: ResourceType::Model,
                    data: contents,
                    metadata: {
                        let mut meta = HashMap::new();
                        meta.insert("format".to_string(), "pickle".to_string());
                        meta.insert("original_format".to_string(), "pytorch".to_string());
                        meta
                    },
                });
            } else {
                // NOTE(review): when preserve_python_code/extract_models are
                // disabled, matching files still fall through to this branch and
                // are imported as generic Data rather than skipped — confirm
                // this is the intended semantics of those flags.
                resources.push(Resource {
                    name: file_name.clone(),
                    resource_type: ResourceType::Data,
                    data: contents,
                    metadata: {
                        let mut meta = HashMap::new();
                        meta.insert("original_format".to_string(), "pytorch".to_string());
                        meta
                    },
                });
            }
        }

        // Archives without `.data/version` still import, with default values.
        let manifest = manifest.unwrap_or_else(|| PyTorchManifest {
            code_version: "1.0.0".to_string(),
            main_module: "main".to_string(),
            dependencies: Vec::new(),
            python_version: None,
        });

        Ok((manifest, resources))
    }
}
228
impl FormatConverter for PyTorchConverter {
    /// Imports a PyTorch archive: entries become resources and the extracted
    /// manifest fields are recorded in the package's manifest metadata.
    fn import_from_format(&self, path: &std::path::Path) -> Result<Package> {
        let (pytorch_manifest, resources) = self.extract_pytorch_package(path)?;

        // Package name falls back to a fixed default when the file stem is
        // missing or not valid UTF-8.
        let package_name = path
            .file_stem()
            .and_then(|s| s.to_str())
            .unwrap_or("imported_pytorch_model")
            .to_string();

        let mut package = Package::new(package_name, pytorch_manifest.code_version);

        for resource in resources {
            package.add_resource(resource);
        }

        package
            .manifest_mut()
            .metadata
            .insert("original_format".to_string(), "pytorch".to_string());
        package
            .manifest_mut()
            .metadata
            .insert("main_module".to_string(), pytorch_manifest.main_module);

        if let Some(python_version) = pytorch_manifest.python_version {
            package
                .manifest_mut()
                .metadata
                .insert("python_version".to_string(), python_version);
        }

        // Dependency versions are unknown at import time, so "*" is recorded.
        for dep in pytorch_manifest.dependencies {
            package.add_dependency(&dep, "*");
        }

        Ok(package)
    }

    /// Exports the package as a PyTorch-style ZIP: `.data/version` holds the
    /// version, `.py` Source resources go under `code/`, Model resources
    /// under `data/`, and everything else keeps its original name.
    fn export_to_format(&self, package: &Package, path: &std::path::Path) -> Result<()> {
        let file = fs::File::create(path)
            .map_err(|e| TorshError::IoError(format!("Failed to create output file: {}", e)))?;

        let mut zip = ZipWriter::new(file);
        zip.set_compression(ZipCompressionLevel::Normal);

        let version_data = package.get_version().as_bytes();
        zip.add_file(".data/version", version_data)
            .map_err(|e| TorshError::IoError(format!("Failed to create version file: {}", e)))?;

        for (name, resource) in package.resources() {
            let file_path =
                if resource.resource_type == ResourceType::Source && name.ends_with(".py") {
                    format!("code/{}", name)
                } else if resource.resource_type == ResourceType::Model {
                    format!("data/{}", name)
                } else {
                    name.clone()
                };

            zip.add_file(&file_path, &resource.data).map_err(|e| {
                TorshError::IoError(format!("Failed to create file {}: {}", file_path, e))
            })?;
        }

        zip.finish()
            .map_err(|e| TorshError::IoError(format!("Failed to finalize ZIP archive: {}", e)))?;

        Ok(())
    }

    fn format(&self) -> PackageFormat {
        PackageFormat::PyTorch
    }

    /// Heuristic: any readable ZIP containing `.data/version`, a `code/`
    /// entry, or a `.pkl` file is treated as a PyTorch package.
    fn is_valid_format(&self, path: &std::path::Path) -> bool {
        if let Ok(file) = fs::File::open(path) {
            if let Ok(archive) = ZipReader::new(file) {
                for entry in archive.entries() {
                    let name = &entry.name;
                    if name == ".data/version"
                        || name.starts_with("code/")
                        || name.ends_with(".pkl")
                    {
                        return true;
                    }
                }
            }
        }
        false
    }
}
328
impl Default for HuggingFaceConverter {
    /// Defaults to importing both the tokenizer files and `config.json`,
    /// with no fallback model type.
    fn default() -> Self {
        Self {
            include_tokenizer: true,
            include_config: true,
            model_type: None,
        }
    }
}
338
339impl HuggingFaceConverter {
340 pub fn new() -> Self {
342 Self::default()
343 }
344
345 pub fn with_include_tokenizer(mut self, include: bool) -> Self {
347 self.include_tokenizer = include;
348 self
349 }
350
351 pub fn with_include_config(mut self, include: bool) -> Self {
353 self.include_config = include;
354 self
355 }
356
357 pub fn with_model_type(mut self, model_type: String) -> Self {
359 self.model_type = Some(model_type);
360 self
361 }
362
363 fn load_huggingface_model(
365 &self,
366 path: &std::path::Path,
367 ) -> Result<(HuggingFaceConfig, Vec<Resource>)> {
368 let model_dir = path;
369
370 if !model_dir.is_dir() {
371 return Err(TorshError::InvalidArgument(
372 "HuggingFace path must be a directory".to_string(),
373 ));
374 }
375
376 let mut config = None;
377 let mut resources = Vec::new();
378
379 let config_path = model_dir.join("config.json");
381 if config_path.exists() && self.include_config {
382 let config_data = fs::read(&config_path)
383 .map_err(|e| TorshError::IoError(format!("Failed to read config.json: {}", e)))?;
384
385 config = Some(
386 serde_json::from_slice::<HuggingFaceConfig>(&config_data).map_err(|e| {
387 TorshError::SerializationError(format!("Invalid config.json: {}", e))
388 })?,
389 );
390
391 resources.push(Resource {
392 name: "config.json".to_string(),
393 resource_type: ResourceType::Config,
394 data: config_data,
395 metadata: {
396 let mut meta = HashMap::new();
397 meta.insert("original_format".to_string(), "huggingface".to_string());
398 meta
399 },
400 });
401 }
402
403 for entry in fs::read_dir(model_dir)
405 .map_err(|e| TorshError::IoError(format!("Failed to read model directory: {}", e)))?
406 {
407 let entry = entry.map_err(|e| {
408 TorshError::IoError(format!("Failed to read directory entry: {}", e))
409 })?;
410 let file_path = entry.path();
411 let file_name = file_path
412 .file_name()
413 .and_then(|n| n.to_str())
414 .unwrap_or("")
415 .to_string();
416
417 if file_name.ends_with(".bin") || file_name.ends_with(".safetensors") {
418 let data = fs::read(&file_path).map_err(|e| {
420 TorshError::IoError(format!("Failed to read {}: {}", file_name, e))
421 })?;
422
423 resources.push(Resource {
424 name: file_name.clone(),
425 resource_type: ResourceType::Model,
426 data,
427 metadata: {
428 let mut meta = HashMap::new();
429 meta.insert("original_format".to_string(), "huggingface".to_string());
430 if file_name.ends_with(".safetensors") {
431 meta.insert("format".to_string(), "safetensors".to_string());
432 } else {
433 meta.insert("format".to_string(), "pytorch".to_string());
434 }
435 meta
436 },
437 });
438 } else if self.include_tokenizer
439 && (file_name.starts_with("tokenizer") || file_name.ends_with(".json"))
440 {
441 let data = fs::read(&file_path).map_err(|e| {
443 TorshError::IoError(format!("Failed to read {}: {}", file_name, e))
444 })?;
445
446 resources.push(Resource {
447 name: file_name,
448 resource_type: ResourceType::Data,
449 data,
450 metadata: {
451 let mut meta = HashMap::new();
452 meta.insert("original_format".to_string(), "huggingface".to_string());
453 meta.insert("type".to_string(), "tokenizer".to_string());
454 meta
455 },
456 });
457 }
458 }
459
460 let config = config.unwrap_or_else(|| HuggingFaceConfig {
461 model_type: self
462 .model_type
463 .clone()
464 .unwrap_or_else(|| "unknown".to_string()),
465 task: None,
466 architectures: None,
467 tokenizer_class: None,
468 vocab_size: None,
469 });
470
471 Ok((config, resources))
472 }
473}
474
475impl FormatConverter for HuggingFaceConverter {
476 fn import_from_format(&self, path: &std::path::Path) -> Result<Package> {
477 let (hf_config, resources) = self.load_huggingface_model(path)?;
478
479 let package_name = path
480 .file_stem()
481 .and_then(|s| s.to_str())
482 .unwrap_or("imported_huggingface_model")
483 .to_string();
484
485 let mut package = Package::new(package_name, "1.0.0".to_string());
486
487 for resource in resources {
489 package.add_resource(resource);
490 }
491
492 package
494 .manifest_mut()
495 .metadata
496 .insert("original_format".to_string(), "huggingface".to_string());
497 package
498 .manifest_mut()
499 .metadata
500 .insert("model_type".to_string(), hf_config.model_type);
501
502 if let Some(task) = hf_config.task {
503 package
504 .manifest_mut()
505 .metadata
506 .insert("task".to_string(), task);
507 }
508
509 if let Some(architectures) = hf_config.architectures {
510 package.manifest_mut().metadata.insert(
511 "architectures".to_string(),
512 serde_json::to_string(&architectures).unwrap_or_default(),
513 );
514 }
515
516 Ok(package)
517 }
518
519 fn export_to_format(&self, package: &Package, path: &std::path::Path) -> Result<()> {
520 let output_dir = path;
521
522 if !output_dir.exists() {
523 fs::create_dir_all(output_dir).map_err(|e| {
524 TorshError::IoError(format!("Failed to create output directory: {}", e))
525 })?;
526 }
527
528 for (name, resource) in package.resources() {
530 let file_path = output_dir.join(name);
531 fs::write(&file_path, &resource.data)
532 .map_err(|e| TorshError::IoError(format!("Failed to write {}: {}", name, e)))?;
533 }
534
535 let config_path = output_dir.join("config.json");
537 if !config_path.exists() {
538 let default_config = HuggingFaceConfig {
539 model_type: package
540 .metadata()
541 .metadata
542 .get("model_type")
543 .cloned()
544 .unwrap_or_else(|| "unknown".to_string()),
545 task: package.metadata().metadata.get("task").cloned(),
546 architectures: package
547 .metadata()
548 .metadata
549 .get("architectures")
550 .and_then(|s| serde_json::from_str(s).ok()),
551 tokenizer_class: None,
552 vocab_size: None,
553 };
554
555 let config_json = serde_json::to_string_pretty(&default_config).map_err(|e| {
556 TorshError::SerializationError(format!("Failed to serialize config: {}", e))
557 })?;
558
559 fs::write(&config_path, config_json)
560 .map_err(|e| TorshError::IoError(format!("Failed to write config.json: {}", e)))?;
561 }
562
563 Ok(())
564 }
565
566 fn format(&self) -> PackageFormat {
567 PackageFormat::HuggingFace
568 }
569
570 fn is_valid_format(&self, path: &std::path::Path) -> bool {
571 let model_dir = path;
572
573 if !model_dir.is_dir() {
574 return false;
575 }
576
577 let config_path = model_dir.join("config.json");
579 if config_path.exists() {
580 return true;
581 }
582
583 if let Ok(entries) = fs::read_dir(model_dir) {
585 for entry in entries {
586 if let Ok(entry) = entry {
587 let file_name = entry.file_name();
588 let file_name_str = file_name.to_string_lossy();
589 if file_name_str.ends_with(".bin") || file_name_str.ends_with(".safetensors") {
590 return true;
591 }
592 }
593 }
594 }
595
596 false
597 }
598}
599
impl Default for OnnxConverter {
    /// Defaults to attaching metadata and skipping inference optimization.
    fn default() -> Self {
        Self {
            include_metadata: true,
            optimize_for_inference: false,
        }
    }
}
608
impl OnnxConverter {
    /// Creates a converter with the default settings.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets whether synthesized ONNX metadata is attached on import.
    pub fn with_include_metadata(mut self, include: bool) -> Self {
        self.include_metadata = include;
        self
    }

    /// Sets the inference-optimization flag.
    // NOTE(review): this flag is stored but never read in this file — confirm
    // whether the optimization is implemented elsewhere or still pending.
    pub fn with_optimize_for_inference(mut self, optimize: bool) -> Self {
        self.optimize_for_inference = optimize;
        self
    }

    /// Builds metadata describing the ONNX model at `path`.
    // NOTE(review): the model file is never parsed here — every field except
    // `doc_string` is a hard-coded placeholder, not data read from the model.
    fn extract_onnx_metadata(&self, path: &std::path::Path) -> Result<OnnxMetadata> {
        Ok(OnnxMetadata {
            ir_version: 8,
            producer_name: "torsh-package".to_string(),
            producer_version: "1.0.0".to_string(),
            domain: "ai.onnx".to_string(),
            model_version: 1,
            doc_string: format!("ONNX model imported from {:?}", path),
        })
    }
}
641
642impl FormatConverter for OnnxConverter {
643 fn import_from_format(&self, path: &std::path::Path) -> Result<Package> {
644 let model_data = fs::read(path)
645 .map_err(|e| TorshError::IoError(format!("Failed to read ONNX model: {}", e)))?;
646
647 let package_name = path
648 .file_stem()
649 .and_then(|s| s.to_str())
650 .unwrap_or("imported_onnx_model")
651 .to_string();
652
653 let mut package = Package::new(package_name, "1.0.0".to_string());
654
655 let model_resource = Resource {
657 name: "model.onnx".to_string(),
658 resource_type: ResourceType::Model,
659 data: model_data,
660 metadata: {
661 let mut meta = HashMap::new();
662 meta.insert("original_format".to_string(), "onnx".to_string());
663 meta.insert("format".to_string(), "onnx".to_string());
664 meta
665 },
666 };
667 package.add_resource(model_resource);
668
669 if self.include_metadata {
671 let onnx_metadata = self.extract_onnx_metadata(path)?;
672 package.manifest_mut().metadata.insert(
673 "onnx_ir_version".to_string(),
674 onnx_metadata.ir_version.to_string(),
675 );
676 package
677 .manifest_mut()
678 .metadata
679 .insert("onnx_producer".to_string(), onnx_metadata.producer_name);
680 package.manifest_mut().metadata.insert(
681 "onnx_producer_version".to_string(),
682 onnx_metadata.producer_version,
683 );
684 }
685
686 package
687 .manifest_mut()
688 .metadata
689 .insert("original_format".to_string(), "onnx".to_string());
690
691 Ok(package)
692 }
693
694 fn export_to_format(&self, package: &Package, path: &std::path::Path) -> Result<()> {
695 let model_resource = package
697 .resources()
698 .iter()
699 .find(|(_, resource)| {
700 resource.resource_type == ResourceType::Model
701 && (resource.name.ends_with(".onnx")
702 || resource
703 .metadata
704 .get("format")
705 .map_or(false, |f| f == "onnx"))
706 })
707 .map(|(_, resource)| resource)
708 .ok_or_else(|| {
709 TorshError::InvalidArgument("No ONNX model found in package".to_string())
710 })?;
711
712 fs::write(path, &model_resource.data)
714 .map_err(|e| TorshError::IoError(format!("Failed to write ONNX model: {}", e)))?;
715
716 Ok(())
717 }
718
719 fn format(&self) -> PackageFormat {
720 PackageFormat::Onnx
721 }
722
723 fn is_valid_format(&self, path: &std::path::Path) -> bool {
724 if let Ok(file) = fs::File::open(path) {
725 use std::io::Read;
726 let mut buffer = [0u8; 16];
727 let mut reader = std::io::BufReader::new(file);
728
729 if reader.read_exact(&mut buffer).is_ok() {
731 return path
734 .extension()
735 .and_then(|e| e.to_str())
736 .map_or(false, |e| e == "onnx");
737 }
738 }
739 false
740 }
741}
742
impl Default for MLflowConverter {
    /// Defaults to including the conda environment and requirements, with no
    /// explicit flavor.
    fn default() -> Self {
        Self {
            include_conda_env: true,
            include_requirements: true,
            flavor: None,
        }
    }
}
752
753impl MLflowConverter {
754 pub fn new() -> Self {
756 Self::default()
757 }
758
759 pub fn with_include_conda_env(mut self, include: bool) -> Self {
761 self.include_conda_env = include;
762 self
763 }
764
765 pub fn with_include_requirements(mut self, include: bool) -> Self {
767 self.include_requirements = include;
768 self
769 }
770
771 pub fn with_flavor(mut self, flavor: String) -> Self {
773 self.flavor = Some(flavor);
774 self
775 }
776
777 fn load_mlflow_model(&self, path: &std::path::Path) -> Result<(MLflowMetadata, Vec<Resource>)> {
779 if !path.is_dir() {
780 return Err(TorshError::InvalidArgument(
781 "MLflow path must be a directory".to_string(),
782 ));
783 }
784
785 let mut metadata = None;
786 let mut resources = Vec::new();
787
788 let mlmodel_path = path.join("MLmodel");
790 if mlmodel_path.exists() {
791 let mlmodel_data = fs::read_to_string(&mlmodel_path)
792 .map_err(|e| TorshError::IoError(format!("Failed to read MLmodel: {}", e)))?;
793
794 metadata = Some(MLflowMetadata {
797 artifact_path: path.to_string_lossy().to_string(),
798 flavors: HashMap::new(),
799 model_uuid: uuid::Uuid::new_v4().to_string(),
800 run_id: "imported".to_string(),
801 utc_time_created: chrono::Utc::now().to_rfc3339(),
802 mlflow_version: "2.0.0".to_string(),
803 });
804
805 resources.push(Resource {
806 name: "MLmodel".to_string(),
807 resource_type: ResourceType::Config,
808 data: mlmodel_data.into_bytes(),
809 metadata: {
810 let mut meta = HashMap::new();
811 meta.insert("original_format".to_string(), "mlflow".to_string());
812 meta
813 },
814 });
815 }
816
817 for entry in fs::read_dir(path)
819 .map_err(|e| TorshError::IoError(format!("Failed to read MLflow directory: {}", e)))?
820 {
821 let entry = entry.map_err(|e| {
822 TorshError::IoError(format!("Failed to read directory entry: {}", e))
823 })?;
824 let file_path = entry.path();
825
826 if file_path.is_file() {
827 let file_name = file_path
828 .file_name()
829 .and_then(|n| n.to_str())
830 .unwrap_or("")
831 .to_string();
832
833 if file_name != "MLmodel" {
834 let data = fs::read(&file_path).map_err(|e| {
835 TorshError::IoError(format!("Failed to read {}: {}", file_name, e))
836 })?;
837
838 let resource_type = if file_name.ends_with(".pkl")
839 || file_name.ends_with(".pt")
840 || file_name.ends_with(".h5")
841 {
842 ResourceType::Model
843 } else if file_name.ends_with(".json") || file_name.ends_with(".yaml") {
844 ResourceType::Config
845 } else if file_name == "requirements.txt" || file_name == "conda.yaml" {
846 ResourceType::Documentation
847 } else {
848 ResourceType::Data
849 };
850
851 resources.push(Resource {
852 name: file_name,
853 resource_type,
854 data,
855 metadata: {
856 let mut meta = HashMap::new();
857 meta.insert("original_format".to_string(), "mlflow".to_string());
858 meta
859 },
860 });
861 }
862 }
863 }
864
865 let metadata = metadata.unwrap_or_else(|| MLflowMetadata {
866 artifact_path: path.to_string_lossy().to_string(),
867 flavors: HashMap::new(),
868 model_uuid: uuid::Uuid::new_v4().to_string(),
869 run_id: "imported".to_string(),
870 utc_time_created: chrono::Utc::now().to_rfc3339(),
871 mlflow_version: "2.0.0".to_string(),
872 });
873
874 Ok((metadata, resources))
875 }
876}
877
878impl FormatConverter for MLflowConverter {
879 fn import_from_format(&self, path: &std::path::Path) -> Result<Package> {
880 let (mlflow_metadata, resources) = self.load_mlflow_model(path)?;
881
882 let package_name = path
883 .file_stem()
884 .and_then(|s| s.to_str())
885 .unwrap_or("imported_mlflow_model")
886 .to_string();
887
888 let mut package = Package::new(package_name, "1.0.0".to_string());
889
890 for resource in resources {
892 package.add_resource(resource);
893 }
894
895 package
897 .manifest_mut()
898 .metadata
899 .insert("original_format".to_string(), "mlflow".to_string());
900 package
901 .manifest_mut()
902 .metadata
903 .insert("mlflow_version".to_string(), mlflow_metadata.mlflow_version);
904 package
905 .manifest_mut()
906 .metadata
907 .insert("model_uuid".to_string(), mlflow_metadata.model_uuid);
908 package
909 .manifest_mut()
910 .metadata
911 .insert("run_id".to_string(), mlflow_metadata.run_id);
912
913 if let Some(flavor) = &self.flavor {
914 package
915 .manifest_mut()
916 .metadata
917 .insert("flavor".to_string(), flavor.clone());
918 }
919
920 Ok(package)
921 }
922
923 fn export_to_format(&self, package: &Package, path: &std::path::Path) -> Result<()> {
924 let output_dir = path;
925
926 if !output_dir.exists() {
927 fs::create_dir_all(output_dir).map_err(|e| {
928 TorshError::IoError(format!("Failed to create output directory: {}", e))
929 })?;
930 }
931
932 for (name, resource) in package.resources() {
934 let file_path = output_dir.join(name);
935 fs::write(&file_path, &resource.data)
936 .map_err(|e| TorshError::IoError(format!("Failed to write {}: {}", name, e)))?;
937 }
938
939 let mlmodel_path = output_dir.join("MLmodel");
941 if !mlmodel_path.exists() {
942 let mlmodel_content = format!(
943 r#"artifact_path: {}
944flavors:
945 python_function:
946 env: conda.yaml
947 loader_module: mlflow.pyfunc.model
948 python_version: 3.9
949model_uuid: {}
950run_id: {}
951utc_time_created: '{}'
952mlflow_version: 2.0.0
953"#,
954 output_dir.to_string_lossy(),
955 package
956 .metadata()
957 .metadata
958 .get("model_uuid")
959 .cloned()
960 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
961 package
962 .metadata()
963 .metadata
964 .get("run_id")
965 .cloned()
966 .unwrap_or_else(|| "exported".to_string()),
967 chrono::Utc::now().to_rfc3339()
968 );
969
970 fs::write(&mlmodel_path, mlmodel_content)
971 .map_err(|e| TorshError::IoError(format!("Failed to write MLmodel: {}", e)))?;
972 }
973
974 Ok(())
975 }
976
977 fn format(&self) -> PackageFormat {
978 PackageFormat::MLflow
979 }
980
981 fn is_valid_format(&self, path: &std::path::Path) -> bool {
982 if !path.is_dir() {
983 return false;
984 }
985
986 let mlmodel_path = path.join("MLmodel");
988 mlmodel_path.exists()
989 }
990}
991
/// Registry of [`FormatConverter`]s, keyed by the format each one handles.
pub struct FormatCompatibilityManager {
    converters: HashMap<PackageFormat, Box<dyn FormatConverter>>,
}
996
997impl Default for FormatCompatibilityManager {
998 fn default() -> Self {
999 let mut manager = Self {
1000 converters: HashMap::new(),
1001 };
1002
1003 manager.register_converter(Box::new(PyTorchConverter::new()));
1005 manager.register_converter(Box::new(HuggingFaceConverter::new()));
1006 manager.register_converter(Box::new(OnnxConverter::new()));
1007 manager.register_converter(Box::new(MLflowConverter::new()));
1008
1009 manager
1010 }
1011}
1012
1013impl FormatCompatibilityManager {
1014 pub fn new() -> Self {
1016 Self::default()
1017 }
1018
1019 pub fn register_converter(&mut self, converter: Box<dyn FormatConverter>) {
1021 let format = converter.format();
1022 self.converters.insert(format, converter);
1023 }
1024
1025 pub fn import_package(&self, path: &std::path::Path) -> Result<(PackageFormat, Package)> {
1027 for (format, converter) in &self.converters {
1028 if converter.is_valid_format(path) {
1029 let package = converter.import_from_format(path)?;
1030 return Ok((*format, package));
1031 }
1032 }
1033
1034 Err(TorshError::InvalidArgument(
1035 "Unrecognized package format".to_string(),
1036 ))
1037 }
1038
1039 pub fn export_package(
1041 &self,
1042 package: &Package,
1043 format: PackageFormat,
1044 path: &std::path::Path,
1045 ) -> Result<()> {
1046 let converter = self.converters.get(&format).ok_or_else(|| {
1047 TorshError::InvalidArgument(format!("Unsupported export format: {:?}", format))
1048 })?;
1049
1050 converter.export_to_format(package, path)
1051 }
1052
1053 pub fn supported_formats(&self) -> Vec<PackageFormat> {
1055 self.converters.keys().copied().collect()
1056 }
1057}
1058
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::TempDir;

    #[test]
    fn test_pytorch_converter_format_detection() {
        assert_eq!(PyTorchConverter::new().format(), PackageFormat::PyTorch);
    }

    #[test]
    fn test_huggingface_converter_creation() {
        let hf = HuggingFaceConverter::new()
            .with_include_tokenizer(false)
            .with_model_type("bert".to_string());

        assert_eq!(hf.format(), PackageFormat::HuggingFace);
        assert!(!hf.include_tokenizer);
        assert_eq!(hf.model_type, Some("bert".to_string()));
    }

    #[test]
    fn test_format_manager() {
        let supported = FormatCompatibilityManager::new().supported_formats();

        assert!(supported.contains(&PackageFormat::PyTorch));
        assert!(supported.contains(&PackageFormat::HuggingFace));
    }

    #[test]
    fn test_huggingface_directory_validation() {
        let dir = TempDir::new().expect("Failed to create temp directory for test");

        // A config.json alone is enough for HuggingFace detection.
        let mut config_file = fs::File::create(dir.path().join("config.json")).unwrap();
        writeln!(
            config_file,
            r#"{{"model_type": "bert", "task": "text-classification"}}"#
        )
        .unwrap();

        assert!(HuggingFaceConverter::new().is_valid_format(dir.path()));
    }

    #[test]
    fn test_package_format_enum() {
        assert_eq!(PackageFormat::PyTorch, PackageFormat::PyTorch);
        assert_ne!(PackageFormat::PyTorch, PackageFormat::HuggingFace);
    }
}