scirs2_io/ml_framework/converters/mod.rs

1//! ML framework converters
2#![allow(dead_code)]
3
4use crate::error::Result;
5use crate::ml_framework::types::{MLFramework, MLModel, MLTensor};
6use std::path::Path;
7
8pub mod coreml;
9pub mod huggingface;
10pub mod jax;
11pub mod mxnet;
12pub mod onnx;
13pub mod pytorch;
14pub mod safetensors;
15pub mod tensorflow;
16
17pub use coreml::CoreMLConverter;
18pub use huggingface::HuggingFaceConverter;
19pub use jax::JAXConverter;
20pub use mxnet::MXNetConverter;
21pub use onnx::ONNXConverter;
22pub use pytorch::PyTorchConverter;
23pub use safetensors::SafeTensorsConverter;
24pub use tensorflow::TensorFlowConverter;
25
26/// Trait for ML framework converters
27pub trait MLFrameworkConverter {
28    fn save_model(&self, model: &MLModel, path: &Path) -> Result<()>;
29    fn load_model(&self, path: &Path) -> Result<MLModel>;
30    #[allow(dead_code)]
31    fn save_tensor(&self, tensor: &MLTensor, path: &Path) -> Result<()>;
32    #[allow(dead_code)]
33    fn load_tensor(&self, path: &Path) -> Result<MLTensor>;
34}
35
36/// Get appropriate converter for framework
37pub fn get_converter(framework: MLFramework) -> Box<dyn MLFrameworkConverter> {
38    match framework {
39        MLFramework::PyTorch => Box::new(PyTorchConverter),
40        MLFramework::ONNX => Box::new(ONNXConverter),
41        MLFramework::SafeTensors => Box::new(SafeTensorsConverter),
42        MLFramework::TensorFlow => Box::new(TensorFlowConverter),
43        MLFramework::JAX => Box::new(JAXConverter),
44        MLFramework::MXNet => Box::new(MXNetConverter),
45        MLFramework::CoreML => Box::new(CoreMLConverter),
46        MLFramework::HuggingFace => Box::new(HuggingFaceConverter),
47    }
48}