1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
use std::sync::Arc;

use crate::{
    detection::{DetectionParams, FaceDetector, RustFacesResult},
    model_repository::{GitHubRepository, ModelRepository},
    BlazeFace, Nms,
};

/// Face-detection model that a [`FaceDetectorBuilder`] can construct.
///
/// Besides being `Copy`, the kind is comparable and hashable so callers can
/// compare detector choices or use them as map/set keys.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum FaceDetection {
    /// BlazeFace model; built by [`FaceDetectorBuilder::build`] via
    /// `BlazeFace::from_file`.
    BlazeFace640,
}

/// Where [`FaceDetectorBuilder::build`] obtains the model file from.
#[derive(Clone, Debug)]
enum OpenMode {
    /// Fetch the model through the GitHub model repository (builder default).
    Download,
    /// Load the model from the given local path.
    File(String),
}

/// Inference backend requested through [`InferParams::provider`].
///
/// Comparable and hashable so callers can branch on or deduplicate providers.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Provider {
    /// ONNX Runtime on the CPU (the [`InferParams`] default).
    OrtCpu,
    /// ONNX Runtime with, presumably, the CUDA execution provider — TODO confirm.
    OrtCuda,
    /// ONNX Runtime with, presumably, the OpenVINO execution provider — TODO confirm.
    OrtVino,
}

pub struct InferParams {
    pub provider: Provider,
    pub cpu_cores: usize,
    pub batch_size: usize,
}

impl Default for InferParams {
    fn default() -> Self {
        Self {
            provider: Provider::OrtCpu,
            cpu_cores: 1,
            batch_size: 1,
        }
    }
}

/// Fluent builder that configures and constructs a boxed [`FaceDetector`].
///
/// Create it with [`FaceDetectorBuilder::new`], chain the setters, then call
/// [`FaceDetectorBuilder::build`].
pub struct FaceDetectorBuilder {
    /// Which detector model to construct.
    detector: FaceDetection,
    /// Detection parameters handed to the detector's constructor in `build`.
    params: DetectionParams,
    /// Inference settings; stored but not currently consumed by `build`.
    infer_params: InferParams,
    /// Whether the model comes from a local file or is downloaded.
    open_mode: OpenMode,
}

impl FaceDetectorBuilder {
    /// Creates a builder for `detector`.
    ///
    /// By default the model is downloaded (see [`Self::download`]) and both the
    /// detection and inference parameters use their `Default` implementations.
    #[must_use]
    pub fn new(detector: FaceDetection) -> Self {
        Self {
            detector,
            open_mode: OpenMode::Download,
            params: DetectionParams::default(),
            infer_params: InferParams::default(),
        }
    }

    /// Loads the model from the local file at `path` instead of downloading it.
    ///
    /// Replaces any earlier [`Self::download`] / `from_file` choice.
    #[must_use]
    pub fn from_file(mut self, path: String) -> Self {
        self.open_mode = OpenMode::File(path);
        self
    }

    /// Fetches the model through the GitHub model repository when building.
    /// This is the default.
    #[must_use]
    pub fn download(mut self) -> Self {
        self.open_mode = OpenMode::Download;
        self
    }

    /// Replaces all detection parameters.
    ///
    /// This also overwrites any NMS settings applied earlier via [`Self::nms`],
    /// so call `nms` after this method if both are used.
    #[must_use]
    pub fn detect_params(mut self, params: DetectionParams) -> Self {
        self.params = params;
        self
    }

    /// Overrides only the `nms` field of the detection parameters.
    #[must_use]
    pub fn nms(mut self, nms: Nms) -> Self {
        self.params.nms = nms;
        self
    }

    /// Replaces the inference parameters.
    ///
    /// NOTE(review): [`Self::build`] does not currently read these settings, so
    /// provider, core count and batch size have no effect yet.
    #[must_use]
    pub fn infer_params(mut self, params: InferParams) -> Self {
        self.infer_params = params;
        self
    }

    /// Builds the configured face detector.
    ///
    /// # Errors
    ///
    /// Returns an error if downloading the model fails or if the ONNX Runtime
    /// environment cannot be created.
    ///
    /// # Panics
    ///
    /// Panics if the downloaded model path is not valid UTF-8.
    /// `BlazeFace::from_file` returns the detector directly rather than a
    /// `Result`. Presumably it panics on a bad model file — TODO confirm.
    pub fn build(&self) -> RustFacesResult<Box<dyn FaceDetector>> {
        // Resolve the model location first. Only touch the repository when a
        // download was requested, and don't create the runtime environment
        // if the download fails.
        let model_path = match &self.open_mode {
            OpenMode::Download => GitHubRepository::new()
                .get_model(self.detector)?
                .to_str()
                .expect("model repository returned a non-UTF-8 model path")
                .to_string(),
            OpenMode::File(path) => path.clone(),
        };

        // Propagate runtime-initialisation failures instead of panicking.
        // This relies on `RustFacesError: From<ort::OrtError>`.
        let env = Arc::new(
            ort::Environment::builder()
                .with_name("BlazeFace")
                .build()?,
        );

        Ok(Box::new(match self.detector {
            FaceDetection::BlazeFace640 => BlazeFace::from_file(env, &model_path, self.params),
        }))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// The default inference settings are the CPU provider, one core and a
    /// batch size of one.
    #[test]
    fn infer_params_default_is_single_core_cpu() {
        let params = InferParams::default();
        assert!(matches!(params.provider, Provider::OrtCpu));
        assert_eq!(params.cpu_cores, 1);
        assert_eq!(params.batch_size, 1);
    }

    /// A new builder downloads by default, and `from_file` / `download`
    /// switch the open mode back and forth.
    #[test]
    fn builder_open_mode_setters() {
        let builder = FaceDetectorBuilder::new(FaceDetection::BlazeFace640);
        assert!(matches!(builder.open_mode, OpenMode::Download));

        let builder = builder.from_file("model.onnx".to_string());
        assert!(matches!(&builder.open_mode, OpenMode::File(p) if p == "model.onnx"));

        let builder = builder.download();
        assert!(matches!(builder.open_mode, OpenMode::Download));
    }
}