Skip to main content

shrew_data/
image_folder.rs

1// ImageFolder — Directory-based image classification dataset
2//
3// Loads images from a directory structure where each subdirectory is a class:
4//
5//   root/
6//     class_a/
7//       img_001.png
8//       img_002.jpg
9//     class_b/
10//       img_003.png
11//       ...
12//
13// Class labels are assigned as sorted indices of subdirectory names.
14//
15// The dataset returns samples with:
16//   - features: pixel values in [C, H, W] layout, normalised to [0, 1]
17//   - feature_shape: [C, H, W]
18//   - target: [class_index as f64]
19//   - target_shape: [1]
20//
21// USAGE:
22//
23//   let ds = ImageFolder::new("data/imagenet/train")
24//       .resize(224, 224)
25//       .build()?;
26//   println!("{} images, {} classes", ds.len(), ds.class_names().len());
27//
28// Requires the `image-folder` feature (which brings in the `image` crate).
29
30#[cfg(feature = "image-folder")]
31pub use inner::*;
32
33#[cfg(feature = "image-folder")]
34mod inner {
35    use std::path::{Path, PathBuf};
36
37    use image::imageops::FilterType;
38    use image::GenericImageView;
39
40    use crate::dataset::{Dataset, Sample};
41
42    /// Supported image extensions (case-insensitive).
43    const EXTENSIONS: &[&str] = &["jpg", "jpeg", "png", "bmp", "gif", "tiff", "tif", "webp"];
44
45    fn is_image(path: &Path) -> bool {
46        path.extension()
47            .and_then(|e| e.to_str())
48            .map(|e| EXTENSIONS.contains(&e.to_ascii_lowercase().as_str()))
49            .unwrap_or(false)
50    }
51
    // ImageFolderBuilder
53
54    /// Builder for [`ImageFolder`].
55    pub struct ImageFolderBuilder {
56        root: PathBuf,
57        resize: Option<(u32, u32)>,
58        grayscale: bool,
59    }
60
61    impl ImageFolderBuilder {
62        /// Create a builder rooted at the given directory.
63        pub fn new<P: AsRef<Path>>(root: P) -> Self {
64            ImageFolderBuilder {
65                root: root.as_ref().to_path_buf(),
66                resize: None,
67                grayscale: false,
68            }
69        }
70
71        /// Resize all images to (width, height) using Lanczos3 filter.
72        pub fn resize(mut self, width: u32, height: u32) -> Self {
73            self.resize = Some((width, height));
74            self
75        }
76
77        /// Convert images to grayscale (1 channel instead of 3).
78        pub fn grayscale(mut self, yes: bool) -> Self {
79            self.grayscale = yes;
80            self
81        }
82
83        /// Scan the directory tree and build the dataset.
84        pub fn build(self) -> Result<ImageFolder, crate::ImageFolderError> {
85            ImageFolder::scan(self.root, self.resize, self.grayscale)
86        }
87    }
88
89    // ImageFolder dataset
90
    /// A directory-based image classification dataset (like torchvision ImageFolder).
    ///
    /// Created through [`ImageFolder::new`] followed by the builder's `build`.
    /// The scan only records file paths and labels; images are decoded when a
    /// sample is requested.
    #[derive(Debug)]
    pub struct ImageFolder {
        /// Sorted class names (subdirectory names). A class's position in this
        /// list is its label.
        class_names: Vec<String>,
        /// Per-sample metadata: (path, class_index).
        entries: Vec<(PathBuf, usize)>,
        /// Optional resize target (width, height).
        resize: Option<(u32, u32)>,
        /// Whether to convert to grayscale.
        grayscale: bool,
        /// Number of channels (1 if grayscale, 3 otherwise).
        channels: usize,
        /// Image width after optional resize (0 if no resize — varies per image).
        width: u32,
        /// Image height after optional resize (0 if no resize — varies per image).
        height: u32,
    }
109
110    impl ImageFolder {
111        /// Convenience entry-point: `ImageFolder::new(root)` returns a builder.
112        pub fn new<P: AsRef<Path>>(root: P) -> ImageFolderBuilder {
113            ImageFolderBuilder::new(root)
114        }
115
116        /// Scan the directory and collect all image paths + class labels.
117        fn scan(
118            root: PathBuf,
119            resize: Option<(u32, u32)>,
120            grayscale: bool,
121        ) -> Result<Self, crate::ImageFolderError> {
122            if !root.is_dir() {
123                return Err(crate::ImageFolderError::NotADirectory(
124                    root.display().to_string(),
125                ));
126            }
127
128            // Collect class subdirectories (sorted)
129            let mut class_dirs: Vec<(String, PathBuf)> = Vec::new();
130            for entry in std::fs::read_dir(&root).map_err(|e| crate::ImageFolderError::Io(e))? {
131                let entry = entry.map_err(|e| crate::ImageFolderError::Io(e))?;
132                let path = entry.path();
133                if path.is_dir() {
134                    if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
135                        class_dirs.push((name.to_string(), path));
136                    }
137                }
138            }
139            class_dirs.sort_by(|a, b| a.0.cmp(&b.0));
140
141            if class_dirs.is_empty() {
142                return Err(crate::ImageFolderError::NoClasses(
143                    root.display().to_string(),
144                ));
145            }
146
147            let class_names: Vec<String> = class_dirs.iter().map(|(n, _)| n.clone()).collect();
148
149            // Collect image paths per class
150            let mut entries: Vec<(PathBuf, usize)> = Vec::new();
151            for (class_idx, (_name, dir)) in class_dirs.iter().enumerate() {
152                let mut paths: Vec<PathBuf> = Vec::new();
153                Self::collect_images(dir, &mut paths);
154                paths.sort();
155                for p in paths {
156                    entries.push((p, class_idx));
157                }
158            }
159
160            if entries.is_empty() {
161                return Err(crate::ImageFolderError::NoImages(
162                    root.display().to_string(),
163                ));
164            }
165
166            let channels = if grayscale { 1 } else { 3 };
167            let (width, height) = resize.unwrap_or((0, 0));
168
169            Ok(ImageFolder {
170                class_names,
171                entries,
172                resize,
173                grayscale,
174                channels,
175                width,
176                height,
177            })
178        }
179
180        /// Recursively collect image files.
181        fn collect_images(dir: &Path, out: &mut Vec<PathBuf>) {
182            if let Ok(rd) = std::fs::read_dir(dir) {
183                for entry in rd.flatten() {
184                    let path = entry.path();
185                    if path.is_dir() {
186                        Self::collect_images(&path, out);
187                    } else if is_image(&path) {
188                        out.push(path);
189                    }
190                }
191            }
192        }
193
194        /// Get the class names (sorted).
195        pub fn class_names(&self) -> &[String] {
196            &self.class_names
197        }
198
199        /// Number of classes.
200        pub fn num_classes(&self) -> usize {
201            self.class_names.len()
202        }
203
204        /// Get the class index for the i-th sample.
205        pub fn class_of(&self, index: usize) -> usize {
206            self.entries[index].1
207        }
208
209        /// Get the file path of the i-th sample.
210        pub fn path_of(&self, index: usize) -> &Path {
211            &self.entries[index].0
212        }
213
214        /// Load and decode an image, returning pixel data in [C, H, W] layout
215        /// with values normalised to [0, 1].
216        fn load_image(
217            &self,
218            index: usize,
219        ) -> Result<(Vec<f64>, [usize; 3]), crate::ImageFolderError> {
220            let path = &self.entries[index].0;
221            let img = image::open(path).map_err(|e| {
222                crate::ImageFolderError::ImageDecode(path.display().to_string(), e.to_string())
223            })?;
224
225            // Optional resize
226            let img = match self.resize {
227                Some((w, h)) => img.resize_exact(w, h, FilterType::Lanczos3),
228                None => img,
229            };
230
231            // Grayscale or RGB
232            let (w, h) = img.dimensions();
233            let (pixels, c) = if self.grayscale {
234                let gray = img.to_luma8();
235                let data: Vec<f64> = gray.as_raw().iter().map(|&v| v as f64 / 255.0).collect();
236                (data, 1usize)
237            } else {
238                let rgb = img.to_rgb8();
239                let raw = rgb.as_raw();
240                // Convert from [H, W, C] interleaved to [C, H, W] planar
241                let npix = (w * h) as usize;
242                let mut data = vec![0.0f64; 3 * npix];
243                for i in 0..npix {
244                    data[i] = raw[i * 3] as f64 / 255.0; // R
245                    data[npix + i] = raw[i * 3 + 1] as f64 / 255.0; // G
246                    data[2 * npix + i] = raw[i * 3 + 2] as f64 / 255.0; // B
247                }
248                (data, 3usize)
249            };
250
251            Ok((pixels, [c, h as usize, w as usize]))
252        }
253    }
254
255    impl Dataset for ImageFolder {
256        fn len(&self) -> usize {
257            self.entries.len()
258        }
259
260        fn get(&self, index: usize) -> Sample {
261            match self.load_image(index) {
262                Ok((features, shape)) => Sample {
263                    features,
264                    feature_shape: shape.to_vec(),
265                    target: vec![self.entries[index].1 as f64],
266                    target_shape: vec![1],
267                },
268                Err(e) => {
269                    // Return a zero sample on error (avoids panicking in Iterator)
270                    let c = self.channels;
271                    let (w, h) = self.resize.unwrap_or((1, 1));
272                    eprintln!(
273                        "ImageFolder: failed to load {:?}: {}",
274                        self.entries[index].0, e
275                    );
276                    Sample {
277                        features: vec![0.0; c * (h as usize) * (w as usize)],
278                        feature_shape: vec![c, h as usize, w as usize],
279                        target: vec![self.entries[index].1 as f64],
280                        target_shape: vec![1],
281                    }
282                }
283            }
284        }
285
286        fn feature_shape(&self) -> &[usize] {
287            // Only valid when resize is set; otherwise shape varies per image.
288            // We return a static reference to a leaked slice for the fixed case.
289            // For dynamic case, we return a placeholder.
290            &[]
291        }
292
293        fn target_shape(&self) -> &[usize] {
294            &[]
295        }
296
297        fn name(&self) -> &str {
298            "ImageFolder"
299        }
300    }
301
302    // Send + Sync — all fields are owned data
303    unsafe impl Send for ImageFolder {}
304    unsafe impl Sync for ImageFolder {}
305}