1use scirs2_core::ndarray::{par_azip, s, Array1, Array2, Array3, Array4};
7use std::f64::consts::PI;
8
9use crate::error::{Result, TransformError};
10
/// Configuration for extracting rectangular patches from 2-D images.
///
/// Built with [`PatchExtractor::new`] and optionally refined with
/// `with_max_patches` / `with_random_state`.
#[derive(Debug, Clone)]
pub struct PatchExtractor {
    /// Patch size as `(rows, cols)`.
    patch_size: (usize, usize),
    /// Upper bound on patches per image; when smaller than the number of possible
    /// positions, positions are sampled at random (with replacement).
    max_patches: Option<usize>,
    /// Seed for the sampling RNG; `None` draws a fresh random seed.
    random_state: Option<u64>,
}
20
21impl PatchExtractor {
22 pub fn new(_patchsize: (usize, usize)) -> Self {
24 PatchExtractor {
25 patch_size: _patchsize,
26 max_patches: None,
27 random_state: None,
28 }
29 }
30
31 pub fn with_max_patches(mut self, maxpatches: usize) -> Self {
33 self.max_patches = Some(maxpatches);
34 self
35 }
36
37 pub fn with_random_state(mut self, seed: u64) -> Self {
39 self.random_state = Some(seed);
40 self
41 }
42
43 pub fn extract_patches_2d(&self, image: &Array2<f64>) -> Result<Array3<f64>> {
45 let (img_height, img_width) = (image.shape()[0], image.shape()[1]);
46 let (patch_height, patch_width) = self.patch_size;
47
48 if patch_height > img_height || patch_width > img_width {
49 return Err(TransformError::InvalidInput(format!(
50 "Patch size ({patch_height}, {patch_width}) exceeds image size ({img_height}, {img_width})"
51 )));
52 }
53
54 let n_patches_h = img_height - patch_height + 1;
55 let n_patches_w = img_width - patch_width + 1;
56 let total_patches = n_patches_h * n_patches_w;
57
58 let n_patches = if let Some(max_p) = self.max_patches {
59 max_p.min(total_patches)
60 } else {
61 total_patches
62 };
63
64 let mut patches = Array3::zeros((n_patches, patch_height, patch_width));
65
66 if n_patches == total_patches {
67 let mut patch_idx = 0;
69 for i in 0..n_patches_h {
70 for j in 0..n_patches_w {
71 let patch = image.slice(s![i..i + patch_height, j..j + patch_width]);
72 patches.slice_mut(s![patch_idx, .., ..]).assign(&patch);
73 patch_idx += 1;
74 }
75 }
76 } else {
77 use scirs2_core::random::rngs::StdRng;
79 use scirs2_core::random::{Rng, SeedableRng};
80
81 let mut rng = if let Some(seed) = self.random_state {
82 StdRng::seed_from_u64(seed)
83 } else {
84 StdRng::seed_from_u64(scirs2_core::random::random::<u64>())
85 };
86
87 for patch_idx in 0..n_patches {
88 let i = rng.gen_range(0..n_patches_h);
89 let j = rng.gen_range(0..n_patches_w);
90 let patch = image.slice(s![i..i + patch_height, j..j + patch_width]);
91 patches.slice_mut(s![patch_idx, .., ..]).assign(&patch);
92 }
93 }
94
95 Ok(patches)
96 }
97
98 pub fn extract_patches_batch(&self, images: &Array3<f64>) -> Result<Array4<f64>> {
100 let n_images = images.shape()[0];
101 let (img_height, img_width) = (images.shape()[1], images.shape()[2]);
102 let (patch_height, patch_width) = self.patch_size;
103
104 if patch_height > img_height || patch_width > img_width {
105 return Err(TransformError::InvalidInput(format!(
106 "Patch size ({patch_height}, {patch_width}) exceeds image size ({img_height}, {img_width})"
107 )));
108 }
109
110 let n_patches_per_image = if let Some(max_p) = self.max_patches {
111 let total = (img_height - patch_height + 1) * (img_width - patch_width + 1);
112 max_p.min(total)
113 } else {
114 (img_height - patch_height + 1) * (img_width - patch_width + 1)
115 };
116
117 let mut all_patches =
118 Array4::zeros((n_images * n_patches_per_image, patch_height, patch_width, 1));
119
120 for (img_idx, image) in images.outer_iter().enumerate() {
121 let patches = self.extract_patches_2d(&image.to_owned())?;
122 let start_idx = img_idx * n_patches_per_image;
123
124 for (patch_idx, patch) in patches.outer_iter().enumerate() {
125 all_patches
126 .slice_mut(s![start_idx + patch_idx, .., .., 0])
127 .assign(&patch);
128 }
129 }
130
131 Ok(all_patches)
132 }
133}
134
135pub struct HOGDescriptor {
137 cell_size: (usize, usize),
139 block_size: (usize, usize),
141 n_bins: usize,
143 block_norm: BlockNorm,
145}
146
/// Block normalisation scheme used by [`HOGDescriptor`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BlockNorm {
    /// Divide by the L1 norm (plus a small epsilon).
    L1,
    /// Divide by the L2 norm (plus a small epsilon).
    L2,
    /// Square root of the L1-normalised values.
    L1Sqrt,
    /// L2-normalise, clip values at 0.2, then L2-normalise again.
    L2Hys,
}
159
160impl HOGDescriptor {
161 pub fn new(_cellsize: (usize, usize), block_size: (usize, usize), n_bins: usize) -> Self {
163 HOGDescriptor {
164 cell_size: _cellsize,
165 block_size,
166 n_bins,
167 block_norm: BlockNorm::L2Hys,
168 }
169 }
170
171 pub fn with_block_norm(mut self, blocknorm: BlockNorm) -> Self {
173 self.block_norm = blocknorm;
174 self
175 }
176
177 fn compute_gradients(&self, image: &Array2<f64>) -> (Array2<f64>, Array2<f64>) {
179 let (height, width) = (image.shape()[0], image.shape()[1]);
180 let mut grad_x = Array2::zeros((height, width));
181 let mut grad_y = Array2::zeros((height, width));
182
183 for i in 0..height {
185 for j in 1..width - 1 {
186 grad_x[[i, j]] = image[[i, j + 1]] - image[[i, j - 1]];
187 }
188 grad_x[[i, 0]] = image[[i, 1]] - image[[i, 0]];
190 grad_x[[i, width - 1]] = image[[i, width - 1]] - image[[i, width - 2]];
191 }
192
193 for j in 0..width {
195 for i in 1..height - 1 {
196 grad_y[[i, j]] = image[[i + 1, j]] - image[[i - 1, j]];
197 }
198 grad_y[[0, j]] = image[[1, j]] - image[[0, j]];
200 grad_y[[height - 1, j]] = image[[height - 1, j]] - image[[height - 2, j]];
201 }
202
203 (grad_x, grad_y)
204 }
205
206 pub fn compute(&self, image: &Array2<f64>) -> Result<Array1<f64>> {
208 let (height, width) = (image.shape()[0], image.shape()[1]);
209 let (cell_h, cell_w) = self.cell_size;
210 let (block_h, block_w) = self.block_size;
211
212 let (grad_x, grad_y) = self.compute_gradients(image);
214
215 let magnitude = (&grad_x * &grad_x + &grad_y * &grad_y).mapv(f64::sqrt);
217 let mut orientation = grad_y.mapv(|y| y.atan2(0.0));
218 orientation.zip_mut_with(&grad_x, |o, &x| *o = (*o).atan2(x));
219
220 let n_cells_h = height / cell_h;
222 let n_cells_w = width / cell_w;
223
224 let mut cell_histograms = Array3::zeros((n_cells_h, n_cells_w, self.n_bins));
226 let bin_size = PI / self.n_bins as f64;
227
228 for cell_i in 0..n_cells_h {
229 for cell_j in 0..n_cells_w {
230 let start_i = cell_i * cell_h;
231 let start_j = cell_j * cell_w;
232
233 for i in start_i..start_i.min(start_i + cell_h).min(height) {
234 for j in start_j..start_j.min(start_j + cell_w).min(width) {
235 let mag = magnitude[[i, j]];
236 let mut angle = orientation[[i, j]];
237
238 if angle < 0.0 {
240 angle += PI;
241 }
242
243 let bin_idx = (angle / bin_size) as usize;
245 let bin_idx = bin_idx.min(self.n_bins - 1);
246
247 cell_histograms[[cell_i, cell_j, bin_idx]] += mag;
248 }
249 }
250 }
251 }
252
253 let n_blocks_h = n_cells_h - block_h + 1;
255 let n_blocks_w = n_cells_w - block_w + 1;
256 let block_features = block_h * block_w * self.n_bins;
257
258 let mut features = Vec::with_capacity(n_blocks_h * n_blocks_w * block_features);
259
260 for block_i in 0..n_blocks_h {
262 for block_j in 0..n_blocks_w {
263 let mut block_hist = Vec::with_capacity(block_features);
264
265 for i in 0..block_h {
267 for j in 0..block_w {
268 let cell_hist = cell_histograms.slice(s![block_i + i, block_j + j, ..]);
269 block_hist.extend(cell_hist.iter());
270 }
271 }
272
273 let block_hist = self.normalize_block(&block_hist);
275 features.extend(block_hist);
276 }
277 }
278
279 Ok(Array1::from_vec(features))
280 }
281
282 fn normalize_block(&self, hist: &[f64]) -> Vec<f64> {
284 let epsilon = 1e-8;
285
286 match self.block_norm {
287 BlockNorm::L1 => {
288 let norm: f64 = hist.iter().sum::<f64>() + epsilon;
289 hist.iter().map(|&v| v / norm).collect()
290 }
291 BlockNorm::L2 => {
292 let norm = hist.iter().map(|&v| v * v).sum::<f64>().sqrt() + epsilon;
293 hist.iter().map(|&v| v / norm).collect()
294 }
295 BlockNorm::L1Sqrt => {
296 let norm: f64 = hist.iter().sum::<f64>() + epsilon;
297 hist.iter().map(|&v| (v / norm).sqrt()).collect()
298 }
299 BlockNorm::L2Hys => {
300 let mut norm = hist.iter().map(|&v| v * v).sum::<f64>().sqrt() + epsilon;
302 let mut normalized: Vec<f64> = hist.iter().map(|&v| v / norm).collect();
303
304 let clip_val = 0.2;
306 for v in &mut normalized {
307 if *v > clip_val {
308 *v = clip_val;
309 }
310 }
311
312 norm = normalized.iter().map(|&v| v * v).sum::<f64>().sqrt() + epsilon;
314 normalized.iter_mut().for_each(|v| *v /= norm);
315
316 normalized
317 }
318 }
319 }
320}
321
322pub struct ImageNormalizer {
324 method: ImageNormMethod,
326 channel_stats: Option<(Array1<f64>, Array1<f64>)>,
328}
329
/// Normalisation strategy used by [`ImageNormalizer`].
///
/// The min/max based variants rescale each image independently; an image whose
/// values are all equal is left unchanged.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum ImageNormMethod {
    /// Rescale each image to `[0, 1]`.
    MinMax,
    /// Per-channel standardisation `(x - mean) / std`, using statistics from `fit`.
    Standard,
    /// Rescale each image to `[-1, 1]`.
    Symmetric,
    /// Rescale each image to `[min, max]` given as `Range(min, max)`.
    Range(f64, f64),
}
342
343impl ImageNormalizer {
344 pub fn new(method: ImageNormMethod) -> Self {
346 ImageNormalizer {
347 method,
348 channel_stats: None,
349 }
350 }
351
352 pub fn fit(&mut self, images: &Array4<f64>) -> Result<()> {
354 if let ImageNormMethod::Standard = self.method {
355 let n_channels = images.shape()[3];
356 let mut means = Array1::zeros(n_channels);
357 let mut stds = Array1::zeros(n_channels);
358
359 for c in 0..n_channels {
361 let channel_data = images.slice(s![.., .., .., c]);
362 let flat_data: Vec<f64> = channel_data.iter().cloned().collect();
363
364 let mean = flat_data.iter().sum::<f64>() / flat_data.len() as f64;
365 let variance = flat_data.iter().map(|&x| (x - mean).powi(2)).sum::<f64>()
366 / flat_data.len() as f64;
367
368 means[c] = mean;
369 stds[c] = variance.sqrt();
370 }
371
372 self.channel_stats = Some((means, stds));
373 }
374
375 Ok(())
376 }
377
378 pub fn transform(&self, images: &Array4<f64>) -> Result<Array4<f64>> {
380 let mut result = images.clone();
381
382 match self.method {
383 ImageNormMethod::MinMax => {
384 for mut image in result.outer_iter_mut() {
386 let min = image.iter().cloned().fold(f64::INFINITY, f64::min);
387 let max = image.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
388 let range = max - min;
389
390 if range > 0.0 {
391 image.mapv_inplace(|v| (v - min) / range);
392 }
393 }
394 }
395 ImageNormMethod::Standard => {
396 if let Some((ref means, ref stds)) = self.channel_stats {
397 let n_channels = images.shape()[3];
398
399 for c in 0..n_channels {
400 let mean = means[c];
401 let std = stds[c].max(1e-8); result
404 .slice_mut(s![.., .., .., c])
405 .mapv_inplace(|v| (v - mean) / std);
406 }
407 } else {
408 return Err(TransformError::NotFitted(
409 "ImageNormalizer must be fitted before transform".into(),
410 ));
411 }
412 }
413 ImageNormMethod::Symmetric => {
414 for mut image in result.outer_iter_mut() {
416 let min = image.iter().cloned().fold(f64::INFINITY, f64::min);
417 let max = image.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
418 let range = max - min;
419
420 if range > 0.0 {
421 image.mapv_inplace(|v| 2.0 * (v - min) / range - 1.0);
422 }
423 }
424 }
425 ImageNormMethod::Range(new_min, new_max) => {
426 let new_range = new_max - new_min;
427
428 for mut image in result.outer_iter_mut() {
429 let min = image.iter().cloned().fold(f64::INFINITY, f64::min);
430 let max = image.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
431 let range = max - min;
432
433 if range > 0.0 {
434 image.mapv_inplace(|v| new_min + new_range * (v - min) / range);
435 }
436 }
437 }
438 }
439
440 Ok(result)
441 }
442
443 pub fn fit_transform(&mut self, images: &Array4<f64>) -> Result<Array4<f64>> {
445 self.fit(images)?;
446 self.transform(images)
447 }
448}
449
450#[allow(dead_code)]
452pub fn rgb_to_grayscale(images: &Array4<f64>) -> Result<Array3<f64>> {
453 let shape = images.shape();
454 if shape[3] != 3 {
455 return Err(TransformError::InvalidInput(format!(
456 "Expected 3 channels for RGB, got {}",
457 shape[3]
458 )));
459 }
460
461 let (n_samples, height, width) = (shape[0], shape[1], shape[2]);
462 let mut grayscale = Array3::zeros((n_samples, height, width));
463
464 let weights = [0.2989, 0.5870, 0.1140];
466
467 par_azip!((mut gray in grayscale.outer_iter_mut(),
468 rgb in images.outer_iter()) {
469 for i in 0..height {
470 for j in 0..width {
471 gray[[i, j]] = weights[0] * rgb[[i, j, 0]]
472 + weights[1] * rgb[[i, j, 1]]
473 + weights[2] * rgb[[i, j, 2]];
474 }
475 }
476 });
477
478 Ok(grayscale)
479}
480
481#[allow(dead_code)]
483pub fn resize_images(images: &Array4<f64>, newsize: (usize, usize)) -> Result<Array4<f64>> {
484 let (n_samples, old_h, old_w, n_channels) = {
485 let shape = images.shape();
486 (shape[0], shape[1], shape[2], shape[3])
487 };
488 let (new_h, new_w) = newsize;
489
490 let mut resized = Array4::zeros((n_samples, new_h, new_w, n_channels));
491
492 let scale_h = old_h as f64 / new_h as f64;
493 let scale_w = old_w as f64 / new_w as f64;
494
495 par_azip!((mut resized_img in resized.outer_iter_mut(),
496 original_img in images.outer_iter()) {
497 for i in 0..new_h {
498 for j in 0..new_w {
499 let orig_i = i as f64 * scale_h;
501 let orig_j = j as f64 * scale_w;
502
503 let i0 = orig_i.floor() as usize;
505 let j0 = orig_j.floor() as usize;
506 let i1 = (i0 + 1).min(old_h - 1);
507 let j1 = (j0 + 1).min(old_w - 1);
508
509 let di = orig_i - i0 as f64;
510 let dj = orig_j - j0 as f64;
511
512 for c in 0..n_channels {
514 let v00 = original_img[[i0, j0, c]];
515 let v01 = original_img[[i0, j1, c]];
516 let v10 = original_img[[i1, j0, c]];
517 let v11 = original_img[[i1, j1, c]];
518
519 let v0 = v00 * (1.0 - dj) + v01 * dj;
520 let v1 = v10 * (1.0 - dj) + v11 * dj;
521 let v = v0 * (1.0 - di) + v1 * di;
522
523 resized_img[[i, j, c]] = v;
524 }
525 }
526 }
527 });
528
529 Ok(resized)
530}