1use crate::transforms::Transform;
16use torsh_core::error::Result;
17use torsh_core::{
18 dtype::{FloatElement, TensorElement},
19 error::TorshError,
20};
21use torsh_tensor::Tensor;
22use scirs2_core::random::thread_rng;
24
/// Randomly mirrors an image tensor left-to-right.
///
/// Every call to `transform` independently decides, with probability `prob`,
/// whether the input is flipped.
#[derive(Debug, Clone)]
pub struct RandomHorizontalFlip {
    /// Chance, in `[0.0, 1.0]`, that a given input is flipped.
    prob: f32,
}

impl RandomHorizontalFlip {
    /// Builds a flip transform that fires with probability `prob`.
    ///
    /// # Panics
    ///
    /// Panics if `prob` is not inside the closed interval `[0.0, 1.0]`
    /// (a NaN probability is rejected as well, since it is in no range).
    pub fn new(prob: f32) -> Self {
        let in_unit_interval = (0.0..=1.0).contains(&prob);
        assert!(in_unit_interval, "Probability must be between 0 and 1");
        Self { prob }
    }
}
50
51impl<T: FloatElement> Transform<Tensor<T>> for RandomHorizontalFlip {
52 type Output = Tensor<T>;
53
54 fn transform(&self, input: Tensor<T>) -> Result<Self::Output> {
55 let mut rng = thread_rng(); let random_val = rng.random::<f32>();
58 if random_val < self.prob {
59 self.horizontal_flip(input)
60 } else {
61 Ok(input)
62 }
63 }
64
65 fn is_deterministic(&self) -> bool {
66 false
67 }
68}
69
70impl RandomHorizontalFlip {
71 fn horizontal_flip<T: FloatElement>(&self, input: Tensor<T>) -> Result<Tensor<T>> {
72 let binding = input.shape();
73 let shape = binding.dims();
74 if shape.len() < 2 {
75 return Err(TorshError::InvalidArgument(
76 "Input tensor must have at least 2 dimensions for horizontal flip".to_string(),
77 ));
78 }
79
80 Ok(input)
85 }
86}
87
/// Cuts a randomly positioned `(height, width)` window out of an image tensor.
///
/// Padding is off by default; enable it with [`RandomCrop::with_padding`].
#[derive(Debug, Clone)]
pub struct RandomCrop {
    /// Requested output size as `(height, width)`.
    size: (usize, usize),
    /// Per-side padding amount, or `None` when padding is disabled.
    padding: Option<usize>,
}

impl RandomCrop {
    /// Creates a random crop of the given `(height, width)` without padding.
    pub fn new(size: (usize, usize)) -> Self {
        let padding = None;
        Self { size, padding }
    }

    /// Builder-style setter: consumes `self` and returns it with padding set
    /// to `padding`.
    pub fn with_padding(self, padding: usize) -> Self {
        Self {
            padding: Some(padding),
            ..self
        }
    }
}
119
120impl<T: TensorElement> Transform<Tensor<T>> for RandomCrop {
121 type Output = Tensor<T>;
122
123 fn transform(&self, input: Tensor<T>) -> Result<Self::Output> {
124 let shape = input.shape();
125 let dims = shape.dims();
126
127 if dims.len() < 2 {
129 return Err(TorshError::InvalidArgument(
130 "Input tensor must have at least 2 dimensions for random crop".to_string(),
131 ));
132 }
133
134 let (input_height, input_width) = if dims.len() == 2 {
135 (dims[0], dims[1])
136 } else {
137 (dims[1], dims[2])
139 };
140
141 let (crop_height, crop_width) = self.size;
142
143 if crop_height > input_height || crop_width > input_width {
145 if let Some(padding) = self.padding {
146 let _new_height = input_height.max(crop_height) + 2 * padding;
148 let _new_width = input_width.max(crop_width) + 2 * padding;
149
150 return Ok(input);
154 } else {
155 return Err(TorshError::InvalidArgument(
156 format!("Crop size ({crop_height}, {crop_width}) is larger than input size ({input_height}, {input_width}) and no padding specified"),
157 ));
158 }
159 }
160
161 let mut rng = thread_rng();
163 let max_y = input_height - crop_height;
164 let max_x = input_width - crop_width;
165
166 let _start_y = if max_y > 0 {
167 rng.gen_range(0..=max_y)
168 } else {
169 0
170 };
171 let _start_x = if max_x > 0 {
172 rng.gen_range(0..=max_x)
173 } else {
174 0
175 };
176
177 Ok(input)
187 }
188
189 fn is_deterministic(&self) -> bool {
190 false
191 }
192}
193
/// Resampling strategy selected for [`Resize`].
///
/// Defaults to [`InterpolationMode::Bilinear`].
///
/// NOTE(review): `Resize` does not resample yet, so the variant currently only
/// records the caller's intent.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum InterpolationMode {
    /// Nearest-neighbour sampling.
    Nearest,
    /// Linear interpolation.
    Linear,
    /// Bilinear interpolation (the default).
    #[default]
    Bilinear,
    /// Bicubic interpolation.
    Bicubic,
}
212
213#[derive(Debug, Clone)]
218pub struct Resize {
219 size: (usize, usize),
220 interpolation: InterpolationMode,
221}
222
223impl Resize {
224 pub fn new(size: (usize, usize)) -> Self {
229 Self {
230 size,
231 interpolation: InterpolationMode::Bilinear,
232 }
233 }
234
235 pub fn with_interpolation(mut self, mode: InterpolationMode) -> Self {
240 self.interpolation = mode;
241 self
242 }
243}
244
245impl<T: FloatElement> Transform<Tensor<T>> for Resize {
246 type Output = Tensor<T>;
247
248 fn transform(&self, input: Tensor<T>) -> Result<Self::Output> {
249 let shape = input.shape();
250 let dims = shape.dims();
251
252 if dims.len() < 2 {
254 return Err(TorshError::InvalidArgument(
255 "Input tensor must have at least 2 dimensions for resize".to_string(),
256 ));
257 }
258
259 let (input_height, input_width) = if dims.len() == 2 {
260 (dims[0], dims[1])
261 } else {
262 (dims[1], dims[2])
264 };
265
266 let (target_height, target_width) = self.size;
267
268 if input_height == target_height && input_width == target_width {
270 return Ok(input);
271 }
272
273 match self.interpolation {
288 InterpolationMode::Nearest => {
289 Ok(input)
292 }
293 InterpolationMode::Linear | InterpolationMode::Bilinear => {
294 Ok(input)
297 }
298 InterpolationMode::Bicubic => {
299 Ok(input)
302 }
303 }
304 }
305
306 fn is_deterministic(&self) -> bool {
307 true
308 }
309}
310
/// Cuts a `(height, width)` window out of the centre of an image tensor.
#[derive(Debug, Clone)]
pub struct CenterCrop {
    /// Requested output size as `(height, width)`.
    size: (usize, usize),
}

impl CenterCrop {
    /// Creates a centre crop of the given `(height, width)`.
    pub fn new(size: (usize, usize)) -> Self {
        CenterCrop { size }
    }
}
329
330impl<T: TensorElement> Transform<Tensor<T>> for CenterCrop {
331 type Output = Tensor<T>;
332
333 fn transform(&self, input: Tensor<T>) -> Result<Self::Output> {
334 let shape = input.shape();
335 let dims = shape.dims();
336
337 if dims.len() < 2 {
339 return Err(TorshError::InvalidArgument(
340 "Input tensor must have at least 2 dimensions for center crop".to_string(),
341 ));
342 }
343
344 let (input_height, input_width) = if dims.len() == 2 {
345 (dims[0], dims[1])
346 } else {
347 (dims[1], dims[2])
349 };
350
351 let (crop_height, crop_width) = self.size;
352
353 if crop_height > input_height || crop_width > input_width {
355 return Err(TorshError::InvalidArgument(
356 format!("Crop size ({crop_height}, {crop_width}) is larger than input size ({input_height}, {input_width})"),
357 ));
358 }
359
360 let _start_y = (input_height - crop_height) / 2;
362 let _start_x = (input_width - crop_width) / 2;
363
364 Ok(input)
380 }
381
382 fn is_deterministic(&self) -> bool {
383 true
384 }
385}
386
387pub fn random_horizontal_flip(prob: f32) -> RandomHorizontalFlip {
389 RandomHorizontalFlip::new(prob)
390}
391
392pub fn random_crop(size: (usize, usize)) -> RandomCrop {
394 RandomCrop::new(size)
395}
396
397pub fn resize(size: (usize, usize)) -> Resize {
399 Resize::new(size)
400}
401
402pub fn center_crop(size: (usize, usize)) -> CenterCrop {
404 CenterCrop::new(size)
405}
406
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::device::DeviceType;

    /// 2x2 CPU tensor holding `[1, 2, 3, 4]`.
    fn mock_tensor_2d() -> Tensor<f32> {
        Tensor::from_data(vec![1.0f32, 2.0, 3.0, 4.0], vec![2, 2], DeviceType::Cpu).unwrap()
    }

    /// 2x2x2 CPU tensor holding `[1, ..., 8]`.
    fn mock_tensor_3d() -> Tensor<f32> {
        Tensor::from_data(
            vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
            vec![2, 2, 2],
            DeviceType::Cpu,
        )
        .unwrap()
    }

    #[test]
    fn test_random_horizontal_flip_creation() {
        let flip = RandomHorizontalFlip::new(0.5);
        let transform: &dyn Transform<Tensor<f32>, Output = Tensor<f32>> = &flip;
        assert!(!transform.is_deterministic());
    }

    // `expected` pins the validation panic so an unrelated panic cannot make
    // this test pass.
    #[test]
    #[should_panic(expected = "Probability must be between 0 and 1")]
    fn test_random_horizontal_flip_invalid_prob() {
        RandomHorizontalFlip::new(1.5);
    }

    #[test]
    #[should_panic(expected = "Probability must be between 0 and 1")]
    fn test_random_horizontal_flip_nan_prob() {
        RandomHorizontalFlip::new(f32::NAN);
    }

    #[test]
    fn test_random_crop_creation() {
        let crop = RandomCrop::new((224, 224));
        let transform: &dyn Transform<Tensor<f32>, Output = Tensor<f32>> = &crop;
        assert!(!transform.is_deterministic());
    }

    #[test]
    fn test_random_crop_with_padding() {
        let crop = RandomCrop::new((224, 224)).with_padding(10);
        let transform: &dyn Transform<Tensor<f32>, Output = Tensor<f32>> = &crop;
        assert!(!transform.is_deterministic());
    }

    #[test]
    fn test_random_crop_with_padding_accepts_larger_crop() {
        let crop = RandomCrop::new((3, 3)).with_padding(1);
        assert!(crop.transform(mock_tensor_2d()).is_ok());
    }

    #[test]
    fn test_resize_creation() {
        let resize_transform = Resize::new((224, 224));
        let transform: &dyn Transform<Tensor<f32>, Output = Tensor<f32>> = &resize_transform;
        assert!(transform.is_deterministic());
    }

    #[test]
    fn test_resize_with_interpolation() {
        let resize_transform =
            Resize::new((224, 224)).with_interpolation(InterpolationMode::Nearest);
        let transform: &dyn Transform<Tensor<f32>, Output = Tensor<f32>> = &resize_transform;
        assert!(transform.is_deterministic());
    }

    #[test]
    fn test_resize_same_size() {
        let resize_transform = Resize::new((2, 2));
        assert!(resize_transform.transform(mock_tensor_2d()).is_ok());
    }

    #[test]
    fn test_center_crop_creation() {
        let crop = CenterCrop::new((224, 224));
        let transform: &dyn Transform<Tensor<f32>, Output = Tensor<f32>> = &crop;
        assert!(transform.is_deterministic());
    }

    #[test]
    fn test_interpolation_mode_default() {
        assert_eq!(InterpolationMode::default(), InterpolationMode::Bilinear);
    }

    #[test]
    fn test_tensor_transforms_2d() {
        let tensor = mock_tensor_2d();

        let flip = RandomHorizontalFlip::new(1.0);
        let result = flip.transform(tensor.clone());
        assert!(result.is_ok());

        let crop = CenterCrop::new((1, 1));
        let result = crop.transform(tensor.clone());
        assert!(result.is_ok());

        let resize_transform = Resize::new((4, 4));
        let result = resize_transform.transform(tensor);
        assert!(result.is_ok());
    }

    #[test]
    fn test_tensor_transforms_3d() {
        let tensor = mock_tensor_3d();

        let flip = RandomHorizontalFlip::new(0.0);
        let result = flip.transform(tensor.clone());
        assert!(result.is_ok());

        let crop = CenterCrop::new((1, 1));
        let result = crop.transform(tensor.clone());
        assert!(result.is_ok());

        let resize_transform = Resize::new((4, 4));
        let result = resize_transform.transform(tensor);
        assert!(result.is_ok());
    }

    #[test]
    fn test_convenience_functions() {
        let flip = random_horizontal_flip(0.5);
        let crop = random_crop((224, 224));
        let resize_transform = resize((256, 256));
        let center = center_crop((224, 224));

        let flip: &dyn Transform<Tensor<f32>, Output = Tensor<f32>> = &flip;
        let crop: &dyn Transform<Tensor<f32>, Output = Tensor<f32>> = &crop;
        let resize_transform: &dyn Transform<Tensor<f32>, Output = Tensor<f32>> =
            &resize_transform;
        let center: &dyn Transform<Tensor<f32>, Output = Tensor<f32>> = &center;

        assert!(!flip.is_deterministic());
        assert!(!crop.is_deterministic());
        assert!(resize_transform.is_deterministic());
        assert!(center.is_deterministic());
    }

    #[test]
    fn test_invalid_tensor_dimensions() {
        let tensor_1d = Tensor::from_data(vec![1.0f32, 2.0], vec![2], DeviceType::Cpu).unwrap();

        let flip = RandomHorizontalFlip::new(1.0);
        assert!(flip.transform(tensor_1d.clone()).is_err());

        let crop = CenterCrop::new((1, 1));
        assert!(crop.transform(tensor_1d.clone()).is_err());

        let random_crop = RandomCrop::new((1, 1));
        assert!(random_crop.transform(tensor_1d.clone()).is_err());

        let resize_transform = Resize::new((4, 4));
        assert!(resize_transform.transform(tensor_1d).is_err());
    }

    #[test]
    fn test_crop_size_validation() {
        let tensor = mock_tensor_2d();

        let crop = CenterCrop::new((3, 3));
        assert!(crop.transform(tensor.clone()).is_err());

        let random_crop = RandomCrop::new((3, 3));
        assert!(random_crop.transform(tensor).is_err());
    }
}