1use crate::transforms::Transform;
2#[cfg(feature = "image-support")]
3use image::{DynamicImage, GenericImageView, ImageBuffer};
4use scirs2_core::RngExt;
5use torsh_core::error::{Result, TorshError};
6use torsh_tensor::Tensor;
7
/// Converts a `DynamicImage` into a channel-first `f32` tensor with values in `[0, 1]`.
#[derive(Debug, Clone, Copy, Default)]
pub struct ImageToTensor;
10
11impl Transform<DynamicImage> for ImageToTensor {
12 type Output = Tensor<f32>;
13
14 fn transform(&self, input: DynamicImage) -> Result<Self::Output> {
15 #[cfg(feature = "image-support")]
16 {
17 let rgb_image = input.to_rgb8();
18 let (width, height) = rgb_image.dimensions();
19
20 let mut data = Vec::with_capacity((width * height * 3) as usize);
22
23 for channel in 0..3 {
25 for y in 0..height {
26 for x in 0..width {
27 let pixel = rgb_image.get_pixel(x, y);
28 let value = pixel[channel] as f32 / 255.0;
29 data.push(value);
30 }
31 }
32 }
33
34 Tensor::from_data(
35 data,
36 vec![3, height as usize, width as usize],
37 torsh_core::device::DeviceType::Cpu,
38 )
39 }
40
41 #[cfg(not(feature = "image-support"))]
42 {
43 Err(TorshError::UnsupportedOperation {
44 op: "image to tensor conversion".to_string(),
45 dtype: "DynamicImage".to_string(),
46 })
47 }
48 }
49}
50
/// Converts a channel-first `[3, H, W]` `f32` tensor back into an RGB8 `DynamicImage`.
#[derive(Debug, Clone, Copy, Default)]
pub struct TensorToImage;
53
54impl Transform<Tensor<f32>> for TensorToImage {
55 type Output = DynamicImage;
56
57 fn transform(&self, input: Tensor<f32>) -> Result<Self::Output> {
58 #[cfg(feature = "image-support")]
59 {
60 let shape = input.shape();
61 if shape.ndim() != 3 {
62 return Err(TorshError::InvalidShape(
63 "Expected 3D tensor (C, H, W)".to_string(),
64 ));
65 }
66
67 let dims = shape.dims();
68 let (channels, height, width) = (dims[0], dims[1], dims[2]);
69
70 if channels != 3 {
71 return Err(TorshError::InvalidShape(
72 "Expected 3 channels for RGB image".to_string(),
73 ));
74 }
75
76 let data = input.to_vec()?;
77 let mut img_data = Vec::with_capacity(width * height * 3);
78
79 for y in 0..height {
81 for x in 0..width {
82 for c in 0..3 {
83 let idx = c * height * width + y * width + x;
84 let value = (data[idx] * 255.0).clamp(0.0, 255.0) as u8;
85 img_data.push(value);
86 }
87 }
88 }
89
90 let img_buffer = ImageBuffer::from_raw(width as u32, height as u32, img_data)
91 .ok_or_else(|| {
92 TorshError::InvalidArgument("Failed to create image buffer".to_string())
93 })?;
94
95 Ok(DynamicImage::ImageRgb8(img_buffer))
96 }
97
98 #[cfg(not(feature = "image-support"))]
99 {
100 Err(TorshError::UnsupportedOperation {
101 op: "tensor to image conversion".to_string(),
102 dtype: "Tensor<f32>".to_string(),
103 })
104 }
105 }
106}
107
108pub struct Compose<T> {
110 transforms: Vec<Box<dyn Transform<T, Output = T>>>,
111}
112
113impl<T: 'static> Compose<T> {
114 pub fn new() -> Self {
115 Self {
116 transforms: Vec::new(),
117 }
118 }
119
120 pub fn add_transform<F>(mut self, transform: F) -> Self
121 where
122 F: Transform<T, Output = T> + 'static,
123 {
124 self.transforms.push(Box::new(transform));
125 self
126 }
127
128 pub fn add_boxed(mut self, transform: Box<dyn Transform<T, Output = T>>) -> Self {
129 self.transforms.push(transform);
130 self
131 }
132}
133
134impl<T> Transform<T> for Compose<T> {
135 type Output = T;
136
137 fn transform(&self, mut input: T) -> Result<Self::Output> {
138 for transform in &self.transforms {
139 input = transform.transform(input)?;
140 }
141 Ok(input)
142 }
143}
144
145impl<T: 'static> Default for Compose<T> {
146 fn default() -> Self {
147 Self::new()
148 }
149}
150
151pub struct RandomHorizontalFlip {
153 prob: f32,
154}
155
156impl RandomHorizontalFlip {
157 pub fn new(prob: f32) -> Self {
158 Self { prob }
159 }
160}
161
162impl Transform<DynamicImage> for RandomHorizontalFlip {
163 type Output = DynamicImage;
164
165 fn transform(&self, input: DynamicImage) -> Result<Self::Output> {
166 #[cfg(feature = "image-support")]
167 {
168 #[allow(unused_imports)] use scirs2_core::random::{Random, Rng};
171 let mut rng = Random::seed(0);
172 if rng.random::<f32>() < self.prob {
173 Ok(input.fliph())
174 } else {
175 Ok(input)
176 }
177 }
178
179 #[cfg(not(feature = "image-support"))]
180 {
181 Err(TorshError::UnsupportedOperation {
182 op: "random horizontal flip".to_string(),
183 dtype: "DynamicImage".to_string(),
184 })
185 }
186 }
187}
188
189pub struct RandomVerticalFlip {
191 prob: f32,
192}
193
194impl RandomVerticalFlip {
195 pub fn new(prob: f32) -> Self {
196 Self { prob }
197 }
198}
199
200impl Transform<DynamicImage> for RandomVerticalFlip {
201 type Output = DynamicImage;
202
203 fn transform(&self, input: DynamicImage) -> Result<Self::Output> {
204 #[cfg(feature = "image-support")]
205 {
206 #[allow(unused_imports)] use scirs2_core::random::{Random, Rng};
209 let mut rng = Random::seed(0);
210 if rng.random::<f32>() < self.prob {
211 Ok(input.flipv())
212 } else {
213 Ok(input)
214 }
215 }
216
217 #[cfg(not(feature = "image-support"))]
218 {
219 Err(TorshError::UnsupportedOperation {
220 op: "random vertical flip".to_string(),
221 dtype: "DynamicImage".to_string(),
222 })
223 }
224 }
225}
226
227pub struct RandomRotation {
229 degrees: f32,
230}
231
232impl RandomRotation {
233 pub fn new(degrees: f32) -> Self {
234 Self { degrees }
235 }
236}
237
238impl Transform<DynamicImage> for RandomRotation {
239 type Output = DynamicImage;
240
241 fn transform(&self, input: DynamicImage) -> Result<Self::Output> {
242 #[cfg(all(feature = "image-support", feature = "imageproc"))]
243 {
244 #[allow(unused_imports)] use scirs2_core::random::{Random, Rng};
247 let mut rng = Random::seed(0);
248 let angle_deg = rng.gen_range(-self.degrees..=self.degrees);
249 let angle_rad = angle_deg.to_radians();
250
251 let rgb_image = input.to_rgb8();
253
254 let rotated = imageproc::geometric_transformations::rotate_about_center(
257 &rgb_image,
258 angle_rad,
259 imageproc::geometric_transformations::Interpolation::Bilinear,
260 image::Rgb([0u8, 0u8, 0u8]), );
262
263 Ok(DynamicImage::ImageRgb8(rotated))
264 }
265
266 #[cfg(all(feature = "image-support", not(feature = "imageproc")))]
267 {
268 #[allow(unused_imports)] use scirs2_core::random::{Random, Rng};
271 let mut rng = Random::seed(0);
272 let _angle = rng.gen_range(-self.degrees..=self.degrees);
273 Ok(input)
276 }
277
278 #[cfg(not(feature = "image-support"))]
279 {
280 Err(TorshError::UnsupportedOperation {
281 op: "random rotation".to_string(),
282 dtype: "DynamicImage".to_string(),
283 })
284 }
285 }
286}
287
288pub mod transforms {
290 use super::*;
291 use crate::transforms::Transform;
292
293 pub struct Resize {
295 size: (u32, u32),
296 }
297
298 impl Resize {
299 pub fn new(size: (u32, u32)) -> Self {
300 Self { size }
301 }
302 }
303
304 impl Transform<DynamicImage> for Resize {
305 type Output = DynamicImage;
306
307 fn transform(&self, input: DynamicImage) -> Result<Self::Output> {
308 #[cfg(feature = "image-support")]
309 {
310 Ok(input.resize_exact(
311 self.size.0,
312 self.size.1,
313 image::imageops::FilterType::Lanczos3,
314 ))
315 }
316
317 #[cfg(not(feature = "image-support"))]
318 {
319 Err(TorshError::UnsupportedOperation {
320 op: "image resize".to_string(),
321 dtype: "DynamicImage".to_string(),
322 })
323 }
324 }
325 }
326
327 pub struct CenterCrop {
329 size: (u32, u32),
330 }
331
332 impl CenterCrop {
333 pub fn new(size: (u32, u32)) -> Self {
334 Self { size }
335 }
336 }
337
338 impl Transform<DynamicImage> for CenterCrop {
339 type Output = DynamicImage;
340
341 fn transform(&self, input: DynamicImage) -> Result<Self::Output> {
342 #[cfg(feature = "image-support")]
343 {
344 let (width, height) = input.dimensions();
345 let (crop_width, crop_height) = self.size;
346
347 if crop_width > width || crop_height > height {
348 return Err(TorshError::InvalidArgument(
349 "Crop size cannot be larger than image size".to_string(),
350 ));
351 }
352
353 let x = (width - crop_width) / 2;
354 let y = (height - crop_height) / 2;
355
356 Ok(input.crop_imm(x, y, crop_width, crop_height))
357 }
358
359 #[cfg(not(feature = "image-support"))]
360 {
361 Err(TorshError::UnsupportedOperation {
362 op: "image center crop".to_string(),
363 dtype: "DynamicImage".to_string(),
364 })
365 }
366 }
367 }
368
369 pub struct Normalize {
371 mean: [f32; 3],
372 std: [f32; 3],
373 }
374
375 impl Normalize {
376 pub fn new(mean: [f32; 3], std: [f32; 3]) -> Self {
377 Self { mean, std }
378 }
379
380 pub fn imagenet() -> Self {
382 Self::new([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
383 }
384 }
385
386 impl Transform<Tensor<f32>> for Normalize {
387 type Output = Tensor<f32>;
388
389 fn transform(&self, input: Tensor<f32>) -> Result<Self::Output> {
390 let shape_ref = input.shape();
393 let shape = shape_ref.dims();
394
395 if shape.len() != 3 {
396 return Err(TorshError::InvalidShape(format!(
397 "Expected 3D tensor (C, H, W), got shape {shape:?}"
398 )));
399 }
400
401 let channels = shape[0];
402 if channels != 3 {
403 return Err(TorshError::InvalidShape(format!(
404 "Expected 3 channels for RGB image, got {channels}"
405 )));
406 }
407
408 let mut data = input.to_vec()?;
410 let height = shape[1];
411 let width = shape[2];
412 let channel_size = height * width;
413
414 for c in 0..3 {
416 let channel_start = c * channel_size;
417 let channel_end = channel_start + channel_size;
418 let mean = self.mean[c];
419 let std = self.std[c];
420
421 for pixel in &mut data[channel_start..channel_end] {
422 *pixel = (*pixel - mean) / std;
423 }
424 }
425
426 Tensor::from_data(data, shape.to_vec(), input.device())
428 }
429 }
430}