1pub mod async_export;
3pub mod coreml;
4pub mod factory;
5pub mod ggml;
6pub mod gguf;
7pub mod gguf_enhanced;
8pub mod nnef;
9pub mod onnx;
10pub mod onnx_runtime;
11pub mod openvino;
12pub mod optimization;
13pub mod tensorrt;
14pub mod tvm;
15
16pub use async_export::{
17 export_model_async, AsyncExportHandle, AsyncExportManager, ExportProgress, ExportStep,
18};
19pub use coreml::*;
20pub use factory::{
21 ExportConstraints, ExportResult, ExporterFactory, ExporterProvider, ExporterRequirements,
22 TargetPlatform, ValidationResult,
23};
24pub use ggml::*;
25pub use gguf::*;
26pub use gguf_enhanced::{GGUFConverter, GGUFExporter as EnhancedGGUFExporter, GGUFTensorType};
27pub use nnef::*;
28pub use onnx::*;
29pub use onnx_runtime::*;
30pub use openvino::*;
31pub use optimization::{
32 OptimizationConfig, OptimizationImpact, OptimizationPass, OptimizationPipeline,
33 OptimizationStats, PipelineStats, TargetHardware,
34};
35pub use tensorrt::*;
36pub use tvm::*;
37
38use crate::traits::Model;
39use anyhow::Result;
40
/// Serialization targets a model can be exported to.
///
/// Each variant corresponds to one of the format-specific submodules declared at the top of
/// this file. [`UniversalExporter::export_model`] matches on this value to pick the exporter
/// that handles the request.
///
/// Variant order is part of the observable behavior (it fixes the discriminants), so new
/// formats should be appended rather than inserted.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ExportFormat {
    /// Handled by `ONNXExporter`.
    ONNX,
    /// Handled by `GGMLExporter`.
    GGML,
    /// Handled by `GGUFExporter`.
    GGUF,
    /// Handled by `NNEFExporter`.
    NNEF,
    /// Handled by `OpenVINOExporter`.
    OpenVINO,
    /// Handled by `TensorRTExporter`.
    TensorRT,
    /// Handled by `TVMExporter`.
    TVM,
    /// Handled by `CoreMLExporter`.
    CoreML,
}
53
54#[derive(Debug, Clone)]
56pub struct ExportConfig {
57 pub format: ExportFormat,
58 pub output_path: String,
59 pub optimize: bool,
60 pub precision: ExportPrecision,
61 pub batch_size: Option<usize>,
62 pub sequence_length: Option<usize>,
63 pub opset_version: Option<i64>, pub quantization: Option<ExportQuantization>,
65 pub input_shape: Option<Vec<usize>>,
66 pub task_type: Option<String>,
67 pub vocab_size: Option<usize>,
68}
69
/// Numeric precision requested for an export via [`ExportConfig::precision`].
///
/// Variant order fixes the discriminants, so new precisions should be appended.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ExportPrecision {
    /// 32-bit floating point; the value used by `ExportConfig::default()`.
    FP32,
    /// 16-bit floating point.
    FP16,
    /// 8-bit integer.
    INT8,
    /// 4-bit integer.
    INT4,
}
77
/// Quantization settings that can be attached to an [`ExportConfig`].
///
/// This type only carries the values; how they are applied is up to the exporter that
/// receives the configuration.
#[derive(Debug, Clone)]
pub struct ExportQuantization {
    /// Bit width to quantize to.
    pub bits: u8,
    /// Optional quantization group size.
    pub group_size: Option<usize>,
    /// Optional calibration inputs.
    // NOTE(review): the strings are presumably paths or sample texts — confirm with the
    // exporters that consume them.
    pub calibration_data: Option<Vec<String>>,
}
84
85pub trait ModelExporter {
87 fn export<M: Model>(&self, model: &M, config: &ExportConfig) -> Result<()>;
88
89 fn supported_formats(&self) -> Vec<ExportFormat>;
90
91 fn validate_model<M: Model>(&self, model: &M, format: ExportFormat) -> Result<()>;
92}
93
/// Stateless exporter that forwards each request to the exporter matching
/// `ExportConfig::format`.
///
/// The type has no fields, so cloning it is free. `Debug` is derived so the exporter can be
/// logged and embedded in types that themselves derive `Debug`.
#[derive(Debug, Clone)]
pub struct UniversalExporter;
97
98impl Default for UniversalExporter {
99 fn default() -> Self {
100 Self::new()
101 }
102}
103
104impl UniversalExporter {
105 pub fn new() -> Self {
106 Self
107 }
108
109 pub fn export_model<M: Model>(&self, model: &M, config: &ExportConfig) -> Result<()> {
110 self.validate_model(model, config.format)?;
111
112 match config.format {
113 ExportFormat::ONNX => {
114 let exporter = ONNXExporter::new();
115 exporter.export(model, config)
116 },
117 ExportFormat::GGML => {
118 let exporter = GGMLExporter::new();
119 exporter.export(model, config)
120 },
121 ExportFormat::GGUF => {
122 let exporter = GGUFExporter::new();
123 exporter.export(model, config)
124 },
125 ExportFormat::NNEF => {
126 let exporter = NNEFExporter::new();
127 exporter.export(model, config)
128 },
129 ExportFormat::OpenVINO => {
130 let exporter = OpenVINOExporter::new();
131 exporter.export(model, config)
132 },
133 ExportFormat::TensorRT => {
134 let exporter = TensorRTExporter::new();
135 exporter.export(model, config)
136 },
137 ExportFormat::TVM => {
138 let exporter = TVMExporter::new();
139 exporter.export(model, config)
140 },
141 ExportFormat::CoreML => {
142 let exporter = CoreMLExporter::new();
143 exporter.export(model, config)
144 },
145 }
146 }
147}
148
149#[derive(Clone)]
151pub enum ConcreteExporter {
152 ONNX(ONNXExporter),
153 GGML(GGMLExporter),
154 GGUF(GGUFExporter),
155 GGUFEnhanced(EnhancedGGUFExporter),
156 NNEF(NNEFExporter),
157 OpenVINO(OpenVINOExporter),
158 TensorRT(TensorRTExporter),
159 TVM(TVMExporter),
160 CoreML(CoreMLExporter),
161 Universal(UniversalExporter),
162}
163
164impl ModelExporter for ConcreteExporter {
165 fn export<M: Model>(&self, model: &M, config: &ExportConfig) -> Result<()> {
166 match self {
167 ConcreteExporter::ONNX(exporter) => exporter.export(model, config),
168 ConcreteExporter::GGML(exporter) => exporter.export(model, config),
169 ConcreteExporter::GGUF(exporter) => exporter.export(model, config),
170 ConcreteExporter::GGUFEnhanced(exporter) => exporter.export(model, config),
171 ConcreteExporter::NNEF(exporter) => exporter.export(model, config),
172 ConcreteExporter::OpenVINO(exporter) => exporter.export(model, config),
173 ConcreteExporter::TensorRT(exporter) => exporter.export(model, config),
174 ConcreteExporter::TVM(exporter) => exporter.export(model, config),
175 ConcreteExporter::CoreML(exporter) => exporter.export(model, config),
176 ConcreteExporter::Universal(exporter) => exporter.export(model, config),
177 }
178 }
179
180 fn supported_formats(&self) -> Vec<ExportFormat> {
181 match self {
182 ConcreteExporter::ONNX(exporter) => exporter.supported_formats(),
183 ConcreteExporter::GGML(exporter) => exporter.supported_formats(),
184 ConcreteExporter::GGUF(exporter) => exporter.supported_formats(),
185 ConcreteExporter::GGUFEnhanced(exporter) => exporter.supported_formats(),
186 ConcreteExporter::NNEF(exporter) => exporter.supported_formats(),
187 ConcreteExporter::OpenVINO(exporter) => exporter.supported_formats(),
188 ConcreteExporter::TensorRT(exporter) => exporter.supported_formats(),
189 ConcreteExporter::TVM(exporter) => exporter.supported_formats(),
190 ConcreteExporter::CoreML(exporter) => exporter.supported_formats(),
191 ConcreteExporter::Universal(exporter) => exporter.supported_formats(),
192 }
193 }
194
195 fn validate_model<M: Model>(&self, model: &M, format: ExportFormat) -> Result<()> {
196 match self {
197 ConcreteExporter::ONNX(exporter) => exporter.validate_model(model, format),
198 ConcreteExporter::GGML(exporter) => exporter.validate_model(model, format),
199 ConcreteExporter::GGUF(exporter) => exporter.validate_model(model, format),
200 ConcreteExporter::GGUFEnhanced(exporter) => exporter.validate_model(model, format),
201 ConcreteExporter::NNEF(exporter) => exporter.validate_model(model, format),
202 ConcreteExporter::OpenVINO(exporter) => exporter.validate_model(model, format),
203 ConcreteExporter::TensorRT(exporter) => exporter.validate_model(model, format),
204 ConcreteExporter::TVM(exporter) => exporter.validate_model(model, format),
205 ConcreteExporter::CoreML(exporter) => exporter.validate_model(model, format),
206 ConcreteExporter::Universal(exporter) => exporter.validate_model(model, format),
207 }
208 }
209}
210
211impl ModelExporter for UniversalExporter {
212 fn export<M: Model>(&self, model: &M, config: &ExportConfig) -> Result<()> {
213 self.export_model(model, config)
214 }
215
216 fn supported_formats(&self) -> Vec<ExportFormat> {
217 vec![
218 ExportFormat::ONNX,
219 ExportFormat::GGML,
220 ExportFormat::GGUF,
221 ExportFormat::NNEF,
222 ExportFormat::OpenVINO,
223 ExportFormat::TensorRT,
224 ExportFormat::TVM,
225 ExportFormat::CoreML,
226 ]
227 }
228
229 fn validate_model<M: Model>(&self, _model: &M, _format: ExportFormat) -> Result<()> {
230 Ok(())
232 }
233}
234
235impl Default for ExportConfig {
236 fn default() -> Self {
237 Self {
238 format: ExportFormat::ONNX,
239 output_path: "model".to_string(),
240 optimize: true,
241 precision: ExportPrecision::FP32,
242 batch_size: Some(1),
243 sequence_length: Some(512),
244 opset_version: Some(14),
245 quantization: None,
246 input_shape: None,
247 task_type: None,
248 vocab_size: None,
249 }
250 }
251}
252
#[cfg(test)]
mod tests {
    use super::*;

    /// `ExportConfig::default()` must produce exactly the values written in its `Default`
    /// impl, including the fields that are left unset.
    #[test]
    fn test_export_config_default() {
        let config = ExportConfig::default();

        assert_eq!(config.format, ExportFormat::ONNX);
        assert_eq!(config.output_path, "model");
        assert!(config.optimize);
        assert_eq!(config.precision, ExportPrecision::FP32);
        assert_eq!(config.batch_size, Some(1));
        assert_eq!(config.sequence_length, Some(512));
        assert_eq!(config.opset_version, Some(14));
        // `ExportQuantization` does not implement `PartialEq`, hence `is_none()`.
        assert!(config.quantization.is_none());
        assert_eq!(config.input_shape, None);
        assert_eq!(config.task_type, None);
        assert_eq!(config.vocab_size, None);
    }

    /// The universal exporter must advertise every format it dispatches on.
    #[test]
    fn test_universal_exporter_creation() {
        let exporter = UniversalExporter::new();
        let formats = exporter.supported_formats();

        let expected = [
            ExportFormat::ONNX,
            ExportFormat::GGML,
            ExportFormat::GGUF,
            ExportFormat::NNEF,
            ExportFormat::OpenVINO,
            ExportFormat::TensorRT,
            ExportFormat::TVM,
            ExportFormat::CoreML,
        ];
        for format in &expected {
            assert!(
                formats.contains(format),
                "{:?} missing from supported formats",
                format
            );
        }
    }

    /// Overriding `precision` with struct-update syntax must store the chosen value and
    /// leave the remaining fields at their defaults.
    #[test]
    fn test_export_precision_variants() {
        let precisions = [
            ExportPrecision::FP32,
            ExportPrecision::FP16,
            ExportPrecision::INT8,
            ExportPrecision::INT4,
        ];

        for &precision in &precisions {
            let config = ExportConfig {
                precision,
                ..Default::default()
            };

            // Compare against the value we set; matching against "any variant" would be a
            // tautology and could never fail.
            assert_eq!(config.precision, precision);
            assert_eq!(config.format, ExportFormat::ONNX);
            assert_eq!(config.output_path, "model");
        }
    }
}