1pub mod datasets;
14pub mod image;
15
16pub use datasets::{ImageNet, CIFAR10, MNIST};
18pub use image::transforms::transforms::{CenterCrop, Normalize, Resize};
19pub use image::{
20 Compose, ImageFolder, ImageToTensor, RandomHorizontalFlip, RandomRotation, RandomVerticalFlip,
21 TensorToImage,
22};
23
24use crate::{dataset::Dataset, transforms::Transform};
26use torsh_core::error::{Result, TorshError};
27use torsh_tensor::Tensor;
28
29#[cfg(not(feature = "std"))]
30use alloc::{boxed::Box, string::String, vec::Vec};
31use std::path::{Path, PathBuf};
32
/// A decoded video clip: an ordered sequence of frame tensors plus timing
/// metadata derived at construction time.
#[derive(Debug, Clone)]
pub struct VideoFrames {
    // Frame tensors in playback order; `VideoToTensor` expects each to be
    // 3D (C, H, W) — producers should uphold this (TODO confirm).
    frames: Vec<Tensor<f32>>,
    // Sampling rate of the clip, in frames per second.
    fps: f32,
    // Clip length in seconds, computed as `frames.len() / fps` in `new`.
    duration: f32,
}
40
41impl VideoFrames {
42 pub fn new(frames: Vec<Tensor<f32>>, fps: f32) -> Self {
43 let duration = frames.len() as f32 / fps;
44 Self {
45 frames,
46 fps,
47 duration,
48 }
49 }
50
51 pub fn frames(&self) -> &[Tensor<f32>] {
52 &self.frames
53 }
54
55 pub fn fps(&self) -> f32 {
56 self.fps
57 }
58
59 pub fn duration(&self) -> f32 {
60 self.duration
61 }
62
63 pub fn num_frames(&self) -> usize {
64 self.frames.len()
65 }
66}
67
/// A dataset rooted at a directory whose immediate subdirectories are class
/// labels, each containing video files (mp4/avi/mov/mkv/wmv/flv/webm).
pub struct VideoFolder {
    // Root directory the dataset was built from.
    root: PathBuf,
    // (video path, class index) pairs discovered during construction.
    samples: Vec<(PathBuf, usize)>,
    // Class names; a sample's `usize` indexes into this vector.
    classes: Vec<String>,
    // Optional user transform; when absent, `VideoToTensor` is used in `get`.
    transform: Option<Box<dyn Transform<VideoFrames, Output = Tensor<f32>>>>,
    // Upper bound on frames loaded per video (defaults to 32 in `new`).
    max_frames: usize,
    // Optional override for the fps reported by loaded clips (default 30.0).
    frame_rate: Option<f32>,
}
77
78impl VideoFolder {
79 pub fn new<P: AsRef<Path>>(root: P) -> Result<Self> {
81 let root = root.as_ref().to_path_buf();
82
83 if !root.exists() {
84 return Err(TorshError::IoError(format!(
85 "Directory does not exist: {root:?}"
86 )));
87 }
88
89 let mut classes = Vec::new();
90 let mut samples = Vec::new();
91
92 for entry in std::fs::read_dir(&root).map_err(|e| TorshError::IoError(e.to_string()))? {
94 let entry = entry.map_err(|e| TorshError::IoError(e.to_string()))?;
95 let path = entry.path();
96
97 if path.is_dir() {
98 let class_name = path
99 .file_name()
100 .and_then(|n| n.to_str())
101 .ok_or_else(|| TorshError::IoError("Invalid class directory name".to_string()))?
102 .to_string();
103
104 let class_idx = classes.len();
105 classes.push(class_name);
106
107 for video_entry in
109 std::fs::read_dir(&path).map_err(|e| TorshError::IoError(e.to_string()))?
110 {
111 let video_entry =
112 video_entry.map_err(|e| TorshError::IoError(e.to_string()))?;
113 let video_path = video_entry.path();
114
115 if Self::is_video_file(&video_path) {
116 samples.push((video_path, class_idx));
117 }
118 }
119 }
120 }
121
122 Ok(Self {
123 root,
124 samples,
125 classes,
126 transform: None,
127 max_frames: 32, frame_rate: None,
129 })
130 }
131
132 pub fn with_max_frames(mut self, max_frames: usize) -> Self {
134 self.max_frames = max_frames;
135 self
136 }
137
138 pub fn with_frame_rate(mut self, fps: f32) -> Self {
140 self.frame_rate = Some(fps);
141 self
142 }
143
144 pub fn with_transform<T>(mut self, transform: T) -> Self
146 where
147 T: Transform<VideoFrames, Output = Tensor<f32>> + 'static,
148 {
149 self.transform = Some(Box::new(transform));
150 self
151 }
152
153 pub fn classes(&self) -> &[String] {
155 &self.classes
156 }
157
158 pub fn root(&self) -> &Path {
160 &self.root
161 }
162
163 pub fn num_samples(&self) -> usize {
165 self.samples.len()
166 }
167
168 fn is_video_file(path: &Path) -> bool {
170 if let Some(extension) = path.extension().and_then(|ext| ext.to_str()) {
171 matches!(
172 extension.to_lowercase().as_str(),
173 "mp4" | "avi" | "mov" | "mkv" | "wmv" | "flv" | "webm"
174 )
175 } else {
176 false
177 }
178 }
179
180 fn load_video(&self, _path: &Path) -> Result<VideoFrames> {
182 let mut frames = Vec::new();
185 for _i in 0..self.max_frames {
186 let frame = torsh_tensor::creation::rand::<f32>(&[3, 224, 224])?;
188 frames.push(frame);
189 }
190
191 let fps = self.frame_rate.unwrap_or(30.0);
192 Ok(VideoFrames::new(frames, fps))
193 }
194}
195
196impl Dataset for VideoFolder {
197 type Item = (Tensor<f32>, usize);
198
199 fn len(&self) -> usize {
200 self.samples.len()
201 }
202
203 fn get(&self, index: usize) -> Result<Self::Item> {
204 if index >= self.samples.len() {
205 return Err(TorshError::IndexError {
206 index,
207 size: self.samples.len(),
208 });
209 }
210
211 let (ref path, class_idx) = self.samples[index];
212 let video_frames = self.load_video(path)?;
213
214 let tensor = if let Some(ref transform) = self.transform {
215 transform.transform(video_frames)?
216 } else {
217 VideoToTensor.transform(video_frames)?
219 };
220
221 Ok((tensor, class_idx))
222 }
223}
224
/// Transform that stacks a clip's (C, H, W) frames into a single
/// (T, C, H, W) tensor.
pub struct VideoToTensor;
227
228impl Transform<VideoFrames> for VideoToTensor {
229 type Output = Tensor<f32>;
230
231 fn transform(&self, input: VideoFrames) -> Result<Self::Output> {
232 let frames = input.frames();
233 if frames.is_empty() {
234 return Err(TorshError::InvalidArgument(
235 "VideoFrames cannot be empty".to_string(),
236 ));
237 }
238
239 let frame_shape = frames[0].shape();
241 let dims = frame_shape.dims();
242
243 if dims.len() != 3 {
244 return Err(TorshError::InvalidShape(
245 "Expected 3D frame tensors (C, H, W)".to_string(),
246 ));
247 }
248
249 let (channels, height, width) = (dims[0], dims[1], dims[2]);
250 let num_frames = frames.len();
251
252 let mut video_data = Vec::with_capacity(num_frames * channels * height * width);
254
255 for frame in frames {
256 let frame_data = frame.to_vec()?;
257 video_data.extend(frame_data);
258 }
259
260 Tensor::from_data(
261 video_data,
262 vec![num_frames, channels, height, width],
263 torsh_core::device::DeviceType::Cpu,
264 )
265 }
266}
267
/// Transform that splits a (T, C, H, W) tensor back into per-frame tensors,
/// producing a `VideoFrames` tagged with the configured fps.
pub struct TensorToVideo {
    // Frames-per-second stamped on the reconstructed `VideoFrames`.
    fps: f32,
}
272
273impl TensorToVideo {
274 pub fn new(fps: f32) -> Self {
275 Self { fps }
276 }
277}
278
279impl Transform<Tensor<f32>> for TensorToVideo {
280 type Output = VideoFrames;
281
282 fn transform(&self, input: Tensor<f32>) -> Result<Self::Output> {
283 let shape = input.shape();
284 let dims = shape.dims();
285
286 if dims.len() != 4 {
287 return Err(TorshError::InvalidShape(
288 "Expected 4D tensor (T, C, H, W)".to_string(),
289 ));
290 }
291
292 let (num_frames, channels, height, width) = (dims[0], dims[1], dims[2], dims[3]);
293 let frame_size = channels * height * width;
294
295 let data = input.to_vec()?;
296 let mut frames = Vec::with_capacity(num_frames);
297
298 for t in 0..num_frames {
299 let start_idx = t * frame_size;
300 let end_idx = start_idx + frame_size;
301 let frame_data = data[start_idx..end_idx].to_vec();
302
303 let frame = Tensor::from_data(
304 frame_data,
305 vec![channels, height, width],
306 torsh_core::device::DeviceType::Cpu,
307 )?;
308
309 frames.push(frame);
310 }
311
312 Ok(VideoFrames::new(frames, self.fps))
313 }
314}