1use std::sync::Arc;
2
3use ort::{
4 execution_providers::{CUDAExecutionProviderOptions, CoreMLExecutionProviderOptions},
5 ExecutionProvider,
6};
7
8use crate::{
9 blazeface::BlazeFaceParams,
10 detection::{FaceDetector, RustFacesResult},
11 model_repository::{GitHubRepository, ModelRepository},
12 BlazeFace, MtCnn, MtCnnParams,
13};
14
15pub enum FaceDetection {
16 BlazeFace640(BlazeFaceParams),
17 BlazeFace320(BlazeFaceParams),
18 MtCnn(MtCnnParams),
19}
20
/// Where [`FaceDetectorBuilder::build`] obtains the model file(s) from.
#[derive(Debug, Clone)]
enum OpenMode {
    /// Fetch the model(s) from the model repository.
    Download,
    /// Load a single model from this local path.
    File(String),
}
26
/// Inference backend used to run the detector's ONNX model(s).
///
/// `PartialEq`/`Eq`/`Hash` are derived so callers can compare providers or use
/// them as map/set keys; the type is plain data, so this is free.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Provider {
    /// ONNX Runtime on the CPU.
    OrtCpu,
    /// ONNX Runtime with the CUDA execution provider on the given device id.
    OrtCuda(i32),
    /// ONNX Runtime with OpenVINO on the given device id.
    /// Currently rejected by [`FaceDetectorBuilder::build`].
    OrtVino(i32),
    /// ONNX Runtime with the CoreML execution provider.
    OrtCoreMl,
}
39
40pub struct InferParams {
42 pub provider: Provider,
44 pub intra_threads: Option<usize>,
46 pub inter_threads: Option<usize>,
48}
49
50impl Default for InferParams {
51 fn default() -> Self {
53 Self {
54 provider: Provider::OrtCpu,
55 intra_threads: None,
56 inter_threads: None,
57 }
58 }
59}
60
61pub struct FaceDetectorBuilder {
63 detector: FaceDetection,
64 open_mode: OpenMode,
65 infer_params: InferParams,
66}
67
68impl FaceDetectorBuilder {
69 pub fn new(detector: FaceDetection) -> Self {
75 Self {
76 detector,
77 open_mode: OpenMode::Download,
78 infer_params: InferParams::default(),
79 }
80 }
81
82 pub fn from_file(mut self, path: String) -> Self {
88 self.open_mode = OpenMode::File(path);
89 self
90 }
91
92 pub fn download(mut self) -> Self {
94 self.open_mode = OpenMode::Download;
95 self
96 }
97
98 pub fn infer_params(mut self, params: InferParams) -> Self {
100 self.infer_params = params;
101 self
102 }
103
104 pub fn build(&self) -> RustFacesResult<Box<dyn FaceDetector>> {
114 let mut ort_builder = ort::Environment::builder().with_name("RustFaces");
115
116 ort_builder = match self.infer_params.provider {
117 Provider::OrtCuda(device_id) => {
118 let provider = ExecutionProvider::CUDA(CUDAExecutionProviderOptions {
119 device_id: device_id as u32,
120 ..Default::default()
121 });
122
123 if !provider.is_available() {
124 eprintln!("Warning: CUDA is not available. It'll likely use CPU inference.");
125 }
126 ort_builder.with_execution_providers([provider])
127 }
128 Provider::OrtVino(_device_id) => {
129 return Err(crate::RustFacesError::Other(
130 "OpenVINO is not supported yet.".to_string(),
131 ));
132 }
133 Provider::OrtCoreMl => {
134 ort_builder.with_execution_providers([ExecutionProvider::CoreML(
135 CoreMLExecutionProviderOptions::default(),
136 )])
137 }
138 _ => ort_builder,
139 };
140
141 let env = Arc::new(ort_builder.build()?);
142 let repository = GitHubRepository::new();
143
144 let model_paths = match &self.open_mode {
145 OpenMode::Download => repository
146 .get_model(&self.detector)?
147 .iter()
148 .map(|path| path.to_str().unwrap().to_string())
149 .collect(),
150 OpenMode::File(path) => vec![path.clone()],
151 };
152
153 match &self.detector {
154 FaceDetection::BlazeFace640(params) => Ok(Box::new(BlazeFace::from_file(
155 env,
156 &model_paths[0],
157 params.clone(),
158 ))),
159 FaceDetection::BlazeFace320(params) => Ok(Box::new(BlazeFace::from_file(
160 env,
161 &model_paths[0],
162 params.clone(),
163 ))),
164 FaceDetection::MtCnn(params) => Ok(Box::new(
165 MtCnn::from_file(
166 env,
167 &model_paths[0],
168 &model_paths[1],
169 &model_paths[2],
170 params.clone(),
171 )
172 .unwrap(),
173 )),
174 }
175 }
176}
177
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration runs on the CPU without thread overrides.
    #[test]
    fn infer_params_default_is_cpu_without_thread_overrides() {
        let params = InferParams::default();
        assert!(matches!(params.provider, Provider::OrtCpu));
        assert_eq!(params.intra_threads, None);
        assert_eq!(params.inter_threads, None);
    }

    /// Struct-update syntax over the default keeps untouched fields at their defaults.
    #[test]
    fn infer_params_struct_update_overrides_only_given_fields() {
        let params = InferParams {
            provider: Provider::OrtCuda(0),
            intra_threads: Some(4),
            ..Default::default()
        };
        assert!(matches!(params.provider, Provider::OrtCuda(0)));
        assert_eq!(params.intra_threads, Some(4));
        assert_eq!(params.inter_threads, None);
    }
}