Skip to main content

xq_vision/
session.rs

1use ort::execution_providers::ExecutionProviderDispatch;
2use ort::session::Session;
3use ort::session::builder::SessionBuilder;
4
5use crate::config::ModelSource;
6use crate::config::SessionConfig;
7use crate::error::Result;
8use crate::error::XqVisionError;
9
/// Policy applied to an execution-provider dispatch before it is handed to
/// the session builder.
///
/// The policy only selects which finalizer is invoked on the dispatch built
/// by `ExecutionProvider::dispatch`; the resulting runtime behavior is
/// defined by the `ort` crate.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ProviderFailure {
    /// Finalize the dispatch with `fail_silently()`.
    Fallback,
    /// Finalize the dispatch with `error_on_failure()`.
    Error,
}
15
/// Execution backends that can be requested for a session.
///
/// Every variant except [`ExecutionProvider::Cpu`] is only usable when the
/// Cargo feature of the same (lower-case) name is enabled; otherwise building
/// its dispatch reports `XqVisionError::UnsupportedProvider`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ExecutionProvider {
    /// `ort`'s `CPUExecutionProvider`; not feature-gated.
    Cpu,
    /// `ort`'s `CoreMLExecutionProvider`; requires the `coreml` feature.
    CoreMl,
    /// `ort`'s `CUDAExecutionProvider`; requires the `cuda` feature.
    Cuda,
    /// `ort`'s `TensorRTExecutionProvider`; requires the `tensorrt` feature.
    TensorRt,
    /// `ort`'s `DirectMLExecutionProvider`; requires the `directml` feature.
    DirectMl,
    /// `ort`'s `OpenVINOExecutionProvider`; requires the `openvino` feature.
    OpenVino,
    /// `ort`'s `XNNPACKExecutionProvider`; requires the `xnnpack` feature.
    Xnnpack,
}
26
27impl ExecutionProvider {
28    fn dispatch(self, failure: ProviderFailure) -> Result<ExecutionProviderDispatch> {
29        let dispatch = match self {
30            Self::Cpu => ort::execution_providers::CPUExecutionProvider::default().build(),
31            Self::CoreMl => {
32                #[cfg(feature = "coreml")]
33                {
34                    ort::execution_providers::CoreMLExecutionProvider::default().build()
35                }
36                #[cfg(not(feature = "coreml"))]
37                {
38                    return Err(XqVisionError::UnsupportedProvider { provider: self });
39                }
40            }
41            Self::Cuda => {
42                #[cfg(feature = "cuda")]
43                {
44                    ort::execution_providers::CUDAExecutionProvider::default().build()
45                }
46                #[cfg(not(feature = "cuda"))]
47                {
48                    return Err(XqVisionError::UnsupportedProvider { provider: self });
49                }
50            }
51            Self::TensorRt => {
52                #[cfg(feature = "tensorrt")]
53                {
54                    ort::execution_providers::TensorRTExecutionProvider::default().build()
55                }
56                #[cfg(not(feature = "tensorrt"))]
57                {
58                    return Err(XqVisionError::UnsupportedProvider { provider: self });
59                }
60            }
61            Self::DirectMl => {
62                #[cfg(feature = "directml")]
63                {
64                    ort::execution_providers::DirectMLExecutionProvider::default().build()
65                }
66                #[cfg(not(feature = "directml"))]
67                {
68                    return Err(XqVisionError::UnsupportedProvider { provider: self });
69                }
70            }
71            Self::OpenVino => {
72                #[cfg(feature = "openvino")]
73                {
74                    ort::execution_providers::OpenVINOExecutionProvider::default().build()
75                }
76                #[cfg(not(feature = "openvino"))]
77                {
78                    return Err(XqVisionError::UnsupportedProvider { provider: self });
79                }
80            }
81            Self::Xnnpack => {
82                #[cfg(feature = "xnnpack")]
83                {
84                    ort::execution_providers::XNNPACKExecutionProvider::default().build()
85                }
86                #[cfg(not(feature = "xnnpack"))]
87                {
88                    return Err(XqVisionError::UnsupportedProvider { provider: self });
89                }
90            }
91        };
92
93        Ok(match failure {
94            ProviderFailure::Fallback => dispatch.fail_silently(),
95            ProviderFailure::Error => dispatch.error_on_failure(),
96        })
97    }
98}
99
100pub(crate) fn create_session(source: &ModelSource, config: &SessionConfig) -> Result<Session> {
101    let mut builder =
102        Session::builder()?.with_optimization_level(config.graph_optimization().into()).map_err(map_builder_error)?;
103
104    if let Some(threads) = config.intra_threads() {
105        builder = builder.with_intra_threads(threads).map_err(map_builder_error)?;
106    }
107    if let Some(threads) = config.inter_threads() {
108        builder = builder.with_inter_threads(threads).map_err(map_builder_error)?;
109    }
110    if config.parallel_execution() {
111        builder = builder.with_parallel_execution(true).map_err(map_builder_error)?;
112    }
113
114    let providers = config
115        .execution_providers()
116        .iter()
117        .copied()
118        .map(|provider| provider.dispatch(config.provider_failure()))
119        .collect::<Result<Vec<_>>>()?;
120    builder = builder.with_execution_providers(providers).map_err(map_builder_error)?;
121
122    match source {
123        ModelSource::File(path) => Ok(builder.commit_from_file(path)?),
124        ModelSource::Memory(bytes) => Ok(builder.commit_from_memory(bytes)?),
125    }
126}
127
128fn map_builder_error(error: ort::Error<SessionBuilder>) -> XqVisionError {
129    XqVisionError::Ort(ort::Error::new(error.to_string()))
130}
131
#[cfg(test)]
mod tests {
    use super::*;

    /// The CPU provider has no feature gate, so dispatching it must succeed
    /// and yield a dispatch that downcasts to `CPUExecutionProvider`.
    #[test]
    fn cpu_provider_dispatches_without_feature_gate() -> Result<()> {
        let cpu = ExecutionProvider::Cpu.dispatch(ProviderFailure::Fallback)?;
        let concrete = cpu.downcast_ref::<ort::execution_providers::CPUExecutionProvider>();
        assert!(concrete.is_some());
        Ok(())
    }
}