// torsh_vision/datasets_impl.rs
//! Dataset loading and management for torsh-vision
//!
//! This module provides both legacy dataset implementations and optimized alternatives
//! with lazy loading, caching, and memory management features.
// Include the optimized implementations directly
pub use crate::optimized_impl::*;

// Legacy implementations (kept for backward compatibility)
use crate::utils::{image_to_tensor, load_images_from_dir};
use crate::{Result, VisionError};
use image::DynamicImage;
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use torsh_tensor::creation;
use torsh_tensor::Tensor;
/// Legacy ImageFolder dataset for loading images from a directory structure
/// where each subdirectory represents a class
///
/// **Note**: This implementation loads all images into memory at once.
/// For large datasets, consider using `OptimizedImageDataset` instead.
#[derive(Debug)]
pub struct ImageFolder {
    /// Eagerly loaded (image tensor, class index) pairs.
    data: Vec<(Tensor<f32>, usize)>,
    /// Mapping from class directory name to class index.
    class_to_idx: HashMap<String, usize>,
    /// Class names, indexed by class index.
    classes: Vec<String>,
}
30impl ImageFolder {
31    /// Create a new ImageFolder dataset
32    ///
33    /// **Memory Warning**: This loads all images into memory immediately.
34    /// For datasets larger than a few GB, use `OptimizedImageDataset`.
35    pub fn new<P: AsRef<Path>>(root: P) -> Result<Self> {
36        eprintln!("Warning: ImageFolder loads all data into memory. Consider using OptimizedImageDataset for large datasets.");
37
38        let root_path = root.as_ref();
39
40        if !root_path.exists() {
41            return Err(VisionError::IoError(std::io::Error::new(
42                std::io::ErrorKind::NotFound,
43                format!("Directory {:?} does not exist", root_path),
44            )));
45        }
46
47        let mut classes = Vec::new();
48        let mut class_to_idx = HashMap::new();
49        let mut data = Vec::new();
50
51        // Collect all subdirectories as classes
52        for entry in std::fs::read_dir(root_path)? {
53            let entry = entry?;
54            let path = entry.path();
55
56            if path.is_dir() {
57                if let Some(class_name) = path.file_name() {
58                    let class_str = class_name.to_string_lossy().to_string();
59                    if !class_to_idx.contains_key(&class_str) {
60                        let class_idx = classes.len();
61                        classes.push(class_str.clone());
62                        class_to_idx.insert(class_str.clone(), class_idx);
63
64                        // Load images from this class directory
65                        let images = load_images_from_dir(&path)?;
66                        for (image, _filename) in images {
67                            let tensor = image_to_tensor(&image)?;
68                            data.push((tensor, class_idx));
69                        }
70                    }
71                }
72            }
73        }
74
75        if classes.is_empty() {
76            return Err(VisionError::TransformError(
77                "No class directories found".to_string(),
78            ));
79        }
80
81        Ok(Self {
82            data,
83            class_to_idx,
84            classes,
85        })
86    }
87
88    pub fn len(&self) -> usize {
89        self.data.len()
90    }
91
92    pub fn is_empty(&self) -> bool {
93        self.data.is_empty()
94    }
95
96    pub fn get(&self, index: usize) -> Option<(Tensor<f32>, usize)> {
97        self.data.get(index).cloned()
98    }
99
100    pub fn classes(&self) -> &[String] {
101        &self.classes
102    }
103
104    pub fn class_to_idx(&self) -> &HashMap<String, usize> {
105        &self.class_to_idx
106    }
107}
108
/// Legacy ImageNet dataset placeholder
///
/// Not a real ImageNet loader; the constructor fills it with a single
/// zero tensor and label.
#[derive(Debug)]
pub struct ImageNet {
    /// Image tensors, one per sample.
    data: Vec<Tensor<f32>>,
    /// Class labels, parallel to `data`.
    labels: Vec<usize>,
}
116impl ImageNet {
117    pub fn new(_root: &str, _train: bool) -> Result<Self> {
118        eprintln!("Warning: ImageNet placeholder implementation. Use OptimizedImageDataset for real datasets.");
119        Ok(Self {
120            data: vec![creation::zeros(&[3, 224, 224]).expect("tensor creation should succeed")],
121            labels: vec![0],
122        })
123    }
124
125    pub fn len(&self) -> usize {
126        self.data.len()
127    }
128
129    pub fn is_empty(&self) -> bool {
130        self.data.is_empty()
131    }
132
133    pub fn get(&self, index: usize) -> Option<(Tensor<f32>, usize)> {
134        if index < self.data.len() {
135            Some((self.data[index].clone(), self.labels[index]))
136        } else {
137            None
138        }
139    }
140}
141
/// Legacy CIFAR-10 dataset loader
///
/// **Note**: This implementation loads the entire dataset into memory at once.
/// For memory-efficient loading, use `OptimizedCIFARDataset` instead.
#[derive(Debug)]
pub struct CIFAR10 {
    /// Image tensors, one `[3, 32, 32]` tensor per sample.
    data: Vec<Tensor<f32>>,
    /// Class labels, parallel to `data`.
    labels: Vec<usize>,
    /// The ten CIFAR-10 class names, indexed by label.
    classes: Vec<String>,
}
153impl CIFAR10 {
154    /// Create a new CIFAR-10 dataset
155    ///
156    /// **Memory Warning**: This loads all data into memory immediately.
157    pub fn new<P: AsRef<Path>>(root: P, train: bool, download: bool) -> Result<Self> {
158        eprintln!("Warning: CIFAR10 loads all data into memory. Consider using OptimizedCIFARDataset for memory efficiency.");
159
160        let root_path = root.as_ref();
161
162        // Create directory if it doesn't exist
163        if !root_path.exists() {
164            std::fs::create_dir_all(root_path)?;
165        }
166
167        let classes = vec![
168            "airplane".to_string(),
169            "automobile".to_string(),
170            "bird".to_string(),
171            "cat".to_string(),
172            "deer".to_string(),
173            "dog".to_string(),
174            "frog".to_string(),
175            "horse".to_string(),
176            "ship".to_string(),
177            "truck".to_string(),
178        ];
179
180        let (all_data, all_labels) = if train {
181            // Load training batches
182            let mut data = Vec::new();
183            let mut labels = Vec::new();
184
185            for i in 1..=5 {
186                let batch_file = root_path.join(format!("data_batch_{}.bin", i));
187                if !batch_file.exists() {
188                    if download {
189                        return Err(VisionError::TransformError(
190                            format!("CIFAR-10 files not found in {:?}. Please download them manually from https://www.cs.toronto.edu/~kriz/cifar.html", root_path)
191                        ));
192                    } else {
193                        return Err(VisionError::IoError(std::io::Error::new(
194                            std::io::ErrorKind::NotFound,
195                            format!("CIFAR-10 training batch {} not found in {:?}", i, root_path),
196                        )));
197                    }
198                }
199
200                let (batch_data, batch_labels) = Self::load_batch(&batch_file)?;
201                data.extend(batch_data);
202                labels.extend(batch_labels);
203            }
204
205            (data, labels)
206        } else {
207            // Load test batch
208            let test_file = root_path.join("test_batch.bin");
209            if !test_file.exists() {
210                if download {
211                    return Err(VisionError::TransformError(
212                        format!("CIFAR-10 files not found in {:?}. Please download them manually from https://www.cs.toronto.edu/~kriz/cifar.html", root_path)
213                    ));
214                } else {
215                    return Err(VisionError::IoError(std::io::Error::new(
216                        std::io::ErrorKind::NotFound,
217                        format!("CIFAR-10 test batch not found in {:?}", root_path),
218                    )));
219                }
220            }
221
222            Self::load_batch(&test_file)?
223        };
224
225        Ok(Self {
226            data: all_data,
227            labels: all_labels,
228            classes,
229        })
230    }
231
232    fn load_batch<P: AsRef<Path>>(path: P) -> Result<(Vec<Tensor<f32>>, Vec<usize>)> {
233        let data = std::fs::read(path)?;
234
235        // Each CIFAR-10 batch contains 10,000 samples
236        // Each sample is 1 byte label + 3072 bytes image data (32x32x3)
237        const SAMPLES_PER_BATCH: usize = 10000;
238        const BYTES_PER_SAMPLE: usize = 1 + 3072; // 1 label + 32*32*3 pixels
239
240        if data.len() != SAMPLES_PER_BATCH * BYTES_PER_SAMPLE {
241            return Err(VisionError::TransformError(format!(
242                "Invalid CIFAR-10 batch file size. Expected {}, got {}",
243                SAMPLES_PER_BATCH * BYTES_PER_SAMPLE,
244                data.len()
245            )));
246        }
247
248        let mut images = Vec::with_capacity(SAMPLES_PER_BATCH);
249        let mut labels = Vec::with_capacity(SAMPLES_PER_BATCH);
250
251        for i in 0..SAMPLES_PER_BATCH {
252            let start_idx = i * BYTES_PER_SAMPLE;
253
254            // First byte is the label
255            let label = data[start_idx] as usize;
256            labels.push(label);
257
258            // Next 3072 bytes are the image data (R, G, B channels in that order)
259            let tensor = creation::zeros(&[3, 32, 32]).expect("tensor creation should succeed");
260
261            // CIFAR-10 format: first 1024 bytes are red channel, next 1024 green, last 1024 blue
262            for channel in 0..3 {
263                for y in 0..32 {
264                    for x in 0..32 {
265                        let pixel_idx = start_idx + 1 + channel * 1024 + y * 32 + x;
266                        let pixel_val = data[pixel_idx] as f32 / 255.0; // Normalize to [0, 1]
267                        tensor.set(&[channel, y, x], pixel_val)?;
268                    }
269                }
270            }
271
272            images.push(tensor);
273        }
274
275        Ok((images, labels))
276    }
277
278    pub fn len(&self) -> usize {
279        self.data.len()
280    }
281
282    pub fn is_empty(&self) -> bool {
283        self.data.is_empty()
284    }
285
286    pub fn get(&self, index: usize) -> Option<(Tensor<f32>, usize)> {
287        if index < self.data.len() {
288            Some((self.data[index].clone(), self.labels[index]))
289        } else {
290            None
291        }
292    }
293
294    pub fn classes(&self) -> &[String] {
295        &self.classes
296    }
297}
298
/// Legacy MNIST dataset loader
///
/// **Note**: This implementation loads the entire dataset into memory at once.
/// For memory-efficient loading, consider using an optimized alternative.
#[derive(Debug)]
pub struct MNIST {
    /// Image tensors, one `[1, rows, cols]` tensor per sample.
    data: Vec<Tensor<f32>>,
    /// Class labels, parallel to `data`.
    labels: Vec<usize>,
}
impl MNIST {
    /// Create a new MNIST dataset
    ///
    /// Loads the IDX-format image and label files for either the training or
    /// the test split from `root`. Automatic downloading is not implemented;
    /// `download == true` only changes the error message when files are missing.
    ///
    /// **Memory Warning**: This loads all data into memory immediately.
    pub fn new<P: AsRef<Path>>(root: P, train: bool, download: bool) -> Result<Self> {
        eprintln!("Warning: MNIST loads all data into memory. Consider optimized alternatives for memory efficiency.");

        let root_path = root.as_ref();

        // Create directory if it doesn't exist
        if !root_path.exists() {
            std::fs::create_dir_all(root_path)?;
        }

        // Standard IDX filenames for the requested split.
        let (images_filename, labels_filename) = if train {
            ("train-images-idx3-ubyte", "train-labels-idx1-ubyte")
        } else {
            ("t10k-images-idx3-ubyte", "t10k-labels-idx1-ubyte")
        };

        let images_path = root_path.join(images_filename);
        let labels_path = root_path.join(labels_filename);

        // Check if files exist, if not and download is true, suggest downloading manually
        if !images_path.exists() || !labels_path.exists() {
            if download {
                return Err(VisionError::TransformError(
                    format!("MNIST files not found in {:?}. Please download them manually from http://yann.lecun.com/exdb/mnist/", root_path)
                ));
            } else {
                return Err(VisionError::IoError(std::io::Error::new(
                    std::io::ErrorKind::NotFound,
                    format!("MNIST files not found in {:?}", root_path),
                )));
            }
        }

        // Load images and labels
        let images = Self::load_images(&images_path)?;
        let labels = Self::load_labels(&labels_path)?;

        if images.len() != labels.len() {
            return Err(VisionError::TransformError(
                "Number of images and labels don't match".to_string(),
            ));
        }

        Ok(Self {
            data: images,
            labels,
        })
    }

    /// Parse an IDX3 image file into one `[1, rows, cols]` tensor per image,
    /// with pixel bytes normalized to [0, 1].
    fn load_images<P: AsRef<Path>>(path: P) -> Result<Vec<Tensor<f32>>> {
        let data = std::fs::read(path)?;

        // The IDX3 header is 16 bytes: magic, count, rows, cols (big-endian u32 each).
        if data.len() < 16 {
            return Err(VisionError::TransformError(
                "Invalid MNIST images file format".to_string(),
            ));
        }

        // Read header
        let magic = u32::from_be_bytes([data[0], data[1], data[2], data[3]]);
        let num_images = u32::from_be_bytes([data[4], data[5], data[6], data[7]]) as usize;
        let rows = u32::from_be_bytes([data[8], data[9], data[10], data[11]]) as usize;
        let cols = u32::from_be_bytes([data[12], data[13], data[14], data[15]]) as usize;

        // 0x00000803 marks an IDX3 (3-dimensional, unsigned byte) image file.
        if magic != 0x00000803 {
            return Err(VisionError::TransformError(
                "Invalid MNIST images file magic number".to_string(),
            ));
        }

        let mut images = Vec::with_capacity(num_images);
        let image_size = rows * cols;

        for i in 0..num_images {
            let start_idx = 16 + i * image_size;
            let end_idx = start_idx + image_size;

            // NOTE(review): a truncated file silently yields fewer images than
            // the header declares instead of returning an error — confirm this
            // lenient behavior is intended.
            if end_idx > data.len() {
                break;
            }

            let tensor = creation::zeros(&[1, rows, cols]).expect("tensor creation should succeed");

            // Pixels are stored row-major; normalize each byte to [0, 1].
            for (pixel_idx, &pixel_val) in data[start_idx..end_idx].iter().enumerate() {
                let y = pixel_idx / cols;
                let x = pixel_idx % cols;
                let normalized_val = pixel_val as f32 / 255.0;
                tensor.set(&[0, y, x], normalized_val)?;
            }

            images.push(tensor);
        }

        Ok(images)
    }

    /// Parse an IDX1 label file into a vector of class indices.
    fn load_labels<P: AsRef<Path>>(path: P) -> Result<Vec<usize>> {
        let data = std::fs::read(path)?;

        // The IDX1 header is 8 bytes: magic and count (big-endian u32 each).
        if data.len() < 8 {
            return Err(VisionError::TransformError(
                "Invalid MNIST labels file format".to_string(),
            ));
        }

        // Read header
        let magic = u32::from_be_bytes([data[0], data[1], data[2], data[3]]);
        let num_labels = u32::from_be_bytes([data[4], data[5], data[6], data[7]]) as usize;

        // 0x00000801 marks an IDX1 (1-dimensional, unsigned byte) label file.
        if magic != 0x00000801 {
            return Err(VisionError::TransformError(
                "Invalid MNIST labels file magic number".to_string(),
            ));
        }

        if data.len() < 8 + num_labels {
            return Err(VisionError::TransformError(
                "MNIST labels file too short".to_string(),
            ));
        }

        // One byte per label, immediately after the header.
        let labels = data[8..8 + num_labels]
            .iter()
            .map(|&label| label as usize)
            .collect();

        Ok(labels)
    }

    /// Number of samples in the dataset.
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// True when the dataset holds no samples.
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// Return a clone of the sample at `index`, or `None` when out of range.
    pub fn get(&self, index: usize) -> Option<(Tensor<f32>, usize)> {
        if index < self.data.len() {
            Some((self.data[index].clone(), self.labels[index]))
        } else {
            None
        }
    }
}
459/// Helper function to create optimized datasets with sensible defaults
460pub fn create_optimized_image_dataset<P: AsRef<Path>>(root: P) -> Result<OptimizedImageDataset> {
461    OptimizedDatasetBuilder::new()
462        .with_cache(1000, 512) // 1000 items, 512MB cache
463        .with_prefetch(true, 16) // Enable prefetching with batch size 16
464        .build_image_dataset(root)
465}
466
467/// Helper function to create optimized CIFAR datasets
468pub fn create_optimized_cifar_dataset<P: AsRef<Path>>(
469    root: P,
470    is_cifar100: bool,
471    train: bool,
472) -> Result<OptimizedCIFARDataset> {
473    OptimizedDatasetBuilder::new()
474        .with_cache(2000, 256) // 2000 items, 256MB cache
475        .with_prefetch(true, 32) // Prefetch in larger batches for CIFAR
476        .build_cifar_dataset(root, is_cifar100, train)
477}
478
// Type aliases for backward compatibility with older torsh-vision names.
pub type CifarDataset = CIFAR10;
pub type MnistDataset = MNIST;
// Placeholder implementations for datasets that aren't fully implemented yet
/// Placeholder COCO dataset; its constructor fills it with a single zero
/// tensor and label rather than loading real annotations.
#[derive(Debug)]
pub struct CocoDataset {
    /// Image tensors, one per sample.
    data: Vec<Tensor<f32>>,
    /// Class labels, parallel to `data`.
    labels: Vec<usize>,
}
490impl CocoDataset {
491    pub fn new<P: AsRef<Path>>(_root: P, _train: bool) -> Result<Self> {
492        eprintln!("Warning: CocoDataset is a placeholder implementation");
493        Ok(Self {
494            data: vec![torsh_tensor::creation::zeros(&[3, 224, 224])
495                .expect("tensor creation should succeed")],
496            labels: vec![0],
497        })
498    }
499
500    pub fn len(&self) -> usize {
501        self.data.len()
502    }
503
504    pub fn is_empty(&self) -> bool {
505        self.data.is_empty()
506    }
507
508    pub fn get(&self, index: usize) -> Option<(Tensor<f32>, usize)> {
509        if index < self.data.len() {
510            Some((self.data[index].clone(), self.labels[index]))
511        } else {
512            None
513        }
514    }
515}
516
/// Placeholder Pascal VOC dataset; its constructor fills it with a single
/// zero tensor and label rather than loading real annotations.
#[derive(Debug)]
pub struct VocDataset {
    /// Image tensors, one per sample.
    data: Vec<Tensor<f32>>,
    /// Class labels, parallel to `data`.
    labels: Vec<usize>,
}
523impl VocDataset {
524    pub fn new<P: AsRef<Path>>(_root: P, _train: bool) -> Result<Self> {
525        eprintln!("Warning: VocDataset is a placeholder implementation");
526        Ok(Self {
527            data: vec![torsh_tensor::creation::zeros(&[3, 224, 224])
528                .expect("tensor creation should succeed")],
529            labels: vec![0],
530        })
531    }
532
533    pub fn len(&self) -> usize {
534        self.data.len()
535    }
536
537    pub fn is_empty(&self) -> bool {
538        self.data.is_empty()
539    }
540
541    pub fn get(&self, index: usize) -> Option<(Tensor<f32>, usize)> {
542        if index < self.data.len() {
543            Some((self.data[index].clone(), self.labels[index]))
544        } else {
545            None
546        }
547    }
548}