1use crate::{Error, Result};
7
8#[derive(Debug, Clone)]
13pub struct Inputs {
14 inputs: Vec<f32>,
15 len: usize,
16 input_dim: usize,
17}
18
19impl Inputs {
20 pub fn from_flat(inputs: Vec<f32>, input_dim: usize) -> Result<Self> {
22 if input_dim == 0 {
23 return Err(Error::InvalidData("input_dim must be > 0".to_owned()));
24 }
25 if inputs.len() % input_dim != 0 {
26 return Err(Error::InvalidData(format!(
27 "inputs length {} is not divisible by input_dim {}",
28 inputs.len(),
29 input_dim
30 )));
31 }
32
33 let len = inputs.len() / input_dim;
34
35 Ok(Self {
36 inputs,
37 len,
38 input_dim,
39 })
40 }
41
42 pub fn from_rows(inputs: &[Vec<f32>]) -> Result<Self> {
46 if inputs.is_empty() {
47 return Err(Error::InvalidData("inputs must not be empty".to_owned()));
48 }
49
50 let input_dim = inputs[0].len();
51 if input_dim == 0 {
52 return Err(Error::InvalidData("input_dim must be > 0".to_owned()));
53 }
54
55 for (i, row) in inputs.iter().enumerate() {
56 if row.len() != input_dim {
57 return Err(Error::InvalidData(format!(
58 "input row {i} has len {}, expected {input_dim}",
59 row.len()
60 )));
61 }
62 }
63
64 let len = inputs.len();
65 let mut inputs_flat = Vec::with_capacity(len * input_dim);
66 for row in inputs {
67 inputs_flat.extend_from_slice(row);
68 }
69
70 Ok(Self {
71 inputs: inputs_flat,
72 len,
73 input_dim,
74 })
75 }
76
77 #[inline]
78 pub fn len(&self) -> usize {
80 self.len
81 }
82
83 #[inline]
84 pub fn is_empty(&self) -> bool {
86 self.len == 0
87 }
88
89 #[inline]
90 pub fn input_dim(&self) -> usize {
92 self.input_dim
93 }
94
95 #[inline]
96 pub fn as_flat(&self) -> &[f32] {
100 &self.inputs
101 }
102
103 #[inline]
104 pub fn input(&self, idx: usize) -> &[f32] {
108 let start = idx * self.input_dim;
109 &self.inputs[start..start + self.input_dim]
110 }
111}
112
113#[derive(Debug, Clone)]
119pub struct Dataset {
120 inputs: Inputs,
121 targets: Vec<f32>,
122 target_dim: usize,
123}
124
125impl Dataset {
126 pub fn from_flat(
130 inputs: Vec<f32>,
131 targets: Vec<f32>,
132 input_dim: usize,
133 target_dim: usize,
134 ) -> Result<Self> {
135 let inputs = Inputs::from_flat(inputs, input_dim)?;
136 if target_dim == 0 {
137 return Err(Error::InvalidData("target_dim must be > 0".to_owned()));
138 }
139
140 if targets.len() != inputs.len() * target_dim {
141 return Err(Error::InvalidData(format!(
142 "targets length {} does not match len * target_dim ({} * {})",
143 targets.len(),
144 inputs.len(),
145 target_dim
146 )));
147 }
148
149 Ok(Self {
150 inputs,
151 targets,
152 target_dim,
153 })
154 }
155
156 pub fn from_rows(inputs: &[Vec<f32>], targets: &[Vec<f32>]) -> Result<Self> {
160 if inputs.len() != targets.len() {
161 return Err(Error::InvalidData(format!(
162 "inputs/targets length mismatch: {} vs {}",
163 inputs.len(),
164 targets.len()
165 )));
166 }
167
168 let inputs = Inputs::from_rows(inputs)?;
169 let target_dim = targets.first().map(|t| t.len()).unwrap_or(0);
170 if target_dim == 0 {
171 return Err(Error::InvalidData("target_dim must be > 0".to_owned()));
172 }
173 for (i, row) in targets.iter().enumerate() {
174 if row.len() != target_dim {
175 return Err(Error::InvalidData(format!(
176 "target row {i} has len {}, expected {target_dim}",
177 row.len()
178 )));
179 }
180 }
181
182 let len = inputs.len();
183 let mut targets_flat = Vec::with_capacity(len * target_dim);
184 for row in targets {
185 targets_flat.extend_from_slice(row);
186 }
187
188 Ok(Self {
189 inputs,
190 targets: targets_flat,
191 target_dim,
192 })
193 }
194
195 #[inline]
196 pub fn len(&self) -> usize {
198 self.inputs.len()
199 }
200
201 #[inline]
202 pub fn is_empty(&self) -> bool {
204 self.inputs.is_empty()
205 }
206
207 #[inline]
208 pub fn input_dim(&self) -> usize {
210 self.inputs.input_dim()
211 }
212
213 #[inline]
214 pub fn target_dim(&self) -> usize {
216 self.target_dim
217 }
218
219 #[inline]
220 pub fn inputs(&self) -> &Inputs {
222 &self.inputs
223 }
224
225 #[inline]
226 pub fn inputs_flat(&self) -> &[f32] {
230 self.inputs.as_flat()
231 }
232
233 #[inline]
234 pub fn targets_flat(&self) -> &[f32] {
238 &self.targets
239 }
240
241 #[inline]
242 pub fn input(&self, idx: usize) -> &[f32] {
246 self.inputs.input(idx)
247 }
248
249 #[inline]
250 pub fn target(&self, idx: usize) -> &[f32] {
254 let start = idx * self.target_dim;
255 &self.targets[start..start + self.target_dim]
256 }
257
258 pub fn batch(&self, start: usize, len: usize) -> Batch<'_> {
262 assert!(len > 0, "batch len must be > 0");
263 assert!(start < self.len(), "batch start out of bounds");
264 assert!(
265 start + len <= self.len(),
266 "batch range out of bounds: start={start} len={len} dataset_len={}",
267 self.len()
268 );
269
270 let in_dim = self.input_dim();
271 let t_dim = self.target_dim();
272 let x0 = start * in_dim;
273 let x1 = (start + len) * in_dim;
274 let y0 = start * t_dim;
275 let y1 = (start + len) * t_dim;
276 Batch {
277 inputs: &self.inputs_flat()[x0..x1],
278 targets: &self.targets_flat()[y0..y1],
279 len,
280 input_dim: in_dim,
281 target_dim: t_dim,
282 }
283 }
284
285 pub fn batches(&self, batch_size: usize) -> Batches<'_> {
289 assert!(batch_size > 0, "batch_size must be > 0");
290 Batches {
291 data: self,
292 batch_size,
293 pos: 0,
294 }
295 }
296}
297
/// A borrowed view of consecutive samples, holding flat row-major input and
/// target slices together with their row widths.
#[derive(Debug, Clone, Copy)]
pub struct Batch<'a> {
    /// Row-major input values for the rows in this batch.
    inputs: &'a [f32],
    /// Row-major target values for the rows in this batch.
    targets: &'a [f32],
    /// Number of rows in this batch.
    len: usize,
    /// Number of values per input row.
    input_dim: usize,
    /// Number of values per target row.
    target_dim: usize,
}

impl<'a> Batch<'a> {
    /// Returns the number of rows in the batch.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` when the batch has no rows.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Returns the number of values in each input row.
    #[inline]
    pub fn input_dim(&self) -> usize {
        self.input_dim
    }

    /// Returns the number of values in each target row.
    #[inline]
    pub fn target_dim(&self) -> usize {
        self.target_dim
    }

    /// Returns the batch's input rows as one contiguous slice.
    #[inline]
    pub fn inputs_flat(&self) -> &'a [f32] {
        self.inputs
    }

    /// Returns the batch's target rows as one contiguous slice.
    #[inline]
    pub fn targets_flat(&self) -> &'a [f32] {
        self.targets
    }

    /// Returns input row `idx`, counted from the start of the batch.
    ///
    /// # Panics
    ///
    /// Panics if `idx >= self.len()`.
    #[inline]
    pub fn input(&self, idx: usize) -> &'a [f32] {
        // Reject bad indices before multiplying so an enormous `idx` cannot
        // wrap around to an in-range offset in builds without overflow checks.
        assert!(
            idx < self.len,
            "batch input index {idx} out of bounds (batch len {})",
            self.len
        );
        let start = idx * self.input_dim;
        &self.inputs[start..start + self.input_dim]
    }

    /// Returns target row `idx`, counted from the start of the batch.
    ///
    /// # Panics
    ///
    /// Panics if `idx >= self.len()`.
    #[inline]
    pub fn target(&self, idx: usize) -> &'a [f32] {
        // Same guard as `input`: validate the row index before multiplying.
        assert!(
            idx < self.len,
            "batch target index {idx} out of bounds (batch len {})",
            self.len
        );
        let start = idx * self.target_dim;
        &self.targets[start..start + self.target_dim]
    }
}
369
370#[derive(Debug, Clone)]
372pub struct Batches<'a> {
373 data: &'a Dataset,
374 batch_size: usize,
375 pos: usize,
376}
377
378impl<'a> Iterator for Batches<'a> {
379 type Item = Batch<'a>;
380
381 fn next(&mut self) -> Option<Self::Item> {
382 if self.pos >= self.data.len() {
383 return None;
384 }
385
386 let start = self.pos;
387 let end = (start + self.batch_size).min(self.data.len());
388 self.pos = end;
389 Some(self.data.batch(start, end - start))
390 }
391}
392
#[cfg(test)]
mod tests {
    use super::*;

    /// `Dataset::from_flat` accepts consistent shapes and rejects an input
    /// buffer that cannot be split into whole rows.
    #[test]
    fn dataset_from_flat_validates_shapes() {
        let well_formed = Dataset::from_flat(vec![0.0, 1.0, 2.0, 3.0], vec![0.0, 1.0], 2, 1);
        assert!(well_formed.is_ok());

        // Three values do not divide into rows of width two.
        let ragged = Dataset::from_flat(vec![0.0, 1.0, 2.0], vec![0.0], 2, 1);
        assert!(ragged.is_err());
    }

    /// Five samples with a batch size of two produce batches of 2, 2 and 1
    /// samples, in the original order.
    #[test]
    fn batches_cover_all_samples_in_order() {
        let inputs = vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let targets = vec![10.0, 11.0, 12.0, 13.0, 14.0];
        let data = Dataset::from_flat(inputs, targets, 2, 1).unwrap();

        // (expected sample count, expected inputs, expected targets) per batch.
        let expected: [(usize, &[f32], &[f32]); 3] = [
            (2, &[0.0, 1.0, 2.0, 3.0], &[10.0, 11.0]),
            (2, &[4.0, 5.0, 6.0, 7.0], &[12.0, 13.0]),
            (1, &[8.0, 9.0], &[14.0]),
        ];

        let produced: Vec<_> = data.batches(2).collect();
        assert_eq!(produced.len(), 3);

        for (batch, (want_len, want_inputs, want_targets)) in produced.iter().zip(expected) {
            assert_eq!(batch.len(), want_len);
            assert_eq!(batch.inputs_flat(), want_inputs);
            assert_eq!(batch.targets_flat(), want_targets);
        }
    }
}