tch_plus/vision/
dataset.rs

1//! A simple dataset structure shared by various computer vision datasets.
2use crate::data::Iter2;
3use crate::{IndexOp, Tensor};
4use rand::Rng;
5
6#[derive(Debug)]
7pub struct Dataset {
8    pub train_images: Tensor,
9    pub train_labels: Tensor,
10    pub test_images: Tensor,
11    pub test_labels: Tensor,
12    pub labels: i64,
13}
14
15impl Dataset {
16    pub fn train_iter(&self, batch_size: i64) -> Iter2 {
17        Iter2::new(&self.train_images, &self.train_labels, batch_size)
18    }
19
20    pub fn test_iter(&self, batch_size: i64) -> Iter2 {
21        Iter2::new(&self.test_images, &self.test_labels, batch_size)
22    }
23}
24
25/// Randomly applies horizontal flips
26/// This expects a 4 dimension NCHW tensor and returns a tensor with
27/// an identical shape.
28pub fn random_flip(t: &Tensor) -> Tensor {
29    let size = t.size();
30    if size.len() != 4 {
31        panic!("unexpected shape for tensor {t:?}")
32    }
33    let output = t.zeros_like();
34    for batch_index in 0..size[0] {
35        let mut output_view = output.i(batch_index);
36        let t_view = t.i(batch_index);
37        let src = if rand::random() { t_view } else { t_view.flip([2]) };
38        output_view.copy_(&src)
39    }
40    output
41}
42
43/// Pad the image using reflections and take some random crops.
44/// This expects a 4 dimension NCHW tensor and returns a tensor with
45/// an identical shape.
46pub fn random_crop(t: &Tensor, pad: i64) -> Tensor {
47    let size = t.size();
48    if size.len() != 4 {
49        panic!("unexpected shape for tensor {t:?}")
50    }
51    let sz_h = size[2];
52    let sz_w = size[3];
53    let padded = t.reflection_pad2d([pad, pad, pad, pad]);
54    let output = t.zeros_like();
55    for bindex in 0..size[0] {
56        let mut output_view = output.i(bindex);
57        let start_w = rand::thread_rng().gen_range(0..2 * pad);
58        let start_h = rand::thread_rng().gen_range(0..2 * pad);
59        let src = padded.i((bindex, .., start_h..start_h + sz_h, start_w..start_w + sz_w));
60        output_view.copy_(&src)
61    }
62    output
63}
64
65/// Applies cutout: randomly remove some square areas in the original images.
66/// <https://arxiv.org/abs/1708.04552>
67pub fn random_cutout(t: &Tensor, sz: i64) -> Tensor {
68    let size = t.size();
69    if size.len() != 4 || sz > size[2] || sz > size[3] {
70        panic!("unexpected shape for tensor {t:?} {sz}")
71    }
72    let mut output = t.zeros_like();
73    output.copy_(t);
74    for bindex in 0..size[0] {
75        let start_h = rand::thread_rng().gen_range(0..size[2] - sz + 1);
76        let start_w = rand::thread_rng().gen_range(0..size[3] - sz + 1);
77        let _output =
78            output.i((bindex, .., start_h..start_h + sz, start_w..start_w + sz)).fill_(0.0);
79    }
80    output
81}
82
83pub fn augmentation(t: &Tensor, flip: bool, crop: i64, cutout: i64) -> Tensor {
84    let mut t = t.shallow_clone();
85    if flip {
86        t = random_flip(&t);
87    }
88    if crop > 0 {
89        t = random_crop(&t, crop);
90    }
91    if cutout > 0 {
92        t = random_cutout(&t, cutout);
93    }
94    t
95}