1use std::fs;
17use std::io;
18use std::path::{Path, PathBuf};
19
20use crate::dataset::{Dataset, Sample};
21
/// Errors produced while locating, decoding, or validating MNIST data.
#[derive(Debug)]
pub enum MnistError {
    /// An underlying I/O failure, or malformed input reported as an I/O error.
    Io(io::Error),
    /// The IDX header's magic number did not match the expected file type.
    InvalidMagic { expected: u32, got: u32 },
    /// The image file and the label file declare different sample counts.
    CountMismatch { images: usize, labels: usize },
    /// Neither the plain nor the `.gz` variant of a required file was found.
    MissingFile(PathBuf),
}

impl std::fmt::Display for MnistError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::MissingFile(path) => write!(f, "MNIST file not found: {}", path.display()),
            Self::CountMismatch { images, labels } => {
                write!(f, "MNIST count mismatch: {images} images vs {labels} labels")
            }
            Self::InvalidMagic { expected, got } => {
                // `#06x` renders e.g. 2051 as `0x0803` (width includes the `0x` prefix).
                write!(f, "MNIST invalid magic: expected {expected:#06x}, got {got:#06x}")
            }
            Self::Io(inner) => write!(f, "MNIST I/O error: {inner}"),
        }
    }
}

impl std::error::Error for MnistError {}

impl From<io::Error> for MnistError {
    /// Wraps a raw I/O error so `?` can be used on `std::fs` / `std::io` calls.
    fn from(err: io::Error) -> Self {
        Self::Io(err)
    }
}
55
/// Which half of MNIST to load: the training files (`train-*`) or the test files (`t10k-*`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MnistSplit {
    /// The training split.
    Train,
    /// The test split.
    Test,
}
62
63#[derive(Debug)]
68pub struct MnistDataset {
69 images: Vec<Vec<u8>>,
70 labels: Vec<u8>,
71 rows: usize,
72 cols: usize,
73 split: MnistSplit,
74}
75
76impl MnistDataset {
77 pub fn load(dir: impl AsRef<Path>, split: MnistSplit) -> Result<Self, MnistError> {
85 let dir = dir.as_ref();
86
87 let (img_name, lbl_name) = match split {
88 MnistSplit::Train => ("train-images-idx3-ubyte", "train-labels-idx1-ubyte"),
89 MnistSplit::Test => ("t10k-images-idx3-ubyte", "t10k-labels-idx1-ubyte"),
90 };
91
92 let img_bytes = read_maybe_gz(dir, img_name)?;
93 let lbl_bytes = read_maybe_gz(dir, lbl_name)?;
94
95 let (images, rows, cols) = parse_idx3_images(&img_bytes)?;
96 let labels = parse_idx1_labels(&lbl_bytes)?;
97
98 if images.len() != labels.len() {
99 return Err(MnistError::CountMismatch {
100 images: images.len(),
101 labels: labels.len(),
102 });
103 }
104
105 Ok(Self {
106 images,
107 labels,
108 rows,
109 cols,
110 split,
111 })
112 }
113
114 pub fn from_raw(
116 image_bytes: &[u8],
117 label_bytes: &[u8],
118 split: MnistSplit,
119 ) -> Result<Self, MnistError> {
120 let (images, rows, cols) = parse_idx3_images(image_bytes)?;
121 let labels = parse_idx1_labels(label_bytes)?;
122
123 if images.len() != labels.len() {
124 return Err(MnistError::CountMismatch {
125 images: images.len(),
126 labels: labels.len(),
127 });
128 }
129
130 Ok(Self {
131 images,
132 labels,
133 rows,
134 cols,
135 split,
136 })
137 }
138
139 pub fn synthetic(n: usize, split: MnistSplit) -> Self {
143 use rand::Rng;
144 let mut rng = rand::thread_rng();
145 let rows = 28;
146 let cols = 28;
147 let mut images = Vec::with_capacity(n);
148 let mut labels = Vec::with_capacity(n);
149
150 for _ in 0..n {
151 let mut img = vec![0u8; rows * cols];
152 for px in &mut img {
153 *px = rng.gen();
154 }
155 images.push(img);
156 labels.push(rng.gen_range(0..10u8));
157 }
158
159 Self {
160 images,
161 labels,
162 rows,
163 cols,
164 split,
165 }
166 }
167
168 pub fn num_samples(&self) -> usize {
170 self.images.len()
171 }
172
173 pub fn image_dims(&self) -> (usize, usize) {
175 (self.rows, self.cols)
176 }
177
178 pub fn image_u8(&self, i: usize) -> &[u8] {
180 &self.images[i]
181 }
182
183 pub fn label(&self, i: usize) -> u8 {
185 self.labels[i]
186 }
187
188 pub fn split(&self) -> MnistSplit {
190 self.split
191 }
192
193 pub fn take(mut self, n: usize) -> Self {
195 let n = n.min(self.images.len());
196 self.images.truncate(n);
197 self.labels.truncate(n);
198 self
199 }
200}
201
202impl Dataset for MnistDataset {
203 fn len(&self) -> usize {
204 self.images.len()
205 }
206
207 fn get(&self, index: usize) -> Sample {
208 let pixels = &self.images[index];
209 let label = self.labels[index];
210
211 Sample {
212 features: pixels.iter().map(|&p| p as f64).collect(),
213 feature_shape: vec![self.rows * self.cols],
214 target: vec![label as f64],
215 target_shape: vec![1],
216 }
217 }
218
219 fn feature_shape(&self) -> &[usize] {
220 &[784] }
226
227 fn target_shape(&self) -> &[usize] {
228 &[1]
229 }
230
231 fn name(&self) -> &str {
232 match self.split {
233 MnistSplit::Train => "MNIST-train",
234 MnistSplit::Test => "MNIST-test",
235 }
236 }
237}
238
239fn read_maybe_gz(dir: &Path, base_name: &str) -> Result<Vec<u8>, MnistError> {
243 let plain = dir.join(base_name);
244 let gz = dir.join(format!("{base_name}.gz"));
245
246 if plain.exists() {
247 Ok(fs::read(&plain)?)
248 } else if gz.exists() {
249 let compressed = fs::read(&gz)?;
250 decompress_gz(&compressed)
251 } else {
252 Err(MnistError::MissingFile(plain))
253 }
254}
255
256fn decompress_gz(data: &[u8]) -> Result<Vec<u8>, MnistError> {
261 if data.len() < 10 {
266 return Err(MnistError::Io(io::Error::new(
267 io::ErrorKind::InvalidData,
268 "gzip data too short",
269 )));
270 }
271
272 if data[0] != 0x1f || data[1] != 0x8b {
274 return Err(MnistError::Io(io::Error::new(
275 io::ErrorKind::InvalidData,
276 "not a gzip file",
277 )));
278 }
279
280 let mut pos = 10;
282 let flags = data[3];
283
284 if flags & 0x04 != 0 {
286 if pos + 2 > data.len() {
287 return Err(io_err("truncated gzip FEXTRA"));
288 }
289 let xlen = u16::from_le_bytes([data[pos], data[pos + 1]]) as usize;
290 pos += 2 + xlen;
291 }
292 if flags & 0x08 != 0 {
294 while pos < data.len() && data[pos] != 0 {
295 pos += 1;
296 }
297 pos += 1; }
299 if flags & 0x10 != 0 {
301 while pos < data.len() && data[pos] != 0 {
302 pos += 1;
303 }
304 pos += 1;
305 }
306 if flags & 0x02 != 0 {
308 pos += 2;
309 }
310
311 if pos >= data.len() {
312 return Err(io_err("truncated gzip header"));
313 }
314
315 let deflate_end = if data.len() >= 8 {
317 data.len() - 8
318 } else {
319 data.len()
320 };
321 let deflate_data = &data[pos..deflate_end];
322
323 inflate_deflate(deflate_data)
325}
326
327fn inflate_deflate(_data: &[u8]) -> Result<Vec<u8>, MnistError> {
336 Err(MnistError::Io(io::Error::new(
337 io::ErrorKind::Unsupported,
338 "gzip decompression requires the `flate2` feature. \
339 Please decompress MNIST files manually (gunzip) or enable flate2.",
340 )))
341}
342
343fn io_err(msg: &str) -> MnistError {
344 MnistError::Io(io::Error::new(io::ErrorKind::InvalidData, msg))
345}
346
347fn parse_idx3_images(data: &[u8]) -> Result<(Vec<Vec<u8>>, usize, usize), MnistError> {
349 if data.len() < 16 {
350 return Err(io_err("IDX3 file too short"));
351 }
352
353 let magic = read_u32_be(data, 0);
354 if magic != 2051 {
355 return Err(MnistError::InvalidMagic {
356 expected: 2051,
357 got: magic,
358 });
359 }
360
361 let count = read_u32_be(data, 4) as usize;
362 let rows = read_u32_be(data, 8) as usize;
363 let cols = read_u32_be(data, 12) as usize;
364 let pixels_per_image = rows * cols;
365
366 let expected_len = 16 + count * pixels_per_image;
367 if data.len() < expected_len {
368 return Err(io_err(&format!(
369 "IDX3 truncated: expected {expected_len} bytes, got {}",
370 data.len()
371 )));
372 }
373
374 let mut images = Vec::with_capacity(count);
375 for i in 0..count {
376 let start = 16 + i * pixels_per_image;
377 let end = start + pixels_per_image;
378 images.push(data[start..end].to_vec());
379 }
380
381 Ok((images, rows, cols))
382}
383
384fn parse_idx1_labels(data: &[u8]) -> Result<Vec<u8>, MnistError> {
386 if data.len() < 8 {
387 return Err(io_err("IDX1 file too short"));
388 }
389
390 let magic = read_u32_be(data, 0);
391 if magic != 2049 {
392 return Err(MnistError::InvalidMagic {
393 expected: 2049,
394 got: magic,
395 });
396 }
397
398 let count = read_u32_be(data, 4) as usize;
399 let expected_len = 8 + count;
400 if data.len() < expected_len {
401 return Err(io_err(&format!(
402 "IDX1 truncated: expected {expected_len} bytes, got {}",
403 data.len()
404 )));
405 }
406
407 Ok(data[8..8 + count].to_vec())
408}
409
/// Decodes the big-endian `u32` stored at `data[off..off + 4]`.
///
/// # Panics
///
/// Panics if fewer than four bytes are available at `off`. Callers check lengths first.
fn read_u32_be(data: &[u8], off: usize) -> u32 {
    data[off..off + 4]
        .iter()
        .fold(0u32, |acc, &byte| (acc << 8) | u32::from(byte))
}
414
/// Serializes `images` as an IDX3 image file (magic 2051).
///
/// The 16-byte big-endian header holds the magic number, `images.len()` (truncated to
/// `u32`), `rows`, and `cols`. The image bytes follow, concatenated in order and
/// copied verbatim; their lengths are not checked against `rows * cols`. Mainly useful
/// for building fixtures for the IDX3 parser in this module.
pub fn build_idx3_bytes(images: &[&[u8]], rows: u32, cols: u32) -> Vec<u8> {
    let header = [2051u32, images.len() as u32, rows, cols];
    let mut out: Vec<u8> = header.iter().flat_map(|word| word.to_be_bytes()).collect();
    out.extend(images.iter().flat_map(|pixels| pixels.iter().copied()));
    out
}
430
/// Serializes `labels` as an IDX1 label file (magic 2049).
///
/// The 8-byte big-endian header holds the magic number and `labels.len()` (truncated to
/// `u32`), followed by the label bytes verbatim. Mainly useful for building fixtures for
/// the IDX1 parser in this module.
pub fn build_idx1_bytes(labels: &[u8]) -> Vec<u8> {
    let header = [2049u32, labels.len() as u32];
    header
        .iter()
        .flat_map(|word| word.to_be_bytes())
        .chain(labels.iter().copied())
        .collect()
}
440
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_idx3_roundtrip() {
        let black = [0u8; 4];
        let white = [255u8; 4];
        let encoded = build_idx3_bytes(&[&black[..], &white[..]], 2, 2);

        let (decoded, rows, cols) = parse_idx3_images(&encoded).unwrap();
        assert_eq!((rows, cols), (2, 2));
        assert_eq!(decoded, vec![black.to_vec(), white.to_vec()]);
    }

    #[test]
    fn test_parse_idx1_roundtrip() {
        let expected = [0u8, 1, 2, 9, 5];
        let decoded = parse_idx1_labels(&build_idx1_bytes(&expected)).unwrap();
        assert_eq!(decoded, expected.to_vec());
    }

    #[test]
    fn test_invalid_magic_idx3() {
        let mut encoded = build_idx3_bytes(&[&[0u8; 4][..]], 2, 2);
        // Corrupt the low byte of the 2051 magic number.
        encoded[3] = 99;
        assert!(matches!(
            parse_idx3_images(&encoded),
            Err(MnistError::InvalidMagic { .. })
        ));
    }

    #[test]
    fn test_invalid_magic_idx1() {
        let mut encoded = build_idx1_bytes(&[0, 1]);
        // Corrupt the low byte of the 2049 magic number.
        encoded[3] = 99;
        assert!(matches!(
            parse_idx1_labels(&encoded),
            Err(MnistError::InvalidMagic { .. })
        ));
    }

    #[test]
    fn test_from_raw() {
        let images = build_idx3_bytes(&[&[128u8; 4][..], &[64u8; 4][..]], 2, 2);
        let labels = build_idx1_bytes(&[3, 7]);
        let ds = MnistDataset::from_raw(&images, &labels, MnistSplit::Train).unwrap();

        assert_eq!(ds.num_samples(), 2);
        assert_eq!((ds.label(0), ds.label(1)), (3, 7));
        assert_eq!(ds.image_u8(0), &[128u8; 4][..]);
    }

    #[test]
    fn test_count_mismatch() {
        // One image but two labels.
        let images = build_idx3_bytes(&[&[0u8; 4][..]], 2, 2);
        let labels = build_idx1_bytes(&[0, 1]);
        let result = MnistDataset::from_raw(&images, &labels, MnistSplit::Train);
        assert!(matches!(result, Err(MnistError::CountMismatch { .. })));
    }

    #[test]
    fn test_dataset_trait() {
        let images = build_idx3_bytes(&[&[100u8; 4][..], &[200u8; 4][..]], 2, 2);
        let labels = build_idx1_bytes(&[5, 8]);
        let ds = MnistDataset::from_raw(&images, &labels, MnistSplit::Test).unwrap();

        assert_eq!(ds.len(), 2);
        assert!(!ds.is_empty());
        assert_eq!(ds.name(), "MNIST-test");

        let first = ds.get(0);
        // Flattened 2x2 image.
        assert_eq!(first.features.len(), 4);
        assert_eq!(first.features[0], 100.0);
        assert_eq!(first.target, vec![5.0]);
        assert_eq!(first.feature_shape, vec![4]);
        assert_eq!(first.target_shape, vec![1]);
    }

    #[test]
    fn test_synthetic() {
        let ds = MnistDataset::synthetic(100, MnistSplit::Train);
        assert_eq!(ds.num_samples(), 100);
        assert_eq!(ds.image_dims(), (28, 28));
        assert!((0..100).all(|i| ds.label(i) < 10));
    }

    #[test]
    fn test_take() {
        let ds = MnistDataset::synthetic(100, MnistSplit::Train);
        assert_eq!(ds.take(10).num_samples(), 10);
    }
}