yscv_model/dataset/
image_manifest.rs1use std::path::{Path, PathBuf};
2
3use crate::ModelError;
4
5use super::helpers::{
6 adapter_sample_len, build_supervised_dataset_from_flat_values, load_dataset_text_file,
7 validate_adapter_sample_shape, validate_csv_delimiter, validate_finite_values,
8};
9use super::image_folder::load_image_as_normalized_rgb_tensor;
10
11#[derive(Debug, Clone, PartialEq)]
16pub struct SupervisedImageManifestConfig {
17 target_shape: Vec<usize>,
18 output_height: usize,
19 output_width: usize,
20 pub(super) image_root: PathBuf,
21 delimiter: char,
22 has_header: bool,
23}
24
25impl SupervisedImageManifestConfig {
26 pub fn new(
27 target_shape: Vec<usize>,
28 output_height: usize,
29 output_width: usize,
30 ) -> Result<Self, ModelError> {
31 validate_adapter_sample_shape("target_shape", &target_shape)?;
32 if output_height == 0 {
33 return Err(ModelError::InvalidDatasetAdapterShape {
34 field: "output_height",
35 shape: vec![output_height],
36 message: "output_height must be > 0".to_string(),
37 });
38 }
39 if output_width == 0 {
40 return Err(ModelError::InvalidDatasetAdapterShape {
41 field: "output_width",
42 shape: vec![output_width],
43 message: "output_width must be > 0".to_string(),
44 });
45 }
46
47 Ok(Self {
48 target_shape,
49 output_height,
50 output_width,
51 image_root: PathBuf::from("."),
52 delimiter: ',',
53 has_header: false,
54 })
55 }
56
57 pub fn with_image_root<P: Into<PathBuf>>(mut self, image_root: P) -> Self {
58 self.image_root = image_root.into();
59 self
60 }
61
62 pub fn with_delimiter(mut self, delimiter: char) -> Result<Self, ModelError> {
63 validate_csv_delimiter(delimiter)?;
64 self.delimiter = delimiter;
65 Ok(self)
66 }
67
68 pub fn with_header(mut self, has_header: bool) -> Self {
69 self.has_header = has_header;
70 self
71 }
72
73 pub fn target_shape(&self) -> &[usize] {
74 &self.target_shape
75 }
76
77 pub fn output_height(&self) -> usize {
78 self.output_height
79 }
80
81 pub fn output_width(&self) -> usize {
82 self.output_width
83 }
84
85 pub fn image_root(&self) -> &Path {
86 &self.image_root
87 }
88
89 pub fn delimiter(&self) -> char {
90 self.delimiter
91 }
92
93 pub fn has_header(&self) -> bool {
94 self.has_header
95 }
96}
97
/// Resolves the image path written in a manifest row.
///
/// Absolute paths are taken verbatim; anything else is interpreted relative to `image_root`.
fn resolve_manifest_image_path(image_root: &Path, image_field: &str) -> PathBuf {
    let candidate = Path::new(image_field);
    if !candidate.is_absolute() {
        return image_root.join(candidate);
    }
    PathBuf::from(image_field)
}
106
107pub fn parse_supervised_image_manifest_csv(
112 content: &str,
113 config: &SupervisedImageManifestConfig,
114) -> Result<super::types::SupervisedDataset, ModelError> {
115 let target_row_len = adapter_sample_len("target_shape", config.target_shape())?;
116 let expected_columns =
117 target_row_len
118 .checked_add(1)
119 .ok_or_else(|| ModelError::InvalidDatasetAdapterShape {
120 field: "manifest_columns",
121 shape: vec![target_row_len],
122 message: "column count overflow".to_string(),
123 })?;
124
125 let mut input_values = Vec::new();
126 let mut target_values = Vec::new();
127 let mut sample_count = 0usize;
128 let mut header_skipped = false;
129 let mut target_row = Vec::with_capacity(target_row_len);
130
131 for (line_idx, raw_line) in content.lines().enumerate() {
132 let line_number = line_idx + 1;
133 let line = raw_line.trim();
134 if line.is_empty() || line.starts_with('#') {
135 continue;
136 }
137 if config.has_header() && !header_skipped {
138 header_skipped = true;
139 continue;
140 }
141
142 let columns = line
143 .split(config.delimiter())
144 .map(str::trim)
145 .collect::<Vec<_>>();
146 if columns.len() != expected_columns {
147 return Err(ModelError::InvalidDatasetRecordColumns {
148 line: line_number,
149 expected: expected_columns,
150 got: columns.len(),
151 });
152 }
153
154 let image_field = columns[0];
155 if image_field.is_empty() {
156 return Err(ModelError::InvalidDatasetRecordPath {
157 line: line_number,
158 message: "image path is empty".to_string(),
159 });
160 }
161
162 let image_path = resolve_manifest_image_path(config.image_root(), image_field);
163 let image_tensor = load_image_as_normalized_rgb_tensor(
164 &image_path,
165 config.output_height(),
166 config.output_width(),
167 )?;
168 input_values.extend_from_slice(image_tensor.data());
169
170 target_row.clear();
171 for (target_idx, target_str) in columns[1..].iter().enumerate() {
172 let value = target_str
173 .parse::<f32>()
174 .map_err(|error| ModelError::DatasetCsvParse {
175 line: line_number,
176 column: target_idx + 2,
177 message: error.to_string(),
178 })?;
179 target_row.push(value);
180 }
181 validate_finite_values(line_number, "target", &target_row)?;
182 target_values.extend_from_slice(&target_row);
183
184 sample_count =
185 sample_count
186 .checked_add(1)
187 .ok_or_else(|| ModelError::InvalidDatasetAdapterShape {
188 field: "sample_count",
189 shape: vec![sample_count],
190 message: "sample count overflow".to_string(),
191 })?;
192 }
193
194 if sample_count == 0 {
195 return Err(ModelError::EmptyDataset);
196 }
197
198 build_supervised_dataset_from_flat_values(
199 &[config.output_height(), config.output_width(), 3],
200 config.target_shape(),
201 sample_count,
202 input_values,
203 target_values,
204 )
205}
206
207pub fn load_supervised_image_manifest_csv_file<P: AsRef<Path>>(
209 path: P,
210 config: &SupervisedImageManifestConfig,
211) -> Result<super::types::SupervisedDataset, ModelError> {
212 let path_ref = path.as_ref();
213 let content = load_dataset_text_file(path_ref)?;
214
215 let mut effective_config = config.clone();
216 if !effective_config.image_root.is_absolute() {
217 let manifest_dir = path_ref.parent().unwrap_or_else(|| Path::new("."));
218 effective_config.image_root = manifest_dir.join(&effective_config.image_root);
219 }
220
221 parse_supervised_image_manifest_csv(&content, &effective_config)
222}