1use super::dataset::Dataset;
7use crate::{kind, IndexOp, Kind, Tensor};
8use std::fs::File;
9use std::io::{BufReader, Read, Result};
10
/// Number of channels per image (second dimension of the decoded image tensor).
const C: i64 = 3;
/// Height of each image (third dimension of the decoded image tensor).
const H: i64 = 32;
/// Width of each image (last dimension of the decoded image tensor).
const W: i64 = 32;
/// Size in bytes of one on-disk record: a single label byte plus `C * H * W` pixel bytes.
const BYTES_PER_IMAGE: i64 = 1 + C * H * W;
/// Number of records decoded from each batch file.
const SAMPLES_PER_FILE: i64 = 10_000;
16
17fn read_file_(filename: &std::path::Path) -> Result<(Tensor, Tensor)> {
18 let mut buf_reader = BufReader::new(File::open(filename)?);
19 let mut data = vec![0u8; (SAMPLES_PER_FILE * BYTES_PER_IMAGE) as usize];
20 buf_reader.read_exact(&mut data)?;
21 let content = Tensor::from_slice(&data);
22 let images = Tensor::zeros([SAMPLES_PER_FILE, C, H, W], kind::FLOAT_CPU);
23 let labels = Tensor::zeros([SAMPLES_PER_FILE], kind::INT64_CPU);
24 for index in 0..SAMPLES_PER_FILE {
25 let content_offset = BYTES_PER_IMAGE * index;
26 labels.i(index).copy_(&content.i(content_offset));
27 images.i(index).copy_(
28 &content
29 .narrow(0, 1 + content_offset, BYTES_PER_IMAGE - 1)
30 .view((C, H, W))
31 .to_kind(Kind::Float),
32 );
33 }
34 Ok((images.to_kind(Kind::Float) / 255.0, labels))
35}
36
37fn read_file(filename: &std::path::Path) -> Result<(Tensor, Tensor)> {
38 read_file_(filename)
39 .map_err(|err| std::io::Error::new(err.kind(), format!("{filename:?} {err}")))
40}
41
42pub fn load_dir<T: AsRef<std::path::Path>>(dir: T) -> Result<Dataset> {
43 let dir = dir.as_ref();
44 let (test_images, test_labels) = read_file(&dir.join("test_batch.bin"))?;
45 let train_images_and_labels = [
46 "data_batch_1.bin",
47 "data_batch_2.bin",
48 "data_batch_3.bin",
49 "data_batch_4.bin",
50 "data_batch_5.bin",
51 ]
52 .iter()
53 .map(|x| read_file(&dir.join(x)))
54 .collect::<Result<Vec<_>>>()?;
55 let (train_images, train_labels): (Vec<_>, Vec<_>) =
56 train_images_and_labels.into_iter().unzip();
57 Ok(Dataset {
58 train_images: Tensor::cat(&train_images, 0),
59 train_labels: Tensor::cat(&train_labels, 0),
60 test_images,
61 test_labels,
62 labels: 10,
63 })
64}