1use scirs2_core::ndarray::{Array3, Axis};
53use scirs2_core::random::thread_rng;
54use sklears_core::{
57 error::{Result, SklearsError},
58 traits::{Estimator, Fit, Transform, Untrained},
59 types::Float,
60};
61use std::f64::consts::PI;
62
63#[cfg(feature = "serde")]
64use serde::{Deserialize, Serialize};
65
/// Strategy used to rescale pixel values when normalizing an image.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum ImageNormalizationStrategy {
    /// Rescale with the fitted minimum and maximum: `(x - min) / (max - min)`.
    MinMax,
    /// Standardize with the fitted mean and standard deviation: `(x - mean) / std`.
    StandardScore,
    /// Min-max rescale into a custom target range; the two fields are the
    /// target minimum and the target maximum, in that order.
    CustomRange(Float, Float),
}
77
78impl Default for ImageNormalizationStrategy {
79 fn default() -> Self {
80 Self::MinMax
81 }
82}
83
/// Color spaces that a `ColorSpaceTransformer` can be configured with.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum ColorSpace {
    /// Red/green/blue; the conversions in this file require exactly 3 channels.
    RGB,
    /// Hue/saturation/value; the conversions in this file store hue as `degrees / 360`.
    HSV,
    /// CIE L*a*b*.
    /// NOTE(review): no conversion to or from this space is implemented in the
    /// visible code; requesting one yields an `InvalidInput` error.
    LAB,
    /// Single-channel luminance, as produced by the RGB-to-grayscale conversion.
    Grayscale,
}
97
/// Interpolation scheme used by `ImageResizer`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum InterpolationMethod {
    /// Copy the value of the single nearest source pixel.
    Nearest,
    /// Blend the four surrounding source pixels, weighted by distance.
    Bilinear,
    /// Bicubic interpolation.
    /// NOTE(review): the resizer in this file currently delegates this variant
    /// to bilinear interpolation.
    Bicubic,
}
109
110impl Default for InterpolationMethod {
111 fn default() -> Self {
112 Self::Bilinear
113 }
114}
115
/// Settings controlling how `ImageNormalizer` computes and applies statistics.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ImageNormalizerConfig {
    /// Which rescaling formula is applied at transform time.
    pub strategy: ImageNormalizationStrategy,
    /// When `true`, statistics are computed separately for every channel
    /// (last axis); when `false`, one set of statistics computed over the whole
    /// image is shared by all channels.
    pub channel_wise: bool,
    /// Lower bound applied to the fitted standard deviation, and the minimum
    /// value range required before min-max style scaling is applied.
    pub epsilon: Float,
}
127
128impl Default for ImageNormalizerConfig {
129 fn default() -> Self {
130 Self {
131 strategy: ImageNormalizationStrategy::MinMax,
132 channel_wise: true,
133 epsilon: 1e-8,
134 }
135 }
136}
137
/// Image normalizer parameterized by a type-state marker (`Untrained` by default).
///
/// In this file the `Untrained` state provides the builder methods and `fit`,
/// which produces an `ImageNormalizerFitted`.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ImageNormalizer<State = Untrained> {
    /// Normalization settings used when fitting.
    config: ImageNormalizerConfig,
    /// Zero-sized marker that carries the `State` type parameter.
    state: std::marker::PhantomData<State>,
}
145
/// Normalizer holding the statistics learned by `ImageNormalizer::fit`.
///
/// Every vector has one entry per channel of the image it was fitted on.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ImageNormalizerFitted {
    /// Configuration carried over from the unfitted normalizer.
    config: ImageNormalizerConfig,
    /// Minimum value per channel.
    min_vals: Vec<Float>,
    /// Maximum value per channel.
    max_vals: Vec<Float>,
    /// Mean value per channel.
    mean_vals: Vec<Float>,
    /// Sample standard deviation per channel, floored at `config.epsilon`.
    std_vals: Vec<Float>,
}
156
157impl ImageNormalizer<Untrained> {
158 pub fn new() -> Self {
160 Self {
161 config: ImageNormalizerConfig::default(),
162 state: std::marker::PhantomData,
163 }
164 }
165
166 pub fn with_strategy(mut self, strategy: ImageNormalizationStrategy) -> Self {
168 self.config.strategy = strategy;
169 self
170 }
171
172 pub fn with_range(mut self, range: (Float, Float)) -> Self {
174 self.config.strategy = ImageNormalizationStrategy::CustomRange(range.0, range.1);
175 self
176 }
177
178 pub fn with_channel_wise(mut self, channel_wise: bool) -> Self {
180 self.config.channel_wise = channel_wise;
181 self
182 }
183
184 pub fn with_epsilon(mut self, epsilon: Float) -> Self {
186 self.config.epsilon = epsilon;
187 self
188 }
189}
190
191impl Estimator for ImageNormalizer<Untrained> {
192 type Config = ImageNormalizerConfig;
193 type Error = SklearsError;
194 type Float = Float;
195
196 fn config(&self) -> &Self::Config {
197 &self.config
198 }
199}
200
201impl Fit<Array3<Float>, ()> for ImageNormalizer<Untrained> {
202 type Fitted = ImageNormalizerFitted;
203
204 fn fit(self, x: &Array3<Float>, _y: &()) -> Result<Self::Fitted> {
205 let shape = x.dim();
206 let n_channels = shape.2;
207
208 let (min_vals, max_vals, mean_vals, std_vals) = if self.config.channel_wise {
209 let mut min_vals = Vec::with_capacity(n_channels);
210 let mut max_vals = Vec::with_capacity(n_channels);
211 let mut mean_vals = Vec::with_capacity(n_channels);
212 let mut std_vals = Vec::with_capacity(n_channels);
213
214 for channel in 0..n_channels {
215 let channel_data = x.index_axis(Axis(2), channel);
216 let data_slice: Vec<Float> = channel_data.iter().copied().collect();
217
218 let min_val = data_slice.iter().fold(Float::INFINITY, |a, &b| a.min(b));
219 let max_val = data_slice
220 .iter()
221 .fold(Float::NEG_INFINITY, |a, &b| a.max(b));
222
223 let mean_val = data_slice.iter().sum::<Float>() / data_slice.len() as Float;
224
225 let var_val = data_slice
226 .iter()
227 .map(|&x| (x - mean_val).powi(2))
228 .sum::<Float>()
229 / (data_slice.len() as Float - 1.0);
230
231 let std_val = var_val.sqrt().max(self.config.epsilon);
232
233 min_vals.push(min_val);
234 max_vals.push(max_val);
235 mean_vals.push(mean_val);
236 std_vals.push(std_val);
237 }
238
239 (min_vals, max_vals, mean_vals, std_vals)
240 } else {
241 let all_data: Vec<Float> = x.iter().copied().collect();
243
244 let min_val = all_data.iter().fold(Float::INFINITY, |a, &b| a.min(b));
245 let max_val = all_data.iter().fold(Float::NEG_INFINITY, |a, &b| a.max(b));
246
247 let mean_val = all_data.iter().sum::<Float>() / all_data.len() as Float;
248
249 let var_val = all_data
250 .iter()
251 .map(|&x| (x - mean_val).powi(2))
252 .sum::<Float>()
253 / (all_data.len() as Float - 1.0);
254
255 let std_val = var_val.sqrt().max(self.config.epsilon);
256
257 (
258 vec![min_val; n_channels],
259 vec![max_val; n_channels],
260 vec![mean_val; n_channels],
261 vec![std_val; n_channels],
262 )
263 };
264
265 Ok(ImageNormalizerFitted {
266 config: self.config,
267 min_vals,
268 max_vals,
269 mean_vals,
270 std_vals,
271 })
272 }
273}
274
275impl Transform<Array3<Float>, Array3<Float>> for ImageNormalizerFitted {
276 fn transform(&self, x: &Array3<Float>) -> Result<Array3<Float>> {
277 let shape = x.dim();
278 let n_channels = shape.2;
279
280 if n_channels != self.min_vals.len() {
281 return Err(SklearsError::InvalidInput(format!(
282 "Expected {} channels, got {}",
283 self.min_vals.len(),
284 n_channels
285 )));
286 }
287
288 let mut result = x.clone();
289
290 match self.config.strategy {
291 ImageNormalizationStrategy::MinMax => {
292 for channel in 0..n_channels {
293 let min_val = self.min_vals[channel];
294 let max_val = self.max_vals[channel];
295 let range = max_val - min_val;
296
297 if range > self.config.epsilon {
298 let mut channel_data = result.index_axis_mut(Axis(2), channel);
299 channel_data.mapv_inplace(|x| (x - min_val) / range);
300 }
301 }
302 }
303 ImageNormalizationStrategy::CustomRange(min_target, max_target) => {
304 let target_range = max_target - min_target;
305 for channel in 0..n_channels {
306 let min_val = self.min_vals[channel];
307 let max_val = self.max_vals[channel];
308 let source_range = max_val - min_val;
309
310 if source_range > self.config.epsilon {
311 let mut channel_data = result.index_axis_mut(Axis(2), channel);
312 channel_data.mapv_inplace(|x| {
313 min_target + ((x - min_val) / source_range) * target_range
314 });
315 }
316 }
317 }
318 ImageNormalizationStrategy::StandardScore => {
319 for channel in 0..n_channels {
320 let mean_val = self.mean_vals[channel];
321 let std_val = self.std_vals[channel];
322
323 let mut channel_data = result.index_axis_mut(Axis(2), channel);
324 channel_data.mapv_inplace(|x| (x - mean_val) / std_val);
325 }
326 }
327 }
328
329 Ok(result)
330 }
331}
332
/// Settings describing which random augmentations `ImageAugmenter` may apply.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ImageAugmenterConfig {
    /// `(min, max)` rotation angle in degrees; an angle is sampled from this range.
    pub rotation_range: Option<(Float, Float)>,
    /// `(min, max)` zoom range.
    /// NOTE(review): not read by the transform visible in this file — confirm
    /// whether zoom is implemented elsewhere.
    pub zoom_range: Option<(Float, Float)>,
    /// Horizontal shift amount.
    /// NOTE(review): not read by the transform visible in this file.
    pub width_shift_range: Option<Float>,
    /// Vertical shift amount.
    /// NOTE(review): not read by the transform visible in this file.
    pub height_shift_range: Option<Float>,
    /// When `true`, the image is mirrored left-right with probability 0.5.
    pub horizontal_flip: bool,
    /// When `true`, the image is mirrored top-bottom with probability 0.5.
    pub vertical_flip: bool,
    /// `(min, max)` multiplicative brightness factor; scaled values are clamped
    /// to `[0, 1]`.
    pub brightness_range: Option<(Float, Float)>,
    /// Seed for reproducible augmentation.
    /// NOTE(review): stored by the builder, but the transform visible in this
    /// file draws from `thread_rng()` — confirm whether seeding is intended.
    pub random_seed: Option<u64>,
}
354
355impl Default for ImageAugmenterConfig {
356 fn default() -> Self {
357 Self {
358 rotation_range: None,
359 zoom_range: None,
360 width_shift_range: None,
361 height_shift_range: None,
362 horizontal_flip: false,
363 vertical_flip: false,
364 brightness_range: None,
365 random_seed: None,
366 }
367 }
368}
369
/// Applies random augmentations (flips, rotation, brightness) to an image.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ImageAugmenter {
    /// Which augmentations are enabled and their parameter ranges.
    config: ImageAugmenterConfig,
}
376
377impl ImageAugmenter {
378 pub fn new() -> Self {
380 Self {
381 config: ImageAugmenterConfig::default(),
382 }
383 }
384
385 pub fn with_rotation_range(mut self, range: (Float, Float)) -> Self {
387 self.config.rotation_range = Some(range);
388 self
389 }
390
391 pub fn with_zoom_range(mut self, range: (Float, Float)) -> Self {
393 self.config.zoom_range = Some(range);
394 self
395 }
396
397 pub fn with_width_shift_range(mut self, range: Float) -> Self {
399 self.config.width_shift_range = Some(range);
400 self
401 }
402
403 pub fn with_height_shift_range(mut self, range: Float) -> Self {
405 self.config.height_shift_range = Some(range);
406 self
407 }
408
409 pub fn with_horizontal_flip(mut self, enabled: bool) -> Self {
411 self.config.horizontal_flip = enabled;
412 self
413 }
414
415 pub fn with_vertical_flip(mut self, enabled: bool) -> Self {
417 self.config.vertical_flip = enabled;
418 self
419 }
420
421 pub fn with_brightness_range(mut self, range: (Float, Float)) -> Self {
423 self.config.brightness_range = Some(range);
424 self
425 }
426
427 pub fn with_random_seed(mut self, seed: u64) -> Self {
429 self.config.random_seed = Some(seed);
430 self
432 }
433}
434
435impl Transform<Array3<Float>, Array3<Float>> for ImageAugmenter {
436 fn transform(&self, x: &Array3<Float>) -> Result<Array3<Float>> {
437 let mut result = x.clone();
438 let mut rng = thread_rng();
439
440 if self.config.horizontal_flip && rng.random::<Float>() < 0.5 {
442 result = self.horizontal_flip(&result)?;
443 }
444
445 if self.config.vertical_flip && rng.random::<Float>() < 0.5 {
447 result = self.vertical_flip(&result)?;
448 }
449
450 if let Some((min_angle, max_angle)) = self.config.rotation_range {
452 let angle = rng.gen_range(min_angle..max_angle);
453 if angle.abs() > 1e-6 {
454 result = self.rotate(&result, angle)?;
455 }
456 }
457
458 if let Some((min_brightness, max_brightness)) = self.config.brightness_range {
460 let brightness_factor = rng.gen_range(min_brightness..max_brightness);
461 if (brightness_factor - 1.0).abs() > 1e-6 {
462 result.mapv_inplace(|x| (x * brightness_factor).clamp(0.0, 1.0));
463 }
464 }
465
466 Ok(result)
467 }
468}
469
470impl ImageAugmenter {
471 fn horizontal_flip(&self, image: &Array3<Float>) -> Result<Array3<Float>> {
473 let shape = image.dim();
474 let mut result = Array3::zeros(shape);
475
476 for row in 0..shape.0 {
477 for col in 0..shape.1 {
478 for channel in 0..shape.2 {
479 result[[row, shape.1 - 1 - col, channel]] = image[[row, col, channel]];
480 }
481 }
482 }
483
484 Ok(result)
485 }
486
487 fn vertical_flip(&self, image: &Array3<Float>) -> Result<Array3<Float>> {
489 let shape = image.dim();
490 let mut result = Array3::zeros(shape);
491
492 for row in 0..shape.0 {
493 for col in 0..shape.1 {
494 for channel in 0..shape.2 {
495 result[[shape.0 - 1 - row, col, channel]] = image[[row, col, channel]];
496 }
497 }
498 }
499
500 Ok(result)
501 }
502
503 fn rotate(&self, image: &Array3<Float>, angle_degrees: Float) -> Result<Array3<Float>> {
505 let shape = image.dim();
506 let mut result = Array3::zeros(shape);
507
508 let angle_rad = angle_degrees * PI / 180.0;
509 let cos_angle = angle_rad.cos();
510 let sin_angle = angle_rad.sin();
511
512 let center_x = shape.1 as Float / 2.0;
513 let center_y = shape.0 as Float / 2.0;
514
515 for row in 0..shape.0 {
517 for col in 0..shape.1 {
518 let x = col as Float - center_x;
519 let y = row as Float - center_y;
520
521 let rotated_x = x * cos_angle - y * sin_angle + center_x;
522 let rotated_y = x * sin_angle + y * cos_angle + center_y;
523
524 let src_col = rotated_x.round() as isize;
525 let src_row = rotated_y.round() as isize;
526
527 if src_row >= 0
528 && src_row < shape.0 as isize
529 && src_col >= 0
530 && src_col < shape.1 as isize
531 {
532 for channel in 0..shape.2 {
533 result[[row, col, channel]] =
534 image[[src_row as usize, src_col as usize, channel]];
535 }
536 }
537 }
538 }
539
540 Ok(result)
541 }
542}
543
/// Converts images from one `ColorSpace` to another.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ColorSpaceTransformer {
    /// Color space the input image is interpreted in.
    source: ColorSpace,
    /// Color space the output image is produced in.
    target: ColorSpace,
}
551
552impl ColorSpaceTransformer {
553 pub fn new() -> Self {
555 Self {
556 source: ColorSpace::RGB,
557 target: ColorSpace::RGB,
558 }
559 }
560
561 pub fn from_colorspace(mut self, colorspace: ColorSpace) -> Self {
563 self.source = colorspace;
564 self
565 }
566
567 pub fn to_colorspace(mut self, colorspace: ColorSpace) -> Self {
569 self.target = colorspace;
570 self
571 }
572
573 pub fn from_rgb(mut self) -> Self {
575 self.source = ColorSpace::RGB;
576 self
577 }
578
579 pub fn to_hsv(mut self) -> Self {
581 self.target = ColorSpace::HSV;
582 self
583 }
584
585 pub fn to_grayscale(mut self) -> Self {
587 self.target = ColorSpace::Grayscale;
588 self
589 }
590}
591
592impl Transform<Array3<Float>, Array3<Float>> for ColorSpaceTransformer {
593 fn transform(&self, x: &Array3<Float>) -> Result<Array3<Float>> {
594 match (self.source, self.target) {
595 (ColorSpace::RGB, ColorSpace::HSV) => self.rgb_to_hsv(x),
596 (ColorSpace::RGB, ColorSpace::Grayscale) => self.rgb_to_grayscale(x),
597 (ColorSpace::HSV, ColorSpace::RGB) => self.hsv_to_rgb(x),
598 (source, target) if source == target => Ok(x.clone()),
599 _ => Err(SklearsError::InvalidInput(format!(
600 "Conversion from {:?} to {:?} not implemented",
601 self.source, self.target
602 ))),
603 }
604 }
605}
606
607impl ColorSpaceTransformer {
608 fn rgb_to_hsv(&self, image: &Array3<Float>) -> Result<Array3<Float>> {
610 let shape = image.dim();
611 if shape.2 != 3 {
612 return Err(SklearsError::InvalidInput(
613 "RGB images must have 3 channels".to_string(),
614 ));
615 }
616
617 let mut result = Array3::zeros(shape);
618
619 for row in 0..shape.0 {
620 for col in 0..shape.1 {
621 let r = image[[row, col, 0]];
622 let g = image[[row, col, 1]];
623 let b = image[[row, col, 2]];
624
625 let max_val = r.max(g).max(b);
626 let min_val = r.min(g).min(b);
627 let delta = max_val - min_val;
628
629 let h = if delta < 1e-8 {
631 0.0
632 } else if (max_val - r).abs() < 1e-8 {
633 60.0 * (((g - b) / delta) % 6.0)
634 } else if (max_val - g).abs() < 1e-8 {
635 60.0 * (((b - r) / delta) + 2.0)
636 } else {
637 60.0 * (((r - g) / delta) + 4.0)
638 };
639
640 let h = if h < 0.0 { h + 360.0 } else { h };
641
642 let s = if max_val < 1e-8 { 0.0 } else { delta / max_val };
644
645 let v = max_val;
647
648 result[[row, col, 0]] = h / 360.0; result[[row, col, 1]] = s;
650 result[[row, col, 2]] = v;
651 }
652 }
653
654 Ok(result)
655 }
656
657 fn hsv_to_rgb(&self, image: &Array3<Float>) -> Result<Array3<Float>> {
659 let shape = image.dim();
660 if shape.2 != 3 {
661 return Err(SklearsError::InvalidInput(
662 "HSV images must have 3 channels".to_string(),
663 ));
664 }
665
666 let mut result = Array3::zeros(shape);
667
668 for row in 0..shape.0 {
669 for col in 0..shape.1 {
670 let h = image[[row, col, 0]] * 360.0; let s = image[[row, col, 1]];
672 let v = image[[row, col, 2]];
673
674 let c = v * s;
675 let x = c * (1.0 - ((h / 60.0) % 2.0 - 1.0).abs());
676 let m = v - c;
677
678 let (r, g, b) = if h < 60.0 {
679 (c, x, 0.0)
680 } else if h < 120.0 {
681 (x, c, 0.0)
682 } else if h < 180.0 {
683 (0.0, c, x)
684 } else if h < 240.0 {
685 (0.0, x, c)
686 } else if h < 300.0 {
687 (x, 0.0, c)
688 } else {
689 (c, 0.0, x)
690 };
691
692 result[[row, col, 0]] = r + m;
693 result[[row, col, 1]] = g + m;
694 result[[row, col, 2]] = b + m;
695 }
696 }
697
698 Ok(result)
699 }
700
701 fn rgb_to_grayscale(&self, image: &Array3<Float>) -> Result<Array3<Float>> {
703 let shape = image.dim();
704 if shape.2 != 3 {
705 return Err(SklearsError::InvalidInput(
706 "RGB images must have 3 channels".to_string(),
707 ));
708 }
709
710 let mut result = Array3::zeros((shape.0, shape.1, 1));
711
712 const R_WEIGHT: Float = 0.299;
714 const G_WEIGHT: Float = 0.587;
715 const B_WEIGHT: Float = 0.114;
716
717 for row in 0..shape.0 {
718 for col in 0..shape.1 {
719 let r = image[[row, col, 0]];
720 let g = image[[row, col, 1]];
721 let b = image[[row, col, 2]];
722
723 let gray = R_WEIGHT * r + G_WEIGHT * g + B_WEIGHT * b;
724 result[[row, col, 0]] = gray;
725 }
726 }
727
728 Ok(result)
729 }
730}
731
/// Resizes images to a fixed target size.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ImageResizer {
    /// Output size as `(height, width)`.
    target_size: (usize, usize),
    /// Interpolation scheme used when sampling the source image.
    method: InterpolationMethod,
}
739
740impl ImageResizer {
741 pub fn new(target_size: (usize, usize)) -> Self {
743 Self {
744 target_size,
745 method: InterpolationMethod::default(),
746 }
747 }
748
749 pub fn with_method(mut self, method: InterpolationMethod) -> Self {
751 self.method = method;
752 self
753 }
754}
755
756impl Transform<Array3<Float>, Array3<Float>> for ImageResizer {
757 fn transform(&self, x: &Array3<Float>) -> Result<Array3<Float>> {
758 let source_shape = x.dim();
759 let (target_height, target_width) = self.target_size;
760
761 if target_height == 0 || target_width == 0 {
762 return Err(SklearsError::InvalidInput(
763 "Target dimensions must be positive".to_string(),
764 ));
765 }
766
767 let mut result = Array3::zeros((target_height, target_width, source_shape.2));
768
769 let height_scale = source_shape.0 as Float / target_height as Float;
770 let width_scale = source_shape.1 as Float / target_width as Float;
771
772 match self.method {
773 InterpolationMethod::Nearest => {
774 for row in 0..target_height {
775 for col in 0..target_width {
776 let src_row = ((row as Float + 0.5) * height_scale).floor() as usize;
777 let src_col = ((col as Float + 0.5) * width_scale).floor() as usize;
778
779 let src_row = src_row.min(source_shape.0 - 1);
780 let src_col = src_col.min(source_shape.1 - 1);
781
782 for channel in 0..source_shape.2 {
783 result[[row, col, channel]] = x[[src_row, src_col, channel]];
784 }
785 }
786 }
787 }
788 InterpolationMethod::Bilinear => {
789 for row in 0..target_height {
790 for col in 0..target_width {
791 let src_y = (row as Float + 0.5) * height_scale - 0.5;
792 let src_x = (col as Float + 0.5) * width_scale - 0.5;
793
794 let y1 = src_y.floor() as isize;
795 let x1 = src_x.floor() as isize;
796 let y2 = y1 + 1;
797 let x2 = x1 + 1;
798
799 let dy = src_y - y1 as Float;
800 let dx = src_x - x1 as Float;
801
802 for channel in 0..source_shape.2 {
803 let mut sum = 0.0;
804
805 if y1 >= 0
807 && y1 < source_shape.0 as isize
808 && x1 >= 0
809 && x1 < source_shape.1 as isize
810 {
811 sum += (1.0 - dx)
812 * (1.0 - dy)
813 * x[[y1 as usize, x1 as usize, channel]];
814 }
815 if y1 >= 0
816 && y1 < source_shape.0 as isize
817 && x2 >= 0
818 && x2 < source_shape.1 as isize
819 {
820 sum += dx * (1.0 - dy) * x[[y1 as usize, x2 as usize, channel]];
821 }
822 if y2 >= 0
823 && y2 < source_shape.0 as isize
824 && x1 >= 0
825 && x1 < source_shape.1 as isize
826 {
827 sum += (1.0 - dx) * dy * x[[y2 as usize, x1 as usize, channel]];
828 }
829 if y2 >= 0
830 && y2 < source_shape.0 as isize
831 && x2 >= 0
832 && x2 < source_shape.1 as isize
833 {
834 sum += dx * dy * x[[y2 as usize, x2 as usize, channel]];
835 }
836
837 result[[row, col, channel]] = sum;
838 }
839 }
840 }
841 }
842 InterpolationMethod::Bicubic => {
843 let bilinear_resizer =
846 ImageResizer::new(self.target_size).with_method(InterpolationMethod::Bilinear);
847 return bilinear_resizer.transform(x);
848 }
849 }
850
851 Ok(result)
852 }
853}
854
/// Edge detection algorithm used by `EdgeDetector`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum EdgeDetectionMethod {
    /// Gradient magnitude computed from the 3x3 Sobel kernels.
    Sobel,
    /// Absolute response of a 3x3 Laplacian kernel.
    Laplacian,
    /// Canny edge detection.
    /// NOTE(review): the detector in this file implements this variant as
    /// Sobel followed by an optional binary threshold.
    Canny,
}
866
867impl Default for EdgeDetectionMethod {
868 fn default() -> Self {
869 Self::Sobel
870 }
871}
872
/// Detects edges in 1-channel or 3-channel images.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct EdgeDetector {
    /// Algorithm used to compute the edge response.
    method: EdgeDetectionMethod,
    /// Optional binarization threshold; in the visible transform it is applied
    /// only for `EdgeDetectionMethod::Canny`.
    threshold: Option<Float>,
    /// Optional Gaussian blur sigma applied before edge detection.
    blur_sigma: Option<Float>,
}
881
882impl EdgeDetector {
883 pub fn new() -> Self {
885 Self {
886 method: EdgeDetectionMethod::default(),
887 threshold: None,
888 blur_sigma: None,
889 }
890 }
891
892 pub fn with_method(mut self, method: EdgeDetectionMethod) -> Self {
894 self.method = method;
895 self
896 }
897
898 pub fn with_threshold(mut self, threshold: Float) -> Self {
900 self.threshold = Some(threshold);
901 self
902 }
903
904 pub fn with_blur_sigma(mut self, sigma: Float) -> Self {
906 self.blur_sigma = Some(sigma);
907 self
908 }
909}
910
911impl Transform<Array3<Float>, Array3<Float>> for EdgeDetector {
912 fn transform(&self, x: &Array3<Float>) -> Result<Array3<Float>> {
913 let gray_image = if x.dim().2 == 3 {
915 let color_transformer = ColorSpaceTransformer::new().from_rgb().to_grayscale();
916 color_transformer.transform(x)?
917 } else if x.dim().2 == 1 {
918 x.clone()
919 } else {
920 return Err(SklearsError::InvalidInput(
921 "Image must have 1 or 3 channels".to_string(),
922 ));
923 };
924
925 let mut processed = gray_image;
926
927 if let Some(sigma) = self.blur_sigma {
929 processed = self.gaussian_blur(&processed, sigma)?;
930 }
931
932 let edges = match self.method {
934 EdgeDetectionMethod::Sobel => self.sobel_edge_detection(&processed)?,
935 EdgeDetectionMethod::Laplacian => self.laplacian_edge_detection(&processed)?,
936 EdgeDetectionMethod::Canny => {
937 let sobel_edges = self.sobel_edge_detection(&processed)?;
939 if let Some(threshold) = self.threshold {
940 self.apply_threshold(&sobel_edges, threshold)?
941 } else {
942 sobel_edges
943 }
944 }
945 };
946
947 Ok(edges)
948 }
949}
950
951impl EdgeDetector {
952 fn gaussian_blur(&self, image: &Array3<Float>, sigma: Float) -> Result<Array3<Float>> {
954 let shape = image.dim();
955 let mut result = image.clone();
956
957 let kernel_size = (6.0 * sigma).ceil() as usize + 1;
959 let kernel_radius = kernel_size / 2;
960
961 let mut kernel = vec![vec![0.0; kernel_size]; kernel_size];
963 let mut kernel_sum = 0.0;
964
965 for i in 0..kernel_size {
966 for j in 0..kernel_size {
967 let x = (i as isize - kernel_radius as isize) as Float;
968 let y = (j as isize - kernel_radius as isize) as Float;
969 let value = (-((x * x + y * y) / (2.0 * sigma * sigma))).exp();
970 kernel[i][j] = value;
971 kernel_sum += value;
972 }
973 }
974
975 for i in 0..kernel_size {
977 for j in 0..kernel_size {
978 kernel[i][j] /= kernel_sum;
979 }
980 }
981
982 for row in kernel_radius..(shape.0 - kernel_radius) {
984 for col in kernel_radius..(shape.1 - kernel_radius) {
985 for channel in 0..shape.2 {
986 let mut sum = 0.0;
987 for ki in 0..kernel_size {
988 for kj in 0..kernel_size {
989 let img_row = row + ki - kernel_radius;
990 let img_col = col + kj - kernel_radius;
991 sum += image[[img_row, img_col, channel]] * kernel[ki][kj];
992 }
993 }
994 result[[row, col, channel]] = sum;
995 }
996 }
997 }
998
999 Ok(result)
1000 }
1001
1002 fn sobel_edge_detection(&self, image: &Array3<Float>) -> Result<Array3<Float>> {
1004 let shape = image.dim();
1005 let mut result = Array3::zeros(shape);
1006
1007 let sobel_x: [[Float; 3]; 3] = [[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]];
1009
1010 let sobel_y: [[Float; 3]; 3] = [[-1.0, -2.0, -1.0], [0.0, 0.0, 0.0], [1.0, 2.0, 1.0]];
1011
1012 for row in 1..(shape.0 - 1) {
1014 for col in 1..(shape.1 - 1) {
1015 for channel in 0..shape.2 {
1016 let mut gx = 0.0;
1017 let mut gy = 0.0;
1018
1019 for i in 0..3 {
1021 for j in 0..3 {
1022 let pixel_val = image[[row + i - 1, col + j - 1, channel]];
1023 gx += pixel_val * sobel_x[i][j];
1024 gy += pixel_val * sobel_y[i][j];
1025 }
1026 }
1027
1028 let gradient_magnitude = (gx * gx + gy * gy).sqrt();
1030 result[[row, col, channel]] = gradient_magnitude;
1031 }
1032 }
1033 }
1034
1035 Ok(result)
1036 }
1037
1038 fn laplacian_edge_detection(&self, image: &Array3<Float>) -> Result<Array3<Float>> {
1040 let shape = image.dim();
1041 let mut result = Array3::zeros(shape);
1042
1043 let laplacian: [[Float; 3]; 3] = [[0.0, -1.0, 0.0], [-1.0, 4.0, -1.0], [0.0, -1.0, 0.0]];
1045
1046 for row in 1..(shape.0 - 1) {
1048 for col in 1..(shape.1 - 1) {
1049 for channel in 0..shape.2 {
1050 let mut sum = 0.0;
1051
1052 for i in 0..3 {
1054 for j in 0..3 {
1055 let pixel_val = image[[row + i - 1, col + j - 1, channel]];
1056 sum += pixel_val * laplacian[i][j];
1057 }
1058 }
1059
1060 result[[row, col, channel]] = sum.abs();
1061 }
1062 }
1063 }
1064
1065 Ok(result)
1066 }
1067
1068 fn apply_threshold(&self, image: &Array3<Float>, threshold: Float) -> Result<Array3<Float>> {
1070 let mut result = image.clone();
1071 result.mapv_inplace(|x| if x > threshold { 1.0 } else { 0.0 });
1072 Ok(result)
1073 }
1074}
1075
/// Extracts a flat feature vector (edge, histogram and moment features) from an image.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ImageFeatureExtractor {
    /// Whether edge density and mean edge strength are included.
    extract_edges: bool,
    /// Whether per-channel intensity histograms are included.
    extract_histograms: bool,
    /// Number of bins in each per-channel histogram.
    histogram_bins: usize,
    /// Whether per-channel mean, variance, skewness and kurtosis are included.
    extract_moments: bool,
}
1085
1086impl ImageFeatureExtractor {
1087 pub fn new() -> Self {
1089 Self {
1090 extract_edges: true,
1091 extract_histograms: true,
1092 histogram_bins: 32,
1093 extract_moments: true,
1094 }
1095 }
1096
1097 pub fn with_edge_features(mut self, enabled: bool) -> Self {
1099 self.extract_edges = enabled;
1100 self
1101 }
1102
1103 pub fn with_histogram_features(mut self, enabled: bool, bins: usize) -> Self {
1105 self.extract_histograms = enabled;
1106 self.histogram_bins = bins;
1107 self
1108 }
1109
1110 pub fn with_moment_features(mut self, enabled: bool) -> Self {
1112 self.extract_moments = enabled;
1113 self
1114 }
1115}
1116
1117impl Transform<Array3<Float>, Vec<Float>> for ImageFeatureExtractor {
1118 fn transform(&self, x: &Array3<Float>) -> Result<Vec<Float>> {
1119 let mut features = Vec::new();
1120
1121 if self.extract_edges {
1123 let edge_detector = EdgeDetector::new().with_method(EdgeDetectionMethod::Sobel);
1124 let edges = edge_detector.transform(x)?;
1125
1126 let total_pixels = edges.len();
1128 let edge_pixels = edges.iter().filter(|&&x| x > 0.1).count();
1129 features.push(edge_pixels as Float / total_pixels as Float);
1130
1131 let mean_edge_strength = edges.iter().sum::<Float>() / total_pixels as Float;
1133 features.push(mean_edge_strength);
1134 }
1135
1136 if self.extract_histograms {
1138 for channel in 0..x.dim().2 {
1139 let channel_data = x.index_axis(Axis(2), channel);
1140 let histogram = self.compute_histogram(&channel_data, self.histogram_bins)?;
1141 features.extend(histogram);
1142 }
1143 }
1144
1145 if self.extract_moments {
1147 for channel in 0..x.dim().2 {
1148 let channel_data = x.index_axis(Axis(2), channel);
1149 let data_vec: Vec<Float> = channel_data.iter().copied().collect();
1150
1151 let mean = data_vec.iter().sum::<Float>() / data_vec.len() as Float;
1153 features.push(mean);
1154
1155 let variance = data_vec.iter().map(|&x| (x - mean).powi(2)).sum::<Float>()
1157 / data_vec.len() as Float;
1158 features.push(variance);
1159
1160 let skewness = data_vec.iter().map(|&x| (x - mean).powi(3)).sum::<Float>()
1162 / (data_vec.len() as Float * variance.powf(1.5));
1163 features.push(skewness);
1164
1165 let kurtosis = data_vec.iter().map(|&x| (x - mean).powi(4)).sum::<Float>()
1167 / (data_vec.len() as Float * variance.powi(2));
1168 features.push(kurtosis);
1169 }
1170 }
1171
1172 Ok(features)
1173 }
1174}
1175
1176impl ImageFeatureExtractor {
1177 fn compute_histogram(
1179 &self,
1180 data: &scirs2_core::ndarray::ArrayView2<Float>,
1181 bins: usize,
1182 ) -> Result<Vec<Float>> {
1183 let min_val = data.iter().fold(Float::INFINITY, |a, &b| a.min(b));
1184 let max_val = data.iter().fold(Float::NEG_INFINITY, |a, &b| a.max(b));
1185
1186 if (max_val - min_val).abs() < 1e-10 {
1187 return Ok(vec![0.0; bins]);
1188 }
1189
1190 let mut histogram = vec![0.0; bins];
1191 let bin_width = (max_val - min_val) / bins as Float;
1192
1193 for &value in data.iter() {
1194 let bin_index = ((value - min_val) / bin_width).floor() as usize;
1195 let bin_index = bin_index.min(bins - 1);
1196 histogram[bin_index] += 1.0;
1197 }
1198
1199 let total_count = data.len() as Float;
1201 for bin in &mut histogram {
1202 *bin /= total_count;
1203 }
1204
1205 Ok(histogram)
1206 }
1207}
1208
1209#[allow(non_snake_case)]
1210#[cfg(test)]
1211mod tests {
1212 use super::*;
1213 use approx::assert_abs_diff_eq;
1214 use scirs2_core::ndarray::arr3;
1215
1216 #[test]
1217 fn test_image_normalizer_minmax() -> Result<()> {
1218 let image = arr3(&[
1219 [[100.0, 50.0, 200.0], [150.0, 75.0, 250.0]],
1220 [[200.0, 100.0, 255.0], [50.0, 25.0, 100.0]],
1221 ]);
1222
1223 let normalizer = ImageNormalizer::new()
1224 .with_strategy(ImageNormalizationStrategy::MinMax)
1225 .with_channel_wise(true);
1226
1227 let fitted = normalizer.fit(&image, &())?;
1228 let normalized = fitted.transform(&image)?;
1229
1230 for channel in 0..3 {
1232 let channel_data = normalized.index_axis(Axis(2), channel);
1233 let min_val = channel_data.iter().fold(Float::INFINITY, |a, &b| a.min(b));
1234 let max_val = channel_data
1235 .iter()
1236 .fold(Float::NEG_INFINITY, |a, &b| a.max(b));
1237
1238 assert_abs_diff_eq!(min_val, 0.0, epsilon = 1e-10);
1239 assert_abs_diff_eq!(max_val, 1.0, epsilon = 1e-10);
1240 }
1241
1242 Ok(())
1243 }
1244
1245 #[test]
1246 fn test_image_normalizer_standard_score() -> Result<()> {
1247 let image = arr3(&[
1248 [[100.0, 50.0, 200.0], [150.0, 75.0, 250.0]],
1249 [[200.0, 100.0, 255.0], [50.0, 25.0, 100.0]],
1250 ]);
1251
1252 let normalizer = ImageNormalizer::new()
1253 .with_strategy(ImageNormalizationStrategy::StandardScore)
1254 .with_channel_wise(true);
1255
1256 let fitted = normalizer.fit(&image, &())?;
1257 let normalized = fitted.transform(&image)?;
1258
1259 for channel in 0..3 {
1261 let channel_data = normalized.index_axis(Axis(2), channel);
1262 let data_vec: Vec<Float> = channel_data.iter().copied().collect();
1263
1264 let mean = data_vec.iter().sum::<Float>() / data_vec.len() as Float;
1265 let std = (data_vec.iter().map(|&x| (x - mean).powi(2)).sum::<Float>()
1266 / (data_vec.len() - 1) as Float)
1267 .sqrt();
1268
1269 assert_abs_diff_eq!(mean, 0.0, epsilon = 1e-10);
1270 assert_abs_diff_eq!(std, 1.0, epsilon = 1e-10);
1271 }
1272
1273 Ok(())
1274 }
1275
1276 #[test]
1277 fn test_image_augmenter_horizontal_flip() -> Result<()> {
1278 let image = arr3(&[
1279 [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
1280 [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]],
1281 ]);
1282
1283 let augmenter = ImageAugmenter::new()
1284 .with_horizontal_flip(true)
1285 .with_random_seed(42); let flipped = augmenter.horizontal_flip(&image)?;
1288
1289 assert_eq!(flipped[[0, 0, 0]], image[[0, 1, 0]]);
1291 assert_eq!(flipped[[0, 1, 0]], image[[0, 0, 0]]);
1292
1293 Ok(())
1294 }
1295
1296 #[test]
1297 fn test_color_space_rgb_to_hsv() -> Result<()> {
1298 let rgb_image = arr3(&[
1299 [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], [[0.0, 0.0, 1.0], [1.0, 1.0, 1.0]], ]);
1302
1303 let transformer = ColorSpaceTransformer::new().from_rgb().to_hsv();
1304
1305 let hsv_image = transformer.transform(&rgb_image)?;
1306
1307 assert_abs_diff_eq!(hsv_image[[0, 0, 0]], 0.0, epsilon = 1e-6);
1309 assert_abs_diff_eq!(hsv_image[[0, 0, 1]], 1.0, epsilon = 1e-6);
1310 assert_abs_diff_eq!(hsv_image[[0, 0, 2]], 1.0, epsilon = 1e-6);
1311
1312 assert_abs_diff_eq!(hsv_image[[1, 1, 1]], 0.0, epsilon = 1e-6);
1314 assert_abs_diff_eq!(hsv_image[[1, 1, 2]], 1.0, epsilon = 1e-6);
1315
1316 Ok(())
1317 }
1318
1319 #[test]
1320 fn test_rgb_to_grayscale() -> Result<()> {
1321 let rgb_image = arr3(&[
1322 [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], [[0.0, 0.0, 1.0], [1.0, 1.0, 1.0]], ]);
1325
1326 let transformer = ColorSpaceTransformer::new().from_rgb().to_grayscale();
1327
1328 let gray_image = transformer.transform(&rgb_image)?;
1329
1330 assert_eq!(gray_image.dim().2, 1);
1332
1333 assert_abs_diff_eq!(gray_image[[0, 0, 0]], 0.299, epsilon = 1e-6);
1335
1336 assert_abs_diff_eq!(gray_image[[0, 1, 0]], 0.587, epsilon = 1e-6);
1338
1339 assert_abs_diff_eq!(gray_image[[1, 0, 0]], 0.114, epsilon = 1e-6);
1341
1342 assert_abs_diff_eq!(gray_image[[1, 1, 0]], 1.0, epsilon = 1e-6);
1344
1345 Ok(())
1346 }
1347
#[test]
/// Upscales a 2x2x2 image to 4x4 with nearest-neighbour interpolation and
/// checks the output shape and its two opposite corners.
fn test_image_resizer_nearest() -> Result<()> {
    let top = [[1.0, 2.0], [3.0, 4.0]];
    let bottom = [[5.0, 6.0], [7.0, 8.0]];
    let source = arr3(&[top, bottom]);

    let upscaled = ImageResizer::new((4, 4))
        .with_method(InterpolationMethod::Nearest)
        .transform(&source)?;

    // The first two axes take the requested size; the third axis is unchanged.
    assert_eq!(upscaled.dim(), (4, 4, 2));

    // Opposite corners of the output must equal the matching source corners.
    assert_eq!(upscaled[[0, 0, 0]], source[[0, 0, 0]]);
    assert_eq!(upscaled[[3, 3, 0]], source[[1, 1, 0]]);

    Ok(())
}
1364
#[test]
/// Resizes a 2x2x2 image to 3x3 with bilinear interpolation and checks that
/// the centre sample lies strictly between the two source values.
fn test_image_resizer_bilinear() -> Result<()> {
    // Both rows hold the same 0 -> 1 step along the second axis, in both channels.
    let ramp = [[0.0, 0.0], [1.0, 1.0]];
    let source = arr3(&[ramp, ramp]);

    let resized = ImageResizer::new((3, 3))
        .with_method(InterpolationMethod::Bilinear)
        .transform(&source)?;

    assert_eq!(resized.dim(), (3, 3, 2));

    // The middle sample must be a blend: neither exactly 0 nor exactly 1.
    assert!(resized[[1, 1, 0]] > 0.0 && resized[[1, 1, 0]] < 1.0);

    Ok(())
}
1380
#[test]
/// Runs the Sobel edge detector on a 4x4x3 image containing a single step
/// along the second axis and checks that a clear response is produced.
fn test_edge_detector_sobel() -> Result<()> {
    // One row of the image: two all-zero entries followed by two all-one
    // entries. The same row is repeated four times along the first axis.
    let step_row = [
        [0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0],
        [1.0, 1.0, 1.0],
        [1.0, 1.0, 1.0],
    ];
    let step_image = arr3(&[step_row; 4]);

    let edges = EdgeDetector::new()
        .with_method(EdgeDetectionMethod::Sobel)
        .transform(&step_image)?;

    // The edge map has a third axis of length one.
    assert_eq!(edges.dim().2, 1);

    // Largest response anywhere in the map; the fold starts at 0.0.
    let strongest = edges.iter().copied().fold(0.0_f64, f64::max);
    assert!(
        strongest > 0.01,
        "Expected edge detection to produce values > 0.01, got max: {}",
        strongest
    );

    Ok(())
}
1428
#[test]
/// Runs the Laplacian edge detector on a 4x4x3 image containing a single step
/// along the second axis and checks that a clear response is produced.
fn test_edge_detector_laplacian() -> Result<()> {
    // One row of the image: two all-zero entries followed by two all-one
    // entries. The same row is repeated four times along the first axis.
    let step_row = [
        [0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0],
        [1.0, 1.0, 1.0],
        [1.0, 1.0, 1.0],
    ];
    let step_image = arr3(&[step_row; 4]);

    let edges = EdgeDetector::new()
        .with_method(EdgeDetectionMethod::Laplacian)
        .transform(&step_image)?;

    // The edge map has a third axis of length one.
    assert_eq!(edges.dim().2, 1);

    // Largest response anywhere in the map; the fold starts at 0.0.
    let strongest = edges.iter().copied().fold(0.0_f64, f64::max);
    assert!(
        strongest > 0.01,
        "Expected edge detection to produce values > 0.01, got max: {}",
        strongest
    );

    Ok(())
}
1476
#[test]
/// Runs the Canny method with a threshold of 0.3 and checks that every value
/// in the resulting edge map is exactly 0.0 or 1.0.
fn test_edge_detector_with_threshold() -> Result<()> {
    // Every pixel of the 2x2x3 input carries the same (0.0, 0.5, 1.0) triple.
    let pixel = [0.0, 0.5, 1.0];
    let uniform = arr3(&[[pixel, pixel], [pixel, pixel]]);

    let thresholded = EdgeDetector::new()
        .with_method(EdgeDetectionMethod::Canny)
        .with_threshold(0.3);
    let edges = thresholded.transform(&uniform)?;

    assert_eq!(edges.dim().2, 1);

    // Binary means no element differs from both 0.0 and 1.0.
    let all_binary = !edges.iter().any(|&v| v != 0.0 && v != 1.0);
    assert!(all_binary);

    Ok(())
}
1499
#[test]
/// Enables edge, histogram (4 bins) and moment features together and checks
/// the size and finiteness of the resulting feature vector.
fn test_image_feature_extractor() -> Result<()> {
    let upper = [[0.0, 0.5, 1.0], [0.2, 0.7, 0.9]];
    let lower = [[0.1, 0.6, 0.8], [0.3, 0.4, 0.6]];
    let sample = arr3(&[upper, lower]);

    let features = ImageFeatureExtractor::new()
        .with_edge_features(true)
        .with_histogram_features(true, 4)
        .with_moment_features(true)
        .transform(&sample)?;

    assert!(!features.is_empty());

    // 26 is the total length expected for this configuration and input; the
    // per-feature breakdown is defined by the extractor itself.
    assert_eq!(features.len(), 26);

    // No feature may be NaN or infinite.
    assert!(features.iter().all(|&x| x.is_finite()));

    Ok(())
}
1527
#[test]
/// Enables only the edge features and checks that exactly two non-negative
/// values are produced.
fn test_image_feature_extractor_selective_features() -> Result<()> {
    let upper = [[0.0, 0.5, 0.2], [0.2, 0.7, 0.1]];
    let lower = [[0.1, 0.6, 0.3], [0.3, 0.4, 0.2]];
    let sample = arr3(&[upper, lower]);

    // Histogram and moment features are switched off; the bin count of 4 is
    // still passed but the histogram flag is false.
    let features = ImageFeatureExtractor::new()
        .with_edge_features(true)
        .with_histogram_features(false, 4)
        .with_moment_features(false)
        .transform(&sample)?;

    assert_eq!(features.len(), 2);

    // Both edge features are expected to be non-negative.
    assert!(features[0] >= 0.0);
    assert!(features[1] >= 0.0);

    Ok(())
}
1552
#[test]
/// Compares Sobel edge maps of a 6x6x1 step image with and without a blur
/// sigma of 2.0, expecting the blurred run to have no more responses above
/// 0.01 or no stronger a peak than the unblurred run.
fn test_gaussian_blur() -> Result<()> {
    // One row: three zeros then three ones along the second axis. The same
    // row is repeated six times along the first axis.
    let step_row = [[0.0], [0.0], [0.0], [1.0], [1.0], [1.0]];
    let step_image = arr3(&[step_row; 6]);

    let sharp_edges = EdgeDetector::new()
        .with_method(EdgeDetectionMethod::Sobel)
        .transform(&step_image)?;
    let blurred_edges = EdgeDetector::new()
        .with_method(EdgeDetectionMethod::Sobel)
        .with_blur_sigma(2.0)
        .transform(&step_image)?;

    // Number of responses above the 0.01 noise floor in each map.
    let sharp_count = sharp_edges.iter().filter(|&&v| v > 0.01).count();
    let blurred_count = blurred_edges.iter().filter(|&&v| v > 0.01).count();

    // Strongest response in each map; the fold starts at 0.0.
    let sharp_peak = sharp_edges.iter().copied().fold(0.0_f64, f64::max);
    let blurred_peak = blurred_edges.iter().copied().fold(0.0_f64, f64::max);

    // Either condition on its own is accepted as evidence of the blur.
    let fewer_responses = blurred_count <= sharp_count;
    let weaker_peak = blurred_peak <= sharp_peak;
    assert!(
        fewer_responses || weaker_peak,
        "Expected blur to reduce edge count ({} vs {}) or max edge strength ({:.6} vs {:.6})",
        blurred_count,
        sharp_count,
        blurred_peak,
        sharp_peak
    );

    Ok(())
}
1598}