Skip to main content

yscv_model/dataset/
image_manifest.rs

1use std::path::{Path, PathBuf};
2
3use crate::ModelError;
4
5use super::helpers::{
6    adapter_sample_len, build_supervised_dataset_from_flat_values, load_dataset_text_file,
7    validate_adapter_sample_shape, validate_csv_delimiter, validate_finite_values,
8};
9use super::image_folder::load_image_as_normalized_rgb_tensor;
10
11/// Configuration for parsing/loading supervised image-manifest CSV datasets.
12///
13/// Expected row format:
14/// `image_path,target_0,target_1,...`
15#[derive(Debug, Clone, PartialEq)]
16pub struct SupervisedImageManifestConfig {
17    target_shape: Vec<usize>,
18    output_height: usize,
19    output_width: usize,
20    pub(super) image_root: PathBuf,
21    delimiter: char,
22    has_header: bool,
23}
24
25impl SupervisedImageManifestConfig {
26    pub fn new(
27        target_shape: Vec<usize>,
28        output_height: usize,
29        output_width: usize,
30    ) -> Result<Self, ModelError> {
31        validate_adapter_sample_shape("target_shape", &target_shape)?;
32        if output_height == 0 {
33            return Err(ModelError::InvalidDatasetAdapterShape {
34                field: "output_height",
35                shape: vec![output_height],
36                message: "output_height must be > 0".to_string(),
37            });
38        }
39        if output_width == 0 {
40            return Err(ModelError::InvalidDatasetAdapterShape {
41                field: "output_width",
42                shape: vec![output_width],
43                message: "output_width must be > 0".to_string(),
44            });
45        }
46
47        Ok(Self {
48            target_shape,
49            output_height,
50            output_width,
51            image_root: PathBuf::from("."),
52            delimiter: ',',
53            has_header: false,
54        })
55    }
56
57    pub fn with_image_root<P: Into<PathBuf>>(mut self, image_root: P) -> Self {
58        self.image_root = image_root.into();
59        self
60    }
61
62    pub fn with_delimiter(mut self, delimiter: char) -> Result<Self, ModelError> {
63        validate_csv_delimiter(delimiter)?;
64        self.delimiter = delimiter;
65        Ok(self)
66    }
67
68    pub fn with_header(mut self, has_header: bool) -> Self {
69        self.has_header = has_header;
70        self
71    }
72
73    pub fn target_shape(&self) -> &[usize] {
74        &self.target_shape
75    }
76
77    pub fn output_height(&self) -> usize {
78        self.output_height
79    }
80
81    pub fn output_width(&self) -> usize {
82        self.output_width
83    }
84
85    pub fn image_root(&self) -> &Path {
86        &self.image_root
87    }
88
89    pub fn delimiter(&self) -> char {
90        self.delimiter
91    }
92
93    pub fn has_header(&self) -> bool {
94        self.has_header
95    }
96}
97
/// Turns an image path taken from a manifest row into a filesystem path.
///
/// Relative paths are joined onto `image_root`; absolute paths are returned
/// unchanged.
fn resolve_manifest_image_path(image_root: &Path, image_field: &str) -> PathBuf {
    let candidate = Path::new(image_field);
    if !candidate.is_absolute() {
        return image_root.join(candidate);
    }
    candidate.to_owned()
}
106
107/// Parses supervised training image-manifest CSV into a `SupervisedDataset`.
108///
109/// Manifest row format:
110/// `image_path,target_0,target_1,...`
111pub fn parse_supervised_image_manifest_csv(
112    content: &str,
113    config: &SupervisedImageManifestConfig,
114) -> Result<super::types::SupervisedDataset, ModelError> {
115    let target_row_len = adapter_sample_len("target_shape", config.target_shape())?;
116    let expected_columns =
117        target_row_len
118            .checked_add(1)
119            .ok_or_else(|| ModelError::InvalidDatasetAdapterShape {
120                field: "manifest_columns",
121                shape: vec![target_row_len],
122                message: "column count overflow".to_string(),
123            })?;
124
125    let mut input_values = Vec::new();
126    let mut target_values = Vec::new();
127    let mut sample_count = 0usize;
128    let mut header_skipped = false;
129    let mut target_row = Vec::with_capacity(target_row_len);
130
131    for (line_idx, raw_line) in content.lines().enumerate() {
132        let line_number = line_idx + 1;
133        let line = raw_line.trim();
134        if line.is_empty() || line.starts_with('#') {
135            continue;
136        }
137        if config.has_header() && !header_skipped {
138            header_skipped = true;
139            continue;
140        }
141
142        let columns = line
143            .split(config.delimiter())
144            .map(str::trim)
145            .collect::<Vec<_>>();
146        if columns.len() != expected_columns {
147            return Err(ModelError::InvalidDatasetRecordColumns {
148                line: line_number,
149                expected: expected_columns,
150                got: columns.len(),
151            });
152        }
153
154        let image_field = columns[0];
155        if image_field.is_empty() {
156            return Err(ModelError::InvalidDatasetRecordPath {
157                line: line_number,
158                message: "image path is empty".to_string(),
159            });
160        }
161
162        let image_path = resolve_manifest_image_path(config.image_root(), image_field);
163        let image_tensor = load_image_as_normalized_rgb_tensor(
164            &image_path,
165            config.output_height(),
166            config.output_width(),
167        )?;
168        input_values.extend_from_slice(image_tensor.data());
169
170        target_row.clear();
171        for (target_idx, target_str) in columns[1..].iter().enumerate() {
172            let value = target_str
173                .parse::<f32>()
174                .map_err(|error| ModelError::DatasetCsvParse {
175                    line: line_number,
176                    column: target_idx + 2,
177                    message: error.to_string(),
178                })?;
179            target_row.push(value);
180        }
181        validate_finite_values(line_number, "target", &target_row)?;
182        target_values.extend_from_slice(&target_row);
183
184        sample_count =
185            sample_count
186                .checked_add(1)
187                .ok_or_else(|| ModelError::InvalidDatasetAdapterShape {
188                    field: "sample_count",
189                    shape: vec![sample_count],
190                    message: "sample count overflow".to_string(),
191                })?;
192    }
193
194    if sample_count == 0 {
195        return Err(ModelError::EmptyDataset);
196    }
197
198    build_supervised_dataset_from_flat_values(
199        &[config.output_height(), config.output_width(), 3],
200        config.target_shape(),
201        sample_count,
202        input_values,
203        target_values,
204    )
205}
206
207/// Loads supervised training image-manifest CSV from file.
208pub fn load_supervised_image_manifest_csv_file<P: AsRef<Path>>(
209    path: P,
210    config: &SupervisedImageManifestConfig,
211) -> Result<super::types::SupervisedDataset, ModelError> {
212    let path_ref = path.as_ref();
213    let content = load_dataset_text_file(path_ref)?;
214
215    let mut effective_config = config.clone();
216    if !effective_config.image_root.is_absolute() {
217        let manifest_dir = path_ref.parent().unwrap_or_else(|| Path::new("."));
218        effective_config.image_root = manifest_dir.join(&effective_config.image_root);
219    }
220
221    parse_supervised_image_manifest_csv(&content, &effective_config)
222}