Skip to main content

yscv_detect/
roi.rs

1use crate::DetectError;
2use yscv_tensor::Tensor;
3
4/// RoI Pooling: for each RoI, divides the region into an `output_size` grid
5/// and max-pools each cell.
6///
7/// * `features` -- `[H, W, C]` single-image feature map.
8/// * `rois` -- slice of `(x1, y1, x2, y2)` in feature-map coordinates.
9/// * `output_size` -- `(out_h, out_w)`.
10///
11/// Returns a tensor of shape `[num_rois, out_h, out_w, C]`.
12pub fn roi_pool(
13    features: &Tensor,
14    rois: &[(f32, f32, f32, f32)],
15    output_size: (usize, usize),
16) -> Result<Tensor, DetectError> {
17    let shape = features.shape();
18    if shape.len() != 3 {
19        return Err(DetectError::InvalidMapShape {
20            expected_rank: 3,
21            got: shape.to_vec(),
22        });
23    }
24    let feat_h = shape[0];
25    let feat_w = shape[1];
26    let channels = shape[2];
27    let (out_h, out_w) = output_size;
28    let num_rois = rois.len();
29
30    let total = num_rois * out_h * out_w * channels;
31    let mut data = vec![f32::NEG_INFINITY; total];
32
33    for (roi_idx, &(rx1, ry1, rx2, ry2)) in rois.iter().enumerate() {
34        let roi_h = (ry2 - ry1).max(0.0);
35        let roi_w = (rx2 - rx1).max(0.0);
36        let bin_h = roi_h / out_h as f32;
37        let bin_w = roi_w / out_w as f32;
38
39        for oh in 0..out_h {
40            for ow in 0..out_w {
41                let y_start = (ry1 + oh as f32 * bin_h).floor() as isize;
42                let y_end = (ry1 + (oh + 1) as f32 * bin_h).ceil() as isize;
43                let x_start = (rx1 + ow as f32 * bin_w).floor() as isize;
44                let x_end = (rx1 + (ow + 1) as f32 * bin_w).ceil() as isize;
45
46                let y_start = y_start.max(0) as usize;
47                let y_end = (y_end as usize).min(feat_h);
48                let x_start = x_start.max(0) as usize;
49                let x_end = (x_end as usize).min(feat_w);
50
51                if y_start >= y_end || x_start >= x_end {
52                    // Empty bin -- leave as NEG_INFINITY (or could use 0).
53                    let base = ((roi_idx * out_h + oh) * out_w + ow) * channels;
54                    for c in 0..channels {
55                        data[base + c] = 0.0;
56                    }
57                    continue;
58                }
59
60                for fy in y_start..y_end {
61                    for fx in x_start..x_end {
62                        for c in 0..channels {
63                            let val = features.get(&[fy, fx, c])?;
64                            let out_idx = ((roi_idx * out_h + oh) * out_w + ow) * channels + c;
65                            if val > data[out_idx] {
66                                data[out_idx] = val;
67                            }
68                        }
69                    }
70                }
71            }
72        }
73    }
74
75    Ok(Tensor::from_vec(
76        vec![num_rois, out_h, out_w, channels],
77        data,
78    )?)
79}
80
81/// RoI Align: bilinear-interpolation version of RoI pooling (no quantisation).
82///
83/// * `features` -- `[H, W, C]` single-image feature map.
84/// * `rois` -- slice of `(x1, y1, x2, y2)` in feature-map coordinates.
85/// * `output_size` -- `(out_h, out_w)`.
86/// * `sampling_ratio` -- number of sampling points per bin dimension
87///   (0 means adaptive: `ceil(bin_size)`).
88///
89/// Returns a tensor of shape `[num_rois, out_h, out_w, C]`.
90pub fn roi_align(
91    features: &Tensor,
92    rois: &[(f32, f32, f32, f32)],
93    output_size: (usize, usize),
94    sampling_ratio: usize,
95) -> Result<Tensor, DetectError> {
96    let shape = features.shape();
97    if shape.len() != 3 {
98        return Err(DetectError::InvalidMapShape {
99            expected_rank: 3,
100            got: shape.to_vec(),
101        });
102    }
103    let feat_h = shape[0];
104    let feat_w = shape[1];
105    let channels = shape[2];
106    let (out_h, out_w) = output_size;
107    let num_rois = rois.len();
108
109    let total = num_rois * out_h * out_w * channels;
110    let mut data = vec![0.0f32; total];
111
112    for (roi_idx, &(rx1, ry1, rx2, ry2)) in rois.iter().enumerate() {
113        let roi_h = (ry2 - ry1).max(1e-6);
114        let roi_w = (rx2 - rx1).max(1e-6);
115        let bin_h = roi_h / out_h as f32;
116        let bin_w = roi_w / out_w as f32;
117
118        let sample_h = if sampling_ratio > 0 {
119            sampling_ratio
120        } else {
121            bin_h.ceil() as usize
122        };
123        let sample_w = if sampling_ratio > 0 {
124            sampling_ratio
125        } else {
126            bin_w.ceil() as usize
127        };
128
129        let count = (sample_h * sample_w) as f32;
130
131        for oh in 0..out_h {
132            for ow in 0..out_w {
133                let base = ((roi_idx * out_h + oh) * out_w + ow) * channels;
134
135                for sy in 0..sample_h {
136                    let y = ry1 + bin_h * (oh as f32 + (sy as f32 + 0.5) / sample_h as f32);
137                    for sx in 0..sample_w {
138                        let x = rx1 + bin_w * (ow as f32 + (sx as f32 + 0.5) / sample_w as f32);
139
140                        // Bilinear interpolation.
141                        if y < -1.0 || y > feat_h as f32 || x < -1.0 || x > feat_w as f32 {
142                            continue; // outside -- contributes 0
143                        }
144
145                        let y = y.max(0.0).min((feat_h - 1) as f32);
146                        let x = x.max(0.0).min((feat_w - 1) as f32);
147
148                        let y_low = y.floor() as usize;
149                        let x_low = x.floor() as usize;
150                        let y_high = (y_low + 1).min(feat_h - 1);
151                        let x_high = (x_low + 1).min(feat_w - 1);
152
153                        let ly = y - y_low as f32;
154                        let lx = x - x_low as f32;
155                        let hy = 1.0 - ly;
156                        let hx = 1.0 - lx;
157
158                        let w1 = hy * hx;
159                        let w2 = hy * lx;
160                        let w3 = ly * hx;
161                        let w4 = ly * lx;
162
163                        for c in 0..channels {
164                            let v1 = features.get(&[y_low, x_low, c])?;
165                            let v2 = features.get(&[y_low, x_high, c])?;
166                            let v3 = features.get(&[y_high, x_low, c])?;
167                            let v4 = features.get(&[y_high, x_high, c])?;
168                            data[base + c] += (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4) / count;
169                        }
170                    }
171                }
172            }
173        }
174    }
175
176    Ok(Tensor::from_vec(
177        vec![num_rois, out_h, out_w, channels],
178        data,
179    )?)
180}