Skip to main content

trustformers_core/export/
mod.rs

1// Model export functionality for various formats
2pub mod async_export;
3pub mod coreml;
4pub mod factory;
5pub mod ggml;
6pub mod gguf;
7pub mod gguf_enhanced;
8pub mod nnef;
9pub mod onnx;
10pub mod onnx_runtime;
11pub mod openvino;
12pub mod optimization;
13pub mod tensorrt;
14pub mod tvm;
15
16pub use async_export::{
17    export_model_async, AsyncExportHandle, AsyncExportManager, ExportProgress, ExportStep,
18};
19pub use coreml::*;
20pub use factory::{
21    ExportConstraints, ExportResult, ExporterFactory, ExporterProvider, ExporterRequirements,
22    TargetPlatform, ValidationResult,
23};
24pub use ggml::*;
25pub use gguf::*;
26pub use gguf_enhanced::{GGUFConverter, GGUFExporter as EnhancedGGUFExporter, GGUFTensorType};
27pub use nnef::*;
28pub use onnx::*;
29pub use onnx_runtime::*;
30pub use openvino::*;
31pub use optimization::{
32    OptimizationConfig, OptimizationImpact, OptimizationPass, OptimizationPipeline,
33    OptimizationStats, PipelineStats, TargetHardware,
34};
35pub use tensorrt::*;
36pub use tvm::*;
37
38use crate::traits::Model;
39use anyhow::Result;
40
/// Target serialization format for a model export.
///
/// `UniversalExporter::export_model` matches on this value to decide which
/// format-specific exporter receives the request.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ExportFormat {
    /// Handled by `ONNXExporter`.
    ONNX,
    /// Handled by `GGMLExporter`.
    GGML,
    /// Handled by `GGUFExporter` (the `gguf` module's exporter).
    GGUF,
    /// Handled by `NNEFExporter`.
    NNEF,
    /// Handled by `OpenVINOExporter`.
    OpenVINO,
    /// Handled by `TensorRTExporter`.
    TensorRT,
    /// Handled by `TVMExporter`.
    TVM,
    /// Handled by `CoreMLExporter`.
    CoreML,
}
53
54/// Export configuration
55#[derive(Debug, Clone)]
56pub struct ExportConfig {
57    pub format: ExportFormat,
58    pub output_path: String,
59    pub optimize: bool,
60    pub precision: ExportPrecision,
61    pub batch_size: Option<usize>,
62    pub sequence_length: Option<usize>,
63    pub opset_version: Option<i64>, // For ONNX
64    pub quantization: Option<ExportQuantization>,
65    pub input_shape: Option<Vec<usize>>,
66    pub task_type: Option<String>,
67    pub vocab_size: Option<usize>,
68}
69
/// Numeric precision requested for an exported model.
///
/// Derives `Hash` in addition to `Eq` (matching the derives on
/// [`ExportFormat`]) so precisions can be used as keys in hash-based
/// collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ExportPrecision {
    /// 32-bit floating point; this is what `ExportConfig::default()` selects.
    FP32,
    /// 16-bit floating point.
    FP16,
    /// 8-bit integer.
    INT8,
    /// 4-bit integer.
    INT4,
}
77
/// Quantization parameters attached to an export request.
#[derive(Clone, Debug)]
pub struct ExportQuantization {
    /// Bit width of the quantized values.
    pub bits: u8,
    /// Optional quantization group size.
    /// NOTE(review): behaviour when `None` is exporter-specific — confirm.
    pub group_size: Option<usize>,
    /// Optional sample texts used for calibration.
    pub calibration_data: Option<Vec<String>>,
}
84
85/// Main export trait
86pub trait ModelExporter {
87    fn export<M: Model>(&self, model: &M, config: &ExportConfig) -> Result<()>;
88
89    fn supported_formats(&self) -> Vec<ExportFormat>;
90
91    fn validate_model<M: Model>(&self, model: &M, format: ExportFormat) -> Result<()>;
92}
93
/// Format-agnostic exporter that forwards each request to the exporter for
/// `ExportConfig::format`.
///
/// This is a stateless unit struct. It derives `Debug`, like the other public
/// types in this module, so it can be printed and embedded in `Debug` types.
#[derive(Debug, Clone)]
pub struct UniversalExporter;
97
98impl Default for UniversalExporter {
99    fn default() -> Self {
100        Self::new()
101    }
102}
103
104impl UniversalExporter {
105    pub fn new() -> Self {
106        Self
107    }
108
109    pub fn export_model<M: Model>(&self, model: &M, config: &ExportConfig) -> Result<()> {
110        self.validate_model(model, config.format)?;
111
112        match config.format {
113            ExportFormat::ONNX => {
114                let exporter = ONNXExporter::new();
115                exporter.export(model, config)
116            },
117            ExportFormat::GGML => {
118                let exporter = GGMLExporter::new();
119                exporter.export(model, config)
120            },
121            ExportFormat::GGUF => {
122                let exporter = GGUFExporter::new();
123                exporter.export(model, config)
124            },
125            ExportFormat::NNEF => {
126                let exporter = NNEFExporter::new();
127                exporter.export(model, config)
128            },
129            ExportFormat::OpenVINO => {
130                let exporter = OpenVINOExporter::new();
131                exporter.export(model, config)
132            },
133            ExportFormat::TensorRT => {
134                let exporter = TensorRTExporter::new();
135                exporter.export(model, config)
136            },
137            ExportFormat::TVM => {
138                let exporter = TVMExporter::new();
139                exporter.export(model, config)
140            },
141            ExportFormat::CoreML => {
142                let exporter = CoreMLExporter::new();
143                exporter.export(model, config)
144            },
145        }
146    }
147}
148
149/// Concrete enum holding all exporter types for dyn compatibility
150#[derive(Clone)]
151pub enum ConcreteExporter {
152    ONNX(ONNXExporter),
153    GGML(GGMLExporter),
154    GGUF(GGUFExporter),
155    GGUFEnhanced(EnhancedGGUFExporter),
156    NNEF(NNEFExporter),
157    OpenVINO(OpenVINOExporter),
158    TensorRT(TensorRTExporter),
159    TVM(TVMExporter),
160    CoreML(CoreMLExporter),
161    Universal(UniversalExporter),
162}
163
164impl ModelExporter for ConcreteExporter {
165    fn export<M: Model>(&self, model: &M, config: &ExportConfig) -> Result<()> {
166        match self {
167            ConcreteExporter::ONNX(exporter) => exporter.export(model, config),
168            ConcreteExporter::GGML(exporter) => exporter.export(model, config),
169            ConcreteExporter::GGUF(exporter) => exporter.export(model, config),
170            ConcreteExporter::GGUFEnhanced(exporter) => exporter.export(model, config),
171            ConcreteExporter::NNEF(exporter) => exporter.export(model, config),
172            ConcreteExporter::OpenVINO(exporter) => exporter.export(model, config),
173            ConcreteExporter::TensorRT(exporter) => exporter.export(model, config),
174            ConcreteExporter::TVM(exporter) => exporter.export(model, config),
175            ConcreteExporter::CoreML(exporter) => exporter.export(model, config),
176            ConcreteExporter::Universal(exporter) => exporter.export(model, config),
177        }
178    }
179
180    fn supported_formats(&self) -> Vec<ExportFormat> {
181        match self {
182            ConcreteExporter::ONNX(exporter) => exporter.supported_formats(),
183            ConcreteExporter::GGML(exporter) => exporter.supported_formats(),
184            ConcreteExporter::GGUF(exporter) => exporter.supported_formats(),
185            ConcreteExporter::GGUFEnhanced(exporter) => exporter.supported_formats(),
186            ConcreteExporter::NNEF(exporter) => exporter.supported_formats(),
187            ConcreteExporter::OpenVINO(exporter) => exporter.supported_formats(),
188            ConcreteExporter::TensorRT(exporter) => exporter.supported_formats(),
189            ConcreteExporter::TVM(exporter) => exporter.supported_formats(),
190            ConcreteExporter::CoreML(exporter) => exporter.supported_formats(),
191            ConcreteExporter::Universal(exporter) => exporter.supported_formats(),
192        }
193    }
194
195    fn validate_model<M: Model>(&self, model: &M, format: ExportFormat) -> Result<()> {
196        match self {
197            ConcreteExporter::ONNX(exporter) => exporter.validate_model(model, format),
198            ConcreteExporter::GGML(exporter) => exporter.validate_model(model, format),
199            ConcreteExporter::GGUF(exporter) => exporter.validate_model(model, format),
200            ConcreteExporter::GGUFEnhanced(exporter) => exporter.validate_model(model, format),
201            ConcreteExporter::NNEF(exporter) => exporter.validate_model(model, format),
202            ConcreteExporter::OpenVINO(exporter) => exporter.validate_model(model, format),
203            ConcreteExporter::TensorRT(exporter) => exporter.validate_model(model, format),
204            ConcreteExporter::TVM(exporter) => exporter.validate_model(model, format),
205            ConcreteExporter::CoreML(exporter) => exporter.validate_model(model, format),
206            ConcreteExporter::Universal(exporter) => exporter.validate_model(model, format),
207        }
208    }
209}
210
211impl ModelExporter for UniversalExporter {
212    fn export<M: Model>(&self, model: &M, config: &ExportConfig) -> Result<()> {
213        self.export_model(model, config)
214    }
215
216    fn supported_formats(&self) -> Vec<ExportFormat> {
217        vec![
218            ExportFormat::ONNX,
219            ExportFormat::GGML,
220            ExportFormat::GGUF,
221            ExportFormat::NNEF,
222            ExportFormat::OpenVINO,
223            ExportFormat::TensorRT,
224            ExportFormat::TVM,
225            ExportFormat::CoreML,
226        ]
227    }
228
229    fn validate_model<M: Model>(&self, _model: &M, _format: ExportFormat) -> Result<()> {
230        // Basic validation - can be extended per format
231        Ok(())
232    }
233}
234
235impl Default for ExportConfig {
236    fn default() -> Self {
237        Self {
238            format: ExportFormat::ONNX,
239            output_path: "model".to_string(),
240            optimize: true,
241            precision: ExportPrecision::FP32,
242            batch_size: Some(1),
243            sequence_length: Some(512),
244            opset_version: Some(14),
245            quantization: None,
246            input_shape: None,
247            task_type: None,
248            vocab_size: None,
249        }
250    }
251}
252
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `ExportFormat` variant, used to check supported-format lists.
    const ALL_FORMATS: [ExportFormat; 8] = [
        ExportFormat::ONNX,
        ExportFormat::GGML,
        ExportFormat::GGUF,
        ExportFormat::NNEF,
        ExportFormat::OpenVINO,
        ExportFormat::TensorRT,
        ExportFormat::TVM,
        ExportFormat::CoreML,
    ];

    #[test]
    fn test_export_config_default() {
        let config = ExportConfig::default();
        assert_eq!(config.format, ExportFormat::ONNX);
        assert_eq!(config.output_path, "model");
        assert!(config.optimize);
        assert_eq!(config.precision, ExportPrecision::FP32);
        assert_eq!(config.batch_size, Some(1));
        assert_eq!(config.sequence_length, Some(512));
        assert_eq!(config.opset_version, Some(14));
        assert!(config.quantization.is_none());
        assert!(config.input_shape.is_none());
        assert!(config.task_type.is_none());
        assert!(config.vocab_size.is_none());
    }

    #[test]
    fn test_universal_exporter_creation() {
        let exporter = UniversalExporter::new();
        let formats = exporter.supported_formats();
        assert_eq!(formats.len(), ALL_FORMATS.len());
        for format in ALL_FORMATS.iter() {
            assert!(formats.contains(format), "missing format {:?}", format);
        }
    }

    #[test]
    fn test_concrete_universal_forwards_supported_formats() {
        let concrete = ConcreteExporter::Universal(UniversalExporter::default());
        assert_eq!(
            concrete.supported_formats(),
            UniversalExporter::new().supported_formats()
        );
    }

    #[test]
    fn test_export_precision_variants() {
        let precisions = [
            ExportPrecision::FP32,
            ExportPrecision::FP16,
            ExportPrecision::INT8,
            ExportPrecision::INT4,
        ];

        for &precision in precisions.iter() {
            let config = ExportConfig {
                precision,
                ..Default::default()
            };
            // The override must actually be applied (the previous check matched
            // every variant and so could never fail), and other defaults kept.
            assert_eq!(config.precision, precision);
            assert_eq!(config.format, ExportFormat::ONNX);
            assert_eq!(config.output_path, "model");
        }
    }
}