rust_mnist/
lib.rs

1#![warn(clippy::pedantic)]
//! A simple struct built by parsing the MNIST dataset.
3
4use log::info;
5use std::convert::TryFrom;
6use std::fs;
7use std::io;
8use std::io::Read;
9
// Names of the four dataset files, appended to the directory prefix given to `Mnist::new`.
/// File holding the training images.
const TRAIN_DATA_FILENAME: &str = "train-images-idx3-ubyte";
/// File holding the test images.
const TEST_DATA_FILENAME: &str = "t10k-images-idx3-ubyte";
/// File holding the training labels.
const TRAIN_LABEL_FILENAME: &str = "train-labels-idx1-ubyte";
/// File holding the test labels.
const TEST_LABEL_FILENAME: &str = "t10k-labels-idx1-ubyte";

// Expected header values and image geometry. Everything is `usize` so the values can be
// compared with parsed header fields and used for array/vec sizing without casts.
/// Magic number expected at the start of an image file (2051).
const IMAGES_MAGIC_NUMBER: usize = 0x0803;
/// Magic number expected at the start of a label file (2049).
const LABELS_MAGIC_NUMBER: usize = 0x0801;
/// Number of images (and labels) in the training set.
const NUM_TRAIN_IMAGES: usize = 60_000;
/// Number of images (and labels) in the test set.
const NUM_TEST_IMAGES: usize = 10_000;
/// Height of every image, in pixels.
const IMAGE_ROWS: usize = 28;
/// Width of every image, in pixels.
const IMAGE_COLUMNS: usize = 28;
23
24pub struct Mnist {
25    // Arrays of images.
26    pub train_data: Vec<[u8; IMAGE_ROWS * IMAGE_COLUMNS]>,
27    pub test_data: Vec<[u8; IMAGE_ROWS * IMAGE_COLUMNS]>,
28
29    // Arrays of labels.
30    pub train_labels: Vec<u8>,
31    pub test_labels: Vec<u8>,
32}
33
34impl Mnist {
35    /// Load MNIST dataset.
36    ///
37    /// # Panics
38    ///
39    /// Panics if the MNIST dataset is not present at the specified path, or if the dataset is
40    /// malformed.
41    #[must_use]
42    pub fn new(mnist_path: &str) -> Mnist {
43        // Get Training Data.
44        info!("Reading MNIST training data.");
45        let train_data = parse_images(&[mnist_path, TRAIN_DATA_FILENAME].concat()).expect(
46            &format!(
47                "Training data file \"{}{}\" not found; did you \
48                     remember to download and extract it?",
49                mnist_path, TRAIN_DATA_FILENAME,
50            )[..],
51        );
52
53        // Assert that numbers extracted from the file were as expected.
54        assert_eq!(
55            train_data.magic_number, IMAGES_MAGIC_NUMBER,
56            "Magic number for training data does not match expected value."
57        );
58        assert_eq!(
59            train_data.num_images, NUM_TRAIN_IMAGES,
60            "Number of images in training data does not match expected value."
61        );
62        assert_eq!(
63            train_data.num_rows, IMAGE_ROWS,
64            "Number of rows per image in training data does not match expected value."
65        );
66        assert_eq!(
67            train_data.num_cols, IMAGE_COLUMNS,
68            "Number of columns per image in training data does not match expected value."
69        );
70
71        // Get Testing Data.
72        info!("Reading MNIST testing data.");
73        let test_data = parse_images(&[mnist_path, TEST_DATA_FILENAME].concat()).expect(
74            &format!(
75                "Test data file \"{}{}\" not found; did you \
76                     remember to download and extract it?",
77                mnist_path, TEST_DATA_FILENAME,
78            )[..],
79        );
80
81        // Assert that numbers extracted from the file were as expected.
82        assert_eq!(
83            test_data.magic_number, IMAGES_MAGIC_NUMBER,
84            "Magic number for testing data does not match expected value."
85        );
86        assert_eq!(
87            test_data.num_images, NUM_TEST_IMAGES,
88            "Number of images in testing data does not match expected value."
89        );
90        assert_eq!(
91            test_data.num_rows, IMAGE_ROWS,
92            "Number of rows per image in testing data does not match expected value."
93        );
94        assert_eq!(
95            test_data.num_cols, IMAGE_COLUMNS,
96            "Number of columns per image in testing data does not match expected value."
97        );
98
99        // Get Training Labels.
100        info!("Reading MNIST training labels.");
101        let (magic_number, num_labels, train_labels) =
102            parse_labels(&[mnist_path, TRAIN_LABEL_FILENAME].concat()).expect(
103                &format!(
104                    "Training label file \"{}{}\" not found; did you \
105                     remember to download and extract it?",
106                    mnist_path, TRAIN_LABEL_FILENAME,
107                )[..],
108            );
109
110        // Assert that numbers extracted from the file were as expected.
111        assert_eq!(
112            magic_number, LABELS_MAGIC_NUMBER,
113            "Magic number for training labels does not match expected value."
114        );
115        assert_eq!(
116            num_labels, NUM_TRAIN_IMAGES,
117            "Number of labels in training labels does not match expected value."
118        );
119
120        // Get Testing Labels.
121        info!("Reading MNIST testing labels.");
122        let (magic_number, num_labels, test_labels) =
123            parse_labels(&[mnist_path, TEST_LABEL_FILENAME].concat()).expect(
124                &format!(
125                    "Test labels file \"{}{}\" not found; did you \
126                     remember to download and extract it?",
127                    mnist_path, TEST_LABEL_FILENAME,
128                )[..],
129            );
130
131        // Assert that numbers extracted from the file were as expected.
132        assert_eq!(
133            magic_number, LABELS_MAGIC_NUMBER,
134            "Magic number for testing labels does not match expected value."
135        );
136        assert_eq!(
137            num_labels, NUM_TEST_IMAGES,
138            "Number of labels in testing labels does not match expected value."
139        );
140
141        Mnist {
142            train_data: train_data.images,
143            test_data: test_data.images,
144            train_labels,
145            test_labels,
146        }
147    }
148}
149
150/// Print a sample image.
151///
152/// # Examples
153/// ```
154/// use rust_mnist::{print_image, Mnist};
155///
156/// let mnist = Mnist::new("examples/MNIST_data/");
157///
158/// // Print one image (the one at index 5).
159/// print_image(&mnist.train_data[5], mnist.train_labels[5]);
160/// ```
161pub fn print_image(image: &[u8; IMAGE_ROWS * IMAGE_COLUMNS], label: u8) {
162    println!("Sample image label: {label} \nSample image:");
163
164    // Print each row.
165    for row in 0..IMAGE_ROWS {
166        for col in 0..IMAGE_COLUMNS {
167            if image[row * IMAGE_COLUMNS + col] == 0 {
168                print!("__");
169            } else {
170                print!("##");
171            }
172        }
173        println!();
174    }
175}
176
177struct MnistImages {
178    magic_number: usize,
179    num_images: usize,
180    num_rows: usize,
181    num_cols: usize,
182    images: Vec<[u8; IMAGE_ROWS * IMAGE_COLUMNS]>,
183}
184
185fn parse_images(filename: &str) -> io::Result<MnistImages> {
186    // Open the file.
187    let images_data_bytes = fs::File::open(filename)?;
188    let images_data_bytes = io::BufReader::new(images_data_bytes);
189    let mut buffer_32: [u8; 4] = [0; 4];
190
191    // Get the magic number.
192    images_data_bytes
193        .get_ref()
194        .take(4)
195        .read_exact(&mut buffer_32)?;
196    let magic_number = usize::try_from(u32::from_be_bytes(buffer_32)).unwrap();
197
198    // Get number of images.
199    images_data_bytes
200        .get_ref()
201        .take(4)
202        .read_exact(&mut buffer_32)?;
203    let num_images = usize::try_from(u32::from_be_bytes(buffer_32)).unwrap();
204
205    // Get number or rows per image.
206    images_data_bytes
207        .get_ref()
208        .take(4)
209        .read_exact(&mut buffer_32)?;
210    let num_rows = usize::try_from(u32::from_be_bytes(buffer_32)).unwrap();
211
212    // Get number or columns per image.
213    images_data_bytes
214        .get_ref()
215        .take(4)
216        .read_exact(&mut buffer_32)?;
217    let num_cols = usize::try_from(u32::from_be_bytes(buffer_32)).unwrap();
218
219    // Buffer for holding image pixels.
220    let mut image_buffer: [u8; IMAGE_ROWS * IMAGE_COLUMNS] = [0; IMAGE_ROWS * IMAGE_COLUMNS];
221
222    // Vector to hold all images in the file.
223    let mut images: Vec<[u8; IMAGE_ROWS * IMAGE_COLUMNS]> = Vec::with_capacity(num_images);
224
225    // Get images from file.
226    for _image in 0..num_images {
227        images_data_bytes
228            .get_ref()
229            .take(u64::try_from(num_rows * num_cols).unwrap())
230            .read_exact(&mut image_buffer)
231            .unwrap();
232        images.push(image_buffer);
233    }
234
235    Ok(MnistImages {
236        magic_number,
237        num_images,
238        num_rows,
239        num_cols,
240        images,
241    })
242}
243
/// Parse an MNIST label file.
///
/// The file starts with two big-endian `u32` header fields (magic number, number of labels)
/// followed by one byte per label. Bytes after the declared number of labels are ignored.
///
/// Returns `(magic_number, num_labels, labels)`. The header values are returned as read;
/// comparing them against the expected values is left to the caller.
///
/// # Errors
///
/// Returns an error if the file cannot be opened or read, or (`UnexpectedEof`) if it ends
/// before the header or all declared labels have been read.
fn parse_labels(filename: &str) -> io::Result<(usize, usize, Vec<u8>)> {
    /// Read one big-endian `u32` header field.
    fn read_u32_be(reader: &mut impl Read) -> io::Result<u32> {
        let mut bytes = [0_u8; 4];
        reader.read_exact(&mut bytes)?;
        Ok(u32::from_be_bytes(bytes))
    }

    /// Widen a header field to `usize`, reporting a value that does not fit as invalid data.
    fn to_usize(value: u32) -> io::Result<usize> {
        usize::try_from(value).map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))
    }

    // Read through the `BufReader` itself so the reads are actually buffered.
    let mut reader = io::BufReader::new(fs::File::open(filename)?);

    let magic_number = to_usize(read_u32_be(&mut reader)?)?;
    let label_count = read_u32_be(&mut reader)?;
    let num_labels = to_usize(label_count)?;

    // Read at most the declared number of labels. The vector grows with the data that is
    // really there, so a bogus count in the header cannot force a huge allocation.
    let mut labels = Vec::new();
    reader
        .take(u64::from(label_count))
        .read_to_end(&mut labels)?;

    if labels.len() != num_labels {
        return Err(io::Error::new(
            io::ErrorKind::UnexpectedEof,
            format!(
                "label file ends after {} of {} declared labels",
                labels.len(),
                num_labels
            ),
        ));
    }

    Ok((magic_number, num_labels, labels))
}