rust_faces/
builder.rs

use std::sync::Arc;

use ort::{
    execution_providers::{CUDAExecutionProviderOptions, CoreMLExecutionProviderOptions},
    ExecutionProvider,
};

use crate::{
    blazeface::BlazeFaceParams,
    detection::{FaceDetector, RustFacesResult},
    model_repository::{GitHubRepository, ModelRepository},
    BlazeFace, MtCnn, MtCnnParams,
};

/// Available face detection models and their parameters.
pub enum FaceDetection {
    BlazeFace640(BlazeFaceParams),
    BlazeFace320(BlazeFaceParams),
    MtCnn(MtCnnParams),
}

/// How the model files are obtained: from a local file or downloaded from the model repository.
#[derive(Clone, Debug)]
enum OpenMode {
    File(String),
    Download,
}

/// Runtime inference provider. Some providers may not be available depending on your ONNX Runtime installation.
#[derive(Clone, Copy, Debug)]
pub enum Provider {
    /// Uses the default CPU inference.
    OrtCpu,
    /// Uses CUDA inference on the given device id.
    OrtCuda(i32),
    /// Uses Intel's OpenVINO inference.
    OrtVino(i32),
    /// Apple's Core ML inference.
    OrtCoreMl,
}

/// Inference parameters.
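///
/// # Example
///
/// A minimal sketch of overriding the defaults, e.g. to run on the first CUDA
/// device. It assumes these types are re-exported at the crate root and that your
/// ONNX Runtime build ships the CUDA execution provider:
///
/// ```
/// use rust_faces::{InferParams, Provider};
///
/// let params = InferParams {
///     provider: Provider::OrtCuda(0),
///     intra_threads: Some(4),
///     ..Default::default()
/// };
/// ```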
pub struct InferParams {
    /// Chooses the ONNX Runtime provider.
    pub provider: Provider,
    /// Sets the number of intra-op threads.
    pub intra_threads: Option<usize>,
    /// Sets the number of inter-op threads.
    pub inter_threads: Option<usize>,
}

impl Default for InferParams {
    /// Default provider is `OrtCpu` (ONNX Runtime CPU).
    fn default() -> Self {
        Self {
            provider: Provider::OrtCpu,
            intra_threads: None,
            inter_threads: None,
        }
    }
}

/// Builder for loading or downloading, configuring, and creating face detectors.
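///
/// # Example
///
/// A minimal usage sketch, assuming `BlazeFaceParams` implements `Default` and that
/// these types are re-exported at the crate root; model download and provider
/// availability depend on your environment:
///
/// ```no_run
/// use rust_faces::{BlazeFaceParams, FaceDetection, FaceDetectorBuilder, InferParams};
///
/// let detector = FaceDetectorBuilder::new(FaceDetection::BlazeFace640(BlazeFaceParams::default()))
///     .download()
///     .infer_params(InferParams::default())
///     .build()
///     .expect("failed to load the face detector");
/// ```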
pub struct FaceDetectorBuilder {
    detector: FaceDetection,
    open_mode: OpenMode,
    infer_params: InferParams,
}

impl FaceDetectorBuilder {
    /// Creates a new builder for the given face detector.
    ///
    /// # Arguments
    ///
    /// * `detector` - The face detector to build.
    pub fn new(detector: FaceDetection) -> Self {
        Self {
            detector,
            open_mode: OpenMode::Download,
            infer_params: InferParams::default(),
        }
    }

    /// Loads the model from the given file path.
    ///
    /// # Arguments
    ///
    /// * `path` - Path to the model file.
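    ///
    /// # Example
    ///
    /// A sketch with a hypothetical local path; the file must be an ONNX model
    /// matching the chosen detector:
    ///
    /// ```no_run
    /// use rust_faces::{BlazeFaceParams, FaceDetection, FaceDetectorBuilder};
    ///
    /// let detector = FaceDetectorBuilder::new(FaceDetection::BlazeFace320(BlazeFaceParams::default()))
    ///     .from_file("models/blazeface-320.onnx".to_string())
    ///     .build();
    /// ```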
    pub fn from_file(mut self, path: String) -> Self {
        self.open_mode = OpenMode::File(path);
        self
    }

    /// Sets the model to be downloaded from the model repository.
    pub fn download(mut self) -> Self {
        self.open_mode = OpenMode::Download;
        self
    }

    /// Sets the inference parameters.
    pub fn infer_params(mut self, params: InferParams) -> Self {
        self.infer_params = params;
        self
    }

    /// Instantiates a new detector.
    ///
    /// # Errors
    ///
    /// Returns an error if the model can't be loaded.
    ///
    /// # Returns
    ///
    /// A new face detector.
    pub fn build(&self) -> RustFacesResult<Box<dyn FaceDetector>> {
        let mut ort_builder = ort::Environment::builder().with_name("RustFaces");

        // Register the requested execution provider with the ONNX Runtime environment.
        ort_builder = match self.infer_params.provider {
            Provider::OrtCuda(device_id) => {
                let provider = ExecutionProvider::CUDA(CUDAExecutionProviderOptions {
                    device_id: device_id as u32,
                    ..Default::default()
                });

                if !provider.is_available() {
                    eprintln!("Warning: CUDA is not available. It'll likely use CPU inference.");
                }
                ort_builder.with_execution_providers([provider])
            }
            Provider::OrtVino(_device_id) => {
                return Err(crate::RustFacesError::Other(
                    "OpenVINO is not supported yet.".to_string(),
                ));
            }
            Provider::OrtCoreMl => {
                ort_builder.with_execution_providers([ExecutionProvider::CoreML(
                    CoreMLExecutionProviderOptions::default(),
                )])
            }
            // OrtCpu: keep the default CPU provider.
            _ => ort_builder,
        };

        let env = Arc::new(ort_builder.build()?);
        let repository = GitHubRepository::new();

        // Resolve the model file paths: either download them from the model
        // repository or use the user-supplied file path.
        let model_paths = match &self.open_mode {
            OpenMode::Download => repository
                .get_model(&self.detector)?
                .iter()
                .map(|path| path.to_str().unwrap().to_string())
                .collect(),
            OpenMode::File(path) => vec![path.clone()],
        };

        match &self.detector {
            FaceDetection::BlazeFace640(params) => Ok(Box::new(BlazeFace::from_file(
                env,
                &model_paths[0],
                params.clone(),
            ))),
            FaceDetection::BlazeFace320(params) => Ok(Box::new(BlazeFace::from_file(
                env,
                &model_paths[0],
                params.clone(),
            ))),
            // MTCNN is a cascade of three networks, hence the three model paths.
            FaceDetection::MtCnn(params) => Ok(Box::new(MtCnn::from_file(
                env,
                &model_paths[0],
                &model_paths[1],
                &model_paths[2],
                params.clone(),
            )?)),
        }
    }
}

#[cfg(test)]
mod tests {}