1use nalgebra::{DMatrix, Scalar};
9use num_traits::{self, cast::AsPrimitive, NumCast};
10use rand::Rng;
11use std::ops::{Add, Div, Mul};
12
13use crate::core::multires;
14use crate::misc::helper::div_rem;
15use crate::misc::type_aliases::Float;
16
17pub trait Number<T>:
19 Scalar
20 + Ord
21 + NumCast
22 + Add<T, Output = T>
23 + Div<T, Output = T>
24 + Mul<T, Output = T>
25 + AsPrimitive<Float>
26 + std::fmt::Display
27{
28}
29
30impl Number<u16> for u16 {}
31
32pub type Picked = u8;
35
36pub struct RegionConfig<T> {
38 pub size: usize,
40 pub threshold_coefs: (Float, T),
42}
43
44#[derive(Copy, Clone)]
46pub struct BlockConfig {
47 pub base_size: usize,
49 pub nb_levels: usize,
51 pub threshold_factor: Float,
53}
54
55pub struct RecursiveConfig {
59 pub nb_iterations_left: usize,
61 pub low_thresh: Float,
63 pub high_thresh: Float,
65 pub random_thresh: Float,
69}
70
71pub const DEFAULT_REGION_CONFIG: RegionConfig<u16> = RegionConfig {
73 size: 32,
74 threshold_coefs: (1.0, 3), };
76
77pub const DEFAULT_BLOCK_CONFIG: BlockConfig = BlockConfig {
79 base_size: 4,
80 nb_levels: 3,
81 threshold_factor: 0.5,
82};
83
84pub const DEFAULT_RECURSIVE_CONFIG: RecursiveConfig = RecursiveConfig {
86 nb_iterations_left: 1,
87 low_thresh: 0.8,
88 high_thresh: 4.0,
89 random_thresh: 1.1,
90};
91
92#[allow(clippy::cast_possible_truncation)]
96#[allow(clippy::cast_sign_loss)]
97#[allow(clippy::cast_precision_loss)]
98pub fn select<T: Number<T>>(
99 gradients: &DMatrix<T>,
100 region_config: RegionConfig<T>,
101 block_config: BlockConfig,
102 recursive_config: RecursiveConfig,
103 nb_target: usize,
104) -> DMatrix<bool> {
105 let median_gradients = region_median_gradients(gradients, region_config.size);
107 let regions_thresholds = region_thresholds(&median_gradients, region_config.threshold_coefs);
108 let (vec_nb_candidates, picked) = pick_all_block_candidates(
109 block_config,
110 region_config.size,
111 ®ions_thresholds,
112 gradients,
113 );
114 let nb_candidates: usize = vec_nb_candidates.iter().sum();
115 let candidates_ratio = nb_candidates as Float / nb_target as Float;
118 let target_size = candidates_ratio.sqrt() * (block_config.base_size as Float + 1.0) - 1.0;
125 let target_size = std::cmp::max(1, target_size.round() as i32) as usize;
126 if candidates_ratio < recursive_config.low_thresh
129 || candidates_ratio > recursive_config.high_thresh
130 {
131 if target_size != block_config.base_size && recursive_config.nb_iterations_left > 0 {
132 let mut b_config = block_config;
133 b_config.base_size = target_size;
134 let mut rec_config = recursive_config;
135 rec_config.nb_iterations_left -= 1;
136 select(gradients, region_config, b_config, rec_config, nb_target)
137 } else {
138 to_mask(&picked)
139 }
140 } else if candidates_ratio > recursive_config.random_thresh {
141 let mut rng = rand::thread_rng();
143 picked.map(|p| p > 0 && rng.gen::<u8>() <= (255.0 / candidates_ratio) as u8)
144 } else {
145 to_mask(&picked)
146 }
147}
148
149fn to_mask(picked: &DMatrix<u8>) -> DMatrix<bool> {
151 picked.map(|p| p > 0)
152}
153
154#[allow(clippy::cast_possible_truncation)]
156fn pick_all_block_candidates<T: Number<T>>(
157 block_config: BlockConfig,
158 regions_size: usize,
159 regions_thresholds: &DMatrix<T>,
160 gradients: &DMatrix<T>,
161) -> (Vec<usize>, DMatrix<Picked>) {
162 let (nb_rows, nb_cols) = gradients.shape();
163 let max_gradients_0 = init_max_gradients(gradients, block_config.base_size);
164 let max_gradients_multires =
165 multires::limited_sequence(block_config.nb_levels, max_gradients_0, |m| {
166 multires::halve(m, max_of_four_gradients)
167 });
168 let mut threshold_level_coef = 1.0;
169 let mut nb_picked = Vec::new();
170 let (blocks_rows, blocks_cols) = max_gradients_multires[0].shape();
171 let mut mask = DMatrix::repeat(blocks_rows, blocks_cols, true);
172 let mut candidates = DMatrix::repeat(nb_rows, nb_cols, 0);
173 for (level, max_gradients_level) in max_gradients_multires.iter().enumerate() {
174 let (nb_picked_level, mask_next_level, new_candidates) = pick_level_block_candidates(
176 threshold_level_coef,
177 (level + 1) as u8,
178 regions_size,
179 regions_thresholds,
180 max_gradients_level,
181 &mask,
182 candidates,
183 );
184 nb_picked.push(nb_picked_level);
185 mask = mask_next_level;
186 candidates = new_candidates;
187 threshold_level_coef *= block_config.threshold_factor;
188 }
189 (nb_picked, candidates)
190}
191
192fn init_max_gradients<T: Number<T>>(
194 gradients: &DMatrix<T>,
195 block_size: usize,
196) -> DMatrix<(T, usize, usize)> {
197 let (nb_rows, nb_cols) = gradients.shape();
198 let nb_rows_blocks = match div_rem(nb_rows, block_size) {
199 (quot, 0) => quot,
200 (quot, _) => quot + 1,
201 };
202 let nb_cols_blocks = match div_rem(nb_cols, block_size) {
203 (quot, 0) => quot,
204 (quot, _) => quot + 1,
205 };
206 DMatrix::from_fn(nb_rows_blocks, nb_cols_blocks, |bi, bj| {
207 let start_i = bi * block_size;
208 let start_j = bj * block_size;
209 let end_i = std::cmp::min(start_i + block_size, nb_rows);
210 let end_j = std::cmp::min(start_j + block_size, nb_cols);
211 let mut tmp_max = (gradients[(start_i, start_j)], start_i, start_j);
212 for j in start_j..end_j {
213 for i in start_i..end_i {
214 let g = gradients[(i, j)];
215 if g > tmp_max.0 {
216 tmp_max = (g, i, j);
217 }
218 }
219 }
220 tmp_max
221 })
222}
223
224fn max_of_four_gradients<T: Number<T>>(
226 g1: (T, usize, usize),
227 g2: (T, usize, usize),
228 g3: (T, usize, usize),
229 g4: (T, usize, usize),
230) -> (T, usize, usize) {
231 let g_max = |g_m1: (T, usize, usize), g_m2: (T, usize, usize)| {
232 if g_m1.0 < g_m2.0 {
233 g_m2
234 } else {
235 g_m1
236 }
237 };
238 g_max(g1, g_max(g2, g_max(g3, g4)))
239}
240
241fn pick_level_block_candidates<T: Number<T>>(
247 threshold_level_coef: Float,
248 level: Picked,
249 regions_size: usize,
250 regions_thresholds: &DMatrix<T>,
251 max_gradients: &DMatrix<(T, usize, usize)>,
252 mask: &DMatrix<bool>,
253 candidates: DMatrix<Picked>,
254) -> (usize, DMatrix<bool>, DMatrix<Picked>) {
255 let (mask_height, mask_width) = mask.shape();
256 let mut mask_next_level = DMatrix::repeat(mask_height / 2, mask_width / 2, true);
257 let mut candidates = candidates;
258 let mut nb_picked = 0;
259 for j in 0..(mask_width / 2 * 2) {
261 for i in 0..(mask_height / 2 * 2) {
262 if mask[(i, j)] {
263 let (g2, i_g, j_g) = max_gradients[(i, j)];
264 let threshold = regions_thresholds[(i_g / regions_size, j_g / regions_size)];
265 if g2.as_() >= threshold_level_coef * threshold.as_() {
266 mask_next_level[(i / 2, j / 2)] = false;
267 candidates[(i_g, j_g)] = level;
268 nb_picked += 1;
269 }
270 } else {
271 mask_next_level[(i / 2, j / 2)] = false;
272 }
273 }
274 }
275 (nb_picked, mask_next_level, candidates)
276}
277
278#[allow(clippy::cast_precision_loss)]
281#[allow(clippy::cast_possible_wrap)]
282#[allow(clippy::cast_sign_loss)]
283#[allow(clippy::cast_possible_truncation)]
284fn region_thresholds<T: Number<T>>(median_gradients: &DMatrix<T>, coefs: (Float, T)) -> DMatrix<T> {
285 let (nb_rows, nb_cols) = median_gradients.shape();
286 DMatrix::from_fn(nb_rows, nb_cols, |i, j| {
287 let start_i = std::cmp::max(0, i as i32 - 1) as usize;
288 let start_j = std::cmp::max(0, j as i32 - 1) as usize;
289 let end_i = std::cmp::min(nb_rows, i + 2);
290 let end_j = std::cmp::min(nb_cols, j + 2);
291 let mut sum: T = num_traits::cast(0).unwrap();
292 let mut nb_elements = 0;
293 for j in start_j..end_j {
294 for i in start_i..end_i {
295 sum = sum + median_gradients[(i, j)];
296 nb_elements += 1;
297 }
298 }
299 let (a, b) = coefs;
300 let thresh_tmp = sum.as_() / nb_elements as Float + b.as_();
301 num_traits::cast(a * thresh_tmp * thresh_tmp).expect("woops")
302 })
303}
304
305fn region_median_gradients<T: Number<T>>(gradients: &DMatrix<T>, size: usize) -> DMatrix<T> {
308 let (nb_rows, nb_cols) = gradients.shape();
309 let nb_rows_regions = match div_rem(nb_rows, size) {
310 (quot, 0) => quot,
311 (quot, _) => quot + 1,
312 };
313 let nb_cols_regions = match div_rem(nb_cols, size) {
314 (quot, 0) => quot,
315 (quot, _) => quot + 1,
316 };
317 DMatrix::from_fn(nb_rows_regions, nb_cols_regions, |i, j| {
318 let height = std::cmp::min(size, nb_rows - i * size);
319 let width = std::cmp::min(size, nb_cols - j * size);
320 let region_slice = gradients.slice((i * size, j * size), (height, width));
321 let mut region_cloned: Vec<T> = region_slice.iter().cloned().collect();
322 region_cloned.sort_unstable();
323 region_cloned[region_cloned.len() / 2]
324 })
325}