use std::collections::HashMap;
use crate::Face;
#[derive(Copy, Clone, Debug)]
pub struct Nms {
pub iou_threshold: f32,
}
impl Default for Nms {
fn default() -> Self {
Self { iou_threshold: 0.3 }
}
}
impl Nms {
pub fn suppress_non_maxima(&self, mut faces: Vec<Face>) -> Vec<Face> {
faces.sort_by(|a, b| a.confidence.partial_cmp(&b.confidence).unwrap());
let mut faces_map = HashMap::new();
faces.iter().rev().enumerate().for_each(|(i, face)| {
faces_map.insert(i, face);
});
let mut nms_faces = Vec::with_capacity(faces.len());
let mut count = 0;
while !faces_map.is_empty() {
if let Some((_, face)) = faces_map.remove_entry(&count) {
nms_faces.push(face.clone());
faces_map.retain(|_, face2| face.rect.iou(&face2.rect) < self.iou_threshold);
}
count += 1;
}
nms_faces
}
}
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;
    use crate::{Face, Rect};

    // Builds a face with a 1x1 box whose top-left corner is at ($x, $y) and no
    // landmarks. A macro keeps the fixture independent of the concrete float
    // types used by `Rect` and `Face`.
    macro_rules! face {
        ($x:expr, $y:expr, $confidence:expr) => {
            Face {
                rect: Rect {
                    x: $x,
                    y: $y,
                    width: 1.0,
                    height: 1.0,
                },
                confidence: $confidence,
                landmarks: None,
            }
        };
    }

    /// Identical boxes collapse to the single most confident face.
    #[rstest]
    fn test_nms() {
        let nms = Nms::default();
        let faces = vec![
            face!(0.0, 0.0, 0.9),
            face!(0.0, 0.0, 0.8),
            face!(0.0, 0.0, 0.7),
        ];
        let faces = nms.suppress_non_maxima(faces);
        assert_eq!(faces.len(), 1);
        assert_eq!(faces[0].confidence, 0.9);
    }

    /// Faces that do not overlap are all kept, most confident first.
    #[rstest]
    fn test_nms_keeps_disjoint_faces_by_descending_confidence() {
        let nms = Nms::default();
        let faces = vec![
            face!(0.0, 0.0, 0.5),
            face!(10.0, 10.0, 0.9),
            face!(20.0, 20.0, 0.7),
        ];
        let faces = nms.suppress_non_maxima(faces);
        assert_eq!(faces.len(), 3);
        assert_eq!(faces[0].confidence, 0.9);
        assert_eq!(faces[1].confidence, 0.7);
        assert_eq!(faces[2].confidence, 0.5);
    }

    /// No input faces yields no output faces.
    #[rstest]
    fn test_nms_empty_input() {
        let nms = Nms::default();
        assert!(nms.suppress_non_maxima(Vec::new()).is_empty());
    }
}