//! MTCNN face detector (rust_faces/mtcnn.rs).

1use std::sync::Arc;
2
3use image::{
4    imageops::{self, FilterType},
5    ImageBuffer, Rgb, RgbImage,
6};
7use ndarray::{s, Array3, Array4, ArrayViewD, Axis, CowArray, Zip};
8use ort::tensor::OrtOwnedTensor;
9
10use crate::{Face, FaceDetector, Nms, Rect, RustFacesResult};
11
/// MtCnn parameters.
///
/// Controls the three-stage cascade (PNet -> RNet -> ONet).
#[derive(Clone)]
pub struct MtCnnParams {
    /// Minimum face size in pixels. Determines the first image-pyramid scale
    /// used by the proposal (PNet) stage.
    pub min_face_size: usize,
    /// Confidence thresholds for each stage (index 0 = PNet, 1 = RNet, 2 = ONet).
    pub thresholds: [f32; 3],
    /// Scale factor for the next pyramid image (each level is this fraction
    /// of the previous one's size).
    pub scale_factor: f32,
    /// Non-maximum suppression used to merge overlapping detections.
    pub nms: Nms,
}
24
25impl Default for MtCnnParams {
26    fn default() -> Self {
27        Self {
28            min_face_size: 24,
29            thresholds: [0.6, 0.7, 0.7],
30            scale_factor: 0.709,
31            nms: Nms::default(),
32        }
33    }
34}
35
/// MtCnn face detector.
///
/// Holds one ONNX Runtime session per cascade stage plus the detection
/// parameters.
pub struct MtCnn {
    // Proposal network: scans the image pyramid for candidate windows.
    pnet: ort::Session,
    // Refine network: re-scores/refines proposals on 24x24 crops.
    rnet: ort::Session,
    // Output network: final boxes and landmarks on 48x48 crops.
    onet: ort::Session,
    // Detection-time configuration (thresholds, scale factor, NMS).
    params: MtCnnParams,
}
43
44impl MtCnn {
45    /// Creates a new MtCnn face detector from the given ONNX model paths.
46    ///
47    /// # Arguments
48    ///
49    /// * `pnet_path` - Path to the P(roposal)Net ONNX model.
50    /// * `rnet_path` - Path to the R(efine)Net ONNX model.
51    /// * `onet_path` - Path to the O(ptimize)Net ONNX model.
52    /// * `params` - MtCnn parameters.
53    ///
54    /// # Returns
55    ///
56    /// * `MtCnn` - MtCnn face detector.
57    pub fn from_file(
58        env: Arc<ort::Environment>,
59        pnet_path: &str,
60        rnet_path: &str,
61        onet_path: &str,
62        params: MtCnnParams,
63    ) -> RustFacesResult<Self> {
64        let pnet = ort::session::SessionBuilder::new(&env)?.with_model_from_file(pnet_path)?;
65        let rnet = ort::session::SessionBuilder::new(&env)?.with_model_from_file(rnet_path)?;
66        let onet = ort::session::SessionBuilder::new(&env)?.with_model_from_file(onet_path)?;
67
68        Ok(Self {
69            pnet,
70            rnet,
71            onet,
72            params,
73        })
74    }
75
    /// Runs the PNet (proposal) stage: slides the fully-convolutional network
    /// over an image pyramid and collects candidate face boxes.
    ///
    /// # Arguments
    ///
    /// * `image` - Source RGB image.
    ///
    /// # Returns
    ///
    /// * Candidate faces in original-image coordinates, after per-scale and
    ///   global NMS, clamped to the image bounds.
    fn run_proposal_inference(
        &self,
        image: &ImageBuffer<Rgb<u8>, &[u8]>,
    ) -> Result<Vec<Face>, crate::RustFacesError> {
        // Each output cell corresponds to a 12x12 window in the (scaled) input,
        // sampled with a stride of 2 pixels.
        const PNET_CELL_SIZE: usize = 12;
        const PNET_STRIDE: usize = 2;

        let (image_width, image_height) = (image.width() as usize, image.height() as usize);

        let scales = {
            // Make the first scale to match the minimum face size.
            // Example, if the minimum face size is the same as PNET_CELL_SIZE,
            // that means each cell in the output feature map will correspond
            // to a 12x12 pixel region in the input image, hence no resize (first_scale = 1.0).
            let first_scale = PNET_CELL_SIZE as f32 / self.params.min_face_size as f32;

            let mut curr_size = image_width.min(image_height) as f32 * first_scale;
            let mut scale = first_scale;
            let mut scales = Vec::new();

            // Keep shrinking until the smaller image side no longer fits a
            // single 12x12 PNet window.
            while curr_size > PNET_CELL_SIZE as f32 {
                scales.push(scale);
                scale *= self.params.scale_factor;
                curr_size *= self.params.scale_factor;
            }
            scales
        };

        let mut face_proposals = Vec::new();
        for scale_factor in scales {
            let image = imageops::resize(
                image,
                (scale_factor * image_width as f32) as u32,
                (scale_factor * image_height as f32) as u32,
                FilterType::Gaussian,
            );

            // Build the 1x3xHxW input tensor, normalizing pixels to ~[-1, 1].
            let (in_width, in_height) = image.dimensions();
            let image = Array4::from_shape_fn(
                (1, 3, in_height as usize, in_width as usize),
                |(_n, c, h, w)| (image.get_pixel(w as u32, h as u32)[c] as f32 - 127.5) / 128.0,
            );

            let output_tensors = self.pnet.run(vec![ort::Value::from_array(
                self.pnet.allocator(),
                &CowArray::from(image).into_dyn(),
            )?])?;

            // NOTE(review): assumes PNet emits box regressions at output 0 and
            // 2-channel scores at output 1, both shaped 1xCxHxW — confirm
            // against the exported ONNX model.
            let box_regressions: OrtOwnedTensor<f32, _> = output_tensors[0].try_extract()?;
            let scores: OrtOwnedTensor<f32, _> = output_tensors[1].try_extract()?;

            let (net_out_width, net_out_height) = {
                let shape = scores.view().dim();
                (shape[3], shape[2])
            };

            // Boxes are found in scaled-image coordinates; this maps them back
            // to the original image.
            let rescale_factor = 1.0 / scale_factor;
            let mut faces = Vec::with_capacity(net_out_width * net_out_height);

            // Iterate per output cell, pairing the 2-channel score lane with
            // the 4-channel regression lane at the same (row, col).
            Zip::indexed(
                scores
                    .view()
                    .to_shape((2, net_out_height, net_out_width))
                    .unwrap()
                    .lanes(Axis(0)),
            )
            .and(
                box_regressions
                    .view()
                    .to_shape((4, net_out_height, net_out_width))
                    .unwrap()
                    .lanes(Axis(0)),
            )
            .for_each(|(row, col), score, regression| {
                // Channel 1 is the face-probability score.
                let score = score[1];
                if score > self.params.thresholds[0] {
                    // Cell position -> window corner in scaled-image pixels,
                    // offset by the regressed deltas.
                    let x1 = col as f32 * PNET_STRIDE as f32 + regression[0];
                    let y1 = row as f32 * PNET_STRIDE as f32 + regression[1];
                    let x2 =
                        col as f32 * PNET_STRIDE as f32 + PNET_CELL_SIZE as f32 + regression[2];
                    let y2 =
                        row as f32 * PNET_STRIDE as f32 + PNET_CELL_SIZE as f32 + regression[3];

                    faces.push(Face {
                        rect: Rect::at(x1, y1)
                            .ending_at(x2, y2)
                            .scale(rescale_factor, rescale_factor),
                        confidence: score,
                        landmarks: None,
                    })
                }
            });

            // Per-scale NMS first, then a global NMS across all scales below.
            face_proposals.extend(self.params.nms.suppress_non_maxima(faces));
        }
        let mut proposals = self.params.nms.suppress_non_maxima(face_proposals);
        proposals.iter_mut().for_each(|face| {
            face.rect = face.rect.clamp(image_width as f32, image_height as f32);
        });
        Ok(proposals)
    }
177
178    fn batch_faces<'a>(
179        &self,
180        image: &'a ImageBuffer<Rgb<u8>, &[u8]>,
181        proposals: &'a [Face],
182        input_size: usize,
183    ) -> impl Iterator<Item = (&'a [Face], Array4<f32>)> + 'a {
184        const BATCH_SIZE: usize = 16;
185        proposals.chunks(BATCH_SIZE).map(move |proposal_batch| {
186            let mut input_tensor = Array4::zeros((proposal_batch.len(), 3, input_size, input_size));
187            for (n, face) in proposal_batch.iter().enumerate() {
188                let face_image =
189                    RgbImage::from_fn(face.rect.width as u32, face.rect.height as u32, |x, y| {
190                        image
191                            .get_pixel(face.rect.x as u32 + x, face.rect.y as u32 + y)
192                            .to_owned()
193                    });
194                let face_image = imageops::resize(
195                    &face_image,
196                    input_size as u32,
197                    input_size as u32,
198                    FilterType::Gaussian,
199                );
200                input_tensor
201                    .slice_mut(s![n, .., .., ..])
202                    .assign(&Array3::from_shape_fn(
203                        (3, input_size, input_size),
204                        |(c, h, w)| {
205                            (face_image.get_pixel(w as u32, h as u32)[c] as f32 - 127.5) / 128.0
206                        },
207                    ));
208            }
209            (proposal_batch, input_tensor)
210        })
211    }
212
    /// Runs the RNet (refine) stage: re-scores PNet proposals on 24x24 crops
    /// and refines their bounding boxes.
    ///
    /// # Arguments
    ///
    /// * `image` - Source RGB image.
    /// * `proposals` - Candidate faces from the PNet stage.
    ///
    /// # Returns
    ///
    /// * Refined faces passing the RNet threshold, after min-mode NMS.
    fn run_refine_net(
        &self,
        image: &ImageBuffer<Rgb<u8>, &[u8]>,
        proposals: &[Face],
    ) -> Result<Vec<Face>, crate::RustFacesError> {
        let mut rnet_faces = Vec::new();
        // RNet consumes 24x24 crops.
        for (faces, input_tensor) in self.batch_faces(image, proposals, 24) {
            let output_tensors = self.rnet.run(vec![ort::Value::from_array(
                self.rnet.allocator(),
                &CowArray::from(input_tensor).into_dyn(),
            )?])?;
            // NOTE(review): assumes RNet emits Nx4 box regressions at output 0
            // and Nx2 scores at output 1 — confirm against the exported model.
            let box_regressions: OrtOwnedTensor<f32, _> = output_tensors[0].try_extract()?;
            let scores: OrtOwnedTensor<f32, _> = output_tensors[1].try_extract()?;
            // Clamp limits: last valid pixel coordinate in each dimension.
            let image_width = (image.width() - 1) as f32;
            let image_height = (image.height() - 1) as f32;

            // Walk face / score-row / regression-row triples in lockstep.
            let batch_faces = itertools::izip!(
                faces.iter(),
                scores
                    .view()
                    .to_shape((faces.len(), 2))
                    .unwrap()
                    .lanes(Axis(1))
                    .into_iter(),
                box_regressions
                    .view()
                    .to_shape((faces.len(), 4))
                    .unwrap()
                    .lanes(Axis(1))
                    .into_iter()
            )
            .filter_map(|(face, score, regression)| {
                // Channel 1 is the face-probability score; index 1 is the
                // RNet (second stage) threshold.
                let score = score[1];
                if score >= self.params.thresholds[1] {
                    let face_width = face.rect.width;
                    let face_height = face.rect.height;
                    let regression = regression.to_vec();

                    // Regressed deltas are relative to the box size.
                    let x1 = face.rect.x + regression[0] * face_width;
                    let y1 = face.rect.y + regression[1] * face_height;
                    let x2 = face.rect.right() + regression[2] * face_width;
                    let y2 = face.rect.bottom() + regression[3] * face_height;

                    Some(Face {
                        rect: Rect::at(x1, y1)
                            .ending_at(x2, y2)
                            .clamp(image_width, image_height),
                        confidence: score,
                        landmarks: None,
                    })
                } else {
                    None
                }
            })
            .collect::<Vec<_>>();

            rnet_faces.extend(batch_faces);
        }
        let boxes = self.params.nms.suppress_non_maxima_min(rnet_faces);
        Ok(boxes)
    }
274
275    fn run_optmized_net(
276        &self,
277        image: &ImageBuffer<Rgb<u8>, &[u8]>,
278        proposals: &[Face],
279    ) -> Result<Vec<Face>, crate::RustFacesError> {
280        let mut onet_faces = Vec::new();
281        for (faces, input_tensor) in self.batch_faces(image, proposals, 48) {
282            let output_tensors = self.onet.run(vec![ort::Value::from_array(
283                self.onet.allocator(),
284                &CowArray::from(input_tensor).into_dyn(),
285            )?])?;
286
287            let box_regressions: OrtOwnedTensor<f32, _> = output_tensors[0].try_extract()?; // 0
288            let landmarks_regressions: OrtOwnedTensor<f32, _> = output_tensors[1].try_extract()?;
289            let scores: OrtOwnedTensor<f32, _> = output_tensors[2].try_extract()?; // 1
290            let image_width = (image.width() - 1) as f32;
291            let image_height = (image.height() - 1) as f32;
292
293            let batch_faces = itertools::izip!(
294                faces.iter(),
295                scores
296                    .view()
297                    .to_shape((faces.len(), 2))
298                    .unwrap()
299                    .lanes(Axis(1))
300                    .into_iter(),
301                box_regressions
302                    .view()
303                    .to_shape((faces.len(), 4))
304                    .unwrap()
305                    .lanes(Axis(1))
306                    .into_iter(),
307                landmarks_regressions
308                    .view()
309                    .to_shape((faces.len(), 10))
310                    .unwrap()
311                    .lanes(Axis(1))
312                    .into_iter()
313            )
314            .filter_map(|(face, score, regression, landmarks)| {
315                let score = score[1];
316                if score >= self.params.thresholds[1] {
317                    let face_width = face.rect.width;
318                    let face_height = face.rect.height;
319                    let regression = regression.to_vec();
320
321                    let x1 = face.rect.x + regression[0] * face_width;
322                    let y1 = face.rect.y + regression[1] * face_height;
323                    let x2 = face.rect.right() + regression[2] * face_width;
324                    let y2 = face.rect.bottom() + regression[3] * face_height;
325
326                    let rect = Rect::at(x1, y1)
327                        .ending_at(x2, y2)
328                        .clamp(image_width, image_height);
329                    let mut landmarks_vec = Vec::new();
330
331                    for i in 0..5 {
332                        landmarks_vec.push((
333                            face.rect.x + landmarks[i] * face_width,
334                            face.rect.y + landmarks[i + 5] * face_height,
335                        ));
336                    }
337                    Some(Face {
338                        rect,
339                        confidence: score,
340                        landmarks: Some(landmarks_vec),
341                    })
342                } else {
343                    None
344                }
345            })
346            .collect::<Vec<_>>();
347
348            onet_faces.extend(batch_faces);
349        }
350        let boxes = self.params.nms.suppress_non_maxima_min(onet_faces);
351        Ok(boxes)
352    }
353}
354
impl FaceDetector for MtCnn {
    /// Detects faces by running the three MTCNN stages in sequence:
    /// PNet proposals -> RNet refinement -> ONet boxes + landmarks.
    ///
    /// `image` is interpreted as an HWC u8 array: shape[0] = height,
    /// shape[1] = width. Assumes 3 RGB channels — not validated here.
    fn detect(&self, image: ArrayViewD<u8>) -> RustFacesResult<Vec<Face>> {
        let shape = image.shape().to_vec();
        let (image_width, image_height) = (shape[1], shape[0]);
        // NOTE(review): `as_slice()` returns None for non-contiguous views and
        // `from_raw` returns None if the buffer is too small for
        // width * height * 3 bytes; both unwraps panic instead of returning an
        // error — confirm callers always pass contiguous RGB data.
        let image = ImageBuffer::<Rgb<u8>, &[u8]>::from_raw(
            image_width as u32,
            image_height as u32,
            image.as_slice().unwrap(),
        )
        .unwrap();

        let proposals = self.run_proposal_inference(&image)?;
        let refined_faces = self.run_refine_net(&image, &proposals)?;
        let optimized_faces = self.run_optmized_net(&image, &refined_faces)?;
        Ok(optimized_faces)
    }
}
372
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use super::*;
    use crate::{
        imaging::ToRgb8,
        model_repository::{GitHubRepository, ModelRepository},
        mtcnn::MtCnn,
        testing::{output_dir, sample_array_image},
        viz,
    };
    use ndarray::Array3;
    use rstest::rstest;
    use std::sync::Arc;

    /// End-to-end smoke test: fetches the MTCNN models, runs detection on a
    /// sample image, and saves a visualization. Only compiled with the "viz"
    /// feature; requires network access for the model download.
    #[cfg(feature = "viz")]
    #[rstest]
    fn should_detect(sample_array_image: Array3<u8>, output_dir: PathBuf) {
        use crate::FaceDetection;

        // ONNX Runtime environment shared by the three sessions.
        let environment = Arc::new(
            ort::Environment::builder()
                .with_name("MtCnn")
                .build()
                .unwrap(),
        );

        // Download (or reuse cached) PNet/RNet/ONet model files.
        let drive = GitHubRepository::new();
        let model_paths = drive
            .get_model(&FaceDetection::MtCnn(MtCnnParams::default()))
            .expect("Can't download model");

        let face_detector = MtCnn::from_file(
            environment,
            model_paths[0].to_str().unwrap(),
            model_paths[1].to_str().unwrap(),
            model_paths[2].to_str().unwrap(),
            MtCnnParams::default(),
        )
        .expect("Failed to load MTCNN detector.");
        let mut canvas = sample_array_image.to_rgb8();
        let faces = face_detector
            .detect(sample_array_image.into_dyn().view())
            .expect("Can't detect faces");

        // Draw the detected boxes/landmarks and save for visual inspection.
        viz::draw_faces(&mut canvas, faces);

        canvas
            .save(output_dir.join("mtcnn.png"))
            .expect("Can't save image");
    }
}
425}