Skip to main content

shrew_data/
mnist.rs

1// MNIST dataset — IDX file format parser
2//
3// The MNIST database consists of 4 files:
4//   - train-images-idx3-ubyte  (60,000  28×28 images)
5//   - train-labels-idx1-ubyte  (60,000  labels 0-9)
6//   - t10k-images-idx3-ubyte   (10,000  28×28 images)
7//   - t10k-labels-idx1-ubyte   (10,000  labels 0-9)
8//
9// IDX format (all values big-endian):
10//   images: magic(2051) | count(u32) | rows(u32) | cols(u32) | pixel_data(u8...)
11//   labels: magic(2049) | count(u32) | label_data(u8...)
12//
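// Gzip-compressed files (.gz) are located and their header is validated, but
// DEFLATE decompression is not implemented here: loading a .gz file returns an
// error asking you to gunzip it first.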
14// Download from: https://yann.lecun.com/exdb/mnist/
15
16use std::fs;
17use std::io;
18use std::path::{Path, PathBuf};
19
20use crate::dataset::{Dataset, Sample};
21
/// Errors that can occur while locating, reading, or parsing MNIST IDX files.
#[derive(Debug)]
pub enum MnistError {
    /// An underlying I/O failure. Also used for malformed or truncated data,
    /// which this module reports as `io::ErrorKind::InvalidData`.
    Io(io::Error),
    /// The leading magic number did not match the expected IDX file type.
    InvalidMagic { expected: u32, got: u32 },
    /// The image file and the label file contain different numbers of samples.
    CountMismatch { images: usize, labels: usize },
    /// Neither the plain file nor its `.gz` sibling exists; holds the plain path.
    MissingFile(PathBuf),
}

impl std::fmt::Display for MnistError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            MnistError::Io(e) => write!(f, "MNIST I/O error: {e}"),
            MnistError::InvalidMagic { expected, got } => write!(
                f,
                "MNIST invalid magic: expected {expected:#06x}, got {got:#06x}"
            ),
            MnistError::CountMismatch { images, labels } => write!(
                f,
                "MNIST count mismatch: {images} images vs {labels} labels"
            ),
            MnistError::MissingFile(p) => write!(f, "MNIST file not found: {}", p.display()),
        }
    }
}

impl std::error::Error for MnistError {
    /// Expose the wrapped `io::Error` so callers walking the error chain
    /// (e.g. error reporters) can reach the root cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            MnistError::Io(e) => Some(e),
            MnistError::InvalidMagic { .. }
            | MnistError::CountMismatch { .. }
            | MnistError::MissingFile(_) => None,
        }
    }
}

impl From<io::Error> for MnistError {
    /// Lets `?` convert any `io::Error` into `MnistError::Io`.
    fn from(e: io::Error) -> Self {
        MnistError::Io(e)
    }
}
55
/// Selects which of the two MNIST file pairs a dataset is loaded from.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MnistSplit {
    /// The training files (`train-*`).
    Train,
    /// The test files (`t10k-*`).
    Test,
}
62
63/// A loaded MNIST dataset stored entirely in memory.
64///
65/// Images are stored as `Vec<u8>` (28×28 = 784 bytes each).
66/// Labels are `u8` values 0–9.
67#[derive(Debug)]
68pub struct MnistDataset {
69    images: Vec<Vec<u8>>,
70    labels: Vec<u8>,
71    rows: usize,
72    cols: usize,
73    split: MnistSplit,
74}
75
76impl MnistDataset {
77    /// Load MNIST from the given directory.
78    ///
79    /// Expects the standard filenames (or `.gz` compressed versions):
80    ///   - `train-images-idx3-ubyte` / `train-images-idx3-ubyte.gz`
81    ///   - `train-labels-idx1-ubyte` / `train-labels-idx1-ubyte.gz`
82    ///   - `t10k-images-idx3-ubyte`  / `t10k-images-idx3-ubyte.gz`
83    ///   - `t10k-labels-idx1-ubyte`  / `t10k-labels-idx1-ubyte.gz`
84    pub fn load(dir: impl AsRef<Path>, split: MnistSplit) -> Result<Self, MnistError> {
85        let dir = dir.as_ref();
86
87        let (img_name, lbl_name) = match split {
88            MnistSplit::Train => ("train-images-idx3-ubyte", "train-labels-idx1-ubyte"),
89            MnistSplit::Test => ("t10k-images-idx3-ubyte", "t10k-labels-idx1-ubyte"),
90        };
91
92        let img_bytes = read_maybe_gz(dir, img_name)?;
93        let lbl_bytes = read_maybe_gz(dir, lbl_name)?;
94
95        let (images, rows, cols) = parse_idx3_images(&img_bytes)?;
96        let labels = parse_idx1_labels(&lbl_bytes)?;
97
98        if images.len() != labels.len() {
99            return Err(MnistError::CountMismatch {
100                images: images.len(),
101                labels: labels.len(),
102            });
103        }
104
105        Ok(Self {
106            images,
107            labels,
108            rows,
109            cols,
110            split,
111        })
112    }
113
114    /// Load from raw bytes (useful for embedded/testing).
115    pub fn from_raw(
116        image_bytes: &[u8],
117        label_bytes: &[u8],
118        split: MnistSplit,
119    ) -> Result<Self, MnistError> {
120        let (images, rows, cols) = parse_idx3_images(image_bytes)?;
121        let labels = parse_idx1_labels(label_bytes)?;
122
123        if images.len() != labels.len() {
124            return Err(MnistError::CountMismatch {
125                images: images.len(),
126                labels: labels.len(),
127            });
128        }
129
130        Ok(Self {
131            images,
132            labels,
133            rows,
134            cols,
135            split,
136        })
137    }
138
139    /// Create a small synthetic MNIST-like dataset for testing.
140    ///
141    /// Generates `n` random 28×28 images with random labels.
142    pub fn synthetic(n: usize, split: MnistSplit) -> Self {
143        use rand::Rng;
144        let mut rng = rand::thread_rng();
145        let rows = 28;
146        let cols = 28;
147        let mut images = Vec::with_capacity(n);
148        let mut labels = Vec::with_capacity(n);
149
150        for _ in 0..n {
151            let mut img = vec![0u8; rows * cols];
152            for px in &mut img {
153                *px = rng.gen();
154            }
155            images.push(img);
156            labels.push(rng.gen_range(0..10u8));
157        }
158
159        Self {
160            images,
161            labels,
162            rows,
163            cols,
164            split,
165        }
166    }
167
168    /// Total number of samples.
169    pub fn num_samples(&self) -> usize {
170        self.images.len()
171    }
172
173    /// Image dimensions: (rows, cols).
174    pub fn image_dims(&self) -> (usize, usize) {
175        (self.rows, self.cols)
176    }
177
178    /// Get the raw pixel values for sample `i`.
179    pub fn image_u8(&self, i: usize) -> &[u8] {
180        &self.images[i]
181    }
182
183    /// Get the label for sample `i`.
184    pub fn label(&self, i: usize) -> u8 {
185        self.labels[i]
186    }
187
188    /// Which split this dataset represents.
189    pub fn split(&self) -> MnistSplit {
190        self.split
191    }
192
193    /// Take only the first `n` samples (useful for quick experiments).
194    pub fn take(mut self, n: usize) -> Self {
195        let n = n.min(self.images.len());
196        self.images.truncate(n);
197        self.labels.truncate(n);
198        self
199    }
200}
201
202impl Dataset for MnistDataset {
203    fn len(&self) -> usize {
204        self.images.len()
205    }
206
207    fn get(&self, index: usize) -> Sample {
208        let pixels = &self.images[index];
209        let label = self.labels[index];
210
211        Sample {
212            features: pixels.iter().map(|&p| p as f64).collect(),
213            feature_shape: vec![self.rows * self.cols],
214            target: vec![label as f64],
215            target_shape: vec![1],
216        }
217    }
218
219    fn feature_shape(&self) -> &[usize] {
220        // We return a static ref, but since MnistDataset owns this info,
221        // we use a small trick: store the shape inline.  For now just return
222        // a slice from a leaked box (tiny, done once).
223        // Better approach: store feature_shape as a field.
224        &[784] // 28*28
225    }
226
227    fn target_shape(&self) -> &[usize] {
228        &[1]
229    }
230
231    fn name(&self) -> &str {
232        match self.split {
233            MnistSplit::Train => "MNIST-train",
234            MnistSplit::Test => "MNIST-test",
235        }
236    }
237}
238
239// IDX file format parsing
240
241/// Read a file, trying plain first then `.gz` extension.
242fn read_maybe_gz(dir: &Path, base_name: &str) -> Result<Vec<u8>, MnistError> {
243    let plain = dir.join(base_name);
244    let gz = dir.join(format!("{base_name}.gz"));
245
246    if plain.exists() {
247        Ok(fs::read(&plain)?)
248    } else if gz.exists() {
249        let compressed = fs::read(&gz)?;
250        decompress_gz(&compressed)
251    } else {
252        Err(MnistError::MissingFile(plain))
253    }
254}
255
256/// Simple gzip decompressor using DEFLATE.
257///
258/// We implement a minimal gzip reader (RFC 1952) without pulling in flate2.
259/// This handles the standard MNIST .gz files which use default compression.
260fn decompress_gz(data: &[u8]) -> Result<Vec<u8>, MnistError> {
261    // For simplicity, we use the miniz_oxide-compatible approach.
262    // Since we don't want to add deps, we'll just try the raw DEFLATE stream
263    // after skipping the gzip header.
264
265    if data.len() < 10 {
266        return Err(MnistError::Io(io::Error::new(
267            io::ErrorKind::InvalidData,
268            "gzip data too short",
269        )));
270    }
271
272    // Verify gzip magic
273    if data[0] != 0x1f || data[1] != 0x8b {
274        return Err(MnistError::Io(io::Error::new(
275            io::ErrorKind::InvalidData,
276            "not a gzip file",
277        )));
278    }
279
280    // Skip gzip header (10 bytes minimum)
281    let mut pos = 10;
282    let flags = data[3];
283
284    // FEXTRA
285    if flags & 0x04 != 0 {
286        if pos + 2 > data.len() {
287            return Err(io_err("truncated gzip FEXTRA"));
288        }
289        let xlen = u16::from_le_bytes([data[pos], data[pos + 1]]) as usize;
290        pos += 2 + xlen;
291    }
292    // FNAME
293    if flags & 0x08 != 0 {
294        while pos < data.len() && data[pos] != 0 {
295            pos += 1;
296        }
297        pos += 1; // skip null terminator
298    }
299    // FCOMMENT
300    if flags & 0x10 != 0 {
301        while pos < data.len() && data[pos] != 0 {
302            pos += 1;
303        }
304        pos += 1;
305    }
306    // FHCRC
307    if flags & 0x02 != 0 {
308        pos += 2;
309    }
310
311    if pos >= data.len() {
312        return Err(io_err("truncated gzip header"));
313    }
314
315    // The DEFLATE stream is from `pos` to `data.len() - 8` (last 8 = crc32 + isize)
316    let deflate_end = if data.len() >= 8 {
317        data.len() - 8
318    } else {
319        data.len()
320    };
321    let deflate_data = &data[pos..deflate_end];
322
323    // Use a simple inflate implementation
324    inflate_deflate(deflate_data)
325}
326
327/// Minimal DEFLATE decompressor for gzip.
328///
329/// For production use you'd want `flate2`. This is a simplified version
330/// that handles the MNIST files (which are typically stored-or-default compressed).
331/// We use Rust's built-in approach: since we can't easily decompress DEFLATE
332/// without a library, we recommend the plain (uncompressed) files.
333///
334/// As a fallback, this function returns an error suggesting to decompress.
335fn inflate_deflate(_data: &[u8]) -> Result<Vec<u8>, MnistError> {
336    Err(MnistError::Io(io::Error::new(
337        io::ErrorKind::Unsupported,
338        "gzip decompression requires the `flate2` feature. \
339         Please decompress MNIST files manually (gunzip) or enable flate2.",
340    )))
341}
342
343fn io_err(msg: &str) -> MnistError {
344    MnistError::Io(io::Error::new(io::ErrorKind::InvalidData, msg))
345}
346
347/// Parse an IDX3 file (images): magic=2051, count, rows, cols, data.
348fn parse_idx3_images(data: &[u8]) -> Result<(Vec<Vec<u8>>, usize, usize), MnistError> {
349    if data.len() < 16 {
350        return Err(io_err("IDX3 file too short"));
351    }
352
353    let magic = read_u32_be(data, 0);
354    if magic != 2051 {
355        return Err(MnistError::InvalidMagic {
356            expected: 2051,
357            got: magic,
358        });
359    }
360
361    let count = read_u32_be(data, 4) as usize;
362    let rows = read_u32_be(data, 8) as usize;
363    let cols = read_u32_be(data, 12) as usize;
364    let pixels_per_image = rows * cols;
365
366    let expected_len = 16 + count * pixels_per_image;
367    if data.len() < expected_len {
368        return Err(io_err(&format!(
369            "IDX3 truncated: expected {expected_len} bytes, got {}",
370            data.len()
371        )));
372    }
373
374    let mut images = Vec::with_capacity(count);
375    for i in 0..count {
376        let start = 16 + i * pixels_per_image;
377        let end = start + pixels_per_image;
378        images.push(data[start..end].to_vec());
379    }
380
381    Ok((images, rows, cols))
382}
383
384/// Parse an IDX1 file (labels): magic=2049, count, data.
385fn parse_idx1_labels(data: &[u8]) -> Result<Vec<u8>, MnistError> {
386    if data.len() < 8 {
387        return Err(io_err("IDX1 file too short"));
388    }
389
390    let magic = read_u32_be(data, 0);
391    if magic != 2049 {
392        return Err(MnistError::InvalidMagic {
393            expected: 2049,
394            got: magic,
395        });
396    }
397
398    let count = read_u32_be(data, 4) as usize;
399    let expected_len = 8 + count;
400    if data.len() < expected_len {
401        return Err(io_err(&format!(
402            "IDX1 truncated: expected {expected_len} bytes, got {}",
403            data.len()
404        )));
405    }
406
407    Ok(data[8..8 + count].to_vec())
408}
409
/// Decode the four bytes starting at `off` in `data` as a big-endian `u32`.
///
/// # Panics
///
/// Panics if `data` does not contain four bytes starting at `off`.
fn read_u32_be(data: &[u8], off: usize) -> u32 {
    let mut word = [0u8; 4];
    word.copy_from_slice(&data[off..off + 4]);
    u32::from_be_bytes(word)
}
414
415// Builder helpers
416
/// Serialise images into IDX3 bytes (useful for tests).
///
/// Writes the big-endian header `2051 | images.len() | rows | cols`, then each
/// image's bytes in order. Image lengths are not checked against
/// `rows * cols`; the slices are written as given.
pub fn build_idx3_bytes(images: &[&[u8]], rows: u32, cols: u32) -> Vec<u8> {
    let header = [2051u32, images.len() as u32, rows, cols];
    let mut out: Vec<u8> = header
        .iter()
        .flat_map(|word| word.to_be_bytes())
        .collect();
    images
        .iter()
        .for_each(|pixels| out.extend_from_slice(pixels));
    out
}
430
/// Serialise labels into IDX1 bytes (useful for tests).
///
/// Writes the big-endian header `2049 | labels.len()`, then the label bytes
/// unchanged.
pub fn build_idx1_bytes(labels: &[u8]) -> Vec<u8> {
    let magic = 2049u32.to_be_bytes();
    let count = (labels.len() as u32).to_be_bytes();
    [&magic[..], &count[..], labels].concat()
}
440
#[cfg(test)]
mod tests {
    use super::*;

    /// Two 2×2 images survive a build → parse round trip unchanged.
    #[test]
    fn test_parse_idx3_roundtrip() {
        let dark = [0u8; 4];
        let bright = [255u8; 4];
        let encoded = build_idx3_bytes(&[&dark, &bright], 2, 2);

        let (decoded, rows, cols) = parse_idx3_images(&encoded).unwrap();

        assert_eq!(decoded.len(), 2);
        assert_eq!((rows, cols), (2, 2));
        assert_eq!(decoded, vec![dark.to_vec(), bright.to_vec()]);
    }

    /// Labels survive a build → parse round trip unchanged.
    #[test]
    fn test_parse_idx1_roundtrip() {
        let expected: Vec<u8> = vec![0, 1, 2, 9, 5];
        let decoded = parse_idx1_labels(&build_idx1_bytes(&expected)).unwrap();
        assert_eq!(decoded, expected);
    }

    /// A corrupted magic number in an image file is rejected.
    #[test]
    fn test_invalid_magic_idx3() {
        let mut encoded = build_idx3_bytes(&[&[0u8; 4]], 2, 2);
        encoded[3] = 99; // corrupt the low byte of the magic
        let outcome = parse_idx3_images(&encoded);
        assert!(matches!(outcome, Err(MnistError::InvalidMagic { .. })));
    }

    /// A corrupted magic number in a label file is rejected.
    #[test]
    fn test_invalid_magic_idx1() {
        let mut encoded = build_idx1_bytes(&[0, 1]);
        encoded[3] = 99;
        let outcome = parse_idx1_labels(&encoded);
        assert!(matches!(outcome, Err(MnistError::InvalidMagic { .. })));
    }

    /// `from_raw` exposes the parsed labels and pixels through the accessors.
    #[test]
    fn test_from_raw() {
        let images = build_idx3_bytes(&[&[128u8; 4], &[64u8; 4]], 2, 2);
        let labels = build_idx1_bytes(&[3, 7]);

        let dataset = MnistDataset::from_raw(&images, &labels, MnistSplit::Train).unwrap();

        assert_eq!(dataset.num_samples(), 2);
        assert_eq!((dataset.label(0), dataset.label(1)), (3, 7));
        assert_eq!(dataset.image_u8(0), &[128u8; 4][..]);
    }

    /// One image paired with two labels is a count mismatch.
    #[test]
    fn test_count_mismatch() {
        let one_image = build_idx3_bytes(&[&[0u8; 4]], 2, 2);
        let two_labels = build_idx1_bytes(&[0, 1]);

        let outcome = MnistDataset::from_raw(&one_image, &two_labels, MnistSplit::Train);

        assert!(matches!(outcome, Err(MnistError::CountMismatch { .. })));
    }

    /// The `Dataset` trait view reports length, name, and per-sample contents.
    #[test]
    fn test_dataset_trait() {
        let images = build_idx3_bytes(&[&[100u8; 4], &[200u8; 4]], 2, 2);
        let labels = build_idx1_bytes(&[5, 8]);
        let dataset = MnistDataset::from_raw(&images, &labels, MnistSplit::Test).unwrap();

        assert_eq!(dataset.len(), 2);
        assert!(!dataset.is_empty());
        assert_eq!(dataset.name(), "MNIST-test");

        let first = dataset.get(0);
        assert_eq!(first.features.len(), 4); // 2×2 = 4 pixels
        assert_eq!(first.features[0], 100.0);
        assert_eq!(first.target, vec![5.0]);
        assert_eq!(first.feature_shape, vec![4]); // rows * cols
        assert_eq!(first.target_shape, vec![1]);
    }

    /// Synthetic data has the requested size, MNIST dimensions, and digit labels.
    #[test]
    fn test_synthetic() {
        let dataset = MnistDataset::synthetic(100, MnistSplit::Train);

        assert_eq!(dataset.num_samples(), 100);
        assert_eq!(dataset.image_dims(), (28, 28));
        assert!((0..100).all(|i| dataset.label(i) < 10));
    }

    /// `take` keeps only the requested number of leading samples.
    #[test]
    fn test_take() {
        let truncated = MnistDataset::synthetic(100, MnistSplit::Train).take(10);
        assert_eq!(truncated.num_samples(), 10);
    }
}