1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
use std::collections::HashMap;

use crate::Face;

/// Settings for non-maximum suppression of detected faces.
///
/// The value is small and plain data, so it is `Copy` and can be passed
/// around freely.
#[derive(Clone, Copy, Debug)]
pub struct Nms {
    /// Overlap cut-off used by `suppress_non_maxima`: a face survives only
    /// while its IoU (as reported by the rectangle's `iou` method) with each
    /// already-selected face is strictly below this value.
    pub iou_threshold: f32,
}

impl Default for Nms {
    fn default() -> Self {
        Self { iou_threshold: 0.3 }
    }
}

impl Nms {
    pub fn suppress_non_maxima(&self, mut faces: Vec<Face>) -> Vec<Face> {
        faces.sort_by(|a, b| a.confidence.partial_cmp(&b.confidence).unwrap());

        let mut faces_map = HashMap::new();
        faces.iter().rev().enumerate().for_each(|(i, face)| {
            faces_map.insert(i, face);
        });

        let mut nms_faces = Vec::with_capacity(faces.len());
        let mut count = 0;
        while !faces_map.is_empty() {
            if let Some((_, face)) = faces_map.remove_entry(&count) {
                nms_faces.push(face.clone());
                faces_map.retain(|_, face2| face.rect.iou(&face2.rect) < self.iou_threshold);
            }
            count += 1;
        }

        nms_faces
    }
}

#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;
    use crate::{Face, Rect};

    /// Builds a face without landmarks whose rectangle is the 1x1 box at the
    /// origin, carrying the given confidence.
    fn unit_face(confidence: f32) -> Face {
        Face {
            rect: Rect {
                x: 0.0,
                y: 0.0,
                width: 1.0,
                height: 1.0,
            },
            confidence,
            landmarks: None,
        }
    }

    /// Three detections sharing the same rectangle must collapse to the single
    /// most confident one.
    #[rstest]
    fn test_nms() {
        let candidates: Vec<Face> = [0.9, 0.8, 0.7].iter().copied().map(unit_face).collect();

        let survivors = Nms::default().suppress_non_maxima(candidates);

        assert_eq!(survivors.len(), 1);
        assert_eq!(survivors[0].confidence, 0.9);
    }
}