scirs2_transform/
image.rs

1//! Image processing transformers for feature extraction
2//!
3//! This module provides utilities for extracting features from images,
4//! including patch extraction, HOG features, and image normalization.
5
6use scirs2_core::ndarray::{par_azip, s, Array1, Array2, Array3, Array4};
7use std::f64::consts::PI;
8
9use crate::error::{Result, TransformError};
10
/// Extracts fixed-size patches from 2D images.
///
/// The extractor only stores configuration; see the `extract_*` methods for the
/// extraction itself.
#[derive(Debug, Clone)]
pub struct PatchExtractor {
    /// Patch size as `(height, width)`
    patch_size: (usize, usize),
    /// Upper bound on the number of patches extracted per image (`None` = no limit)
    max_patches: Option<usize>,
    /// Seed used for patch selection when only a subset of patches is extracted
    random_state: Option<u64>,
}
20
21impl PatchExtractor {
22    /// Create a new patch extractor
23    pub fn new(_patchsize: (usize, usize)) -> Self {
24        PatchExtractor {
25            patch_size: _patchsize,
26            max_patches: None,
27            random_state: None,
28        }
29    }
30
31    /// Set maximum number of patches to extract
32    pub fn with_max_patches(mut self, maxpatches: usize) -> Self {
33        self.max_patches = Some(maxpatches);
34        self
35    }
36
37    /// Set random seed for reproducible patch selection
38    pub fn with_random_state(mut self, seed: u64) -> Self {
39        self.random_state = Some(seed);
40        self
41    }
42
43    /// Extract patches from a 2D grayscale image
44    pub fn extract_patches_2d(&self, image: &Array2<f64>) -> Result<Array3<f64>> {
45        let (img_height, img_width) = (image.shape()[0], image.shape()[1]);
46        let (patch_height, patch_width) = self.patch_size;
47
48        if patch_height > img_height || patch_width > img_width {
49            return Err(TransformError::InvalidInput(format!(
50                "Patch size ({patch_height}, {patch_width}) exceeds image size ({img_height}, {img_width})"
51            )));
52        }
53
54        let n_patches_h = img_height - patch_height + 1;
55        let n_patches_w = img_width - patch_width + 1;
56        let total_patches = n_patches_h * n_patches_w;
57
58        let n_patches = if let Some(max_p) = self.max_patches {
59            max_p.min(total_patches)
60        } else {
61            total_patches
62        };
63
64        let mut patches = Array3::zeros((n_patches, patch_height, patch_width));
65
66        if n_patches == total_patches {
67            // Extract all patches
68            let mut patch_idx = 0;
69            for i in 0..n_patches_h {
70                for j in 0..n_patches_w {
71                    let patch = image.slice(s![i..i + patch_height, j..j + patch_width]);
72                    patches.slice_mut(s![patch_idx, .., ..]).assign(&patch);
73                    patch_idx += 1;
74                }
75            }
76        } else {
77            // Random patch selection
78            use scirs2_core::random::rngs::StdRng;
79            use scirs2_core::random::{Rng, SeedableRng};
80
81            let mut rng = if let Some(seed) = self.random_state {
82                StdRng::seed_from_u64(seed)
83            } else {
84                StdRng::seed_from_u64(scirs2_core::random::random::<u64>())
85            };
86
87            for patch_idx in 0..n_patches {
88                let i = rng.gen_range(0..n_patches_h);
89                let j = rng.gen_range(0..n_patches_w);
90                let patch = image.slice(s![i..i + patch_height, j..j + patch_width]);
91                patches.slice_mut(s![patch_idx, .., ..]).assign(&patch);
92            }
93        }
94
95        Ok(patches)
96    }
97
98    /// Extract patches from a batch of 2D images
99    pub fn extract_patches_batch(&self, images: &Array3<f64>) -> Result<Array4<f64>> {
100        let n_images = images.shape()[0];
101        let (img_height, img_width) = (images.shape()[1], images.shape()[2]);
102        let (patch_height, patch_width) = self.patch_size;
103
104        if patch_height > img_height || patch_width > img_width {
105            return Err(TransformError::InvalidInput(format!(
106                "Patch size ({patch_height}, {patch_width}) exceeds image size ({img_height}, {img_width})"
107            )));
108        }
109
110        let n_patches_per_image = if let Some(max_p) = self.max_patches {
111            let total = (img_height - patch_height + 1) * (img_width - patch_width + 1);
112            max_p.min(total)
113        } else {
114            (img_height - patch_height + 1) * (img_width - patch_width + 1)
115        };
116
117        let mut all_patches =
118            Array4::zeros((n_images * n_patches_per_image, patch_height, patch_width, 1));
119
120        for (img_idx, image) in images.outer_iter().enumerate() {
121            let patches = self.extract_patches_2d(&image.to_owned())?;
122            let start_idx = img_idx * n_patches_per_image;
123
124            for (patch_idx, patch) in patches.outer_iter().enumerate() {
125                all_patches
126                    .slice_mut(s![start_idx + patch_idx, .., .., 0])
127                    .assign(&patch);
128            }
129        }
130
131        Ok(all_patches)
132    }
133}
134
135/// Histogram of Oriented Gradients (HOG) feature extractor
136pub struct HOGDescriptor {
137    /// Size of each cell in pixels (height, width)
138    cell_size: (usize, usize),
139    /// Size of each block in cells (height, width)
140    block_size: (usize, usize),
141    /// Number of orientation bins
142    n_bins: usize,
143    /// Block normalization method
144    block_norm: BlockNorm,
145}
146
/// Block normalization methods for HOG.
///
/// The variants are plain tags and can be compared with `==`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BlockNorm {
    /// L1 normalization
    L1,
    /// L2 normalization
    L2,
    /// L1-sqrt normalization (square root of the L1-normalized values)
    L1Sqrt,
    /// L2-Hys normalization (L2, clipping, then L2 again)
    L2Hys,
}
159
160impl HOGDescriptor {
161    /// Create a new HOG descriptor
162    pub fn new(_cellsize: (usize, usize), block_size: (usize, usize), n_bins: usize) -> Self {
163        HOGDescriptor {
164            cell_size: _cellsize,
165            block_size,
166            n_bins,
167            block_norm: BlockNorm::L2Hys,
168        }
169    }
170
171    /// Set block normalization method
172    pub fn with_block_norm(mut self, blocknorm: BlockNorm) -> Self {
173        self.block_norm = blocknorm;
174        self
175    }
176
177    /// Compute gradients for an image
178    fn compute_gradients(&self, image: &Array2<f64>) -> (Array2<f64>, Array2<f64>) {
179        let (height, width) = (image.shape()[0], image.shape()[1]);
180        let mut grad_x = Array2::zeros((height, width));
181        let mut grad_y = Array2::zeros((height, width));
182
183        // Compute x-gradients
184        for i in 0..height {
185            for j in 1..width - 1 {
186                grad_x[[i, j]] = image[[i, j + 1]] - image[[i, j - 1]];
187            }
188            // Handle boundaries
189            grad_x[[i, 0]] = image[[i, 1]] - image[[i, 0]];
190            grad_x[[i, width - 1]] = image[[i, width - 1]] - image[[i, width - 2]];
191        }
192
193        // Compute y-gradients
194        for j in 0..width {
195            for i in 1..height - 1 {
196                grad_y[[i, j]] = image[[i + 1, j]] - image[[i - 1, j]];
197            }
198            // Handle boundaries
199            grad_y[[0, j]] = image[[1, j]] - image[[0, j]];
200            grad_y[[height - 1, j]] = image[[height - 1, j]] - image[[height - 2, j]];
201        }
202
203        (grad_x, grad_y)
204    }
205
206    /// Compute HOG features for a single image
207    pub fn compute(&self, image: &Array2<f64>) -> Result<Array1<f64>> {
208        let (height, width) = (image.shape()[0], image.shape()[1]);
209        let (cell_h, cell_w) = self.cell_size;
210        let (block_h, block_w) = self.block_size;
211
212        // Compute gradients
213        let (grad_x, grad_y) = self.compute_gradients(image);
214
215        // Compute magnitude and orientation
216        let magnitude = (&grad_x * &grad_x + &grad_y * &grad_y).mapv(f64::sqrt);
217        let mut orientation = grad_y.mapv(|y| y.atan2(0.0));
218        orientation.zip_mut_with(&grad_x, |o, &x| *o = (*o).atan2(x));
219
220        // Number of cells
221        let n_cells_h = height / cell_h;
222        let n_cells_w = width / cell_w;
223
224        // Build orientation histograms for each cell
225        let mut cell_histograms = Array3::zeros((n_cells_h, n_cells_w, self.n_bins));
226        let bin_size = PI / self.n_bins as f64;
227
228        for cell_i in 0..n_cells_h {
229            for cell_j in 0..n_cells_w {
230                let start_i = cell_i * cell_h;
231                let start_j = cell_j * cell_w;
232
233                for i in start_i..start_i.min(start_i + cell_h).min(height) {
234                    for j in start_j..start_j.min(start_j + cell_w).min(width) {
235                        let mag = magnitude[[i, j]];
236                        let mut angle = orientation[[i, j]];
237
238                        // Convert to 0-pi range
239                        if angle < 0.0 {
240                            angle += PI;
241                        }
242
243                        // Compute bin indices
244                        let bin_idx = (angle / bin_size) as usize;
245                        let bin_idx = bin_idx.min(self.n_bins - 1);
246
247                        cell_histograms[[cell_i, cell_j, bin_idx]] += mag;
248                    }
249                }
250            }
251        }
252
253        // Number of blocks
254        let n_blocks_h = n_cells_h - block_h + 1;
255        let n_blocks_w = n_cells_w - block_w + 1;
256        let block_features = block_h * block_w * self.n_bins;
257
258        let mut features = Vec::with_capacity(n_blocks_h * n_blocks_w * block_features);
259
260        // Extract and normalize blocks
261        for block_i in 0..n_blocks_h {
262            for block_j in 0..n_blocks_w {
263                let mut block_hist = Vec::with_capacity(block_features);
264
265                // Collect histograms from cells in this block
266                for i in 0..block_h {
267                    for j in 0..block_w {
268                        let cell_hist = cell_histograms.slice(s![block_i + i, block_j + j, ..]);
269                        block_hist.extend(cell_hist.iter());
270                    }
271                }
272
273                // Normalize block
274                let block_hist = self.normalize_block(&block_hist);
275                features.extend(block_hist);
276            }
277        }
278
279        Ok(Array1::from_vec(features))
280    }
281
282    /// Normalize a block histogram
283    fn normalize_block(&self, hist: &[f64]) -> Vec<f64> {
284        let epsilon = 1e-8;
285
286        match self.block_norm {
287            BlockNorm::L1 => {
288                let norm: f64 = hist.iter().sum::<f64>() + epsilon;
289                hist.iter().map(|&v| v / norm).collect()
290            }
291            BlockNorm::L2 => {
292                let norm = hist.iter().map(|&v| v * v).sum::<f64>().sqrt() + epsilon;
293                hist.iter().map(|&v| v / norm).collect()
294            }
295            BlockNorm::L1Sqrt => {
296                let norm: f64 = hist.iter().sum::<f64>() + epsilon;
297                hist.iter().map(|&v| (v / norm).sqrt()).collect()
298            }
299            BlockNorm::L2Hys => {
300                // L2 normalization with clipping
301                let mut norm = hist.iter().map(|&v| v * v).sum::<f64>().sqrt() + epsilon;
302                let mut normalized: Vec<f64> = hist.iter().map(|&v| v / norm).collect();
303
304                // Clip values to 0.2
305                let clip_val = 0.2;
306                for v in &mut normalized {
307                    if *v > clip_val {
308                        *v = clip_val;
309                    }
310                }
311
312                // Re-normalize
313                norm = normalized.iter().map(|&v| v * v).sum::<f64>().sqrt() + epsilon;
314                normalized.iter_mut().for_each(|v| *v /= norm);
315
316                normalized
317            }
318        }
319    }
320}
321
322/// Image normalization transformer
323pub struct ImageNormalizer {
324    /// Normalization method
325    method: ImageNormMethod,
326    /// Channel-wise statistics (mean, std) for standardization
327    channel_stats: Option<(Array1<f64>, Array1<f64>)>,
328}
329
/// Image normalization methods.
///
/// Values can be compared with `==`; `Range` compares its two bounds as
/// floating-point numbers.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum ImageNormMethod {
    /// Min-max normalization to [0, 1]
    MinMax,
    /// Standardization (zero mean, unit variance)
    Standard,
    /// Normalization to [-1, 1]
    Symmetric,
    /// Custom range normalization to `[first, second]`
    Range(f64, f64),
}
342
343impl ImageNormalizer {
344    /// Create a new image normalizer
345    pub fn new(method: ImageNormMethod) -> Self {
346        ImageNormalizer {
347            method,
348            channel_stats: None,
349        }
350    }
351
352    /// Fit the normalizer on a batch of images
353    pub fn fit(&mut self, images: &Array4<f64>) -> Result<()> {
354        if let ImageNormMethod::Standard = self.method {
355            let n_channels = images.shape()[3];
356            let mut means = Array1::zeros(n_channels);
357            let mut stds = Array1::zeros(n_channels);
358
359            // Compute channel-wise statistics
360            for c in 0..n_channels {
361                let channel_data = images.slice(s![.., .., .., c]);
362                let flat_data: Vec<f64> = channel_data.iter().cloned().collect();
363
364                let mean = flat_data.iter().sum::<f64>() / flat_data.len() as f64;
365                let variance = flat_data.iter().map(|&x| (x - mean).powi(2)).sum::<f64>()
366                    / flat_data.len() as f64;
367
368                means[c] = mean;
369                stds[c] = variance.sqrt();
370            }
371
372            self.channel_stats = Some((means, stds));
373        }
374
375        Ok(())
376    }
377
378    /// Transform images
379    pub fn transform(&self, images: &Array4<f64>) -> Result<Array4<f64>> {
380        let mut result = images.clone();
381
382        match self.method {
383            ImageNormMethod::MinMax => {
384                // Normalize each image independently to [0, 1]
385                for mut image in result.outer_iter_mut() {
386                    let min = image.iter().cloned().fold(f64::INFINITY, f64::min);
387                    let max = image.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
388                    let range = max - min;
389
390                    if range > 0.0 {
391                        image.mapv_inplace(|v| (v - min) / range);
392                    }
393                }
394            }
395            ImageNormMethod::Standard => {
396                if let Some((ref means, ref stds)) = self.channel_stats {
397                    let n_channels = images.shape()[3];
398
399                    for c in 0..n_channels {
400                        let mean = means[c];
401                        let std = stds[c].max(1e-8); // Avoid division by zero
402
403                        result
404                            .slice_mut(s![.., .., .., c])
405                            .mapv_inplace(|v| (v - mean) / std);
406                    }
407                } else {
408                    return Err(TransformError::NotFitted(
409                        "ImageNormalizer must be fitted before transform".into(),
410                    ));
411                }
412            }
413            ImageNormMethod::Symmetric => {
414                // Normalize to [-1, 1]
415                for mut image in result.outer_iter_mut() {
416                    let min = image.iter().cloned().fold(f64::INFINITY, f64::min);
417                    let max = image.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
418                    let range = max - min;
419
420                    if range > 0.0 {
421                        image.mapv_inplace(|v| 2.0 * (v - min) / range - 1.0);
422                    }
423                }
424            }
425            ImageNormMethod::Range(new_min, new_max) => {
426                let new_range = new_max - new_min;
427
428                for mut image in result.outer_iter_mut() {
429                    let min = image.iter().cloned().fold(f64::INFINITY, f64::min);
430                    let max = image.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
431                    let range = max - min;
432
433                    if range > 0.0 {
434                        image.mapv_inplace(|v| new_min + new_range * (v - min) / range);
435                    }
436                }
437            }
438        }
439
440        Ok(result)
441    }
442
443    /// Fit and transform in one step
444    pub fn fit_transform(&mut self, images: &Array4<f64>) -> Result<Array4<f64>> {
445        self.fit(images)?;
446        self.transform(images)
447    }
448}
449
450/// Convert RGB images to grayscale
451#[allow(dead_code)]
452pub fn rgb_to_grayscale(images: &Array4<f64>) -> Result<Array3<f64>> {
453    let shape = images.shape();
454    if shape[3] != 3 {
455        return Err(TransformError::InvalidInput(format!(
456            "Expected 3 channels for RGB, got {}",
457            shape[3]
458        )));
459    }
460
461    let (n_samples, height, width) = (shape[0], shape[1], shape[2]);
462    let mut grayscale = Array3::zeros((n_samples, height, width));
463
464    // Use standard RGB to grayscale conversion weights
465    let weights = [0.2989, 0.5870, 0.1140];
466
467    par_azip!((mut gray in grayscale.outer_iter_mut(),
468               rgb in images.outer_iter()) {
469        for i in 0..height {
470            for j in 0..width {
471                gray[[i, j]] = weights[0] * rgb[[i, j, 0]]
472                             + weights[1] * rgb[[i, j, 1]]
473                             + weights[2] * rgb[[i, j, 2]];
474            }
475        }
476    });
477
478    Ok(grayscale)
479}
480
481/// Resize images using bilinear interpolation
482#[allow(dead_code)]
483pub fn resize_images(images: &Array4<f64>, newsize: (usize, usize)) -> Result<Array4<f64>> {
484    let (n_samples, old_h, old_w, n_channels) = {
485        let shape = images.shape();
486        (shape[0], shape[1], shape[2], shape[3])
487    };
488    let (new_h, new_w) = newsize;
489
490    let mut resized = Array4::zeros((n_samples, new_h, new_w, n_channels));
491
492    let scale_h = old_h as f64 / new_h as f64;
493    let scale_w = old_w as f64 / new_w as f64;
494
495    par_azip!((mut resized_img in resized.outer_iter_mut(),
496               original_img in images.outer_iter()) {
497        for i in 0..new_h {
498            for j in 0..new_w {
499                // Map to original coordinates
500                let orig_i = i as f64 * scale_h;
501                let orig_j = j as f64 * scale_w;
502
503                // Get integer parts and fractions
504                let i0 = orig_i.floor() as usize;
505                let j0 = orig_j.floor() as usize;
506                let i1 = (i0 + 1).min(old_h - 1);
507                let j1 = (j0 + 1).min(old_w - 1);
508
509                let di = orig_i - i0 as f64;
510                let dj = orig_j - j0 as f64;
511
512                // Bilinear interpolation for each channel
513                for c in 0..n_channels {
514                    let v00 = original_img[[i0, j0, c]];
515                    let v01 = original_img[[i0, j1, c]];
516                    let v10 = original_img[[i1, j0, c]];
517                    let v11 = original_img[[i1, j1, c]];
518
519                    let v0 = v00 * (1.0 - dj) + v01 * dj;
520                    let v1 = v10 * (1.0 - dj) + v11 * dj;
521                    let v = v0 * (1.0 - di) + v1 * di;
522
523                    resized_img[[i, j, c]] = v;
524                }
525            }
526        }
527    });
528
529    Ok(resized)
530}