scirs2_io/ml_framework/converters/
mod.rs1#![allow(dead_code)]
3
4use crate::error::Result;
5use crate::ml_framework::types::{MLFramework, MLModel, MLTensor};
6use std::path::Path;
7
8pub mod coreml;
9pub mod huggingface;
10pub mod jax;
11pub mod mxnet;
12pub mod onnx;
13pub mod pytorch;
14pub mod safetensors;
15pub mod tensorflow;
16
17pub use coreml::CoreMLConverter;
18pub use huggingface::HuggingFaceConverter;
19pub use jax::JAXConverter;
20pub use mxnet::MXNetConverter;
21pub use onnx::ONNXConverter;
22pub use pytorch::PyTorchConverter;
23pub use safetensors::SafeTensorsConverter;
24pub use tensorflow::TensorFlowConverter;
25
26pub trait MLFrameworkConverter {
28 fn save_model(&self, model: &MLModel, path: &Path) -> Result<()>;
29 fn load_model(&self, path: &Path) -> Result<MLModel>;
30 #[allow(dead_code)]
31 fn save_tensor(&self, tensor: &MLTensor, path: &Path) -> Result<()>;
32 #[allow(dead_code)]
33 fn load_tensor(&self, path: &Path) -> Result<MLTensor>;
34}
35
36pub fn get_converter(framework: MLFramework) -> Box<dyn MLFrameworkConverter> {
38 match framework {
39 MLFramework::PyTorch => Box::new(PyTorchConverter),
40 MLFramework::ONNX => Box::new(ONNXConverter),
41 MLFramework::SafeTensors => Box::new(SafeTensorsConverter),
42 MLFramework::TensorFlow => Box::new(TensorFlowConverter),
43 MLFramework::JAX => Box::new(JAXConverter),
44 MLFramework::MXNet => Box::new(MXNetConverter),
45 MLFramework::CoreML => Box::new(CoreMLConverter),
46 MLFramework::HuggingFace => Box::new(HuggingFaceConverter),
47 }
48}