// tch_plus/data.rs

//! Dataset iterators.
use crate::{kind, kind::Kind, Device, IndexOp, TchError, Tensor};
use std::collections::HashMap;

/// An iterator over a pair of tensors which have the same first dimension
/// size.
/// The typical use case is to iterate over batches. Each batch is a pair
/// containing a (potentially random) slice of each of the two input
/// tensors.
#[derive(Debug)]
pub struct Iter2 {
    // Features tensor; its first dimension indexes samples.
    xs: Tensor,
    // Targets tensor; same first-dimension size as `xs`.
    ys: Tensor,
    // Index of the next batch to be produced by `next`.
    batch_index: i64,
    // Number of samples per returned batch.
    batch_size: i64,
    // Total number of samples, i.e. the shared first-dimension size.
    total_size: i64,
    // Device each returned batch is transferred to (defaults to CPU).
    device: Device,
    // When true, a trailing batch smaller than `batch_size` is still yielded.
    return_smaller_last_batch: bool,
}

21impl Iter2 {
22    /// Returns a new iterator.
23    ///
24    /// This takes as input two tensors which first dimension must match. The
25    /// returned iterator can be used to range over mini-batches of data of
26    /// specified size.
27    /// An error is returned if `xs` and `ys` have different first dimension
28    /// sizes.
29    ///
30    /// # Arguments
31    ///
32    /// * `xs` - the features to be used by the model.
33    /// * `ys` - the targets that the model attempts to predict.
34    /// * `batch_size` - the size of batches to be returned.
35    pub fn f_new(xs: &Tensor, ys: &Tensor, batch_size: i64) -> Result<Iter2, TchError> {
36        let total_size = xs.size()[0];
37        if ys.size()[0] != total_size {
38            return Err(TchError::Shape(format!(
39                "different dimension for the two inputs {xs:?} {ys:?}"
40            )));
41        }
42        Ok(Iter2 {
43            xs: xs.shallow_clone(),
44            ys: ys.shallow_clone(),
45            batch_index: 0,
46            batch_size,
47            total_size,
48            device: Device::Cpu,
49            return_smaller_last_batch: false,
50        })
51    }
52
53    /// Returns a new iterator.
54    ///
55    /// This takes as input two tensors which first dimension must match. The
56    /// returned iterator can be used to range over mini-batches of data of
57    /// specified size.
58    /// Panics if `xs` and `ys` have different first dimension sizes.
59    ///
60    /// # Arguments
61    ///
62    /// * `xs` - the features to be used by the model.
63    /// * `ys` - the targets that the model attempts to predict.
64    /// * `batch_size` - the size of batches to be returned.
65    pub fn new(xs: &Tensor, ys: &Tensor, batch_size: i64) -> Iter2 {
66        Iter2::f_new(xs, ys, batch_size).unwrap()
67    }
68
69    /// Shuffles the dataset.
70    ///
71    /// The iterator would still run over the whole dataset but the order in
72    /// which elements are grouped in mini-batches is randomized.
73    pub fn shuffle(&mut self) -> &mut Iter2 {
74        let index = Tensor::randperm(self.total_size, (Kind::Int64, self.device));
75        self.xs = self.xs.index_select(0, &index);
76        self.ys = self.ys.index_select(0, &index);
77        self
78    }
79
80    /// Transfers the mini-batches to a specified device.
81    #[allow(clippy::wrong_self_convention)]
82    pub fn to_device(&mut self, device: Device) -> &mut Iter2 {
83        self.device = device;
84        self
85    }
86
87    /// When set, returns the last batch even if smaller than the batch size.
88    pub fn return_smaller_last_batch(&mut self) -> &mut Iter2 {
89        self.return_smaller_last_batch = true;
90        self
91    }
92}
93
94impl Iterator for Iter2 {
95    type Item = (Tensor, Tensor);
96
97    fn next(&mut self) -> Option<Self::Item> {
98        let start = self.batch_index * self.batch_size;
99        let size = std::cmp::min(self.batch_size, self.total_size - start);
100        if size <= 0 || (!self.return_smaller_last_batch && size < self.batch_size) {
101            None
102        } else {
103            self.batch_index += 1;
104            Some((
105                self.xs.i(start..start + size).to_device(self.device),
106                self.ys.i(start..start + size).to_device(self.device),
107            ))
108        }
109    }
110}
111
/// Text data holder.
#[derive(Debug)]
pub struct TextData {
    // The file content re-encoded as dense labels, one element per byte.
    data: Tensor,
    // Maps a label (used as index) back to the original character.
    char_for_label: Vec<char>,
    // Maps an original byte to its dense label.
    label_for_char: HashMap<u8, u8>,
}

/// Text data iterator.
#[derive(Debug)]
pub struct TextDataIter {
    // Shallow copy of the label-encoded dataset.
    data: Tensor,
    // Number of characters per sample.
    seq_len: i64,
    // Index of the next batch to be produced by `next`.
    batch_index: i64,
    // Number of samples per returned batch.
    batch_size: i64,
    // Shuffled start positions for the sequences.
    indexes: Tensor,
    // Number of valid start positions, i.e. `data_len - seq_len + 1`.
    indexes_len: i64,
}

131impl TextData {
132    /// Creates a text dataset from a file.
133    pub fn new<P: AsRef<std::path::Path>>(filename: P) -> Result<TextData, TchError> {
134        let mut buffer = std::fs::read(&filename).map_err(|err| {
135            std::io::Error::new(err.kind(), format!("{:?} {err}", filename.as_ref()))
136        })?;
137
138        let mut label_for_char = HashMap::<u8, u8>::new();
139        let mut char_for_label = Vec::<char>::new();
140        for c in buffer.iter_mut() {
141            *c = *label_for_char.entry(*c).or_insert_with(|| {
142                let label = char_for_label.len() as u8;
143                char_for_label.push(*c as char);
144                label
145            })
146        }
147
148        Ok(TextData { data: Tensor::from_slice(&buffer), char_for_label, label_for_char })
149    }
150
151    /// Returns the number of different characters/labels used by the dataset.
152    pub fn labels(&self) -> i64 {
153        self.char_for_label.len() as i64
154    }
155
156    /// Returns a shallow copy of the data.
157    pub fn data(&self) -> Tensor {
158        self.data.shallow_clone()
159    }
160
161    pub fn label_to_char(&self, label: i64) -> char {
162        self.char_for_label[label as usize]
163    }
164
165    pub fn char_to_label(&self, c: char) -> Result<u8, TchError> {
166        match self.label_for_char.get(&(c as u8)) {
167            None => Err(TchError::Convert(format!("cannot find char {c}"))),
168            Some(v) => Ok(*v),
169        }
170    }
171
172    /// Returns a batch iterator over the dataset.
173    /// Each sample is made of seq_len characters.
174    pub fn iter_shuffle(&self, seq_len: i64, batch_size: i64) -> TextDataIter {
175        let indexes_len = self.data.size()[0] - seq_len + 1;
176        TextDataIter {
177            data: self.data.shallow_clone(),
178            seq_len,
179            batch_index: 0,
180            batch_size,
181            indexes: Tensor::randperm(indexes_len, kind::INT64_CPU),
182            indexes_len,
183        }
184    }
185}
186
187impl Iterator for TextDataIter {
188    type Item = Tensor;
189
190    fn next(&mut self) -> Option<Self::Item> {
191        let start = self.batch_index * self.batch_size;
192        let size = std::cmp::min(self.batch_size, self.indexes_len - start);
193        if size < self.batch_size {
194            None
195        } else {
196            self.batch_index += 1;
197            let indexes = Vec::<i64>::try_from(&self.indexes.i(start..start + size)).unwrap();
198            let batch: Vec<_> = indexes.iter().map(|&i| self.data.i(i..i + self.seq_len)).collect();
199            let batch: Vec<_> = batch.iter().collect();
200            Some(Tensor::stack(&batch, 0))
201        }
202    }
203}