//! Optimized dataset implementations with lazy loading and memory management
//!
//! This module provides memory-efficient alternatives to the basic dataset implementations
//! with features like lazy loading, caching integration, and unified interfaces.

use crate::error_handling::{EnhancedVisionError, ErrorHandler};
use crate::io::VisionIO;
use crate::ImageCache;
use crate::{Result, VisionError};
use image::DynamicImage;
use std::collections::HashMap;
use std::io::{Read, Seek, SeekFrom};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use torsh_tensor::{creation, Tensor};
15
/// Configuration for optimized datasets
///
/// Controls caching limits, prefetching behavior, and data validation for
/// the dataset types in this module. Use [`DatasetConfig::default`] for
/// sensible defaults, or build one fluently via `OptimizedDatasetBuilder`.
#[derive(Debug, Clone)]
pub struct DatasetConfig {
    /// Maximum number of items to keep in memory cache
    /// (bounds the CIFAR per-sample tensor cache)
    pub max_cache_items: usize,
    /// Maximum memory usage in MB for cached items
    /// (passed to `ImageCache::new` for image datasets)
    pub max_cache_memory_mb: usize,
    /// Enable background prefetching
    /// (when false, `prefetch` calls become no-ops)
    pub enable_prefetch: bool,
    /// Batch size for prefetching
    /// NOTE(review): not yet consulted by any `prefetch` implementation in
    /// this module — confirm intended use before relying on it.
    pub prefetch_batch_size: usize,
    /// Enable data validation
    /// NOTE(review): not yet consulted anywhere in this module — confirm
    /// intended use.
    pub validate_data: bool,
}
30
31impl Default for DatasetConfig {
32    fn default() -> Self {
33        Self {
34            max_cache_items: 1000,
35            max_cache_memory_mb: 512,
36            enable_prefetch: true,
37            prefetch_batch_size: 16,
38            validate_data: true,
39        }
40    }
41}
42
/// Generic dataset trait for unified interface
///
/// Implemented by all optimized datasets in this module so callers can
/// index, batch, prefetch, and monitor caching through one API.
/// Implementors must be `Send + Sync` so datasets can be shared across
/// worker threads.
pub trait OptimizedDataset: Send + Sync {
    /// The type of a single sample (e.g. `(Tensor<f32>, usize)` for
    /// image/label pairs).
    type Item;

    /// Get the total number of items in the dataset
    fn len(&self) -> usize;

    /// Check if the dataset is empty
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Get an item by index with lazy loading
    ///
    /// Returns an error if `index` is out of bounds or the underlying data
    /// cannot be loaded.
    fn get_item(&self, index: usize) -> Result<Self::Item>;

    /// Get multiple items efficiently with batching
    ///
    /// Default implementation loads items one by one, short-circuiting on
    /// the first error.
    fn get_batch(&self, indices: &[usize]) -> Result<Vec<Self::Item>> {
        indices.iter().map(|&i| self.get_item(i)).collect()
    }

    /// Get dataset metadata
    fn metadata(&self) -> DatasetMetadata;

    /// Prefetch items for improved performance
    ///
    /// Best-effort: implementations may skip out-of-range indices or do
    /// nothing when prefetching is disabled in the configuration.
    fn prefetch(&self, indices: &[usize]) -> Result<()>;

    /// Clear cache to free memory
    fn clear_cache(&self);

    /// Get cache statistics
    fn cache_stats(&self) -> CacheStatistics;
}
75
/// Dataset metadata information
///
/// Describes a dataset's identity, label space, size, and per-item tensor
/// layout. Returned by [`OptimizedDataset::metadata`].
#[derive(Debug, Clone)]
pub struct DatasetMetadata {
    /// Human-readable dataset name (e.g. "CIFAR-10").
    pub name: String,
    /// Dataset format/version string.
    pub version: String,
    /// Number of distinct classes.
    pub num_classes: usize,
    /// Class names, indexed by label id.
    pub class_names: Vec<String>,
    /// Total number of samples in the dataset.
    pub total_items: usize,
    /// Shape of a single item tensor, e.g. `[3, 32, 32]` (CHW).
    pub item_shape: Vec<usize>,
    /// Element type name of the item tensor (e.g. "f32").
    pub data_type: String,
}
87
/// Cache statistics for monitoring
///
/// Snapshot of cache behavior, returned by
/// [`OptimizedDataset::cache_stats`].
#[derive(Debug, Clone)]
pub struct CacheStatistics {
    /// Number of lookups served from the cache.
    pub cache_hits: usize,
    /// Number of lookups that had to load from disk.
    pub cache_misses: usize,
    /// hits / (hits + misses), in `[0.0, 1.0]`.
    pub hit_rate: f64,
    /// Approximate memory used by cached items, in megabytes.
    pub memory_usage_mb: f64,
    /// Number of items currently held in the cache.
    pub cached_items: usize,
}
97
/// Optimized image classification dataset with lazy loading
///
/// Expects an ImageFolder-style layout (one subdirectory per class, image
/// files inside each). Images are decoded on demand through a shared
/// [`ImageCache`] rather than loaded eagerly.
pub struct OptimizedImageDataset {
    /// Caching/prefetch configuration.
    config: DatasetConfig,
    /// Image I/O backend; also used to filter supported file types.
    io: Arc<VisionIO>,
    /// Shared decoded-image cache, bounded by `config.max_cache_memory_mb`.
    cache: Arc<ImageCache>,
    /// Paths of all discovered image files.
    image_paths: Vec<PathBuf>,
    /// Label id for each entry in `image_paths` (parallel vector).
    labels: Vec<usize>,
    /// Class names, indexed by label id.
    class_names: Vec<String>,
    /// Reverse lookup from class name to label id.
    class_to_idx: HashMap<String, usize>,
    /// Metadata precomputed at construction time.
    metadata: DatasetMetadata,
}
109
110impl OptimizedImageDataset {
111    /// Create a new optimized image dataset
112    pub fn new<P: AsRef<Path>>(root: P, config: DatasetConfig) -> Result<Self> {
113        let root_path = root.as_ref();
114
115        if !root_path.exists() {
116            return Err(VisionError::IoError(std::io::Error::new(
117                std::io::ErrorKind::NotFound,
118                format!("Dataset directory {:?} does not exist", root_path),
119            )));
120        }
121
122        let io = Arc::new(VisionIO::new());
123        let cache = Arc::new(ImageCache::new(config.max_cache_memory_mb));
124
125        let mut image_paths = Vec::new();
126        let mut labels = Vec::new();
127        let mut class_names = Vec::new();
128        let mut class_to_idx = HashMap::new();
129
130        // Scan directory structure for classes and images
131        for entry in std::fs::read_dir(root_path)? {
132            let entry = entry?;
133            let path = entry.path();
134
135            if path.is_dir() {
136                if let Some(class_name) = path.file_name() {
137                    let class_str = class_name.to_string_lossy().to_string();
138                    let class_idx = class_names.len();
139                    class_names.push(class_str.clone());
140                    class_to_idx.insert(class_str, class_idx);
141
142                    // Scan images in this class directory
143                    for img_entry in std::fs::read_dir(&path)? {
144                        let img_entry = img_entry?;
145                        let img_path = img_entry.path();
146
147                        if img_path.is_file() && io.is_supported_image(&img_path) {
148                            image_paths.push(img_path);
149                            labels.push(class_idx);
150                        }
151                    }
152                }
153            }
154        }
155
156        if class_names.is_empty() {
157            return Err(VisionError::TransformError(
158                "No class directories found".to_string(),
159            ));
160        }
161
162        let metadata = DatasetMetadata {
163            name: "OptimizedImageDataset".to_string(),
164            version: "1.0".to_string(),
165            num_classes: class_names.len(),
166            class_names: class_names.clone(),
167            total_items: image_paths.len(),
168            item_shape: vec![3, 224, 224], // Default shape, will be updated dynamically
169            data_type: "f32".to_string(),
170        };
171
172        Ok(Self {
173            config,
174            io,
175            cache,
176            image_paths,
177            labels,
178            class_names,
179            class_to_idx,
180            metadata,
181        })
182    }
183
184    /// Get class information
185    pub fn classes(&self) -> &[String] {
186        &self.class_names
187    }
188
189    /// Get class to index mapping
190    pub fn class_to_idx(&self) -> &HashMap<String, usize> {
191        &self.class_to_idx
192    }
193}
194
195impl OptimizedDataset for OptimizedImageDataset {
196    type Item = (Tensor<f32>, usize);
197
198    fn len(&self) -> usize {
199        self.image_paths.len()
200    }
201
202    fn get_item(&self, index: usize) -> Result<Self::Item> {
203        if index >= self.image_paths.len() {
204            return Err(VisionError::InvalidArgument(format!(
205                "Index {} out of bounds for dataset of size {}",
206                index,
207                self.image_paths.len()
208            )));
209        }
210
211        let image_path = &self.image_paths[index];
212        let label = self.labels[index];
213
214        // Load image using cached I/O
215        let image = self.cache.get_or_load(image_path)?;
216
217        // Convert to tensor
218        let tensor = crate::utils::image_to_tensor(&image)?;
219
220        Ok((tensor, label))
221    }
222
223    fn get_batch(&self, indices: &[usize]) -> Result<Vec<Self::Item>> {
224        // Validate all indices first
225        for &index in indices {
226            if index >= self.image_paths.len() {
227                return Err(VisionError::InvalidArgument(format!(
228                    "Index {} out of bounds for dataset of size {}",
229                    index,
230                    self.image_paths.len()
231                )));
232            }
233        }
234
235        // Use parallel loading for large batches
236        if indices.len() > 4 {
237            // For now, fall back to sequential loading
238            // In the future, could implement parallel loading with rayon
239            indices.iter().map(|&i| self.get_item(i)).collect()
240        } else {
241            indices.iter().map(|&i| self.get_item(i)).collect()
242        }
243    }
244
245    fn metadata(&self) -> DatasetMetadata {
246        self.metadata.clone()
247    }
248
249    fn prefetch(&self, indices: &[usize]) -> Result<()> {
250        if !self.config.enable_prefetch {
251            return Ok(());
252        }
253
254        // Collect image paths for prefetching
255        let paths: Vec<_> = indices
256            .iter()
257            .filter_map(|&i| self.image_paths.get(i))
258            .collect();
259
260        // Trigger background loading (simplified implementation)
261        for path in paths {
262            let _ = self.cache.get_or_load(path);
263        }
264
265        Ok(())
266    }
267
268    fn clear_cache(&self) {
269        self.cache.clear();
270    }
271
272    fn cache_stats(&self) -> CacheStatistics {
273        let stats = self.cache.stats();
274        CacheStatistics {
275            cache_hits: stats.hit_count,
276            cache_misses: stats.miss_count,
277            hit_rate: stats.hit_rate,
278            memory_usage_mb: stats.current_size_bytes as f64 / (1024.0 * 1024.0),
279            cached_items: stats.entry_count,
280        }
281    }
282}
283
/// Optimized CIFAR dataset with lazy loading and validation
///
/// Reads samples on demand from the CIFAR-10/CIFAR-100 binary files and
/// keeps decoded tensors in a bounded in-memory cache.
pub struct OptimizedCIFARDataset {
    /// Caching/prefetch configuration.
    config: DatasetConfig,
    /// Path to the primary binary file. For CIFAR-10 training data this is
    /// `data_batch_1.bin`; the remaining batches are resolved per sample.
    data_path: PathBuf,
    /// `true` for CIFAR-100, `false` for CIFAR-10.
    is_cifar100: bool,
    /// `true` for the training split, `false` for the test split.
    is_train: bool,
    /// Decoded-sample cache keyed by sample index, bounded by
    /// `config.max_cache_items`.
    cached_data: Arc<std::sync::Mutex<HashMap<usize, (Tensor<f32>, usize)>>>,
    /// Class names, indexed by label id.
    classes: Vec<String>,
    /// Total number of samples in this split.
    total_samples: usize,
    /// Metadata precomputed at construction time.
    metadata: DatasetMetadata,
}
295
296impl OptimizedCIFARDataset {
297    /// Create a new optimized CIFAR dataset
298    pub fn new<P: AsRef<Path>>(
299        root: P,
300        is_cifar100: bool,
301        train: bool,
302        config: DatasetConfig,
303    ) -> Result<Self> {
304        let root_path = root.as_ref();
305
306        if !root_path.exists() {
307            std::fs::create_dir_all(root_path)?;
308        }
309
310        let (data_path, total_samples, classes) = if is_cifar100 {
311            let file_name = if train { "train.bin" } else { "test.bin" };
312            let path = root_path.join(file_name);
313
314            if !path.exists() {
315                return Err(VisionError::TransformError(
316                    format!("CIFAR-100 {} file not found in {:?}. Please download from https://www.cs.toronto.edu/~kriz/cifar.html", 
317                           if train { "training" } else { "test" }, root_path)
318                ));
319            }
320
321            let samples = if train { 50000 } else { 10000 };
322            let classes = Self::get_cifar100_classes();
323            (path, samples, classes)
324        } else {
325            // CIFAR-10
326            let total_samples = if train { 50000 } else { 10000 };
327            let path = if train {
328                root_path.join("data_batch_1.bin") // We'll handle all batches in get_item
329            } else {
330                root_path.join("test_batch.bin")
331            };
332
333            let classes = vec![
334                "airplane",
335                "automobile",
336                "bird",
337                "cat",
338                "deer",
339                "dog",
340                "frog",
341                "horse",
342                "ship",
343                "truck",
344            ]
345            .into_iter()
346            .map(|s| s.to_string())
347            .collect();
348
349            (path, total_samples, classes)
350        };
351
352        let metadata = DatasetMetadata {
353            name: if is_cifar100 { "CIFAR-100" } else { "CIFAR-10" }.to_string(),
354            version: "1.0".to_string(),
355            num_classes: classes.len(),
356            class_names: classes.clone(),
357            total_items: total_samples,
358            item_shape: vec![3, 32, 32],
359            data_type: "f32".to_string(),
360        };
361
362        Ok(Self {
363            config,
364            data_path,
365            is_cifar100,
366            is_train: train,
367            cached_data: Arc::new(std::sync::Mutex::new(HashMap::new())),
368            classes,
369            total_samples,
370            metadata,
371        })
372    }
373
374    fn get_cifar100_classes() -> Vec<String> {
375        vec![
376            "apple",
377            "aquarium_fish",
378            "baby",
379            "bear",
380            "beaver",
381            "bed",
382            "bee",
383            "beetle",
384            "bicycle",
385            "bottle",
386            "bowl",
387            "boy",
388            "bridge",
389            "bus",
390            "butterfly",
391            "camel",
392            "can",
393            "castle",
394            "caterpillar",
395            "cattle",
396            "chair",
397            "chimpanzee",
398            "clock",
399            "cloud",
400            "cockroach",
401            "couch",
402            "crab",
403            "crocodile",
404            "cup",
405            "dinosaur",
406            "dolphin",
407            "elephant",
408            "flatfish",
409            "forest",
410            "fox",
411            "girl",
412            "hamster",
413            "house",
414            "kangaroo",
415            "keyboard",
416            "lamp",
417            "lawn_mower",
418            "leopard",
419            "lion",
420            "lizard",
421            "lobster",
422            "man",
423            "maple_tree",
424            "motorcycle",
425            "mountain",
426            "mouse",
427            "mushroom",
428            "oak_tree",
429            "orange",
430            "orchid",
431            "otter",
432            "palm_tree",
433            "pear",
434            "pickup_truck",
435            "pine_tree",
436            "plain",
437            "plate",
438            "poppy",
439            "porcupine",
440            "possum",
441            "rabbit",
442            "raccoon",
443            "ray",
444            "road",
445            "rocket",
446            "rose",
447            "sea",
448            "seal",
449            "shark",
450            "shrew",
451            "skunk",
452            "skyscraper",
453            "snail",
454            "snake",
455            "spider",
456            "squirrel",
457            "streetcar",
458            "sunflower",
459            "sweet_pepper",
460            "table",
461            "tank",
462            "telephone",
463            "television",
464            "tiger",
465            "tractor",
466            "train",
467            "trout",
468            "tulip",
469            "turtle",
470            "wardrobe",
471            "whale",
472            "willow_tree",
473            "wolf",
474            "woman",
475            "worm",
476        ]
477        .into_iter()
478        .map(|s| s.to_string())
479        .collect()
480    }
481
482    fn load_cifar_sample(&self, index: usize) -> Result<(Tensor<f32>, usize)> {
483        // Check cache first
484        {
485            let cache = self
486                .cached_data
487                .lock()
488                .expect("lock should not be poisoned");
489            if let Some(cached_item) = cache.get(&index) {
490                return Ok(cached_item.clone());
491            }
492        }
493
494        // Load from disk
495        let (tensor, label) = if self.is_cifar100 {
496            self.load_cifar100_sample(index)?
497        } else {
498            self.load_cifar10_sample(index)?
499        };
500
501        // Cache the result
502        {
503            let mut cache = self
504                .cached_data
505                .lock()
506                .expect("lock should not be poisoned");
507            if cache.len() < self.config.max_cache_items {
508                cache.insert(index, (tensor.clone(), label));
509            }
510        }
511
512        Ok((tensor, label))
513    }
514
515    fn load_cifar10_sample(&self, index: usize) -> Result<(Tensor<f32>, usize)> {
516        let (batch_idx, sample_idx) = if self.is_train {
517            (index / 10000, index % 10000)
518        } else {
519            (0, index) // Test batch
520        };
521
522        let batch_file = if self.is_train {
523            self.data_path
524                .parent()
525                .expect("data_path should have parent")
526                .join(format!("data_batch_{}.bin", batch_idx + 1))
527        } else {
528            self.data_path.clone()
529        };
530
531        if !batch_file.exists() {
532            return Err(VisionError::IoError(std::io::Error::new(
533                std::io::ErrorKind::NotFound,
534                format!("CIFAR-10 batch file {:?} not found", batch_file),
535            )));
536        }
537
538        let data = std::fs::read(batch_file)?;
539        let start_idx = sample_idx * 3073; // 1 label + 3072 pixels
540
541        if start_idx + 3073 > data.len() {
542            return Err(VisionError::TransformError(
543                "Invalid CIFAR-10 file format".to_string(),
544            ));
545        }
546
547        let label = data[start_idx] as usize;
548        let tensor = torsh_tensor::creation::zeros(&[3, 32, 32])
549            .map_err(|e| VisionError::TransformError(format!("Failed to create tensor: {}", e)))?;
550
551        // Load RGB channels
552        for channel in 0..3 {
553            for y in 0..32 {
554                for x in 0..32 {
555                    let pixel_idx = start_idx + 1 + channel * 1024 + y * 32 + x;
556                    let pixel_val = data[pixel_idx] as f32 / 255.0;
557                    tensor.set(&[channel, y, x], pixel_val)?;
558                }
559            }
560        }
561
562        Ok((tensor, label))
563    }
564
565    fn load_cifar100_sample(&self, index: usize) -> Result<(Tensor<f32>, usize)> {
566        let data = std::fs::read(&self.data_path)?;
567        let start_idx = index * 3074; // 2 labels + 3072 pixels
568
569        if start_idx + 3074 > data.len() {
570            return Err(VisionError::TransformError(
571                "Invalid CIFAR-100 file format".to_string(),
572            ));
573        }
574
575        let _coarse_label = data[start_idx] as usize;
576        let fine_label = data[start_idx + 1] as usize;
577        let tensor = torsh_tensor::creation::zeros(&[3, 32, 32])
578            .map_err(|e| VisionError::TransformError(format!("Failed to create tensor: {}", e)))?;
579
580        // Load RGB channels
581        for channel in 0..3 {
582            for y in 0..32 {
583                for x in 0..32 {
584                    let pixel_idx = start_idx + 2 + channel * 1024 + y * 32 + x;
585                    let pixel_val = data[pixel_idx] as f32 / 255.0;
586                    tensor.set(&[channel, y, x], pixel_val)?;
587                }
588            }
589        }
590
591        Ok((tensor, fine_label))
592    }
593
594    /// Get class information
595    pub fn classes(&self) -> &[String] {
596        &self.classes
597    }
598}
599
600impl OptimizedDataset for OptimizedCIFARDataset {
601    type Item = (Tensor<f32>, usize);
602
603    fn len(&self) -> usize {
604        self.total_samples
605    }
606
607    fn get_item(&self, index: usize) -> Result<Self::Item> {
608        if index >= self.total_samples {
609            return Err(VisionError::InvalidArgument(format!(
610                "Index {} out of bounds for dataset of size {}",
611                index, self.total_samples
612            )));
613        }
614
615        self.load_cifar_sample(index)
616    }
617
618    fn metadata(&self) -> DatasetMetadata {
619        self.metadata.clone()
620    }
621
622    fn prefetch(&self, indices: &[usize]) -> Result<()> {
623        if !self.config.enable_prefetch {
624            return Ok(());
625        }
626
627        for &index in indices {
628            if index < self.total_samples {
629                let _ = self.load_cifar_sample(index);
630            }
631        }
632
633        Ok(())
634    }
635
636    fn clear_cache(&self) {
637        let mut cache = self
638            .cached_data
639            .lock()
640            .expect("lock should not be poisoned");
641        cache.clear();
642    }
643
644    fn cache_stats(&self) -> CacheStatistics {
645        let cache = self
646            .cached_data
647            .lock()
648            .expect("lock should not be poisoned");
649        CacheStatistics {
650            cache_hits: 0, // Would need to track this
651            cache_misses: 0,
652            hit_rate: 0.0,
653            memory_usage_mb: cache.len() as f64 * 3.0 * 32.0 * 32.0 * 4.0 / (1024.0 * 1024.0), // Rough estimate
654            cached_items: cache.len(),
655        }
656    }
657}
658
/// Optimized dataset builder for easy configuration
///
/// Fluent helper that accumulates a [`DatasetConfig`] via `with_*` methods
/// and then constructs one of the dataset types in this module.
pub struct OptimizedDatasetBuilder {
    /// Configuration accumulated by the `with_*` methods.
    config: DatasetConfig,
}
663
impl OptimizedDatasetBuilder {
    /// Create a new dataset builder
    ///
    /// Starts from [`DatasetConfig::default`].
    pub fn new() -> Self {
        Self {
            config: DatasetConfig::default(),
        }
    }

    /// Set cache configuration
    ///
    /// `max_items` bounds the CIFAR per-sample cache; `max_memory_mb` bounds
    /// the image cache.
    pub fn with_cache(mut self, max_items: usize, max_memory_mb: usize) -> Self {
        self.config.max_cache_items = max_items;
        self.config.max_cache_memory_mb = max_memory_mb;
        self
    }

    /// Enable or disable prefetching
    pub fn with_prefetch(mut self, enable: bool, batch_size: usize) -> Self {
        self.config.enable_prefetch = enable;
        self.config.prefetch_batch_size = batch_size;
        self
    }

    /// Enable or disable data validation
    pub fn with_validation(mut self, enable: bool) -> Self {
        self.config.validate_data = enable;
        self
    }

    /// Build an optimized image dataset
    ///
    /// Consumes the builder; propagates errors from
    /// `OptimizedImageDataset::new`.
    pub fn build_image_dataset<P: AsRef<Path>>(self, root: P) -> Result<OptimizedImageDataset> {
        OptimizedImageDataset::new(root, self.config)
    }

    /// Build an optimized CIFAR dataset
    ///
    /// Consumes the builder; propagates errors from
    /// `OptimizedCIFARDataset::new`.
    pub fn build_cifar_dataset<P: AsRef<Path>>(
        self,
        root: P,
        is_cifar100: bool,
        train: bool,
    ) -> Result<OptimizedCIFARDataset> {
        OptimizedCIFARDataset::new(root, is_cifar100, train, self.config)
    }
}
707
/// `Default` delegates to [`OptimizedDatasetBuilder::new`].
impl Default for OptimizedDatasetBuilder {
    fn default() -> Self {
        Self::new()
    }
}
713
#[cfg(test)]
mod tests {
    // `use tempfile::TempDir;` was removed: it was never used in this
    // module and triggered an unused-import warning.
    use super::*;

    /// Builder setters should land in the accumulated config.
    #[test]
    fn test_dataset_builder() {
        let builder = OptimizedDatasetBuilder::new()
            .with_cache(500, 256)
            .with_prefetch(true, 8)
            .with_validation(true);

        assert_eq!(builder.config.max_cache_items, 500);
        assert_eq!(builder.config.max_cache_memory_mb, 256);
        // `assert!` instead of `assert_eq!(…, true)` per clippy's
        // bool_assert_comparison lint.
        assert!(builder.config.enable_prefetch);
        assert_eq!(builder.config.prefetch_batch_size, 8);
        assert!(builder.config.validate_data);
    }

    /// Metadata struct should store fields verbatim.
    #[test]
    fn test_dataset_metadata() {
        let metadata = DatasetMetadata {
            name: "Test".to_string(),
            version: "1.0".to_string(),
            num_classes: 10,
            class_names: vec!["class1".to_string(), "class2".to_string()],
            total_items: 1000,
            item_shape: vec![3, 32, 32],
            data_type: "f32".to_string(),
        };

        assert_eq!(metadata.name, "Test");
        assert_eq!(metadata.num_classes, 10);
        assert_eq!(metadata.total_items, 1000);
    }
}