Skip to main content

xq_vision/
config.rs

1use std::path::Path;
2use std::path::PathBuf;
3use std::sync::Arc;
4
5use ort::session::builder::GraphOptimizationLevel;
6
7use crate::session::ExecutionProvider;
8use crate::session::ProviderFailure;
9
/// Where the bytes of a model come from.
///
/// Cloning a [`ModelSource::Memory`] is cheap: the buffer sits behind an
/// [`Arc`], so a clone only bumps the reference count.
#[derive(Clone)]
pub enum ModelSource {
    /// A model stored on disk at the given path.
    File(PathBuf),
    /// A model already held in memory, shared behind an [`Arc`].
    Memory(Arc<[u8]>),
}

// `Debug` is written by hand: the derived impl would print every byte of
// `Memory`, so a single `{:?}` on a real model would flood logs with the whole
// buffer. Only the length is shown instead; `File` keeps the derived shape.
impl std::fmt::Debug for ModelSource {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::File(path) => f.debug_tuple("File").field(path).finish(),
            Self::Memory(bytes) => f
                .debug_tuple("Memory")
                .field(&format_args!("<{} bytes>", bytes.len()))
                .finish(),
        }
    }
}

impl ModelSource {
    /// Creates a source that refers to a model file at `path`.
    ///
    /// The path is only stored; nothing is read or checked here.
    #[must_use]
    pub fn file(path: impl Into<PathBuf>) -> Self { Self::File(path.into()) }

    /// Creates a source from model bytes that are already in memory.
    ///
    /// `bytes` is converted into a `Vec<u8>` and then into a shared
    /// `Arc<[u8]>`, so later clones of the source do not copy the buffer.
    #[must_use]
    pub fn memory(bytes: impl Into<Vec<u8>>) -> Self { Self::Memory(Arc::from(bytes.into())) }

    /// Returns the file path, or `None` when the source is in memory.
    #[must_use]
    pub fn as_file(&self) -> Option<&Path> {
        match self {
            Self::File(path) => Some(path.as_path()),
            Self::Memory(_) => None,
        }
    }

    /// Returns the in-memory bytes, or `None` when the source is a file.
    #[must_use]
    pub fn as_bytes(&self) -> Option<&[u8]> {
        match self {
            Self::File(_) => None,
            Self::Memory(bytes) => Some(bytes),
        }
    }
}
39
40impl From<PathBuf> for ModelSource {
41    fn from(value: PathBuf) -> Self { Self::file(value) }
42}
43
44impl From<&Path> for ModelSource {
45    fn from(value: &Path) -> Self { Self::file(value) }
46}
47
48impl From<&str> for ModelSource {
49    fn from(value: &str) -> Self { Self::file(value) }
50}
51
/// Graph-optimization level requested for a session.
///
/// Each variant converts into one variant of `ort`'s `GraphOptimizationLevel`.
/// [`GraphOptimization::All`] is the default.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum GraphOptimization {
    /// Converts to `GraphOptimizationLevel::Disable`.
    Disable,
    /// Converts to `GraphOptimizationLevel::Level1`.
    Basic,
    /// Converts to `GraphOptimizationLevel::Level2`.
    Extended,
    /// Converts to `GraphOptimizationLevel::Level3`.
    Layout,
    /// Converts to `GraphOptimizationLevel::All`.
    #[default]
    All,
}
61
62impl From<GraphOptimization> for GraphOptimizationLevel {
63    fn from(value: GraphOptimization) -> Self {
64        match value {
65            GraphOptimization::Disable => Self::Disable,
66            GraphOptimization::Basic => Self::Level1,
67            GraphOptimization::Extended => Self::Level2,
68            GraphOptimization::Layout => Self::Level3,
69            GraphOptimization::All => Self::All,
70        }
71    }
72}
73
74#[derive(Debug, Clone, PartialEq, Eq)]
75pub struct SessionConfig {
76    execution_providers: Vec<ExecutionProvider>,
77    provider_failure: ProviderFailure,
78    graph_optimization: GraphOptimization,
79    intra_threads: Option<usize>,
80    inter_threads: Option<usize>,
81    parallel_execution: bool,
82}
83
84impl Default for SessionConfig {
85    fn default() -> Self {
86        Self {
87            execution_providers: default_execution_providers(),
88            provider_failure: ProviderFailure::Fallback,
89            graph_optimization: GraphOptimization::All,
90            intra_threads: None,
91            inter_threads: None,
92            parallel_execution: false,
93        }
94    }
95}
96
97/// Build the default provider list from compiled-in execution-provider features.
98///
99/// Accelerators come first (in a stable order) and `Cpu` is always appended as
100/// the final fallback. When no provider feature is enabled the list is just
101/// `[Cpu]`. The provider list is intentionally not configurable at runtime —
102/// the cargo feature set is the single source of truth.
103fn default_execution_providers() -> Vec<ExecutionProvider> {
104    vec![
105        #[cfg(feature = "cuda")]
106        ExecutionProvider::Cuda,
107        #[cfg(feature = "tensorrt")]
108        ExecutionProvider::TensorRt,
109        #[cfg(feature = "coreml")]
110        ExecutionProvider::CoreMl,
111        #[cfg(feature = "directml")]
112        ExecutionProvider::DirectMl,
113        #[cfg(feature = "openvino")]
114        ExecutionProvider::OpenVino,
115        #[cfg(feature = "xnnpack")]
116        ExecutionProvider::Xnnpack,
117        ExecutionProvider::Cpu,
118    ]
119}
120
121impl SessionConfig {
122    #[must_use]
123    pub fn new() -> Self { Self::default() }
124
125    #[must_use]
126    pub fn execution_providers(&self) -> &[ExecutionProvider] { &self.execution_providers }
127
128    #[must_use]
129    pub fn provider_failure(&self) -> ProviderFailure { self.provider_failure }
130
131    #[must_use]
132    pub fn graph_optimization(&self) -> GraphOptimization { self.graph_optimization }
133
134    #[must_use]
135    pub fn intra_threads(&self) -> Option<usize> { self.intra_threads }
136
137    #[must_use]
138    pub fn inter_threads(&self) -> Option<usize> { self.inter_threads }
139
140    #[must_use]
141    pub fn parallel_execution(&self) -> bool { self.parallel_execution }
142
143    #[must_use]
144    pub fn with_provider_failure(mut self, failure: ProviderFailure) -> Self {
145        self.provider_failure = failure;
146        self
147    }
148
149    #[must_use]
150    pub fn with_graph_optimization(mut self, level: GraphOptimization) -> Self {
151        self.graph_optimization = level;
152        self
153    }
154
155    #[must_use]
156    pub fn with_intra_threads(mut self, threads: usize) -> Self {
157        self.intra_threads = Some(threads);
158        self
159    }
160
161    #[must_use]
162    pub fn with_inter_threads(mut self, threads: usize) -> Self {
163        self.inter_threads = Some(threads);
164        self
165    }
166
167    #[must_use]
168    pub fn with_parallel_execution(mut self, enabled: bool) -> Self {
169        self.parallel_execution = enabled;
170        self
171    }
172}