1use ort::execution_providers::ExecutionProviderDispatch;
2use ort::session::Session;
3use ort::session::builder::SessionBuilder;
4
5use crate::config::ModelSource;
6use crate::config::SessionConfig;
7use crate::error::Result;
8use crate::error::XqVisionError;
9
/// Failure policy attached to every execution provider handed to `ort`.
///
/// The policy is applied per provider when its dispatch is built: `Fallback` configures the
/// dispatch with `fail_silently()`, `Error` configures it with `error_on_failure()`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ProviderFailure {
    /// Tolerate a provider that `ort` cannot set up (dispatch built with `fail_silently()`).
    Fallback,
    /// Surface a provider that `ort` cannot set up as an error (dispatch built with
    /// `error_on_failure()`).
    Error,
}
15
/// An ONNX Runtime execution provider a session can be configured with.
///
/// Only [`ExecutionProvider::Cpu`] is always compiled in; every other variant is gated behind
/// the Cargo feature named after it (listed on each variant). Requesting a provider whose
/// feature is disabled produces an `UnsupportedProvider` error when the session is built.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ExecutionProvider {
    /// The CPU provider; not feature-gated.
    Cpu,
    /// The Core ML provider; requires the `coreml` feature.
    CoreMl,
    /// The CUDA provider; requires the `cuda` feature.
    Cuda,
    /// The TensorRT provider; requires the `tensorrt` feature.
    TensorRt,
    /// The DirectML provider; requires the `directml` feature.
    DirectMl,
    /// The OpenVINO provider; requires the `openvino` feature.
    OpenVino,
    /// The XNNPACK provider; requires the `xnnpack` feature.
    Xnnpack,
}
26
27impl ExecutionProvider {
28 fn dispatch(self, failure: ProviderFailure) -> Result<ExecutionProviderDispatch> {
29 let dispatch = match self {
30 Self::Cpu => ort::execution_providers::CPUExecutionProvider::default().build(),
31 Self::CoreMl => {
32 #[cfg(feature = "coreml")]
33 {
34 ort::execution_providers::CoreMLExecutionProvider::default().build()
35 }
36 #[cfg(not(feature = "coreml"))]
37 {
38 return Err(XqVisionError::UnsupportedProvider { provider: self });
39 }
40 }
41 Self::Cuda => {
42 #[cfg(feature = "cuda")]
43 {
44 ort::execution_providers::CUDAExecutionProvider::default().build()
45 }
46 #[cfg(not(feature = "cuda"))]
47 {
48 return Err(XqVisionError::UnsupportedProvider { provider: self });
49 }
50 }
51 Self::TensorRt => {
52 #[cfg(feature = "tensorrt")]
53 {
54 ort::execution_providers::TensorRTExecutionProvider::default().build()
55 }
56 #[cfg(not(feature = "tensorrt"))]
57 {
58 return Err(XqVisionError::UnsupportedProvider { provider: self });
59 }
60 }
61 Self::DirectMl => {
62 #[cfg(feature = "directml")]
63 {
64 ort::execution_providers::DirectMLExecutionProvider::default().build()
65 }
66 #[cfg(not(feature = "directml"))]
67 {
68 return Err(XqVisionError::UnsupportedProvider { provider: self });
69 }
70 }
71 Self::OpenVino => {
72 #[cfg(feature = "openvino")]
73 {
74 ort::execution_providers::OpenVINOExecutionProvider::default().build()
75 }
76 #[cfg(not(feature = "openvino"))]
77 {
78 return Err(XqVisionError::UnsupportedProvider { provider: self });
79 }
80 }
81 Self::Xnnpack => {
82 #[cfg(feature = "xnnpack")]
83 {
84 ort::execution_providers::XNNPACKExecutionProvider::default().build()
85 }
86 #[cfg(not(feature = "xnnpack"))]
87 {
88 return Err(XqVisionError::UnsupportedProvider { provider: self });
89 }
90 }
91 };
92
93 Ok(match failure {
94 ProviderFailure::Fallback => dispatch.fail_silently(),
95 ProviderFailure::Error => dispatch.error_on_failure(),
96 })
97 }
98}
99
100pub(crate) fn create_session(source: &ModelSource, config: &SessionConfig) -> Result<Session> {
101 let mut builder =
102 Session::builder()?.with_optimization_level(config.graph_optimization().into()).map_err(map_builder_error)?;
103
104 if let Some(threads) = config.intra_threads() {
105 builder = builder.with_intra_threads(threads).map_err(map_builder_error)?;
106 }
107 if let Some(threads) = config.inter_threads() {
108 builder = builder.with_inter_threads(threads).map_err(map_builder_error)?;
109 }
110 if config.parallel_execution() {
111 builder = builder.with_parallel_execution(true).map_err(map_builder_error)?;
112 }
113
114 let providers = config
115 .execution_providers()
116 .iter()
117 .copied()
118 .map(|provider| provider.dispatch(config.provider_failure()))
119 .collect::<Result<Vec<_>>>()?;
120 builder = builder.with_execution_providers(providers).map_err(map_builder_error)?;
121
122 match source {
123 ModelSource::File(path) => Ok(builder.commit_from_file(path)?),
124 ModelSource::Memory(bytes) => Ok(builder.commit_from_memory(bytes)?),
125 }
126}
127
128fn map_builder_error(error: ort::Error<SessionBuilder>) -> XqVisionError {
129 XqVisionError::Ort(ort::Error::new(error.to_string()))
130}
131
#[cfg(test)]
mod tests {
    use super::*;

    /// The CPU provider has no feature gate, so dispatching it with the `Fallback` policy must
    /// succeed and yield a CPU provider.
    #[test]
    fn cpu_provider_dispatches_without_feature_gate() -> Result<()> {
        let dispatch = ExecutionProvider::Cpu.dispatch(ProviderFailure::Fallback)?;
        assert!(dispatch.downcast_ref::<ort::execution_providers::CPUExecutionProvider>().is_some());
        Ok(())
    }

    /// The `Error` policy must not change which provider is built.
    #[test]
    fn cpu_provider_dispatches_with_error_policy() -> Result<()> {
        let dispatch = ExecutionProvider::Cpu.dispatch(ProviderFailure::Error)?;
        assert!(dispatch.downcast_ref::<ort::execution_providers::CPUExecutionProvider>().is_some());
        Ok(())
    }

    /// A provider whose Cargo feature is disabled is rejected up front. The `Fallback` policy
    /// does not suppress this error.
    #[cfg(not(feature = "cuda"))]
    #[test]
    fn disabled_provider_is_reported_as_unsupported() {
        match ExecutionProvider::Cuda.dispatch(ProviderFailure::Fallback) {
            Err(XqVisionError::UnsupportedProvider { provider }) => {
                assert_eq!(provider, ExecutionProvider::Cuda);
            }
            Err(_) => panic!("expected XqVisionError::UnsupportedProvider"),
            Ok(_) => panic!("CUDA dispatch must fail when the `cuda` feature is disabled"),
        }
    }
}