1use image::ImageReader;
2use std::fs;
3use std::path::{Path, PathBuf};
4use yscv_imgproc::resize_nearest;
5use yscv_tensor::Tensor;
6
7use crate::ModelError;
8
9use super::helpers::build_supervised_dataset_from_flat_values;
10use super::types::SupervisedDataset;
11
/// Selects how the class of each image is encoded in the dataset targets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ImageFolderTargetMode {
    /// One scalar per sample holding the class index as an `f32`
    /// (per-sample target shape `[1]`). This is the default mode.
    #[default]
    ClassIndex,
    /// One vector per sample with one slot per class: `1.0` in the slot of
    /// the sample's class and `0.0` in every other slot.
    OneHot,
}
22
23#[derive(Debug, Clone, PartialEq)]
24pub struct SupervisedImageFolderConfig {
25 output_height: usize,
26 output_width: usize,
27 target_mode: ImageFolderTargetMode,
28 allowed_extensions: Vec<String>,
29}
30
31impl SupervisedImageFolderConfig {
32 pub fn new(output_height: usize, output_width: usize) -> Result<Self, ModelError> {
33 if output_height == 0 {
34 return Err(ModelError::InvalidDatasetAdapterShape {
35 field: "output_height",
36 shape: vec![output_height],
37 message: "output_height must be > 0".to_string(),
38 });
39 }
40 if output_width == 0 {
41 return Err(ModelError::InvalidDatasetAdapterShape {
42 field: "output_width",
43 shape: vec![output_width],
44 message: "output_width must be > 0".to_string(),
45 });
46 }
47 Ok(Self {
48 output_height,
49 output_width,
50 target_mode: ImageFolderTargetMode::ClassIndex,
51 allowed_extensions: default_image_folder_extensions(),
52 })
53 }
54
55 pub fn output_height(&self) -> usize {
56 self.output_height
57 }
58
59 pub fn output_width(&self) -> usize {
60 self.output_width
61 }
62
63 pub fn with_target_mode(mut self, target_mode: ImageFolderTargetMode) -> Self {
64 self.target_mode = target_mode;
65 self
66 }
67
68 pub fn target_mode(&self) -> ImageFolderTargetMode {
69 self.target_mode
70 }
71
72 pub fn with_allowed_extensions(
73 mut self,
74 allowed_extensions: Vec<String>,
75 ) -> Result<Self, ModelError> {
76 self.allowed_extensions = normalize_image_extensions(allowed_extensions)?;
77 Ok(self)
78 }
79
80 pub fn allowed_extensions(&self) -> &[String] {
81 &self.allowed_extensions
82 }
83}
84
85#[derive(Debug, Clone, PartialEq)]
87pub struct SupervisedImageFolderLoadResult {
88 pub dataset: SupervisedDataset,
89 pub class_names: Vec<String>,
90}
91
92pub fn load_supervised_image_folder_dataset<P: AsRef<Path>>(
106 root: P,
107 config: &SupervisedImageFolderConfig,
108) -> Result<SupervisedDataset, ModelError> {
109 load_supervised_image_folder_dataset_with_classes(root, config).map(|loaded| loaded.dataset)
110}
111
112pub fn load_supervised_image_folder_dataset_with_classes<P: AsRef<Path>>(
117 root: P,
118 config: &SupervisedImageFolderConfig,
119) -> Result<SupervisedImageFolderLoadResult, ModelError> {
120 let root_ref = root.as_ref();
121 let class_dirs = read_sorted_class_directories(root_ref)?;
122 let class_names = class_dirs
123 .iter()
124 .map(|class_dir| class_name_from_path(class_dir))
125 .collect::<Result<Vec<_>, _>>()?;
126 let class_count = class_dirs.len();
127
128 let mut input_values = Vec::new();
129 let mut target_values = Vec::new();
130 let mut sample_count = 0usize;
131
132 for (class_id, class_dir) in class_dirs.iter().enumerate() {
133 let image_files =
134 read_sorted_supported_image_files(class_dir, config.allowed_extensions())?;
135 for image_path in image_files {
136 let image_tensor = load_image_as_normalized_rgb_tensor(
137 &image_path,
138 config.output_height(),
139 config.output_width(),
140 )?;
141 input_values.extend_from_slice(image_tensor.data());
142 append_image_folder_target(
143 &mut target_values,
144 class_id,
145 class_count,
146 config.target_mode(),
147 )?;
148 sample_count = sample_count.checked_add(1).ok_or_else(|| {
149 ModelError::InvalidDatasetAdapterShape {
150 field: "sample_count",
151 shape: vec![sample_count],
152 message: "sample count overflow".to_string(),
153 }
154 })?;
155 }
156 }
157
158 if sample_count == 0 {
159 return Err(ModelError::EmptyDataset);
160 }
161
162 let target_shape = match config.target_mode() {
163 ImageFolderTargetMode::ClassIndex => vec![1],
164 ImageFolderTargetMode::OneHot => vec![class_count],
165 };
166
167 let dataset = build_supervised_dataset_from_flat_values(
168 &[config.output_height(), config.output_width(), 3],
169 &target_shape,
170 sample_count,
171 input_values,
172 target_values,
173 )?;
174
175 Ok(SupervisedImageFolderLoadResult {
176 dataset,
177 class_names,
178 })
179}
180
181fn class_name_from_path(class_dir: &Path) -> Result<String, ModelError> {
182 class_dir
183 .file_name()
184 .map(|name| name.to_string_lossy().into_owned())
185 .ok_or_else(|| ModelError::InvalidDatasetAdapterShape {
186 field: "class_name",
187 shape: Vec::new(),
188 message: format!(
189 "failed to infer class name from directory path {}",
190 class_dir.display()
191 ),
192 })
193}
194
195fn append_image_folder_target(
196 target_values: &mut Vec<f32>,
197 class_id: usize,
198 class_count: usize,
199 target_mode: ImageFolderTargetMode,
200) -> Result<(), ModelError> {
201 match target_mode {
202 ImageFolderTargetMode::ClassIndex => {
203 target_values.push(class_id as f32);
204 Ok(())
205 }
206 ImageFolderTargetMode::OneHot => {
207 if class_count == 0 || class_id >= class_count {
208 return Err(ModelError::InvalidDatasetAdapterShape {
209 field: "target_shape",
210 shape: vec![class_count],
211 message: "invalid one-hot class configuration".to_string(),
212 });
213 }
214 let next_len = target_values
215 .len()
216 .checked_add(class_count)
217 .ok_or_else(|| ModelError::InvalidDatasetAdapterShape {
218 field: "target_values",
219 shape: vec![target_values.len(), class_count],
220 message: "target vector length overflow".to_string(),
221 })?;
222 target_values.resize(next_len, 0.0);
223 let target_start = next_len - class_count;
224 target_values[target_start + class_id] = 1.0;
225 Ok(())
226 }
227 }
228}
229
230fn read_sorted_class_directories(root: &Path) -> Result<Vec<PathBuf>, ModelError> {
231 let entries = fs::read_dir(root).map_err(|error| ModelError::DatasetLoadIo {
232 path: root.display().to_string(),
233 message: error.to_string(),
234 })?;
235
236 let mut class_dirs = Vec::new();
237 for entry in entries {
238 let entry = entry.map_err(|error| ModelError::DatasetLoadIo {
239 path: root.display().to_string(),
240 message: error.to_string(),
241 })?;
242 let file_type = entry
243 .file_type()
244 .map_err(|error| ModelError::DatasetLoadIo {
245 path: entry.path().display().to_string(),
246 message: error.to_string(),
247 })?;
248 if file_type.is_dir() {
249 class_dirs.push(entry.path());
250 }
251 }
252
253 class_dirs.sort_by(|left, right| {
254 let left_name = left
255 .file_name()
256 .map(|name| name.to_string_lossy().into_owned())
257 .unwrap_or_default();
258 let right_name = right
259 .file_name()
260 .map(|name| name.to_string_lossy().into_owned())
261 .unwrap_or_default();
262 left_name.cmp(&right_name).then_with(|| left.cmp(right))
263 });
264 Ok(class_dirs)
265}
266
267fn read_sorted_supported_image_files(
268 class_dir: &Path,
269 allowed_extensions: &[String],
270) -> Result<Vec<PathBuf>, ModelError> {
271 let entries = fs::read_dir(class_dir).map_err(|error| ModelError::DatasetLoadIo {
272 path: class_dir.display().to_string(),
273 message: error.to_string(),
274 })?;
275
276 let mut image_files = Vec::new();
277 for entry in entries {
278 let entry = entry.map_err(|error| ModelError::DatasetLoadIo {
279 path: class_dir.display().to_string(),
280 message: error.to_string(),
281 })?;
282 let file_type = entry
283 .file_type()
284 .map_err(|error| ModelError::DatasetLoadIo {
285 path: entry.path().display().to_string(),
286 message: error.to_string(),
287 })?;
288 if file_type.is_file() && has_supported_image_extension(&entry.path(), allowed_extensions) {
289 image_files.push(entry.path());
290 }
291 }
292
293 image_files.sort_by(|left, right| {
294 let left_name = left
295 .file_name()
296 .map(|name| name.to_string_lossy().into_owned())
297 .unwrap_or_default();
298 let right_name = right
299 .file_name()
300 .map(|name| name.to_string_lossy().into_owned())
301 .unwrap_or_default();
302 left_name.cmp(&right_name).then_with(|| left.cmp(right))
303 });
304 Ok(image_files)
305}
306
/// Reports whether the extension of `path` is one of `allowed_extensions`.
///
/// The comparison ignores ASCII case. A path with no extension, or with an
/// extension that is not valid UTF-8, is never accepted.
fn has_supported_image_extension(path: &Path, allowed_extensions: &[String]) -> bool {
    match path.extension().and_then(|raw| raw.to_str()) {
        Some(candidate) => allowed_extensions
            .iter()
            .any(|allowed| allowed.eq_ignore_ascii_case(candidate)),
        None => false,
    }
}
315
/// Returns the extensions accepted when no custom list is configured:
/// `jpg`, `jpeg`, `png`, `bmp` and `webp`, in that order.
fn default_image_folder_extensions() -> Vec<String> {
    const DEFAULT_EXTENSIONS: [&str; 5] = ["jpg", "jpeg", "png", "bmp", "webp"];
    DEFAULT_EXTENSIONS
        .iter()
        .map(|extension| (*extension).to_owned())
        .collect()
}
322
323fn normalize_image_extensions(extensions: Vec<String>) -> Result<Vec<String>, ModelError> {
324 if extensions.is_empty() {
325 return Err(ModelError::InvalidImageFolderExtension {
326 extension: "<list>".to_string(),
327 message: "extension list must be non-empty".to_string(),
328 });
329 }
330
331 let mut normalized = Vec::with_capacity(extensions.len());
332 for extension in extensions {
333 let trimmed = extension.trim();
334 if trimmed.is_empty() {
335 return Err(ModelError::InvalidImageFolderExtension {
336 extension,
337 message: "extension must be non-empty".to_string(),
338 });
339 }
340 if trimmed.starts_with('.') {
341 return Err(ModelError::InvalidImageFolderExtension {
342 extension: trimmed.to_string(),
343 message: "extension must not start with '.'".to_string(),
344 });
345 }
346 if !trimmed
347 .bytes()
348 .all(|byte| byte.is_ascii_alphanumeric() || byte == b'_')
349 {
350 return Err(ModelError::InvalidImageFolderExtension {
351 extension: trimmed.to_string(),
352 message: "extension must contain only ASCII letters, digits, or '_'".to_string(),
353 });
354 }
355 let lowered = trimmed.to_ascii_lowercase();
356 if !normalized.iter().any(|existing| existing == &lowered) {
357 normalized.push(lowered);
358 }
359 }
360
361 Ok(normalized)
362}
363
364pub(super) fn load_image_as_normalized_rgb_tensor(
365 path: &Path,
366 output_height: usize,
367 output_width: usize,
368) -> Result<Tensor, ModelError> {
369 let path_string = path.display().to_string();
370 let decoded = ImageReader::open(path)
371 .map_err(|error| ModelError::DatasetImageDecode {
372 path: path_string.clone(),
373 message: error.to_string(),
374 })?
375 .decode()
376 .map_err(|error| ModelError::DatasetImageDecode {
377 path: path_string.clone(),
378 message: error.to_string(),
379 })?;
380 let rgb = decoded.to_rgb8();
381 let (width, height) = rgb.dimensions();
382
383 let image_height = height as usize;
384 let image_width = width as usize;
385 let normalized = rgb
386 .as_raw()
387 .iter()
388 .map(|value| (*value as f32) * (1.0 / 255.0))
389 .collect::<Vec<_>>();
390 let image_tensor = Tensor::from_vec(vec![image_height, image_width, 3], normalized)?;
391
392 if image_height == output_height && image_width == output_width {
393 Ok(image_tensor)
394 } else {
395 resize_nearest(&image_tensor, output_height, output_width).map_err(Into::into)
396 }
397}