rust_faces/
priorboxes.rs

1use itertools::iproduct;
2
3use crate::Rect;
4
/// Configuration of the anchor ("prior box") grid built by `PriorBoxes::new`.
///
/// The grid has one level per entry of `steps`; `min_sizes[k]` lists the
/// anchor side lengths generated at every cell of level `k`, so both vectors
/// are expected to have the same length.
#[derive(Debug, Clone)]
pub struct PriorBoxesParams {
    /// Anchor side lengths per level, in the same units as the image size
    /// given to `PriorBoxes::new`.
    min_sizes: Vec<Vec<usize>>,
    /// Cell size (stride) per level, in the same units as the image size.
    steps: Vec<usize>,
    /// `(center, size)` variances used when decoding regression outputs.
    variance: (f32, f32),
}

impl PriorBoxesParams {
    /// Creates a custom anchor configuration.
    ///
    /// `min_sizes[k]` and `steps[k]` describe pyramid level `k`; `variance` is
    /// the `(center, size)` pair used when decoding predictions.
    ///
    /// # Panics
    ///
    /// Panics if `min_sizes` and `steps` have different lengths, or if any
    /// step is zero (a zero step would divide by zero when building the grid).
    #[must_use]
    pub fn new(min_sizes: Vec<Vec<usize>>, steps: Vec<usize>, variance: (f32, f32)) -> Self {
        assert_eq!(
            min_sizes.len(),
            steps.len(),
            "PriorBoxesParams: min_sizes and steps must describe the same number of levels"
        );
        assert!(
            steps.iter().all(|&step| step > 0),
            "PriorBoxesParams: every step must be non-zero"
        );
        Self {
            min_sizes,
            steps,
            variance,
        }
    }
}
11
12impl Default for PriorBoxesParams {
13    fn default() -> Self {
14        Self {
15            min_sizes: vec![vec![8, 11], vec![14, 19, 26, 38, 64, 149]],
16            steps: vec![8, 16],
17            variance: (0.1, 0.2),
18        }
19    }
20}
21
/// A precomputed anchor grid plus the variances needed to decode model
/// predictions against it.
#[derive(Debug, Clone)]
pub struct PriorBoxes {
    /// Anchors as `(cx, cy, w, h)`, each divided by the image size along the
    /// matching axis.
    pub anchors: Vec<(f32, f32, f32, f32)>,
    /// `(center, size)` decoding variances.
    variances: (f32, f32),
}
26
27impl PriorBoxes {
28    pub fn new(params: &PriorBoxesParams, image_size: (usize, usize)) -> Self {
29        let feature_map_sizes: Vec<(usize, usize)> = params
30            .steps
31            .iter()
32            .map(|&step| (image_size.0 / step, image_size.1 / step))
33            .collect();
34
35        let mut anchors = Vec::new();
36
37        for ((f, min_sizes), step) in feature_map_sizes
38            .iter()
39            .zip(params.min_sizes.iter())
40            .zip(params.steps.iter())
41        {
42            let step = *step;
43            for (i, j) in iproduct!(0..f.1, 0..f.0) {
44                for min_size in min_sizes {
45                    let s_kx = *min_size as f32 / image_size.0 as f32;
46                    let s_ky = *min_size as f32 / image_size.1 as f32;
47                    let cx = (j as f32 + 0.5) * step as f32 / image_size.0 as f32;
48                    let cy = (i as f32 + 0.5) * step as f32 / image_size.1 as f32;
49                    anchors.push((cx, cy, s_kx, s_ky));
50                }
51            }
52        }
53
54        Self {
55            anchors,
56            variances: params.variance,
57        }
58    }
59
60    pub fn decode_box(&self, prior: &(f32, f32, f32, f32), pred: &(f32, f32, f32, f32)) -> Rect {
61        let (anchor_cx, anchor_cy, s_kx, s_ky) = prior;
62        let (x1, y1, x2, y2) = pred;
63
64        let cx = anchor_cx + x1 * self.variances.0 * s_kx;
65        let cy = anchor_cy + y1 * self.variances.0 * s_ky;
66        let width = s_kx * (x2 * self.variances.1).exp();
67        let height = s_ky * (y2 * self.variances.1).exp();
68        let x_start = cx - width / 2.0;
69        let y_start = cy - height / 2.0;
70        Rect::at(x_start, y_start).ending_at(width + x_start, height + y_start)
71    }
72
73    pub fn decode_landmark(
74        &self,
75        prior: &(f32, f32, f32, f32),
76        landmark: (f32, f32),
77    ) -> (f32, f32) {
78        let (anchor_cx, anchor_cy, s_kx, s_ky) = prior;
79        let (x, y) = landmark;
80        let x = anchor_cx + x * self.variances.0 * s_kx;
81        let y = anchor_cy + y * self.variances.0 * s_ky;
82        (x, y)
83    }
84}