torsh_data/vision/mod.rs

//! Vision-specific datasets and transformations
//!
//! This module provides comprehensive computer vision support, including:
//! - Image datasets and transformations
//! - Standard vision datasets (MNIST, CIFAR-10, ImageNet)
//! - Video processing components
//!
//! The module is organized into focused submodules:
//! - `image`: Image datasets and transformations
//! - `datasets`: Standard vision datasets
//! - `video`: Video processing (currently implemented inline in this module)
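//!
//! # Example
//!
//! A minimal usage sketch for the inline video components defined below. The
//! `torsh_data::vision` import path and the `"videos/"` directory are
//! illustrative assumptions, so the snippet is not run as a doctest:
//!
//! ```ignore
//! use torsh_data::dataset::Dataset;
//! use torsh_data::vision::{VideoFolder, VideoToTensor};
//!
//! // Scan `videos/<class>/<file>` and load up to 16 frames per clip.
//! let dataset = VideoFolder::new("videos/")?
//!     .with_max_frames(16)
//!     .with_frame_rate(15.0)
//!     .with_transform(VideoToTensor);
//!
//! let (clip, label) = dataset.get(0)?; // clip has shape (T, C, H, W)
//! assert_eq!(dataset.len(), dataset.num_samples());
//! ```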

pub mod datasets;
pub mod image;

// Re-export for backward compatibility
pub use datasets::{ImageNet, CIFAR10, MNIST};
pub use image::transforms::transforms::{CenterCrop, Normalize, Resize};
pub use image::{
    Compose, ImageFolder, ImageToTensor, RandomHorizontalFlip, RandomRotation, RandomVerticalFlip,
    TensorToImage,
};

// Inline video components for now (extracted from original)
use crate::{dataset::Dataset, transforms::Transform};
use torsh_core::error::{Result, TorshError};
use torsh_tensor::Tensor;

#[cfg(not(feature = "std"))]
use alloc::{boxed::Box, string::String, vec::Vec};
use std::path::{Path, PathBuf};

/// Video frame container
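///
/// # Example
///
/// A small sketch of how the clip duration is derived from the frame count and
/// frame rate; the random frames are placeholders for illustration only:
///
/// ```ignore
/// let frames: Vec<_> = (0..4)
///     .map(|_| torsh_tensor::creation::rand::<f32>(&[3, 224, 224]).unwrap())
///     .collect();
/// let clip = VideoFrames::new(frames, 2.0);
/// assert_eq!(clip.num_frames(), 4);
/// assert_eq!(clip.duration(), 2.0); // 4 frames at 2 fps = 2 seconds
/// ```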
#[derive(Debug, Clone)]
pub struct VideoFrames {
    frames: Vec<Tensor<f32>>,
    fps: f32,
    duration: f32,
}

impl VideoFrames {
    /// Create a new frame container; the duration is derived from the frame count and `fps`
    pub fn new(frames: Vec<Tensor<f32>>, fps: f32) -> Self {
        let duration = frames.len() as f32 / fps;
        Self {
            frames,
            fps,
            duration,
        }
    }

    /// Decoded frames in temporal order
    pub fn frames(&self) -> &[Tensor<f32>] {
        &self.frames
    }

    /// Frame rate in frames per second
    pub fn fps(&self) -> f32 {
        self.fps
    }

    /// Clip duration in seconds
    pub fn duration(&self) -> f32 {
        self.duration
    }

    /// Number of frames in the clip
    pub fn num_frames(&self) -> usize {
        self.frames.len()
    }
}

/// Video dataset that loads video files from a root directory containing one
/// subdirectory per class
pub struct VideoFolder {
    root: PathBuf,
    samples: Vec<(PathBuf, usize)>,
    classes: Vec<String>,
    transform: Option<Box<dyn Transform<VideoFrames, Output = Tensor<f32>>>>,
    max_frames: usize,
    frame_rate: Option<f32>,
}

impl VideoFolder {
    /// Create a new video folder dataset by scanning `root` for class subdirectories
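    ///
    /// The scan expects one subdirectory per class containing video files; the
    /// names below are purely illustrative:
    ///
    /// ```text
    /// videos/
    /// ├── cats/
    /// │   ├── clip_001.mp4
    /// │   └── clip_002.avi
    /// └── dogs/
    ///     └── clip_003.mkv
    /// ```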
    pub fn new<P: AsRef<Path>>(root: P) -> Result<Self> {
        let root = root.as_ref().to_path_buf();

        if !root.exists() {
            return Err(TorshError::IoError(format!(
                "Directory does not exist: {root:?}"
            )));
        }

        let mut classes = Vec::new();
        let mut samples = Vec::new();

        // Scan subdirectories for classes
        for entry in std::fs::read_dir(&root).map_err(|e| TorshError::IoError(e.to_string()))? {
            let entry = entry.map_err(|e| TorshError::IoError(e.to_string()))?;
            let path = entry.path();

            if path.is_dir() {
                let class_name = path
                    .file_name()
                    .and_then(|n| n.to_str())
                    .ok_or_else(|| TorshError::IoError("Invalid class directory name".to_string()))?
                    .to_string();

                let class_idx = classes.len();
                classes.push(class_name);

                // Scan videos in class directory
                for video_entry in
                    std::fs::read_dir(&path).map_err(|e| TorshError::IoError(e.to_string()))?
                {
                    let video_entry =
                        video_entry.map_err(|e| TorshError::IoError(e.to_string()))?;
                    let video_path = video_entry.path();

                    if Self::is_video_file(&video_path) {
                        samples.push((video_path, class_idx));
                    }
                }
            }
        }

        Ok(Self {
            root,
            samples,
            classes,
            transform: None,
            max_frames: 32, // Default to 32 frames
            frame_rate: None,
        })
    }

    /// Set maximum number of frames to extract
    pub fn with_max_frames(mut self, max_frames: usize) -> Self {
        self.max_frames = max_frames;
        self
    }

    /// Set target frame rate for extraction
    pub fn with_frame_rate(mut self, fps: f32) -> Self {
        self.frame_rate = Some(fps);
        self
    }

    /// Set transform to apply to video frames
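    ///
    /// For example, [`VideoToTensor`] (defined later in this module) satisfies
    /// the bound and stacks the frames into one `(T, C, H, W)` tensor; a sketch
    /// with an assumed `"videos/"` directory:
    ///
    /// ```ignore
    /// let dataset = VideoFolder::new("videos/")?.with_transform(VideoToTensor);
    /// ```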
    pub fn with_transform<T>(mut self, transform: T) -> Self
    where
        T: Transform<VideoFrames, Output = Tensor<f32>> + 'static,
    {
        self.transform = Some(Box::new(transform));
        self
    }

    /// Get class names
    pub fn classes(&self) -> &[String] {
        &self.classes
    }

    /// Get the root directory path
    pub fn root(&self) -> &Path {
        &self.root
    }

    /// Get the number of samples
    pub fn num_samples(&self) -> usize {
        self.samples.len()
    }

    /// Check if file is a supported video format
    fn is_video_file(path: &Path) -> bool {
        if let Some(extension) = path.extension().and_then(|ext| ext.to_str()) {
            matches!(
                extension.to_lowercase().as_str(),
                "mp4" | "avi" | "mov" | "mkv" | "wmv" | "flv" | "webm"
            )
        } else {
            false
        }
    }

    /// Load video frames (simplified implementation)
    fn load_video(&self, _path: &Path) -> Result<VideoFrames> {
        // In a real implementation, this would use ffmpeg or similar to extract frames
        // For now, create dummy video data
        let mut frames = Vec::new();
        for _i in 0..self.max_frames {
            // Create dummy frame (3 channels, 224x224 - typical video frame size)
            let frame = torsh_tensor::creation::rand::<f32>(&[3, 224, 224])?;
            frames.push(frame);
        }

        let fps = self.frame_rate.unwrap_or(30.0);
        Ok(VideoFrames::new(frames, fps))
    }
}

impl Dataset for VideoFolder {
    type Item = (Tensor<f32>, usize);

    fn len(&self) -> usize {
        self.samples.len()
    }

    fn get(&self, index: usize) -> Result<Self::Item> {
        if index >= self.samples.len() {
            return Err(TorshError::IndexError {
                index,
                size: self.samples.len(),
            });
        }

        let (ref path, class_idx) = self.samples[index];
        let video_frames = self.load_video(path)?;

        let tensor = if let Some(ref transform) = self.transform {
            transform.transform(video_frames)?
        } else {
            // Default: stack frames into a single (T, C, H, W) tensor
            VideoToTensor.transform(video_frames)?
        };

        Ok((tensor, class_idx))
    }
}

/// Transform that stacks video frames into a single `(T, C, H, W)` tensor
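///
/// # Example
///
/// A sketch with placeholder random frames showing the resulting layout; the
/// shape assertion assumes `dims()` returns the dimensions as a slice:
///
/// ```ignore
/// let frames: Vec<_> = (0..4)
///     .map(|_| torsh_tensor::creation::rand::<f32>(&[3, 8, 8]).unwrap())
///     .collect();
/// let video = VideoToTensor.transform(VideoFrames::new(frames, 30.0))?;
/// assert_eq!(video.shape().dims(), &[4, 3, 8, 8]);
/// ```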
pub struct VideoToTensor;

impl Transform<VideoFrames> for VideoToTensor {
    type Output = Tensor<f32>;

    fn transform(&self, input: VideoFrames) -> Result<Self::Output> {
        let frames = input.frames();
        if frames.is_empty() {
            return Err(TorshError::InvalidArgument(
                "VideoFrames cannot be empty".to_string(),
            ));
        }

        // Get frame dimensions
        let frame_shape = frames[0].shape();
        let dims = frame_shape.dims();

        if dims.len() != 3 {
            return Err(TorshError::InvalidShape(
                "Expected 3D frame tensors (C, H, W)".to_string(),
            ));
        }

        let (channels, height, width) = (dims[0], dims[1], dims[2]);
        let num_frames = frames.len();

        // Concatenate frames into a single tensor: (T, C, H, W)
        let mut video_data = Vec::with_capacity(num_frames * channels * height * width);

        for frame in frames {
            let frame_data = frame.to_vec()?;
            video_data.extend(frame_data);
        }

        Tensor::from_data(
            video_data,
            vec![num_frames, channels, height, width],
            torsh_core::device::DeviceType::Cpu,
        )
    }
}

/// Transform that splits a `(T, C, H, W)` tensor back into individual video frames
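///
/// # Example
///
/// The inverse of [`VideoToTensor`]: a `(T, C, H, W)` tensor is split back into
/// `T` frames of shape `(C, H, W)` at the requested frame rate; the random
/// input is a placeholder:
///
/// ```ignore
/// let video = torsh_tensor::creation::rand::<f32>(&[4, 3, 8, 8])?;
/// let frames = TensorToVideo::new(30.0).transform(video)?;
/// assert_eq!(frames.num_frames(), 4);
/// assert_eq!(frames.fps(), 30.0);
/// ```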
pub struct TensorToVideo {
    fps: f32,
}

impl TensorToVideo {
    /// Create a transform that reassembles frames at the given frame rate
    pub fn new(fps: f32) -> Self {
        Self { fps }
    }
}

impl Transform<Tensor<f32>> for TensorToVideo {
    type Output = VideoFrames;

    fn transform(&self, input: Tensor<f32>) -> Result<Self::Output> {
        let shape = input.shape();
        let dims = shape.dims();

        if dims.len() != 4 {
            return Err(TorshError::InvalidShape(
                "Expected 4D tensor (T, C, H, W)".to_string(),
            ));
        }

        let (num_frames, channels, height, width) = (dims[0], dims[1], dims[2], dims[3]);
        let frame_size = channels * height * width;

        let data = input.to_vec()?;
        let mut frames = Vec::with_capacity(num_frames);

        for t in 0..num_frames {
            let start_idx = t * frame_size;
            let end_idx = start_idx + frame_size;
            let frame_data = data[start_idx..end_idx].to_vec();

            let frame = Tensor::from_data(
                frame_data,
                vec![channels, height, width],
                torsh_core::device::DeviceType::Cpu,
            )?;

            frames.push(frame);
        }

        Ok(VideoFrames::new(frames, self.fps))
    }
}