1use crate::DetectError;
2use yscv_tensor::Tensor;
3
4pub fn roi_pool(
13 features: &Tensor,
14 rois: &[(f32, f32, f32, f32)],
15 output_size: (usize, usize),
16) -> Result<Tensor, DetectError> {
17 let shape = features.shape();
18 if shape.len() != 3 {
19 return Err(DetectError::InvalidMapShape {
20 expected_rank: 3,
21 got: shape.to_vec(),
22 });
23 }
24 let feat_h = shape[0];
25 let feat_w = shape[1];
26 let channels = shape[2];
27 let (out_h, out_w) = output_size;
28 let num_rois = rois.len();
29
30 let total = num_rois * out_h * out_w * channels;
31 let mut data = vec![f32::NEG_INFINITY; total];
32
33 for (roi_idx, &(rx1, ry1, rx2, ry2)) in rois.iter().enumerate() {
34 let roi_h = (ry2 - ry1).max(0.0);
35 let roi_w = (rx2 - rx1).max(0.0);
36 let bin_h = roi_h / out_h as f32;
37 let bin_w = roi_w / out_w as f32;
38
39 for oh in 0..out_h {
40 for ow in 0..out_w {
41 let y_start = (ry1 + oh as f32 * bin_h).floor() as isize;
42 let y_end = (ry1 + (oh + 1) as f32 * bin_h).ceil() as isize;
43 let x_start = (rx1 + ow as f32 * bin_w).floor() as isize;
44 let x_end = (rx1 + (ow + 1) as f32 * bin_w).ceil() as isize;
45
46 let y_start = y_start.max(0) as usize;
47 let y_end = (y_end as usize).min(feat_h);
48 let x_start = x_start.max(0) as usize;
49 let x_end = (x_end as usize).min(feat_w);
50
51 if y_start >= y_end || x_start >= x_end {
52 let base = ((roi_idx * out_h + oh) * out_w + ow) * channels;
54 for c in 0..channels {
55 data[base + c] = 0.0;
56 }
57 continue;
58 }
59
60 for fy in y_start..y_end {
61 for fx in x_start..x_end {
62 for c in 0..channels {
63 let val = features.get(&[fy, fx, c])?;
64 let out_idx = ((roi_idx * out_h + oh) * out_w + ow) * channels + c;
65 if val > data[out_idx] {
66 data[out_idx] = val;
67 }
68 }
69 }
70 }
71 }
72 }
73 }
74
75 Ok(Tensor::from_vec(
76 vec![num_rois, out_h, out_w, channels],
77 data,
78 )?)
79}
80
81pub fn roi_align(
91 features: &Tensor,
92 rois: &[(f32, f32, f32, f32)],
93 output_size: (usize, usize),
94 sampling_ratio: usize,
95) -> Result<Tensor, DetectError> {
96 let shape = features.shape();
97 if shape.len() != 3 {
98 return Err(DetectError::InvalidMapShape {
99 expected_rank: 3,
100 got: shape.to_vec(),
101 });
102 }
103 let feat_h = shape[0];
104 let feat_w = shape[1];
105 let channels = shape[2];
106 let (out_h, out_w) = output_size;
107 let num_rois = rois.len();
108
109 let total = num_rois * out_h * out_w * channels;
110 let mut data = vec![0.0f32; total];
111
112 for (roi_idx, &(rx1, ry1, rx2, ry2)) in rois.iter().enumerate() {
113 let roi_h = (ry2 - ry1).max(1e-6);
114 let roi_w = (rx2 - rx1).max(1e-6);
115 let bin_h = roi_h / out_h as f32;
116 let bin_w = roi_w / out_w as f32;
117
118 let sample_h = if sampling_ratio > 0 {
119 sampling_ratio
120 } else {
121 bin_h.ceil() as usize
122 };
123 let sample_w = if sampling_ratio > 0 {
124 sampling_ratio
125 } else {
126 bin_w.ceil() as usize
127 };
128
129 let count = (sample_h * sample_w) as f32;
130
131 for oh in 0..out_h {
132 for ow in 0..out_w {
133 let base = ((roi_idx * out_h + oh) * out_w + ow) * channels;
134
135 for sy in 0..sample_h {
136 let y = ry1 + bin_h * (oh as f32 + (sy as f32 + 0.5) / sample_h as f32);
137 for sx in 0..sample_w {
138 let x = rx1 + bin_w * (ow as f32 + (sx as f32 + 0.5) / sample_w as f32);
139
140 if y < -1.0 || y > feat_h as f32 || x < -1.0 || x > feat_w as f32 {
142 continue; }
144
145 let y = y.max(0.0).min((feat_h - 1) as f32);
146 let x = x.max(0.0).min((feat_w - 1) as f32);
147
148 let y_low = y.floor() as usize;
149 let x_low = x.floor() as usize;
150 let y_high = (y_low + 1).min(feat_h - 1);
151 let x_high = (x_low + 1).min(feat_w - 1);
152
153 let ly = y - y_low as f32;
154 let lx = x - x_low as f32;
155 let hy = 1.0 - ly;
156 let hx = 1.0 - lx;
157
158 let w1 = hy * hx;
159 let w2 = hy * lx;
160 let w3 = ly * hx;
161 let w4 = ly * lx;
162
163 for c in 0..channels {
164 let v1 = features.get(&[y_low, x_low, c])?;
165 let v2 = features.get(&[y_low, x_high, c])?;
166 let v3 = features.get(&[y_high, x_low, c])?;
167 let v4 = features.get(&[y_high, x_high, c])?;
168 data[base + c] += (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4) / count;
169 }
170 }
171 }
172 }
173 }
174 }
175
176 Ok(Tensor::from_vec(
177 vec![num_rois, out_h, out_w, channels],
178 data,
179 )?)
180}