1use std::sync::Arc;
2
3use image::{
4 imageops::{self, FilterType},
5 ImageBuffer, Rgb, RgbImage,
6};
7use ndarray::{s, Array3, Array4, ArrayViewD, Axis, CowArray, Zip};
8use ort::tensor::OrtOwnedTensor;
9
10use crate::{Face, FaceDetector, Nms, Rect, RustFacesResult};
11
#[derive(Clone)]
pub struct MtCnnParams {
    /// Minimum face size, in pixels, that the detector searches for.
    /// Smaller values add pyramid levels and increase runtime.
    pub min_face_size: usize,
    /// Face-score thresholds for the three cascade stages, in order:
    /// P-Net, R-Net, O-Net.
    pub thresholds: [f32; 3],
    /// Down-scaling factor between consecutive image-pyramid levels
    /// used by the proposal (P-Net) stage.
    pub scale_factor: f32,
    /// Non-maximum suppression settings used to merge overlapping boxes.
    pub nms: Nms,
}
24
25impl Default for MtCnnParams {
26 fn default() -> Self {
27 Self {
28 min_face_size: 24,
29 thresholds: [0.6, 0.7, 0.7],
30 scale_factor: 0.709,
31 nms: Nms::default(),
32 }
33 }
34}
35
/// MTCNN face detector: a three-stage cascade of ONNX models —
/// proposal (P-Net), refinement (R-Net) and output (O-Net) — run
/// through ONNX Runtime sessions.
pub struct MtCnn {
    // Proposal network: scans an image pyramid for candidate boxes.
    pnet: ort::Session,
    // Refinement network: rescores and adjusts 24x24 crops.
    rnet: ort::Session,
    // Output network: final scoring, box regression and landmarks on 48x48 crops.
    onet: ort::Session,
    params: MtCnnParams,
}
43
44impl MtCnn {
45 pub fn from_file(
58 env: Arc<ort::Environment>,
59 pnet_path: &str,
60 rnet_path: &str,
61 onet_path: &str,
62 params: MtCnnParams,
63 ) -> RustFacesResult<Self> {
64 let pnet = ort::session::SessionBuilder::new(&env)?.with_model_from_file(pnet_path)?;
65 let rnet = ort::session::SessionBuilder::new(&env)?.with_model_from_file(rnet_path)?;
66 let onet = ort::session::SessionBuilder::new(&env)?.with_model_from_file(onet_path)?;
67
68 Ok(Self {
69 pnet,
70 rnet,
71 onet,
72 params,
73 })
74 }
75
    /// Runs the proposal (P-Net) stage: scans an image pyramid with the
    /// fully-convolutional P-Net and returns candidate face boxes in
    /// original-image coordinates, merged by NMS and clamped to the image.
    fn run_proposal_inference(
        &self,
        image: &ImageBuffer<Rgb<u8>, &[u8]>,
    ) -> Result<Vec<Face>, crate::RustFacesError> {
        // P-Net's receptive field is 12x12 with an effective stride of 2.
        const PNET_CELL_SIZE: usize = 12;
        const PNET_STRIDE: usize = 2;

        let (image_width, image_height) = (image.width() as usize, image.height() as usize);

        // Build the pyramid scales: start so a `min_face_size` face maps to
        // one 12x12 cell, then shrink by `scale_factor` until the smaller
        // image side would fall below one cell.
        let scales = {
            let first_scale = PNET_CELL_SIZE as f32 / self.params.min_face_size as f32;

            let mut curr_size = image_width.min(image_height) as f32 * first_scale;
            let mut scale = first_scale;
            let mut scales = Vec::new();

            while curr_size > PNET_CELL_SIZE as f32 {
                scales.push(scale);
                scale *= self.params.scale_factor;
                curr_size *= self.params.scale_factor;
            }
            scales
        };

        let mut face_proposals = Vec::new();
        for scale_factor in scales {
            let image = imageops::resize(
                image,
                (scale_factor * image_width as f32) as u32,
                (scale_factor * image_height as f32) as u32,
                FilterType::Gaussian,
            );

            // NCHW float tensor normalized to roughly [-1, 1].
            let (in_width, in_height) = image.dimensions();
            let image = Array4::from_shape_fn(
                (1, 3, in_height as usize, in_width as usize),
                |(_n, c, h, w)| (image.get_pixel(w as u32, h as u32)[c] as f32 - 127.5) / 128.0,
            );

            let output_tensors = self.pnet.run(vec![ort::Value::from_array(
                self.pnet.allocator(),
                &CowArray::from(image).into_dyn(),
            )?])?;

            // P-Net outputs, in order: box regressions, face/non-face scores.
            let box_regressions: OrtOwnedTensor<f32, _> = output_tensors[0].try_extract()?;
            let scores: OrtOwnedTensor<f32, _> = output_tensors[1].try_extract()?;

            // Output maps are NCHW: shape[2] is height, shape[3] is width.
            let (net_out_width, net_out_height) = {
                let shape = scores.view().dim();
                (shape[3], shape[2])
            };

            // Boxes are produced in pyramid-level coordinates; rescale back
            // to the original image.
            let rescale_factor = 1.0 / scale_factor;
            let mut faces = Vec::with_capacity(net_out_width * net_out_height);

            // Walk the score/regression maps in lockstep; each (row, col)
            // cell corresponds to one 12x12 window at this pyramid level.
            Zip::indexed(
                scores
                    .view()
                    .to_shape((2, net_out_height, net_out_width))
                    .unwrap()
                    .lanes(Axis(0)),
            )
            .and(
                box_regressions
                    .view()
                    .to_shape((4, net_out_height, net_out_width))
                    .unwrap()
                    .lanes(Axis(0)),
            )
            .for_each(|(row, col), score, regression| {
                // Index 1 is the "face" probability; index 0 is "non-face".
                let score = score[1];
                if score > self.params.thresholds[0] {
                    // NOTE(review): regression offsets are applied here in
                    // pyramid-level pixels rather than scaled by box size as
                    // in some MTCNN implementations — presumably matching how
                    // these particular models were exported; confirm against
                    // the model source.
                    let x1 = col as f32 * PNET_STRIDE as f32 + regression[0];
                    let y1 = row as f32 * PNET_STRIDE as f32 + regression[1];
                    let x2 =
                        col as f32 * PNET_STRIDE as f32 + PNET_CELL_SIZE as f32 + regression[2];
                    let y2 =
                        row as f32 * PNET_STRIDE as f32 + PNET_CELL_SIZE as f32 + regression[3];

                    faces.push(Face {
                        rect: Rect::at(x1, y1)
                            .ending_at(x2, y2)
                            .scale(rescale_factor, rescale_factor),
                        confidence: score,
                        landmarks: None,
                    })
                }
            });

            // Per-level NMS first, then a global NMS across all levels.
            face_proposals.extend(self.params.nms.suppress_non_maxima(faces));
        }
        let mut proposals = self.params.nms.suppress_non_maxima(face_proposals);
        proposals.iter_mut().for_each(|face| {
            face.rect = face.rect.clamp(image_width as f32, image_height as f32);
        });
        Ok(proposals)
    }
177
    /// Crops each proposal from `image`, resizes it to
    /// `input_size` x `input_size`, and yields normalized NCHW float
    /// tensors in batches of up to 16, each paired with the slice of
    /// proposals it was built from.
    ///
    /// Pixels are normalized as (v - 127.5) / 128, matching the P-Net
    /// preprocessing.
    ///
    /// NOTE(review): `face.rect.x/y` are cast straight to `u32`, and the
    /// crop assumes the rect lies fully inside the image — presumably
    /// guaranteed by the `clamp` calls in the earlier stages; a negative
    /// or out-of-bounds rect would wrap/panic in `get_pixel`. TODO confirm.
    fn batch_faces<'a>(
        &self,
        image: &'a ImageBuffer<Rgb<u8>, &[u8]>,
        proposals: &'a [Face],
        input_size: usize,
    ) -> impl Iterator<Item = (&'a [Face], Array4<f32>)> + 'a {
        const BATCH_SIZE: usize = 16;
        proposals.chunks(BATCH_SIZE).map(move |proposal_batch| {
            let mut input_tensor = Array4::zeros((proposal_batch.len(), 3, input_size, input_size));
            for (n, face) in proposal_batch.iter().enumerate() {
                // Copy the face region out of the source image...
                let face_image =
                    RgbImage::from_fn(face.rect.width as u32, face.rect.height as u32, |x, y| {
                        image
                            .get_pixel(face.rect.x as u32 + x, face.rect.y as u32 + y)
                            .to_owned()
                    });
                // ...then resize it to the network's expected input size.
                let face_image = imageops::resize(
                    &face_image,
                    input_size as u32,
                    input_size as u32,
                    FilterType::Gaussian,
                );
                // Write the normalized crop into slot `n` of the batch tensor.
                input_tensor
                    .slice_mut(s![n, .., .., ..])
                    .assign(&Array3::from_shape_fn(
                        (3, input_size, input_size),
                        |(c, h, w)| {
                            (face_image.get_pixel(w as u32, h as u32)[c] as f32 - 127.5) / 128.0
                        },
                    ));
            }
            (proposal_batch, input_tensor)
        })
    }
212
    /// Runs the refinement (R-Net) stage over the P-Net proposals.
    ///
    /// Each proposal is cropped, resized to 24x24 and rescored; boxes
    /// passing `thresholds[1]` are adjusted by the predicted bounding-box
    /// regression, clamped to the image, and merged with min-area NMS.
    fn run_refine_net(
        &self,
        image: &ImageBuffer<Rgb<u8>, &[u8]>,
        proposals: &[Face],
    ) -> Result<Vec<Face>, crate::RustFacesError> {
        let mut rnet_faces = Vec::new();
        for (faces, input_tensor) in self.batch_faces(image, proposals, 24) {
            let output_tensors = self.rnet.run(vec![ort::Value::from_array(
                self.rnet.allocator(),
                &CowArray::from(input_tensor).into_dyn(),
            )?])?;
            // R-Net outputs, in order: box regressions, face/non-face scores.
            let box_regressions: OrtOwnedTensor<f32, _> = output_tensors[0].try_extract()?;
            let scores: OrtOwnedTensor<f32, _> = output_tensors[1].try_extract()?;
            let image_width = (image.width() - 1) as f32;
            let image_height = (image.height() - 1) as f32;

            let batch_faces = itertools::izip!(
                faces.iter(),
                // One (non-face, face) score pair per proposal.
                scores
                    .view()
                    .to_shape((faces.len(), 2))
                    .unwrap()
                    .lanes(Axis(1))
                    .into_iter(),
                // One (dx1, dy1, dx2, dy2) regression per proposal.
                box_regressions
                    .view()
                    .to_shape((faces.len(), 4))
                    .unwrap()
                    .lanes(Axis(1))
                    .into_iter()
            )
            .filter_map(|(face, score, regression)| {
                let score = score[1];
                if score >= self.params.thresholds[1] {
                    let face_width = face.rect.width;
                    let face_height = face.rect.height;
                    let regression = regression.to_vec();

                    // Regression offsets are relative to the proposal's size.
                    let x1 = face.rect.x + regression[0] * face_width;
                    let y1 = face.rect.y + regression[1] * face_height;
                    let x2 = face.rect.right() + regression[2] * face_width;
                    let y2 = face.rect.bottom() + regression[3] * face_height;

                    Some(Face {
                        rect: Rect::at(x1, y1)
                            .ending_at(x2, y2)
                            .clamp(image_width, image_height),
                        confidence: score,
                        landmarks: None,
                    })
                } else {
                    None
                }
            })
            .collect::<Vec<_>>();

            rnet_faces.extend(batch_faces);
        }
        let boxes = self.params.nms.suppress_non_maxima_min(rnet_faces);
        Ok(boxes)
    }
274
275 fn run_optmized_net(
276 &self,
277 image: &ImageBuffer<Rgb<u8>, &[u8]>,
278 proposals: &[Face],
279 ) -> Result<Vec<Face>, crate::RustFacesError> {
280 let mut onet_faces = Vec::new();
281 for (faces, input_tensor) in self.batch_faces(image, proposals, 48) {
282 let output_tensors = self.onet.run(vec![ort::Value::from_array(
283 self.onet.allocator(),
284 &CowArray::from(input_tensor).into_dyn(),
285 )?])?;
286
287 let box_regressions: OrtOwnedTensor<f32, _> = output_tensors[0].try_extract()?; let landmarks_regressions: OrtOwnedTensor<f32, _> = output_tensors[1].try_extract()?;
289 let scores: OrtOwnedTensor<f32, _> = output_tensors[2].try_extract()?; let image_width = (image.width() - 1) as f32;
291 let image_height = (image.height() - 1) as f32;
292
293 let batch_faces = itertools::izip!(
294 faces.iter(),
295 scores
296 .view()
297 .to_shape((faces.len(), 2))
298 .unwrap()
299 .lanes(Axis(1))
300 .into_iter(),
301 box_regressions
302 .view()
303 .to_shape((faces.len(), 4))
304 .unwrap()
305 .lanes(Axis(1))
306 .into_iter(),
307 landmarks_regressions
308 .view()
309 .to_shape((faces.len(), 10))
310 .unwrap()
311 .lanes(Axis(1))
312 .into_iter()
313 )
314 .filter_map(|(face, score, regression, landmarks)| {
315 let score = score[1];
316 if score >= self.params.thresholds[1] {
317 let face_width = face.rect.width;
318 let face_height = face.rect.height;
319 let regression = regression.to_vec();
320
321 let x1 = face.rect.x + regression[0] * face_width;
322 let y1 = face.rect.y + regression[1] * face_height;
323 let x2 = face.rect.right() + regression[2] * face_width;
324 let y2 = face.rect.bottom() + regression[3] * face_height;
325
326 let rect = Rect::at(x1, y1)
327 .ending_at(x2, y2)
328 .clamp(image_width, image_height);
329 let mut landmarks_vec = Vec::new();
330
331 for i in 0..5 {
332 landmarks_vec.push((
333 face.rect.x + landmarks[i] * face_width,
334 face.rect.y + landmarks[i + 5] * face_height,
335 ));
336 }
337 Some(Face {
338 rect,
339 confidence: score,
340 landmarks: Some(landmarks_vec),
341 })
342 } else {
343 None
344 }
345 })
346 .collect::<Vec<_>>();
347
348 onet_faces.extend(batch_faces);
349 }
350 let boxes = self.params.nms.suppress_non_maxima_min(onet_faces);
351 Ok(boxes)
352 }
353}
354
impl FaceDetector for MtCnn {
    /// Detects faces by running the full three-stage cascade:
    /// P-Net proposals, R-Net refinement, then O-Net output.
    ///
    /// NOTE(review): assumes `image` is a contiguous, HWC-ordered RGB8
    /// array — both `as_slice().unwrap()` and `from_raw(...).unwrap()`
    /// will panic otherwise; TODO consider returning an error instead.
    fn detect(&self, image: ArrayViewD<u8>) -> RustFacesResult<Vec<Face>> {
        let shape = image.shape().to_vec();
        // Shape is (height, width, channels).
        let (image_width, image_height) = (shape[1], shape[0]);
        // Zero-copy view over the input buffer as an `image` buffer.
        let image = ImageBuffer::<Rgb<u8>, &[u8]>::from_raw(
            image_width as u32,
            image_height as u32,
            image.as_slice().unwrap(),
        )
        .unwrap();

        let proposals = self.run_proposal_inference(&image)?;
        let refined_faces = self.run_refine_net(&image, &proposals)?;
        let optimized_faces = self.run_optmized_net(&image, &refined_faces)?;
        Ok(optimized_faces)
    }
}
372
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use super::*;
    use crate::{
        imaging::ToRgb8,
        model_repository::{GitHubRepository, ModelRepository},
        mtcnn::MtCnn,
        testing::{output_dir, sample_array_image},
        viz,
    };
    use ndarray::Array3;
    use rstest::rstest;
    use std::sync::Arc;

    /// End-to-end smoke test: fetches the three MTCNN models, detects
    /// faces on the sample image, and saves a visualization.
    /// Requires the `viz` feature; `get_model` presumably needs network
    /// access (or a warm cache) for the model download.
    #[cfg(feature = "viz")]
    #[rstest]
    fn should_detect(sample_array_image: Array3<u8>, output_dir: PathBuf) {
        use crate::FaceDetection;

        let environment = Arc::new(
            ort::Environment::builder()
                .with_name("MtCnn")
                .build()
                .unwrap(),
        );

        // Resolve local paths for the P-Net, R-Net and O-Net model files.
        let drive = GitHubRepository::new();
        let model_paths = drive
            .get_model(&FaceDetection::MtCnn(MtCnnParams::default()))
            .expect("Can't download model");

        let face_detector = MtCnn::from_file(
            environment,
            model_paths[0].to_str().unwrap(),
            model_paths[1].to_str().unwrap(),
            model_paths[2].to_str().unwrap(),
            MtCnnParams::default(),
        )
        .expect("Failed to load MTCNN detector.");
        let mut canvas = sample_array_image.to_rgb8();
        let faces = face_detector
            .detect(sample_array_image.into_dyn().view())
            .expect("Can't detect faces");

        // Draw the detections and write the result for manual inspection.
        viz::draw_faces(&mut canvas, faces);

        canvas
            .save(output_dir.join("mtcnn.png"))
            .expect("Can't save image");
    }
}