1use std::path::Path;
2use std::path::PathBuf;
3use std::sync::Arc;
4
5use ort::session::builder::GraphOptimizationLevel;
6
7use crate::session::ExecutionProvider;
8use crate::session::ProviderFailure;
9
/// Location of the model to load: either a path on disk or an in-memory buffer.
#[derive(Clone, Debug)]
pub enum ModelSource {
    /// Model stored in a file at this path.
    File(PathBuf),
    /// Model bytes held in memory; being an `Arc<[u8]>`, clones share one buffer.
    Memory(Arc<[u8]>),
}
15
16impl ModelSource {
17 #[must_use]
18 pub fn file(path: impl Into<PathBuf>) -> Self { Self::File(path.into()) }
19
20 #[must_use]
21 pub fn memory(bytes: impl Into<Vec<u8>>) -> Self { Self::Memory(Arc::from(bytes.into())) }
22
23 #[must_use]
24 pub fn as_file(&self) -> Option<&Path> {
25 match self {
26 Self::File(path) => Some(path.as_path()),
27 Self::Memory(_) => None,
28 }
29 }
30
31 #[must_use]
32 pub fn as_bytes(&self) -> Option<&[u8]> {
33 match self {
34 Self::File(_) => None,
35 Self::Memory(bytes) => Some(bytes),
36 }
37 }
38}
39
40impl From<PathBuf> for ModelSource {
41 fn from(value: PathBuf) -> Self { Self::file(value) }
42}
43
44impl From<&Path> for ModelSource {
45 fn from(value: &Path) -> Self { Self::file(value) }
46}
47
48impl From<&str> for ModelSource {
49 fn from(value: &str) -> Self { Self::file(value) }
50}
51
/// Graph optimization level requested for a session.
///
/// Each variant is converted to ort's `GraphOptimizationLevel`; see the `From` impl for the
/// exact mapping. Defaults to [`GraphOptimization::All`].
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum GraphOptimization {
    /// No graph optimizations (ort `Disable`).
    Disable,
    /// Maps to ort `Level1`.
    Basic,
    /// Maps to ort `Level2`.
    Extended,
    /// Maps to ort `Level3`.
    Layout,
    /// Maps to ort `All`; this is the default.
    #[default]
    All,
}
61
62impl From<GraphOptimization> for GraphOptimizationLevel {
63 fn from(value: GraphOptimization) -> Self {
64 match value {
65 GraphOptimization::Disable => Self::Disable,
66 GraphOptimization::Basic => Self::Level1,
67 GraphOptimization::Extended => Self::Level2,
68 GraphOptimization::Layout => Self::Level3,
69 GraphOptimization::All => Self::All,
70 }
71 }
72}
73
74#[derive(Debug, Clone, PartialEq, Eq)]
75pub struct SessionConfig {
76 execution_providers: Vec<ExecutionProvider>,
77 provider_failure: ProviderFailure,
78 graph_optimization: GraphOptimization,
79 intra_threads: Option<usize>,
80 inter_threads: Option<usize>,
81 parallel_execution: bool,
82}
83
84impl Default for SessionConfig {
85 fn default() -> Self {
86 Self {
87 execution_providers: default_execution_providers(),
88 provider_failure: ProviderFailure::Fallback,
89 graph_optimization: GraphOptimization::All,
90 intra_threads: None,
91 inter_threads: None,
92 parallel_execution: false,
93 }
94 }
95}
96
97fn default_execution_providers() -> Vec<ExecutionProvider> {
104 vec![
105 #[cfg(feature = "cuda")]
106 ExecutionProvider::Cuda,
107 #[cfg(feature = "tensorrt")]
108 ExecutionProvider::TensorRt,
109 #[cfg(feature = "coreml")]
110 ExecutionProvider::CoreMl,
111 #[cfg(feature = "directml")]
112 ExecutionProvider::DirectMl,
113 #[cfg(feature = "openvino")]
114 ExecutionProvider::OpenVino,
115 #[cfg(feature = "xnnpack")]
116 ExecutionProvider::Xnnpack,
117 ExecutionProvider::Cpu,
118 ]
119}
120
121impl SessionConfig {
122 #[must_use]
123 pub fn new() -> Self { Self::default() }
124
125 #[must_use]
126 pub fn execution_providers(&self) -> &[ExecutionProvider] { &self.execution_providers }
127
128 #[must_use]
129 pub fn provider_failure(&self) -> ProviderFailure { self.provider_failure }
130
131 #[must_use]
132 pub fn graph_optimization(&self) -> GraphOptimization { self.graph_optimization }
133
134 #[must_use]
135 pub fn intra_threads(&self) -> Option<usize> { self.intra_threads }
136
137 #[must_use]
138 pub fn inter_threads(&self) -> Option<usize> { self.inter_threads }
139
140 #[must_use]
141 pub fn parallel_execution(&self) -> bool { self.parallel_execution }
142
143 #[must_use]
144 pub fn with_provider_failure(mut self, failure: ProviderFailure) -> Self {
145 self.provider_failure = failure;
146 self
147 }
148
149 #[must_use]
150 pub fn with_graph_optimization(mut self, level: GraphOptimization) -> Self {
151 self.graph_optimization = level;
152 self
153 }
154
155 #[must_use]
156 pub fn with_intra_threads(mut self, threads: usize) -> Self {
157 self.intra_threads = Some(threads);
158 self
159 }
160
161 #[must_use]
162 pub fn with_inter_threads(mut self, threads: usize) -> Self {
163 self.inter_threads = Some(threads);
164 self
165 }
166
167 #[must_use]
168 pub fn with_parallel_execution(mut self, enabled: bool) -> Self {
169 self.parallel_execution = enabled;
170 self
171 }
172}