1pub use crate::optimized_impl::*;
8
9use crate::utils::{image_to_tensor, load_images_from_dir};
11use crate::{Result, VisionError};
12use image::DynamicImage;
13use std::collections::HashMap;
14use std::path::{Path, PathBuf};
15use torsh_tensor::creation;
16use torsh_tensor::Tensor;
17
18#[derive(Debug)]
24pub struct ImageFolder {
25 data: Vec<(Tensor<f32>, usize)>,
26 class_to_idx: HashMap<String, usize>,
27 classes: Vec<String>,
28}
29
30impl ImageFolder {
31 pub fn new<P: AsRef<Path>>(root: P) -> Result<Self> {
36 eprintln!("Warning: ImageFolder loads all data into memory. Consider using OptimizedImageDataset for large datasets.");
37
38 let root_path = root.as_ref();
39
40 if !root_path.exists() {
41 return Err(VisionError::IoError(std::io::Error::new(
42 std::io::ErrorKind::NotFound,
43 format!("Directory {:?} does not exist", root_path),
44 )));
45 }
46
47 let mut classes = Vec::new();
48 let mut class_to_idx = HashMap::new();
49 let mut data = Vec::new();
50
51 for entry in std::fs::read_dir(root_path)? {
53 let entry = entry?;
54 let path = entry.path();
55
56 if path.is_dir() {
57 if let Some(class_name) = path.file_name() {
58 let class_str = class_name.to_string_lossy().to_string();
59 if !class_to_idx.contains_key(&class_str) {
60 let class_idx = classes.len();
61 classes.push(class_str.clone());
62 class_to_idx.insert(class_str.clone(), class_idx);
63
64 let images = load_images_from_dir(&path)?;
66 for (image, _filename) in images {
67 let tensor = image_to_tensor(&image)?;
68 data.push((tensor, class_idx));
69 }
70 }
71 }
72 }
73 }
74
75 if classes.is_empty() {
76 return Err(VisionError::TransformError(
77 "No class directories found".to_string(),
78 ));
79 }
80
81 Ok(Self {
82 data,
83 class_to_idx,
84 classes,
85 })
86 }
87
88 pub fn len(&self) -> usize {
89 self.data.len()
90 }
91
92 pub fn is_empty(&self) -> bool {
93 self.data.is_empty()
94 }
95
96 pub fn get(&self, index: usize) -> Option<(Tensor<f32>, usize)> {
97 self.data.get(index).cloned()
98 }
99
100 pub fn classes(&self) -> &[String] {
101 &self.classes
102 }
103
104 pub fn class_to_idx(&self) -> &HashMap<String, usize> {
105 &self.class_to_idx
106 }
107}
108
109#[derive(Debug)]
111pub struct ImageNet {
112 data: Vec<Tensor<f32>>,
113 labels: Vec<usize>,
114}
115
116impl ImageNet {
117 pub fn new(_root: &str, _train: bool) -> Result<Self> {
118 eprintln!("Warning: ImageNet placeholder implementation. Use OptimizedImageDataset for real datasets.");
119 Ok(Self {
120 data: vec![creation::zeros(&[3, 224, 224]).expect("tensor creation should succeed")],
121 labels: vec![0],
122 })
123 }
124
125 pub fn len(&self) -> usize {
126 self.data.len()
127 }
128
129 pub fn is_empty(&self) -> bool {
130 self.data.is_empty()
131 }
132
133 pub fn get(&self, index: usize) -> Option<(Tensor<f32>, usize)> {
134 if index < self.data.len() {
135 Some((self.data[index].clone(), self.labels[index]))
136 } else {
137 None
138 }
139 }
140}
141
142#[derive(Debug)]
147pub struct CIFAR10 {
148 data: Vec<Tensor<f32>>,
149 labels: Vec<usize>,
150 classes: Vec<String>,
151}
152
153impl CIFAR10 {
154 pub fn new<P: AsRef<Path>>(root: P, train: bool, download: bool) -> Result<Self> {
158 eprintln!("Warning: CIFAR10 loads all data into memory. Consider using OptimizedCIFARDataset for memory efficiency.");
159
160 let root_path = root.as_ref();
161
162 if !root_path.exists() {
164 std::fs::create_dir_all(root_path)?;
165 }
166
167 let classes = vec![
168 "airplane".to_string(),
169 "automobile".to_string(),
170 "bird".to_string(),
171 "cat".to_string(),
172 "deer".to_string(),
173 "dog".to_string(),
174 "frog".to_string(),
175 "horse".to_string(),
176 "ship".to_string(),
177 "truck".to_string(),
178 ];
179
180 let (all_data, all_labels) = if train {
181 let mut data = Vec::new();
183 let mut labels = Vec::new();
184
185 for i in 1..=5 {
186 let batch_file = root_path.join(format!("data_batch_{}.bin", i));
187 if !batch_file.exists() {
188 if download {
189 return Err(VisionError::TransformError(
190 format!("CIFAR-10 files not found in {:?}. Please download them manually from https://www.cs.toronto.edu/~kriz/cifar.html", root_path)
191 ));
192 } else {
193 return Err(VisionError::IoError(std::io::Error::new(
194 std::io::ErrorKind::NotFound,
195 format!("CIFAR-10 training batch {} not found in {:?}", i, root_path),
196 )));
197 }
198 }
199
200 let (batch_data, batch_labels) = Self::load_batch(&batch_file)?;
201 data.extend(batch_data);
202 labels.extend(batch_labels);
203 }
204
205 (data, labels)
206 } else {
207 let test_file = root_path.join("test_batch.bin");
209 if !test_file.exists() {
210 if download {
211 return Err(VisionError::TransformError(
212 format!("CIFAR-10 files not found in {:?}. Please download them manually from https://www.cs.toronto.edu/~kriz/cifar.html", root_path)
213 ));
214 } else {
215 return Err(VisionError::IoError(std::io::Error::new(
216 std::io::ErrorKind::NotFound,
217 format!("CIFAR-10 test batch not found in {:?}", root_path),
218 )));
219 }
220 }
221
222 Self::load_batch(&test_file)?
223 };
224
225 Ok(Self {
226 data: all_data,
227 labels: all_labels,
228 classes,
229 })
230 }
231
232 fn load_batch<P: AsRef<Path>>(path: P) -> Result<(Vec<Tensor<f32>>, Vec<usize>)> {
233 let data = std::fs::read(path)?;
234
235 const SAMPLES_PER_BATCH: usize = 10000;
238 const BYTES_PER_SAMPLE: usize = 1 + 3072; if data.len() != SAMPLES_PER_BATCH * BYTES_PER_SAMPLE {
241 return Err(VisionError::TransformError(format!(
242 "Invalid CIFAR-10 batch file size. Expected {}, got {}",
243 SAMPLES_PER_BATCH * BYTES_PER_SAMPLE,
244 data.len()
245 )));
246 }
247
248 let mut images = Vec::with_capacity(SAMPLES_PER_BATCH);
249 let mut labels = Vec::with_capacity(SAMPLES_PER_BATCH);
250
251 for i in 0..SAMPLES_PER_BATCH {
252 let start_idx = i * BYTES_PER_SAMPLE;
253
254 let label = data[start_idx] as usize;
256 labels.push(label);
257
258 let tensor = creation::zeros(&[3, 32, 32]).expect("tensor creation should succeed");
260
261 for channel in 0..3 {
263 for y in 0..32 {
264 for x in 0..32 {
265 let pixel_idx = start_idx + 1 + channel * 1024 + y * 32 + x;
266 let pixel_val = data[pixel_idx] as f32 / 255.0; tensor.set(&[channel, y, x], pixel_val)?;
268 }
269 }
270 }
271
272 images.push(tensor);
273 }
274
275 Ok((images, labels))
276 }
277
278 pub fn len(&self) -> usize {
279 self.data.len()
280 }
281
282 pub fn is_empty(&self) -> bool {
283 self.data.is_empty()
284 }
285
286 pub fn get(&self, index: usize) -> Option<(Tensor<f32>, usize)> {
287 if index < self.data.len() {
288 Some((self.data[index].clone(), self.labels[index]))
289 } else {
290 None
291 }
292 }
293
294 pub fn classes(&self) -> &[String] {
295 &self.classes
296 }
297}
298
299#[derive(Debug)]
304pub struct MNIST {
305 data: Vec<Tensor<f32>>,
306 labels: Vec<usize>,
307}
308
309impl MNIST {
310 pub fn new<P: AsRef<Path>>(root: P, train: bool, download: bool) -> Result<Self> {
314 eprintln!("Warning: MNIST loads all data into memory. Consider optimized alternatives for memory efficiency.");
315
316 let root_path = root.as_ref();
317
318 if !root_path.exists() {
320 std::fs::create_dir_all(root_path)?;
321 }
322
323 let (images_filename, labels_filename) = if train {
324 ("train-images-idx3-ubyte", "train-labels-idx1-ubyte")
325 } else {
326 ("t10k-images-idx3-ubyte", "t10k-labels-idx1-ubyte")
327 };
328
329 let images_path = root_path.join(images_filename);
330 let labels_path = root_path.join(labels_filename);
331
332 if !images_path.exists() || !labels_path.exists() {
334 if download {
335 return Err(VisionError::TransformError(
336 format!("MNIST files not found in {:?}. Please download them manually from http://yann.lecun.com/exdb/mnist/", root_path)
337 ));
338 } else {
339 return Err(VisionError::IoError(std::io::Error::new(
340 std::io::ErrorKind::NotFound,
341 format!("MNIST files not found in {:?}", root_path),
342 )));
343 }
344 }
345
346 let images = Self::load_images(&images_path)?;
348 let labels = Self::load_labels(&labels_path)?;
349
350 if images.len() != labels.len() {
351 return Err(VisionError::TransformError(
352 "Number of images and labels don't match".to_string(),
353 ));
354 }
355
356 Ok(Self {
357 data: images,
358 labels,
359 })
360 }
361
362 fn load_images<P: AsRef<Path>>(path: P) -> Result<Vec<Tensor<f32>>> {
363 let data = std::fs::read(path)?;
364
365 if data.len() < 16 {
366 return Err(VisionError::TransformError(
367 "Invalid MNIST images file format".to_string(),
368 ));
369 }
370
371 let magic = u32::from_be_bytes([data[0], data[1], data[2], data[3]]);
373 let num_images = u32::from_be_bytes([data[4], data[5], data[6], data[7]]) as usize;
374 let rows = u32::from_be_bytes([data[8], data[9], data[10], data[11]]) as usize;
375 let cols = u32::from_be_bytes([data[12], data[13], data[14], data[15]]) as usize;
376
377 if magic != 0x00000803 {
378 return Err(VisionError::TransformError(
379 "Invalid MNIST images file magic number".to_string(),
380 ));
381 }
382
383 let mut images = Vec::with_capacity(num_images);
384 let image_size = rows * cols;
385
386 for i in 0..num_images {
387 let start_idx = 16 + i * image_size;
388 let end_idx = start_idx + image_size;
389
390 if end_idx > data.len() {
391 break;
392 }
393
394 let tensor = creation::zeros(&[1, rows, cols]).expect("tensor creation should succeed");
395
396 for (pixel_idx, &pixel_val) in data[start_idx..end_idx].iter().enumerate() {
397 let y = pixel_idx / cols;
398 let x = pixel_idx % cols;
399 let normalized_val = pixel_val as f32 / 255.0;
400 tensor.set(&[0, y, x], normalized_val)?;
401 }
402
403 images.push(tensor);
404 }
405
406 Ok(images)
407 }
408
409 fn load_labels<P: AsRef<Path>>(path: P) -> Result<Vec<usize>> {
410 let data = std::fs::read(path)?;
411
412 if data.len() < 8 {
413 return Err(VisionError::TransformError(
414 "Invalid MNIST labels file format".to_string(),
415 ));
416 }
417
418 let magic = u32::from_be_bytes([data[0], data[1], data[2], data[3]]);
420 let num_labels = u32::from_be_bytes([data[4], data[5], data[6], data[7]]) as usize;
421
422 if magic != 0x00000801 {
423 return Err(VisionError::TransformError(
424 "Invalid MNIST labels file magic number".to_string(),
425 ));
426 }
427
428 if data.len() < 8 + num_labels {
429 return Err(VisionError::TransformError(
430 "MNIST labels file too short".to_string(),
431 ));
432 }
433
434 let labels = data[8..8 + num_labels]
435 .iter()
436 .map(|&label| label as usize)
437 .collect();
438
439 Ok(labels)
440 }
441
442 pub fn len(&self) -> usize {
443 self.data.len()
444 }
445
446 pub fn is_empty(&self) -> bool {
447 self.data.is_empty()
448 }
449
450 pub fn get(&self, index: usize) -> Option<(Tensor<f32>, usize)> {
451 if index < self.data.len() {
452 Some((self.data[index].clone(), self.labels[index]))
453 } else {
454 None
455 }
456 }
457}
458
/// Builds an `OptimizedImageDataset` rooted at `root` using fixed cache and prefetch
/// settings.
///
/// # Errors
///
/// Propagates whatever `OptimizedDatasetBuilder::build_image_dataset` returns.
pub fn create_optimized_image_dataset<P: AsRef<Path>>(root: P) -> Result<OptimizedImageDataset> {
    OptimizedDatasetBuilder::new()
        // NOTE(review): the meaning of (1000, 512) is defined by
        // `OptimizedDatasetBuilder::with_cache`, which is not visible here — confirm
        // (possibly entry count and size budget).
        .with_cache(1000, 512)
        // NOTE(review): presumably (enabled, prefetch depth) — confirm against the builder.
        .with_prefetch(true, 16)
        .build_image_dataset(root)
}
466
/// Builds an `OptimizedCIFARDataset` rooted at `root` using fixed cache and prefetch
/// settings.
///
/// `is_cifar100` and `train` are forwarded unchanged to
/// `OptimizedDatasetBuilder::build_cifar_dataset`.
///
/// # Errors
///
/// Propagates whatever `OptimizedDatasetBuilder::build_cifar_dataset` returns.
pub fn create_optimized_cifar_dataset<P: AsRef<Path>>(
    root: P,
    is_cifar100: bool,
    train: bool,
) -> Result<OptimizedCIFARDataset> {
    OptimizedDatasetBuilder::new()
        // NOTE(review): argument meaning is defined by `with_cache` (not visible here) —
        // confirm what 2000 and 256 represent.
        .with_cache(2000, 256)
        // NOTE(review): presumably (enabled, prefetch depth) — confirm against the builder.
        .with_prefetch(true, 32)
        .build_cifar_dataset(root, is_cifar100, train)
}
478
/// Alternative name for [`CIFAR10`].
pub type CifarDataset = CIFAR10;
/// Alternative name for [`MNIST`].
pub type MnistDataset = MNIST;
482
483#[derive(Debug)]
485pub struct CocoDataset {
486 data: Vec<Tensor<f32>>,
487 labels: Vec<usize>,
488}
489
490impl CocoDataset {
491 pub fn new<P: AsRef<Path>>(_root: P, _train: bool) -> Result<Self> {
492 eprintln!("Warning: CocoDataset is a placeholder implementation");
493 Ok(Self {
494 data: vec![torsh_tensor::creation::zeros(&[3, 224, 224])
495 .expect("tensor creation should succeed")],
496 labels: vec![0],
497 })
498 }
499
500 pub fn len(&self) -> usize {
501 self.data.len()
502 }
503
504 pub fn is_empty(&self) -> bool {
505 self.data.is_empty()
506 }
507
508 pub fn get(&self, index: usize) -> Option<(Tensor<f32>, usize)> {
509 if index < self.data.len() {
510 Some((self.data[index].clone(), self.labels[index]))
511 } else {
512 None
513 }
514 }
515}
516
517#[derive(Debug)]
518pub struct VocDataset {
519 data: Vec<Tensor<f32>>,
520 labels: Vec<usize>,
521}
522
523impl VocDataset {
524 pub fn new<P: AsRef<Path>>(_root: P, _train: bool) -> Result<Self> {
525 eprintln!("Warning: VocDataset is a placeholder implementation");
526 Ok(Self {
527 data: vec![torsh_tensor::creation::zeros(&[3, 224, 224])
528 .expect("tensor creation should succeed")],
529 labels: vec![0],
530 })
531 }
532
533 pub fn len(&self) -> usize {
534 self.data.len()
535 }
536
537 pub fn is_empty(&self) -> bool {
538 self.data.is_empty()
539 }
540
541 pub fn get(&self, index: usize) -> Option<(Tensor<f32>, usize)> {
542 if index < self.data.len() {
543 Some((self.data[index].clone(), self.labels[index]))
544 } else {
545 None
546 }
547 }
548}