Skip to main content

rust_mlp/
data.rs

1//! Contiguous dataset helpers.
2//!
3//! The training loop operates on slices to avoid per-step allocations. `Inputs` and
4//! `Dataset` provide validated, row-major storage for feature/target matrices.
5
6use crate::{Error, Result};
7
/// Feature matrix `X`: one row per sample, stored in a single contiguous buffer.
///
/// The layout is row-major, so sample `i` occupies
/// `inputs[i * input_dim..(i + 1) * input_dim]`, and the constructors guarantee
/// `inputs.len() == len * input_dim`.
#[derive(Debug, Clone)]
pub struct Inputs {
    /// Flat row-major feature buffer.
    inputs: Vec<f32>,
    /// Number of samples (rows).
    len: usize,
    /// Number of features per sample (columns).
    input_dim: usize,
}
18
19impl Inputs {
20    /// Build inputs from a flat buffer with shape `(len, input_dim)`.
21    pub fn from_flat(inputs: Vec<f32>, input_dim: usize) -> Result<Self> {
22        if input_dim == 0 {
23            return Err(Error::InvalidData("input_dim must be > 0".to_owned()));
24        }
25        if inputs.len() % input_dim != 0 {
26            return Err(Error::InvalidData(format!(
27                "inputs length {} is not divisible by input_dim {}",
28                inputs.len(),
29                input_dim
30            )));
31        }
32
33        let len = inputs.len() / input_dim;
34
35        Ok(Self {
36            inputs,
37            len,
38            input_dim,
39        })
40    }
41
42    /// Build inputs from per-sample rows.
43    ///
44    /// This is a convenience constructor (it copies into contiguous storage).
45    pub fn from_rows(inputs: &[Vec<f32>]) -> Result<Self> {
46        if inputs.is_empty() {
47            return Err(Error::InvalidData("inputs must not be empty".to_owned()));
48        }
49
50        let input_dim = inputs[0].len();
51        if input_dim == 0 {
52            return Err(Error::InvalidData("input_dim must be > 0".to_owned()));
53        }
54
55        for (i, row) in inputs.iter().enumerate() {
56            if row.len() != input_dim {
57                return Err(Error::InvalidData(format!(
58                    "input row {i} has len {}, expected {input_dim}",
59                    row.len()
60                )));
61            }
62        }
63
64        let len = inputs.len();
65        let mut inputs_flat = Vec::with_capacity(len * input_dim);
66        for row in inputs {
67            inputs_flat.extend_from_slice(row);
68        }
69
70        Ok(Self {
71            inputs: inputs_flat,
72            len,
73            input_dim,
74        })
75    }
76
77    #[inline]
78    /// Returns the number of samples.
79    pub fn len(&self) -> usize {
80        self.len
81    }
82
83    #[inline]
84    /// Returns true if there are no samples.
85    pub fn is_empty(&self) -> bool {
86        self.len == 0
87    }
88
89    #[inline]
90    /// Returns the per-sample input dimension.
91    pub fn input_dim(&self) -> usize {
92        self.input_dim
93    }
94
95    #[inline]
96    /// Returns the underlying contiguous buffer.
97    ///
98    /// Shape: `(len * input_dim,)`.
99    pub fn as_flat(&self) -> &[f32] {
100        &self.inputs
101    }
102
103    #[inline]
104    /// Returns the `idx`-th input row (shape: `(input_dim,)`).
105    ///
106    /// Panics if `idx >= len`.
107    pub fn input(&self, idx: usize) -> &[f32] {
108        let start = idx * self.input_dim;
109        &self.inputs[start..start + self.input_dim]
110    }
111}
112
113/// A supervised dataset: inputs (X) and targets (Y).
114///
115/// Stored as contiguous buffers with row-major layout:
116/// - `inputs.len() == len * input_dim`
117/// - `targets.len() == len * target_dim`
118#[derive(Debug, Clone)]
119pub struct Dataset {
120    inputs: Inputs,
121    targets: Vec<f32>,
122    target_dim: usize,
123}
124
125impl Dataset {
126    /// Build a dataset from flat buffers.
127    ///
128    /// `inputs` is `(len, input_dim)` and `targets` is `(len, target_dim)`.
129    pub fn from_flat(
130        inputs: Vec<f32>,
131        targets: Vec<f32>,
132        input_dim: usize,
133        target_dim: usize,
134    ) -> Result<Self> {
135        let inputs = Inputs::from_flat(inputs, input_dim)?;
136        if target_dim == 0 {
137            return Err(Error::InvalidData("target_dim must be > 0".to_owned()));
138        }
139
140        if targets.len() != inputs.len() * target_dim {
141            return Err(Error::InvalidData(format!(
142                "targets length {} does not match len * target_dim ({} * {})",
143                targets.len(),
144                inputs.len(),
145                target_dim
146            )));
147        }
148
149        Ok(Self {
150            inputs,
151            targets,
152            target_dim,
153        })
154    }
155
156    /// Build a dataset from per-sample rows.
157    ///
158    /// This is a convenience constructor (it copies into contiguous storage).
159    pub fn from_rows(inputs: &[Vec<f32>], targets: &[Vec<f32>]) -> Result<Self> {
160        if inputs.len() != targets.len() {
161            return Err(Error::InvalidData(format!(
162                "inputs/targets length mismatch: {} vs {}",
163                inputs.len(),
164                targets.len()
165            )));
166        }
167
168        let inputs = Inputs::from_rows(inputs)?;
169        let target_dim = targets.first().map(|t| t.len()).unwrap_or(0);
170        if target_dim == 0 {
171            return Err(Error::InvalidData("target_dim must be > 0".to_owned()));
172        }
173        for (i, row) in targets.iter().enumerate() {
174            if row.len() != target_dim {
175                return Err(Error::InvalidData(format!(
176                    "target row {i} has len {}, expected {target_dim}",
177                    row.len()
178                )));
179            }
180        }
181
182        let len = inputs.len();
183        let mut targets_flat = Vec::with_capacity(len * target_dim);
184        for row in targets {
185            targets_flat.extend_from_slice(row);
186        }
187
188        Ok(Self {
189            inputs,
190            targets: targets_flat,
191            target_dim,
192        })
193    }
194
195    #[inline]
196    /// Returns the number of samples.
197    pub fn len(&self) -> usize {
198        self.inputs.len()
199    }
200
201    #[inline]
202    /// Returns true if there are no samples.
203    pub fn is_empty(&self) -> bool {
204        self.inputs.is_empty()
205    }
206
207    #[inline]
208    /// Returns the per-sample input dimension.
209    pub fn input_dim(&self) -> usize {
210        self.inputs.input_dim()
211    }
212
213    #[inline]
214    /// Returns the per-sample target dimension.
215    pub fn target_dim(&self) -> usize {
216        self.target_dim
217    }
218
219    #[inline]
220    /// Returns a view of the inputs (X).
221    pub fn inputs(&self) -> &Inputs {
222        &self.inputs
223    }
224
225    #[inline]
226    /// Returns the underlying contiguous inputs buffer.
227    ///
228    /// Shape: `(len * input_dim,)`.
229    pub fn inputs_flat(&self) -> &[f32] {
230        self.inputs.as_flat()
231    }
232
233    #[inline]
234    /// Returns the underlying contiguous targets buffer.
235    ///
236    /// Shape: `(len * target_dim,)`.
237    pub fn targets_flat(&self) -> &[f32] {
238        &self.targets
239    }
240
241    #[inline]
242    /// Returns the `idx`-th input row (shape: `(input_dim,)`).
243    ///
244    /// Panics if `idx >= len`.
245    pub fn input(&self, idx: usize) -> &[f32] {
246        self.inputs.input(idx)
247    }
248
249    #[inline]
250    /// Returns the `idx`-th target row (shape: `(target_dim,)`).
251    ///
252    /// Panics if `idx >= len`.
253    pub fn target(&self, idx: usize) -> &[f32] {
254        let start = idx * self.target_dim;
255        &self.targets[start..start + self.target_dim]
256    }
257
258    /// Returns a contiguous batch view.
259    ///
260    /// Panics if the requested range is out of bounds.
261    pub fn batch(&self, start: usize, len: usize) -> Batch<'_> {
262        assert!(len > 0, "batch len must be > 0");
263        assert!(start < self.len(), "batch start out of bounds");
264        assert!(
265            start + len <= self.len(),
266            "batch range out of bounds: start={start} len={len} dataset_len={}",
267            self.len()
268        );
269
270        let in_dim = self.input_dim();
271        let t_dim = self.target_dim();
272        let x0 = start * in_dim;
273        let x1 = (start + len) * in_dim;
274        let y0 = start * t_dim;
275        let y1 = (start + len) * t_dim;
276        Batch {
277            inputs: &self.inputs_flat()[x0..x1],
278            targets: &self.targets_flat()[y0..y1],
279            len,
280            input_dim: in_dim,
281            target_dim: t_dim,
282        }
283    }
284
285    /// Iterate contiguous batch views.
286    ///
287    /// Panics if `batch_size == 0`.
288    pub fn batches(&self, batch_size: usize) -> Batches<'_> {
289        assert!(batch_size > 0, "batch_size must be > 0");
290        Batches {
291            data: self,
292            batch_size,
293            pos: 0,
294        }
295    }
296}
297
/// Borrowed view over a contiguous run of samples from a [`Dataset`].
///
/// `inputs` and `targets` are row-major slices holding `len * input_dim` and
/// `len * target_dim` values respectively. The view is `Copy`, so it can be passed
/// around freely without touching the underlying buffers.
#[derive(Debug, Clone, Copy)]
pub struct Batch<'a> {
    /// Row-major input values for the samples in this batch.
    inputs: &'a [f32],
    /// Row-major target values for the samples in this batch.
    targets: &'a [f32],
    /// Number of samples in this batch.
    len: usize,
    /// Values per input row.
    input_dim: usize,
    /// Values per target row.
    target_dim: usize,
}
309
310impl<'a> Batch<'a> {
311    #[inline]
312    /// Returns the number of samples in this batch.
313    pub fn len(&self) -> usize {
314        self.len
315    }
316
317    #[inline]
318    /// Returns true if this batch is empty.
319    pub fn is_empty(&self) -> bool {
320        self.len == 0
321    }
322
323    #[inline]
324    /// Returns the per-sample input dimension.
325    pub fn input_dim(&self) -> usize {
326        self.input_dim
327    }
328
329    #[inline]
330    /// Returns the per-sample target dimension.
331    pub fn target_dim(&self) -> usize {
332        self.target_dim
333    }
334
335    #[inline]
336    /// Returns the contiguous flat inputs buffer.
337    ///
338    /// Shape: `(len * input_dim,)`.
339    pub fn inputs_flat(&self) -> &'a [f32] {
340        self.inputs
341    }
342
343    #[inline]
344    /// Returns the contiguous flat targets buffer.
345    ///
346    /// Shape: `(len * target_dim,)`.
347    pub fn targets_flat(&self) -> &'a [f32] {
348        self.targets
349    }
350
351    #[inline]
352    /// Returns the `idx`-th input row (shape: `(input_dim,)`).
353    ///
354    /// Panics if `idx >= len`.
355    pub fn input(&self, idx: usize) -> &'a [f32] {
356        let start = idx * self.input_dim;
357        &self.inputs[start..start + self.input_dim]
358    }
359
360    #[inline]
361    /// Returns the `idx`-th target row (shape: `(target_dim,)`).
362    ///
363    /// Panics if `idx >= len`.
364    pub fn target(&self, idx: usize) -> &'a [f32] {
365        let start = idx * self.target_dim;
366        &self.targets[start..start + self.target_dim]
367    }
368}
369
370/// Iterator over contiguous batches.
371#[derive(Debug, Clone)]
372pub struct Batches<'a> {
373    data: &'a Dataset,
374    batch_size: usize,
375    pos: usize,
376}
377
378impl<'a> Iterator for Batches<'a> {
379    type Item = Batch<'a>;
380
381    fn next(&mut self) -> Option<Self::Item> {
382        if self.pos >= self.data.len() {
383            return None;
384        }
385
386        let start = self.pos;
387        let end = (start + self.batch_size).min(self.data.len());
388        self.pos = end;
389        Some(self.data.batch(start, end - start))
390    }
391}
392
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn dataset_from_flat_validates_shapes() {
        let ok = Dataset::from_flat(vec![0.0, 1.0, 2.0, 3.0], vec![0.0, 1.0], 2, 1);
        assert!(ok.is_ok());

        // Three input values cannot be split into rows of two.
        let err = Dataset::from_flat(vec![0.0, 1.0, 2.0], vec![0.0], 2, 1);
        assert!(err.is_err());

        // Zero-width targets are rejected.
        assert!(Dataset::from_flat(vec![0.0, 1.0], vec![], 2, 0).is_err());
        // The target count must match the number of input rows.
        assert!(Dataset::from_flat(vec![0.0, 1.0], vec![0.0, 1.0], 2, 1).is_err());
    }

    #[test]
    fn dataset_from_rows_matches_from_flat() {
        let rows = Dataset::from_rows(
            &[vec![0.0, 1.0], vec![2.0, 3.0]],
            &[vec![10.0], vec![11.0]],
        )
        .unwrap();
        let flat =
            Dataset::from_flat(vec![0.0, 1.0, 2.0, 3.0], vec![10.0, 11.0], 2, 1).unwrap();

        assert_eq!(rows.len(), 2);
        assert_eq!(rows.input_dim(), 2);
        assert_eq!(rows.target_dim(), 1);
        assert_eq!(rows.inputs_flat(), flat.inputs_flat());
        assert_eq!(rows.targets_flat(), flat.targets_flat());
        assert_eq!(rows.input(1), &[2.0, 3.0]);
        assert_eq!(rows.target(1), &[11.0]);
    }

    #[test]
    fn dataset_from_rows_rejects_bad_shapes() {
        // Mismatched sample counts.
        assert!(Dataset::from_rows(&[vec![0.0]], &[vec![1.0], vec![2.0]]).is_err());
        // No samples at all.
        assert!(Dataset::from_rows(&[], &[]).is_err());
        // Ragged input rows.
        assert!(
            Dataset::from_rows(&[vec![0.0], vec![1.0, 2.0]], &[vec![0.0], vec![1.0]]).is_err()
        );
        // Ragged target rows.
        assert!(
            Dataset::from_rows(&[vec![0.0], vec![1.0]], &[vec![0.0], vec![1.0, 2.0]]).is_err()
        );
        // Empty target rows.
        assert!(Dataset::from_rows(&[vec![0.0]], &[vec![]]).is_err());
    }

    #[test]
    fn batch_returns_requested_window() {
        let x = vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0];
        let y = vec![10.0, 11.0, 12.0];
        let data = Dataset::from_flat(x, y, 2, 1).unwrap();

        let b = data.batch(1, 2);
        assert_eq!(b.len(), 2);
        assert_eq!(b.input_dim(), 2);
        assert_eq!(b.target_dim(), 1);
        assert_eq!(b.inputs_flat(), &[2.0, 3.0, 4.0, 5.0]);
        assert_eq!(b.targets_flat(), &[11.0, 12.0]);
        // Row accessors are relative to the start of the batch.
        assert_eq!(b.input(0), &[2.0, 3.0]);
        assert_eq!(b.target(1), &[12.0]);
    }

    #[test]
    #[should_panic(expected = "batch range out of bounds")]
    fn batch_past_end_panics() {
        let data = Dataset::from_flat(vec![0.0, 1.0], vec![0.0, 1.0], 1, 1).unwrap();
        let _ = data.batch(1, 2);
    }

    #[test]
    fn batches_cover_all_samples_in_order() {
        // len=5, input_dim=2, target_dim=1
        let x = vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let y = vec![10.0, 11.0, 12.0, 13.0, 14.0];
        let data = Dataset::from_flat(x, y, 2, 1).unwrap();

        let batches: Vec<_> = data.batches(2).collect();
        assert_eq!(batches.len(), 3);

        assert_eq!(batches[0].len(), 2);
        assert_eq!(batches[0].inputs_flat(), &[0.0, 1.0, 2.0, 3.0]);
        assert_eq!(batches[0].targets_flat(), &[10.0, 11.0]);

        assert_eq!(batches[1].len(), 2);
        assert_eq!(batches[1].inputs_flat(), &[4.0, 5.0, 6.0, 7.0]);
        assert_eq!(batches[1].targets_flat(), &[12.0, 13.0]);

        assert_eq!(batches[2].len(), 1);
        assert_eq!(batches[2].inputs_flat(), &[8.0, 9.0]);
        assert_eq!(batches[2].targets_flat(), &[14.0]);

        // A batch size larger than the dataset yields a single batch with every sample.
        assert_eq!(data.batches(10).count(), 1);
    }
}