1#![warn(clippy::pedantic)]
2use log::info;
5use std::convert::TryFrom;
6use std::fs;
7use std::io;
8use std::io::Read;
9
// File names of the four data files; `Mnist::new` appends each one to the path it is given.

/// Image file holding the training set.
const TRAIN_DATA_FILENAME: &str = "train-images-idx3-ubyte";
/// Image file holding the test set.
const TEST_DATA_FILENAME: &str = "t10k-images-idx3-ubyte";
/// Label file holding the training labels.
const TRAIN_LABEL_FILENAME: &str = "train-labels-idx1-ubyte";
/// Label file holding the test labels.
const TEST_LABEL_FILENAME: &str = "t10k-labels-idx1-ubyte";

// Header values `Mnist::new` insists on; any mismatch aborts loading.

/// Required first header field of an image file (2051 in decimal).
const IMAGES_MAGIC_NUMBER: usize = 0x0000_0803;
/// Required first header field of a label file (2049 in decimal).
const LABELS_MAGIC_NUMBER: usize = 0x0000_0801;
/// Number of images, and of labels, required in the training files.
const NUM_TRAIN_IMAGES: usize = 60_000;
/// Number of images, and of labels, required in the test files.
const NUM_TEST_IMAGES: usize = 10_000;
/// Rows per image; an image occupies `IMAGE_ROWS * IMAGE_COLUMNS` bytes.
const IMAGE_ROWS: usize = 28;
/// Columns per image.
const IMAGE_COLUMNS: usize = 28;
23
24pub struct Mnist {
25 pub train_data: Vec<[u8; IMAGE_ROWS * IMAGE_COLUMNS]>,
27 pub test_data: Vec<[u8; IMAGE_ROWS * IMAGE_COLUMNS]>,
28
29 pub train_labels: Vec<u8>,
31 pub test_labels: Vec<u8>,
32}
33
34impl Mnist {
35 #[must_use]
42 pub fn new(mnist_path: &str) -> Mnist {
43 info!("Reading MNIST training data.");
45 let train_data = parse_images(&[mnist_path, TRAIN_DATA_FILENAME].concat()).expect(
46 &format!(
47 "Training data file \"{}{}\" not found; did you \
48 remember to download and extract it?",
49 mnist_path, TRAIN_DATA_FILENAME,
50 )[..],
51 );
52
53 assert_eq!(
55 train_data.magic_number, IMAGES_MAGIC_NUMBER,
56 "Magic number for training data does not match expected value."
57 );
58 assert_eq!(
59 train_data.num_images, NUM_TRAIN_IMAGES,
60 "Number of images in training data does not match expected value."
61 );
62 assert_eq!(
63 train_data.num_rows, IMAGE_ROWS,
64 "Number of rows per image in training data does not match expected value."
65 );
66 assert_eq!(
67 train_data.num_cols, IMAGE_COLUMNS,
68 "Number of columns per image in training data does not match expected value."
69 );
70
71 info!("Reading MNIST testing data.");
73 let test_data = parse_images(&[mnist_path, TEST_DATA_FILENAME].concat()).expect(
74 &format!(
75 "Test data file \"{}{}\" not found; did you \
76 remember to download and extract it?",
77 mnist_path, TEST_DATA_FILENAME,
78 )[..],
79 );
80
81 assert_eq!(
83 test_data.magic_number, IMAGES_MAGIC_NUMBER,
84 "Magic number for testing data does not match expected value."
85 );
86 assert_eq!(
87 test_data.num_images, NUM_TEST_IMAGES,
88 "Number of images in testing data does not match expected value."
89 );
90 assert_eq!(
91 test_data.num_rows, IMAGE_ROWS,
92 "Number of rows per image in testing data does not match expected value."
93 );
94 assert_eq!(
95 test_data.num_cols, IMAGE_COLUMNS,
96 "Number of columns per image in testing data does not match expected value."
97 );
98
99 info!("Reading MNIST training labels.");
101 let (magic_number, num_labels, train_labels) =
102 parse_labels(&[mnist_path, TRAIN_LABEL_FILENAME].concat()).expect(
103 &format!(
104 "Training label file \"{}{}\" not found; did you \
105 remember to download and extract it?",
106 mnist_path, TRAIN_LABEL_FILENAME,
107 )[..],
108 );
109
110 assert_eq!(
112 magic_number, LABELS_MAGIC_NUMBER,
113 "Magic number for training labels does not match expected value."
114 );
115 assert_eq!(
116 num_labels, NUM_TRAIN_IMAGES,
117 "Number of labels in training labels does not match expected value."
118 );
119
120 info!("Reading MNIST testing labels.");
122 let (magic_number, num_labels, test_labels) =
123 parse_labels(&[mnist_path, TEST_LABEL_FILENAME].concat()).expect(
124 &format!(
125 "Test labels file \"{}{}\" not found; did you \
126 remember to download and extract it?",
127 mnist_path, TEST_LABEL_FILENAME,
128 )[..],
129 );
130
131 assert_eq!(
133 magic_number, LABELS_MAGIC_NUMBER,
134 "Magic number for testing labels does not match expected value."
135 );
136 assert_eq!(
137 num_labels, NUM_TEST_IMAGES,
138 "Number of labels in testing labels does not match expected value."
139 );
140
141 Mnist {
142 train_data: train_data.images,
143 test_data: test_data.images,
144 train_labels,
145 test_labels,
146 }
147 }
148}
149
150pub fn print_image(image: &[u8; IMAGE_ROWS * IMAGE_COLUMNS], label: u8) {
162 println!("Sample image label: {label} \nSample image:");
163
164 for row in 0..IMAGE_ROWS {
166 for col in 0..IMAGE_COLUMNS {
167 if image[row * IMAGE_COLUMNS + col] == 0 {
168 print!("__");
169 } else {
170 print!("##");
171 }
172 }
173 println!();
174 }
175}
176
177struct MnistImages {
178 magic_number: usize,
179 num_images: usize,
180 num_rows: usize,
181 num_cols: usize,
182 images: Vec<[u8; IMAGE_ROWS * IMAGE_COLUMNS]>,
183}
184
185fn parse_images(filename: &str) -> io::Result<MnistImages> {
186 let images_data_bytes = fs::File::open(filename)?;
188 let images_data_bytes = io::BufReader::new(images_data_bytes);
189 let mut buffer_32: [u8; 4] = [0; 4];
190
191 images_data_bytes
193 .get_ref()
194 .take(4)
195 .read_exact(&mut buffer_32)?;
196 let magic_number = usize::try_from(u32::from_be_bytes(buffer_32)).unwrap();
197
198 images_data_bytes
200 .get_ref()
201 .take(4)
202 .read_exact(&mut buffer_32)?;
203 let num_images = usize::try_from(u32::from_be_bytes(buffer_32)).unwrap();
204
205 images_data_bytes
207 .get_ref()
208 .take(4)
209 .read_exact(&mut buffer_32)?;
210 let num_rows = usize::try_from(u32::from_be_bytes(buffer_32)).unwrap();
211
212 images_data_bytes
214 .get_ref()
215 .take(4)
216 .read_exact(&mut buffer_32)?;
217 let num_cols = usize::try_from(u32::from_be_bytes(buffer_32)).unwrap();
218
219 let mut image_buffer: [u8; IMAGE_ROWS * IMAGE_COLUMNS] = [0; IMAGE_ROWS * IMAGE_COLUMNS];
221
222 let mut images: Vec<[u8; IMAGE_ROWS * IMAGE_COLUMNS]> = Vec::with_capacity(num_images);
224
225 for _image in 0..num_images {
227 images_data_bytes
228 .get_ref()
229 .take(u64::try_from(num_rows * num_cols).unwrap())
230 .read_exact(&mut image_buffer)
231 .unwrap();
232 images.push(image_buffer);
233 }
234
235 Ok(MnistImages {
236 magic_number,
237 num_images,
238 num_rows,
239 num_cols,
240 images,
241 })
242}
243
/// Reads an MNIST label file: an 8-byte header followed by one byte per label.
///
/// The header consists of two big-endian `u32` fields: the magic number and the number
/// of labels. Returns `(magic_number, num_labels, labels)`; the header values are not
/// validated here so the caller can compare them with its own expectations.
///
/// All reads go through the `BufReader` (reading via `get_ref()` would bypass the buffer
/// and cost one system call per label), and every I/O failure is reported through the
/// returned `Result`.
///
/// # Errors
///
/// Returns an [`io::Error`] when the file cannot be opened or read, or when it ends
/// before the header or the declared number of labels has been read (`UnexpectedEof`).
fn parse_labels(filename: &str) -> io::Result<(usize, usize, Vec<u8>)> {
    /// Reads one big-endian `u32` header field and widens it to `usize`.
    fn read_header_field(reader: &mut impl Read) -> io::Result<usize> {
        let mut raw = [0_u8; 4];
        reader.read_exact(&mut raw)?;
        usize::try_from(u32::from_be_bytes(raw))
            .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))
    }

    let mut reader = io::BufReader::new(fs::File::open(filename)?);

    let magic_number = read_header_field(&mut reader)?;
    let num_labels = read_header_field(&mut reader)?;

    // Read at most the declared number of bytes. The vector grows with the data actually
    // present, so a corrupt count cannot trigger an oversized allocation.
    let declared_len = u64::try_from(num_labels)
        .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
    let mut labels = Vec::new();
    reader.take(declared_len).read_to_end(&mut labels)?;

    if labels.len() != num_labels {
        return Err(io::Error::new(
            io::ErrorKind::UnexpectedEof,
            "label file ends before the declared number of labels",
        ));
    }

    Ok((magic_number, num_labels, labels))
}