Skip to main content

yscv_model/dataset/
image_folder.rs

1use image::ImageReader;
2use std::fs;
3use std::path::{Path, PathBuf};
4use yscv_imgproc::resize_nearest;
5use yscv_tensor::Tensor;
6
7use crate::ModelError;
8
9use super::helpers::build_supervised_dataset_from_flat_values;
10use super::types::SupervisedDataset;
11
/// How class labels are encoded in the targets produced by the image-folder loaders.
///
/// Selected via [`SupervisedImageFolderConfig::with_target_mode`]; defaults to
/// [`ImageFolderTargetMode::ClassIndex`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ImageFolderTargetMode {
    /// One value per sample holding the class index as `f32`; targets are shaped `[N, 1]`.
    #[default]
    ClassIndex,
    /// One vector per sample with `1.0` at the class index and `0.0` elsewhere;
    /// targets are shaped `[N, class_count]`.
    OneHot,
}
22
23#[derive(Debug, Clone, PartialEq)]
24pub struct SupervisedImageFolderConfig {
25    output_height: usize,
26    output_width: usize,
27    target_mode: ImageFolderTargetMode,
28    allowed_extensions: Vec<String>,
29}
30
31impl SupervisedImageFolderConfig {
32    pub fn new(output_height: usize, output_width: usize) -> Result<Self, ModelError> {
33        if output_height == 0 {
34            return Err(ModelError::InvalidDatasetAdapterShape {
35                field: "output_height",
36                shape: vec![output_height],
37                message: "output_height must be > 0".to_string(),
38            });
39        }
40        if output_width == 0 {
41            return Err(ModelError::InvalidDatasetAdapterShape {
42                field: "output_width",
43                shape: vec![output_width],
44                message: "output_width must be > 0".to_string(),
45            });
46        }
47        Ok(Self {
48            output_height,
49            output_width,
50            target_mode: ImageFolderTargetMode::ClassIndex,
51            allowed_extensions: default_image_folder_extensions(),
52        })
53    }
54
55    pub fn output_height(&self) -> usize {
56        self.output_height
57    }
58
59    pub fn output_width(&self) -> usize {
60        self.output_width
61    }
62
63    pub fn with_target_mode(mut self, target_mode: ImageFolderTargetMode) -> Self {
64        self.target_mode = target_mode;
65        self
66    }
67
68    pub fn target_mode(&self) -> ImageFolderTargetMode {
69        self.target_mode
70    }
71
72    pub fn with_allowed_extensions(
73        mut self,
74        allowed_extensions: Vec<String>,
75    ) -> Result<Self, ModelError> {
76        self.allowed_extensions = normalize_image_extensions(allowed_extensions)?;
77        Ok(self)
78    }
79
80    pub fn allowed_extensions(&self) -> &[String] {
81        &self.allowed_extensions
82    }
83}
84
/// Result payload for image-folder dataset loading with explicit class mapping.
///
/// Returned by `load_supervised_image_folder_dataset_with_classes`.
#[derive(Debug, Clone, PartialEq)]
pub struct SupervisedImageFolderLoadResult {
    /// The loaded samples (normalized RGB inputs paired with class targets).
    pub dataset: SupervisedDataset,
    /// Class directory names in sorted order; entry `i` names the class whose samples were
    /// assigned class index `i` in the targets.
    pub class_names: Vec<String>,
}
91
92/// Loads supervised training samples from an image-folder classification tree.
93///
94/// Directory contract:
95/// - direct child directories under `root` are class buckets,
96/// - class index is assigned by deterministic lexicographic directory ordering,
97/// - supported image extensions are `jpg`, `jpeg`, and `png` (case-insensitive),
98/// - non-image files are ignored.
99///
100/// Runtime mapping:
101/// - inputs shape: `[N, output_height, output_width, 3]`,
102/// - targets shape depends on `config.target_mode()`:
103///   - `ClassIndex`: `[N, 1]` with scalar class ids,
104///   - `OneHot`: `[N, class_count]` one-hot vectors.
105pub fn load_supervised_image_folder_dataset<P: AsRef<Path>>(
106    root: P,
107    config: &SupervisedImageFolderConfig,
108) -> Result<SupervisedDataset, ModelError> {
109    load_supervised_image_folder_dataset_with_classes(root, config).map(|loaded| loaded.dataset)
110}
111
112/// Loads supervised training samples from an image-folder classification tree and returns class mapping.
113///
114/// Class names are returned in deterministic lexicographic directory order
115/// and correspond to class indices used in targets.
116pub fn load_supervised_image_folder_dataset_with_classes<P: AsRef<Path>>(
117    root: P,
118    config: &SupervisedImageFolderConfig,
119) -> Result<SupervisedImageFolderLoadResult, ModelError> {
120    let root_ref = root.as_ref();
121    let class_dirs = read_sorted_class_directories(root_ref)?;
122    let class_names = class_dirs
123        .iter()
124        .map(|class_dir| class_name_from_path(class_dir))
125        .collect::<Result<Vec<_>, _>>()?;
126    let class_count = class_dirs.len();
127
128    let mut input_values = Vec::new();
129    let mut target_values = Vec::new();
130    let mut sample_count = 0usize;
131
132    for (class_id, class_dir) in class_dirs.iter().enumerate() {
133        let image_files =
134            read_sorted_supported_image_files(class_dir, config.allowed_extensions())?;
135        for image_path in image_files {
136            let image_tensor = load_image_as_normalized_rgb_tensor(
137                &image_path,
138                config.output_height(),
139                config.output_width(),
140            )?;
141            input_values.extend_from_slice(image_tensor.data());
142            append_image_folder_target(
143                &mut target_values,
144                class_id,
145                class_count,
146                config.target_mode(),
147            )?;
148            sample_count = sample_count.checked_add(1).ok_or_else(|| {
149                ModelError::InvalidDatasetAdapterShape {
150                    field: "sample_count",
151                    shape: vec![sample_count],
152                    message: "sample count overflow".to_string(),
153                }
154            })?;
155        }
156    }
157
158    if sample_count == 0 {
159        return Err(ModelError::EmptyDataset);
160    }
161
162    let target_shape = match config.target_mode() {
163        ImageFolderTargetMode::ClassIndex => vec![1],
164        ImageFolderTargetMode::OneHot => vec![class_count],
165    };
166
167    let dataset = build_supervised_dataset_from_flat_values(
168        &[config.output_height(), config.output_width(), 3],
169        &target_shape,
170        sample_count,
171        input_values,
172        target_values,
173    )?;
174
175    Ok(SupervisedImageFolderLoadResult {
176        dataset,
177        class_names,
178    })
179}
180
181fn class_name_from_path(class_dir: &Path) -> Result<String, ModelError> {
182    class_dir
183        .file_name()
184        .map(|name| name.to_string_lossy().into_owned())
185        .ok_or_else(|| ModelError::InvalidDatasetAdapterShape {
186            field: "class_name",
187            shape: Vec::new(),
188            message: format!(
189                "failed to infer class name from directory path {}",
190                class_dir.display()
191            ),
192        })
193}
194
195fn append_image_folder_target(
196    target_values: &mut Vec<f32>,
197    class_id: usize,
198    class_count: usize,
199    target_mode: ImageFolderTargetMode,
200) -> Result<(), ModelError> {
201    match target_mode {
202        ImageFolderTargetMode::ClassIndex => {
203            target_values.push(class_id as f32);
204            Ok(())
205        }
206        ImageFolderTargetMode::OneHot => {
207            if class_count == 0 || class_id >= class_count {
208                return Err(ModelError::InvalidDatasetAdapterShape {
209                    field: "target_shape",
210                    shape: vec![class_count],
211                    message: "invalid one-hot class configuration".to_string(),
212                });
213            }
214            let next_len = target_values
215                .len()
216                .checked_add(class_count)
217                .ok_or_else(|| ModelError::InvalidDatasetAdapterShape {
218                    field: "target_values",
219                    shape: vec![target_values.len(), class_count],
220                    message: "target vector length overflow".to_string(),
221                })?;
222            target_values.resize(next_len, 0.0);
223            let target_start = next_len - class_count;
224            target_values[target_start + class_id] = 1.0;
225            Ok(())
226        }
227    }
228}
229
230fn read_sorted_class_directories(root: &Path) -> Result<Vec<PathBuf>, ModelError> {
231    let entries = fs::read_dir(root).map_err(|error| ModelError::DatasetLoadIo {
232        path: root.display().to_string(),
233        message: error.to_string(),
234    })?;
235
236    let mut class_dirs = Vec::new();
237    for entry in entries {
238        let entry = entry.map_err(|error| ModelError::DatasetLoadIo {
239            path: root.display().to_string(),
240            message: error.to_string(),
241        })?;
242        let file_type = entry
243            .file_type()
244            .map_err(|error| ModelError::DatasetLoadIo {
245                path: entry.path().display().to_string(),
246                message: error.to_string(),
247            })?;
248        if file_type.is_dir() {
249            class_dirs.push(entry.path());
250        }
251    }
252
253    class_dirs.sort_by(|left, right| {
254        let left_name = left
255            .file_name()
256            .map(|name| name.to_string_lossy().into_owned())
257            .unwrap_or_default();
258        let right_name = right
259            .file_name()
260            .map(|name| name.to_string_lossy().into_owned())
261            .unwrap_or_default();
262        left_name.cmp(&right_name).then_with(|| left.cmp(right))
263    });
264    Ok(class_dirs)
265}
266
267fn read_sorted_supported_image_files(
268    class_dir: &Path,
269    allowed_extensions: &[String],
270) -> Result<Vec<PathBuf>, ModelError> {
271    let entries = fs::read_dir(class_dir).map_err(|error| ModelError::DatasetLoadIo {
272        path: class_dir.display().to_string(),
273        message: error.to_string(),
274    })?;
275
276    let mut image_files = Vec::new();
277    for entry in entries {
278        let entry = entry.map_err(|error| ModelError::DatasetLoadIo {
279            path: class_dir.display().to_string(),
280            message: error.to_string(),
281        })?;
282        let file_type = entry
283            .file_type()
284            .map_err(|error| ModelError::DatasetLoadIo {
285                path: entry.path().display().to_string(),
286                message: error.to_string(),
287            })?;
288        if file_type.is_file() && has_supported_image_extension(&entry.path(), allowed_extensions) {
289            image_files.push(entry.path());
290        }
291    }
292
293    image_files.sort_by(|left, right| {
294        let left_name = left
295            .file_name()
296            .map(|name| name.to_string_lossy().into_owned())
297            .unwrap_or_default();
298        let right_name = right
299            .file_name()
300            .map(|name| name.to_string_lossy().into_owned())
301            .unwrap_or_default();
302        left_name.cmp(&right_name).then_with(|| left.cmp(right))
303    });
304    Ok(image_files)
305}
306
/// Reports whether `path` has an extension equal (ASCII case-insensitively) to one of
/// `allowed_extensions`.
///
/// Paths without an extension, dot-files such as `.png`, and extensions that are not valid
/// UTF-8 never match.
fn has_supported_image_extension(path: &Path, allowed_extensions: &[String]) -> bool {
    match path.extension().and_then(|raw| raw.to_str()) {
        Some(ext) => allowed_extensions
            .iter()
            .any(|candidate| ext.eq_ignore_ascii_case(candidate)),
        None => false,
    }
}
315
/// Extensions accepted when the caller does not override them: lowercase, without a
/// leading dot.
fn default_image_folder_extensions() -> Vec<String> {
    const DEFAULTS: [&str; 5] = ["jpg", "jpeg", "png", "bmp", "webp"];
    DEFAULTS.iter().map(|&ext| ext.to_owned()).collect()
}
322
323fn normalize_image_extensions(extensions: Vec<String>) -> Result<Vec<String>, ModelError> {
324    if extensions.is_empty() {
325        return Err(ModelError::InvalidImageFolderExtension {
326            extension: "<list>".to_string(),
327            message: "extension list must be non-empty".to_string(),
328        });
329    }
330
331    let mut normalized = Vec::with_capacity(extensions.len());
332    for extension in extensions {
333        let trimmed = extension.trim();
334        if trimmed.is_empty() {
335            return Err(ModelError::InvalidImageFolderExtension {
336                extension,
337                message: "extension must be non-empty".to_string(),
338            });
339        }
340        if trimmed.starts_with('.') {
341            return Err(ModelError::InvalidImageFolderExtension {
342                extension: trimmed.to_string(),
343                message: "extension must not start with '.'".to_string(),
344            });
345        }
346        if !trimmed
347            .bytes()
348            .all(|byte| byte.is_ascii_alphanumeric() || byte == b'_')
349        {
350            return Err(ModelError::InvalidImageFolderExtension {
351                extension: trimmed.to_string(),
352                message: "extension must contain only ASCII letters, digits, or '_'".to_string(),
353            });
354        }
355        let lowered = trimmed.to_ascii_lowercase();
356        if !normalized.iter().any(|existing| existing == &lowered) {
357            normalized.push(lowered);
358        }
359    }
360
361    Ok(normalized)
362}
363
364pub(super) fn load_image_as_normalized_rgb_tensor(
365    path: &Path,
366    output_height: usize,
367    output_width: usize,
368) -> Result<Tensor, ModelError> {
369    let path_string = path.display().to_string();
370    let decoded = ImageReader::open(path)
371        .map_err(|error| ModelError::DatasetImageDecode {
372            path: path_string.clone(),
373            message: error.to_string(),
374        })?
375        .decode()
376        .map_err(|error| ModelError::DatasetImageDecode {
377            path: path_string.clone(),
378            message: error.to_string(),
379        })?;
380    let rgb = decoded.to_rgb8();
381    let (width, height) = rgb.dimensions();
382
383    let image_height = height as usize;
384    let image_width = width as usize;
385    let normalized = rgb
386        .as_raw()
387        .iter()
388        .map(|value| (*value as f32) * (1.0 / 255.0))
389        .collect::<Vec<_>>();
390    let image_tensor = Tensor::from_vec(vec![image_height, image_width, 3], normalized)?;
391
392    if image_height == output_height && image_width == output_width {
393        Ok(image_tensor)
394    } else {
395        resize_nearest(&image_tensor, output_height, output_width).map_err(Into::into)
396    }
397}