1use crate::error_handling::{EnhancedVisionError, ErrorHandler};
7use crate::io::VisionIO;
8use crate::ImageCache;
9use crate::{Result, VisionError};
10use image::DynamicImage;
11use std::collections::HashMap;
12use std::path::{Path, PathBuf};
13use std::sync::Arc;
14use torsh_tensor::{creation, Tensor};
15
/// Configuration options controlling caching, prefetching, and validation
/// behaviour for the optimized dataset implementations in this module.
#[derive(Debug, Clone)]
pub struct DatasetConfig {
    // Maximum number of decoded items kept in the in-memory sample cache.
    pub max_cache_items: usize,
    // Upper bound on cache memory usage, in megabytes (passed to `ImageCache`).
    pub max_cache_memory_mb: usize,
    // When false, `prefetch()` calls become no-ops.
    pub enable_prefetch: bool,
    // Number of items per prefetch batch.
    // NOTE(review): not consulted by the datasets in this file — presumably
    // used by an external data loader; verify before relying on it.
    pub prefetch_batch_size: usize,
    // Whether datasets should validate their data on load.
    // NOTE(review): not consulted by the datasets in this file — confirm
    // where (or whether) it takes effect.
    pub validate_data: bool,
}
30
31impl Default for DatasetConfig {
32 fn default() -> Self {
33 Self {
34 max_cache_items: 1000,
35 max_cache_memory_mb: 512,
36 enable_prefetch: true,
37 prefetch_batch_size: 16,
38 validate_data: true,
39 }
40 }
41}
42
/// Common interface for datasets with built-in caching and prefetching.
///
/// Implementors must be `Send + Sync` so a dataset can be shared across
/// data-loading threads.
pub trait OptimizedDataset: Send + Sync {
    /// The type of a single sample (e.g. a tensor/label pair).
    type Item;

    /// Total number of samples in the dataset.
    fn len(&self) -> usize;

    /// Returns `true` when the dataset contains no samples.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Loads the sample at `index`.
    fn get_item(&self, index: usize) -> Result<Self::Item>;

    /// Loads several samples; the default implementation calls
    /// [`Self::get_item`] per index and fails on the first error.
    fn get_batch(&self, indices: &[usize]) -> Result<Vec<Self::Item>> {
        indices.iter().map(|&i| self.get_item(i)).collect()
    }

    /// Static description of the dataset (name, classes, item shape, ...).
    fn metadata(&self) -> DatasetMetadata;

    /// Hints that the given indices will be accessed soon so implementations
    /// may warm their caches.
    fn prefetch(&self, indices: &[usize]) -> Result<()>;

    /// Evicts all cached items.
    fn clear_cache(&self);

    /// Returns a snapshot of cache hit/miss and memory statistics.
    fn cache_stats(&self) -> CacheStatistics;
}
75
/// Static description of a dataset, returned by
/// `OptimizedDataset::metadata`.
#[derive(Debug, Clone)]
pub struct DatasetMetadata {
    // Human-readable dataset name (e.g. "CIFAR-10").
    pub name: String,
    // Dataset/format version string.
    pub version: String,
    // Number of distinct classes.
    pub num_classes: usize,
    // Class names indexed by class id.
    pub class_names: Vec<String>,
    // Total number of samples.
    pub total_items: usize,
    // Nominal per-item tensor shape (e.g. [3, 32, 32]).
    pub item_shape: Vec<usize>,
    // Element type of item tensors, as a string (e.g. "f32").
    pub data_type: String,
}
87
/// Snapshot of a dataset cache's state, returned by
/// `OptimizedDataset::cache_stats`.
#[derive(Debug, Clone)]
pub struct CacheStatistics {
    // Number of lookups served from the cache.
    pub cache_hits: usize,
    // Number of lookups that required a load from disk.
    pub cache_misses: usize,
    // hits / (hits + misses); 0.0 when not tracked.
    pub hit_rate: f64,
    // Estimated memory held by cached items, in megabytes.
    pub memory_usage_mb: f64,
    // Number of items currently cached.
    pub cached_items: usize,
}
97
/// ImageFolder-style dataset (`root/<class_name>/<image files>`) with a
/// shared decoded-image cache.
pub struct OptimizedImageDataset {
    // Caching / prefetch configuration.
    config: DatasetConfig,
    // I/O backend used to recognize supported image files.
    io: Arc<VisionIO>,
    // Decoded-image cache, sized from `config.max_cache_memory_mb`.
    cache: Arc<ImageCache>,
    // Path of every sample; parallel to `labels`.
    image_paths: Vec<PathBuf>,
    // Class index of every sample; parallel to `image_paths`.
    labels: Vec<usize>,
    // Class names indexed by class id.
    class_names: Vec<String>,
    // Reverse mapping from class name to class id.
    class_to_idx: HashMap<String, usize>,
    // Static description built at construction time.
    metadata: DatasetMetadata,
}
109
110impl OptimizedImageDataset {
111 pub fn new<P: AsRef<Path>>(root: P, config: DatasetConfig) -> Result<Self> {
113 let root_path = root.as_ref();
114
115 if !root_path.exists() {
116 return Err(VisionError::IoError(std::io::Error::new(
117 std::io::ErrorKind::NotFound,
118 format!("Dataset directory {:?} does not exist", root_path),
119 )));
120 }
121
122 let io = Arc::new(VisionIO::new());
123 let cache = Arc::new(ImageCache::new(config.max_cache_memory_mb));
124
125 let mut image_paths = Vec::new();
126 let mut labels = Vec::new();
127 let mut class_names = Vec::new();
128 let mut class_to_idx = HashMap::new();
129
130 for entry in std::fs::read_dir(root_path)? {
132 let entry = entry?;
133 let path = entry.path();
134
135 if path.is_dir() {
136 if let Some(class_name) = path.file_name() {
137 let class_str = class_name.to_string_lossy().to_string();
138 let class_idx = class_names.len();
139 class_names.push(class_str.clone());
140 class_to_idx.insert(class_str, class_idx);
141
142 for img_entry in std::fs::read_dir(&path)? {
144 let img_entry = img_entry?;
145 let img_path = img_entry.path();
146
147 if img_path.is_file() && io.is_supported_image(&img_path) {
148 image_paths.push(img_path);
149 labels.push(class_idx);
150 }
151 }
152 }
153 }
154 }
155
156 if class_names.is_empty() {
157 return Err(VisionError::TransformError(
158 "No class directories found".to_string(),
159 ));
160 }
161
162 let metadata = DatasetMetadata {
163 name: "OptimizedImageDataset".to_string(),
164 version: "1.0".to_string(),
165 num_classes: class_names.len(),
166 class_names: class_names.clone(),
167 total_items: image_paths.len(),
168 item_shape: vec![3, 224, 224], data_type: "f32".to_string(),
170 };
171
172 Ok(Self {
173 config,
174 io,
175 cache,
176 image_paths,
177 labels,
178 class_names,
179 class_to_idx,
180 metadata,
181 })
182 }
183
184 pub fn classes(&self) -> &[String] {
186 &self.class_names
187 }
188
189 pub fn class_to_idx(&self) -> &HashMap<String, usize> {
191 &self.class_to_idx
192 }
193}
194
195impl OptimizedDataset for OptimizedImageDataset {
196 type Item = (Tensor<f32>, usize);
197
198 fn len(&self) -> usize {
199 self.image_paths.len()
200 }
201
202 fn get_item(&self, index: usize) -> Result<Self::Item> {
203 if index >= self.image_paths.len() {
204 return Err(VisionError::InvalidArgument(format!(
205 "Index {} out of bounds for dataset of size {}",
206 index,
207 self.image_paths.len()
208 )));
209 }
210
211 let image_path = &self.image_paths[index];
212 let label = self.labels[index];
213
214 let image = self.cache.get_or_load(image_path)?;
216
217 let tensor = crate::utils::image_to_tensor(&image)?;
219
220 Ok((tensor, label))
221 }
222
223 fn get_batch(&self, indices: &[usize]) -> Result<Vec<Self::Item>> {
224 for &index in indices {
226 if index >= self.image_paths.len() {
227 return Err(VisionError::InvalidArgument(format!(
228 "Index {} out of bounds for dataset of size {}",
229 index,
230 self.image_paths.len()
231 )));
232 }
233 }
234
235 if indices.len() > 4 {
237 indices.iter().map(|&i| self.get_item(i)).collect()
240 } else {
241 indices.iter().map(|&i| self.get_item(i)).collect()
242 }
243 }
244
245 fn metadata(&self) -> DatasetMetadata {
246 self.metadata.clone()
247 }
248
249 fn prefetch(&self, indices: &[usize]) -> Result<()> {
250 if !self.config.enable_prefetch {
251 return Ok(());
252 }
253
254 let paths: Vec<_> = indices
256 .iter()
257 .filter_map(|&i| self.image_paths.get(i))
258 .collect();
259
260 for path in paths {
262 let _ = self.cache.get_or_load(path);
263 }
264
265 Ok(())
266 }
267
268 fn clear_cache(&self) {
269 self.cache.clear();
270 }
271
272 fn cache_stats(&self) -> CacheStatistics {
273 let stats = self.cache.stats();
274 CacheStatistics {
275 cache_hits: stats.hit_count,
276 cache_misses: stats.miss_count,
277 hit_rate: stats.hit_rate,
278 memory_usage_mb: stats.current_size_bytes as f64 / (1024.0 * 1024.0),
279 cached_items: stats.entry_count,
280 }
281 }
282}
283
/// CIFAR-10 / CIFAR-100 dataset read from the original binary files, with a
/// bounded in-memory cache of decoded samples.
pub struct OptimizedCIFARDataset {
    // Caching / prefetch configuration.
    config: DatasetConfig,
    // Primary data file: `train.bin`/`test.bin` for CIFAR-100; for CIFAR-10,
    // `data_batch_1.bin` (train) or `test_batch.bin` (test).
    data_path: PathBuf,
    // True for CIFAR-100 (fine labels), false for CIFAR-10.
    is_cifar100: bool,
    // True for the training split, false for the test split.
    is_train: bool,
    // Decoded samples keyed by index, bounded by `config.max_cache_items`.
    cached_data: Arc<std::sync::Mutex<HashMap<usize, (Tensor<f32>, usize)>>>,
    // Class names indexed by label id.
    classes: Vec<String>,
    // Total number of samples in this split.
    total_samples: usize,
    // Static description built at construction time.
    metadata: DatasetMetadata,
}
295
296impl OptimizedCIFARDataset {
297 pub fn new<P: AsRef<Path>>(
299 root: P,
300 is_cifar100: bool,
301 train: bool,
302 config: DatasetConfig,
303 ) -> Result<Self> {
304 let root_path = root.as_ref();
305
306 if !root_path.exists() {
307 std::fs::create_dir_all(root_path)?;
308 }
309
310 let (data_path, total_samples, classes) = if is_cifar100 {
311 let file_name = if train { "train.bin" } else { "test.bin" };
312 let path = root_path.join(file_name);
313
314 if !path.exists() {
315 return Err(VisionError::TransformError(
316 format!("CIFAR-100 {} file not found in {:?}. Please download from https://www.cs.toronto.edu/~kriz/cifar.html",
317 if train { "training" } else { "test" }, root_path)
318 ));
319 }
320
321 let samples = if train { 50000 } else { 10000 };
322 let classes = Self::get_cifar100_classes();
323 (path, samples, classes)
324 } else {
325 let total_samples = if train { 50000 } else { 10000 };
327 let path = if train {
328 root_path.join("data_batch_1.bin") } else {
330 root_path.join("test_batch.bin")
331 };
332
333 let classes = vec![
334 "airplane",
335 "automobile",
336 "bird",
337 "cat",
338 "deer",
339 "dog",
340 "frog",
341 "horse",
342 "ship",
343 "truck",
344 ]
345 .into_iter()
346 .map(|s| s.to_string())
347 .collect();
348
349 (path, total_samples, classes)
350 };
351
352 let metadata = DatasetMetadata {
353 name: if is_cifar100 { "CIFAR-100" } else { "CIFAR-10" }.to_string(),
354 version: "1.0".to_string(),
355 num_classes: classes.len(),
356 class_names: classes.clone(),
357 total_items: total_samples,
358 item_shape: vec![3, 32, 32],
359 data_type: "f32".to_string(),
360 };
361
362 Ok(Self {
363 config,
364 data_path,
365 is_cifar100,
366 is_train: train,
367 cached_data: Arc::new(std::sync::Mutex::new(HashMap::new())),
368 classes,
369 total_samples,
370 metadata,
371 })
372 }
373
374 fn get_cifar100_classes() -> Vec<String> {
375 vec![
376 "apple",
377 "aquarium_fish",
378 "baby",
379 "bear",
380 "beaver",
381 "bed",
382 "bee",
383 "beetle",
384 "bicycle",
385 "bottle",
386 "bowl",
387 "boy",
388 "bridge",
389 "bus",
390 "butterfly",
391 "camel",
392 "can",
393 "castle",
394 "caterpillar",
395 "cattle",
396 "chair",
397 "chimpanzee",
398 "clock",
399 "cloud",
400 "cockroach",
401 "couch",
402 "crab",
403 "crocodile",
404 "cup",
405 "dinosaur",
406 "dolphin",
407 "elephant",
408 "flatfish",
409 "forest",
410 "fox",
411 "girl",
412 "hamster",
413 "house",
414 "kangaroo",
415 "keyboard",
416 "lamp",
417 "lawn_mower",
418 "leopard",
419 "lion",
420 "lizard",
421 "lobster",
422 "man",
423 "maple_tree",
424 "motorcycle",
425 "mountain",
426 "mouse",
427 "mushroom",
428 "oak_tree",
429 "orange",
430 "orchid",
431 "otter",
432 "palm_tree",
433 "pear",
434 "pickup_truck",
435 "pine_tree",
436 "plain",
437 "plate",
438 "poppy",
439 "porcupine",
440 "possum",
441 "rabbit",
442 "raccoon",
443 "ray",
444 "road",
445 "rocket",
446 "rose",
447 "sea",
448 "seal",
449 "shark",
450 "shrew",
451 "skunk",
452 "skyscraper",
453 "snail",
454 "snake",
455 "spider",
456 "squirrel",
457 "streetcar",
458 "sunflower",
459 "sweet_pepper",
460 "table",
461 "tank",
462 "telephone",
463 "television",
464 "tiger",
465 "tractor",
466 "train",
467 "trout",
468 "tulip",
469 "turtle",
470 "wardrobe",
471 "whale",
472 "willow_tree",
473 "wolf",
474 "woman",
475 "worm",
476 ]
477 .into_iter()
478 .map(|s| s.to_string())
479 .collect()
480 }
481
482 fn load_cifar_sample(&self, index: usize) -> Result<(Tensor<f32>, usize)> {
483 {
485 let cache = self
486 .cached_data
487 .lock()
488 .expect("lock should not be poisoned");
489 if let Some(cached_item) = cache.get(&index) {
490 return Ok(cached_item.clone());
491 }
492 }
493
494 let (tensor, label) = if self.is_cifar100 {
496 self.load_cifar100_sample(index)?
497 } else {
498 self.load_cifar10_sample(index)?
499 };
500
501 {
503 let mut cache = self
504 .cached_data
505 .lock()
506 .expect("lock should not be poisoned");
507 if cache.len() < self.config.max_cache_items {
508 cache.insert(index, (tensor.clone(), label));
509 }
510 }
511
512 Ok((tensor, label))
513 }
514
515 fn load_cifar10_sample(&self, index: usize) -> Result<(Tensor<f32>, usize)> {
516 let (batch_idx, sample_idx) = if self.is_train {
517 (index / 10000, index % 10000)
518 } else {
519 (0, index) };
521
522 let batch_file = if self.is_train {
523 self.data_path
524 .parent()
525 .expect("data_path should have parent")
526 .join(format!("data_batch_{}.bin", batch_idx + 1))
527 } else {
528 self.data_path.clone()
529 };
530
531 if !batch_file.exists() {
532 return Err(VisionError::IoError(std::io::Error::new(
533 std::io::ErrorKind::NotFound,
534 format!("CIFAR-10 batch file {:?} not found", batch_file),
535 )));
536 }
537
538 let data = std::fs::read(batch_file)?;
539 let start_idx = sample_idx * 3073; if start_idx + 3073 > data.len() {
542 return Err(VisionError::TransformError(
543 "Invalid CIFAR-10 file format".to_string(),
544 ));
545 }
546
547 let label = data[start_idx] as usize;
548 let tensor = torsh_tensor::creation::zeros(&[3, 32, 32])
549 .map_err(|e| VisionError::TransformError(format!("Failed to create tensor: {}", e)))?;
550
551 for channel in 0..3 {
553 for y in 0..32 {
554 for x in 0..32 {
555 let pixel_idx = start_idx + 1 + channel * 1024 + y * 32 + x;
556 let pixel_val = data[pixel_idx] as f32 / 255.0;
557 tensor.set(&[channel, y, x], pixel_val)?;
558 }
559 }
560 }
561
562 Ok((tensor, label))
563 }
564
565 fn load_cifar100_sample(&self, index: usize) -> Result<(Tensor<f32>, usize)> {
566 let data = std::fs::read(&self.data_path)?;
567 let start_idx = index * 3074; if start_idx + 3074 > data.len() {
570 return Err(VisionError::TransformError(
571 "Invalid CIFAR-100 file format".to_string(),
572 ));
573 }
574
575 let _coarse_label = data[start_idx] as usize;
576 let fine_label = data[start_idx + 1] as usize;
577 let tensor = torsh_tensor::creation::zeros(&[3, 32, 32])
578 .map_err(|e| VisionError::TransformError(format!("Failed to create tensor: {}", e)))?;
579
580 for channel in 0..3 {
582 for y in 0..32 {
583 for x in 0..32 {
584 let pixel_idx = start_idx + 2 + channel * 1024 + y * 32 + x;
585 let pixel_val = data[pixel_idx] as f32 / 255.0;
586 tensor.set(&[channel, y, x], pixel_val)?;
587 }
588 }
589 }
590
591 Ok((tensor, fine_label))
592 }
593
594 pub fn classes(&self) -> &[String] {
596 &self.classes
597 }
598}
599
600impl OptimizedDataset for OptimizedCIFARDataset {
601 type Item = (Tensor<f32>, usize);
602
603 fn len(&self) -> usize {
604 self.total_samples
605 }
606
607 fn get_item(&self, index: usize) -> Result<Self::Item> {
608 if index >= self.total_samples {
609 return Err(VisionError::InvalidArgument(format!(
610 "Index {} out of bounds for dataset of size {}",
611 index, self.total_samples
612 )));
613 }
614
615 self.load_cifar_sample(index)
616 }
617
618 fn metadata(&self) -> DatasetMetadata {
619 self.metadata.clone()
620 }
621
622 fn prefetch(&self, indices: &[usize]) -> Result<()> {
623 if !self.config.enable_prefetch {
624 return Ok(());
625 }
626
627 for &index in indices {
628 if index < self.total_samples {
629 let _ = self.load_cifar_sample(index);
630 }
631 }
632
633 Ok(())
634 }
635
636 fn clear_cache(&self) {
637 let mut cache = self
638 .cached_data
639 .lock()
640 .expect("lock should not be poisoned");
641 cache.clear();
642 }
643
644 fn cache_stats(&self) -> CacheStatistics {
645 let cache = self
646 .cached_data
647 .lock()
648 .expect("lock should not be poisoned");
649 CacheStatistics {
650 cache_hits: 0, cache_misses: 0,
652 hit_rate: 0.0,
653 memory_usage_mb: cache.len() as f64 * 3.0 * 32.0 * 32.0 * 4.0 / (1024.0 * 1024.0), cached_items: cache.len(),
655 }
656 }
657}
658
/// Fluent builder that assembles a [`DatasetConfig`] and constructs the
/// optimized dataset types from it.
pub struct OptimizedDatasetBuilder {
    // Configuration accumulated by the `with_*` methods.
    config: DatasetConfig,
}
663
664impl OptimizedDatasetBuilder {
665 pub fn new() -> Self {
667 Self {
668 config: DatasetConfig::default(),
669 }
670 }
671
672 pub fn with_cache(mut self, max_items: usize, max_memory_mb: usize) -> Self {
674 self.config.max_cache_items = max_items;
675 self.config.max_cache_memory_mb = max_memory_mb;
676 self
677 }
678
679 pub fn with_prefetch(mut self, enable: bool, batch_size: usize) -> Self {
681 self.config.enable_prefetch = enable;
682 self.config.prefetch_batch_size = batch_size;
683 self
684 }
685
686 pub fn with_validation(mut self, enable: bool) -> Self {
688 self.config.validate_data = enable;
689 self
690 }
691
692 pub fn build_image_dataset<P: AsRef<Path>>(self, root: P) -> Result<OptimizedImageDataset> {
694 OptimizedImageDataset::new(root, self.config)
695 }
696
697 pub fn build_cifar_dataset<P: AsRef<Path>>(
699 self,
700 root: P,
701 is_cifar100: bool,
702 train: bool,
703 ) -> Result<OptimizedCIFARDataset> {
704 OptimizedCIFARDataset::new(root, is_cifar100, train, self.config)
705 }
706}
707
708impl Default for OptimizedDatasetBuilder {
709 fn default() -> Self {
710 Self::new()
711 }
712}
713
#[cfg(test)]
mod tests {
    // The unused `use tempfile::TempDir;` import was removed: nothing in
    // this module creates temporary directories.
    use super::*;

    /// Verifies that the builder's `with_*` methods write through to the
    /// underlying `DatasetConfig`.
    #[test]
    fn test_dataset_builder() {
        let builder = OptimizedDatasetBuilder::new()
            .with_cache(500, 256)
            .with_prefetch(true, 8)
            .with_validation(true);

        assert_eq!(builder.config.max_cache_items, 500);
        assert_eq!(builder.config.max_cache_memory_mb, 256);
        assert!(builder.config.enable_prefetch);
        assert_eq!(builder.config.prefetch_batch_size, 8);
        assert!(builder.config.validate_data);
    }

    /// Sanity-checks `DatasetMetadata` field round-tripping.
    #[test]
    fn test_dataset_metadata() {
        let metadata = DatasetMetadata {
            name: "Test".to_string(),
            version: "1.0".to_string(),
            num_classes: 10,
            class_names: vec!["class1".to_string(), "class2".to_string()],
            total_items: 1000,
            item_shape: vec![3, 32, 32],
            data_type: "f32".to_string(),
        };

        assert_eq!(metadata.name, "Test");
        assert_eq!(metadata.num_classes, 10);
        assert_eq!(metadata.total_items, 1000);
    }
}