use crate::{kind, kind::Kind, Device, IndexOp, TchError, Tensor};
use std::collections::HashMap;

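/// An iterator over mini-batches taken from a pair of tensors that share the
/// same first dimension, e.g. the features and labels of a dataset.
///
/// A usage sketch (the tensor values are made up for illustration and the
/// example assumes this module is exposed as `tch::data`):
///
/// ```no_run
/// use tch::{data::Iter2, Device, Tensor};
///
/// let xs = Tensor::from_slice(&[0f32, 1.0, 2.0, 3.0, 4.0, 5.0]);
/// let ys = Tensor::from_slice(&[0i64, 1, 0, 1, 0, 1]);
/// let mut batches = Iter2::new(&xs, &ys, 2);
/// for (bxs, bys) in batches.shuffle().to_device(Device::Cpu) {
///     // Each of `bxs` and `bys` contains two samples per iteration.
///     let _ = (bxs, bys);
/// }
/// ```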
#[derive(Debug)]
pub struct Iter2 {
    xs: Tensor,
    ys: Tensor,
    batch_index: i64,
    batch_size: i64,
    total_size: i64,
    device: Device,
    return_smaller_last_batch: bool,
}

impl Iter2 {
    /// Returns a new iterator over mini-batches taken from `xs` and `ys`.
    ///
    /// The two tensors must have the same first dimension; an error is
    /// returned otherwise.
    ///
    /// # Arguments
    ///
    /// * `xs` - the features used by the model.
    /// * `ys` - the targets that the model attempts to predict.
    /// * `batch_size` - the number of samples per returned batch.
    pub fn f_new(xs: &Tensor, ys: &Tensor, batch_size: i64) -> Result<Iter2, TchError> {
        let total_size = xs.size()[0];
        if ys.size()[0] != total_size {
            return Err(TchError::Shape(format!(
                "different first dimension for the two inputs {xs:?} {ys:?}"
            )));
        }
        Ok(Iter2 {
            xs: xs.shallow_clone(),
            ys: ys.shallow_clone(),
            batch_index: 0,
            batch_size,
            total_size,
            device: Device::Cpu,
            return_smaller_last_batch: false,
        })
    }

    /// Returns a new iterator over mini-batches taken from `xs` and `ys`.
    ///
    /// Panics if the two tensors have different first dimensions; see
    /// [`Iter2::f_new`] for a non-panicking variant.
    pub fn new(xs: &Tensor, ys: &Tensor, batch_size: i64) -> Iter2 {
        Iter2::f_new(xs, ys, batch_size).unwrap()
    }

    /// Shuffles the dataset by applying the same random permutation to both
    /// tensors along their first dimension.
    pub fn shuffle(&mut self) -> &mut Iter2 {
        let index = Tensor::randperm(self.total_size, (Kind::Int64, self.device));
        self.xs = self.xs.index_select(0, &index);
        self.ys = self.ys.index_select(0, &index);
        self
    }

    /// Transfers the returned mini-batches to the specified device.
    #[allow(clippy::wrong_self_convention)]
    pub fn to_device(&mut self, device: Device) -> &mut Iter2 {
        self.device = device;
        self
    }

    /// When the number of samples is not a multiple of the batch size, the
    /// final, smaller batch is returned instead of being dropped.
    pub fn return_smaller_last_batch(&mut self) -> &mut Iter2 {
        self.return_smaller_last_batch = true;
        self
    }
}

impl Iterator for Iter2 {
    type Item = (Tensor, Tensor);

    fn next(&mut self) -> Option<Self::Item> {
        let start = self.batch_index * self.batch_size;
        let size = std::cmp::min(self.batch_size, self.total_size - start);
        // Stop when the data is exhausted; a trailing smaller batch is only
        // returned when `return_smaller_last_batch` has been requested.
        if size <= 0 || (!self.return_smaller_last_batch && size < self.batch_size) {
            None
        } else {
            self.batch_index += 1;
            Some((
                self.xs.i(start..start + size).to_device(self.device),
                self.ys.i(start..start + size).to_device(self.device),
            ))
        }
    }
}

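/// A dataset built from a text file for character-level modelling: each byte
/// of the file is mapped to an integer label in `0..labels()`.
///
/// A usage sketch, assuming this module is exposed as `tch::data` and that
/// `input.txt` is a hypothetical training file:
///
/// ```no_run
/// use tch::data::TextData;
///
/// let data = TextData::new("input.txt").unwrap();
/// println!("distinct characters: {}", data.labels());
/// // Shuffled batches of 32 sequences, each 64 labels long, i.e. shape [32, 64].
/// for batch in data.iter_shuffle(64, 32) {
///     let _ = batch;
/// }
/// ```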
#[derive(Debug)]
pub struct TextData {
    data: Tensor,
    char_for_label: Vec<char>,
    label_for_char: HashMap<u8, u8>,
}

/// An iterator over shuffled batches of character sequences, as returned by
/// [`TextData::iter_shuffle`].
#[derive(Debug)]
pub struct TextDataIter {
    data: Tensor,
    seq_len: i64,
    batch_index: i64,
    batch_size: i64,
    indexes: Tensor,
    indexes_len: i64,
}

impl TextData {
    /// Creates a text dataset from a file.
    ///
    /// Each byte of the file is mapped to a label in `0..labels()`, in order
    /// of first appearance.
    pub fn new<P: AsRef<std::path::Path>>(filename: P) -> Result<TextData, TchError> {
        let mut buffer = std::fs::read(&filename).map_err(|err| {
            std::io::Error::new(err.kind(), format!("{:?} {err}", filename.as_ref()))
        })?;

        // Relabel the bytes in place so that labels are assigned contiguously
        // starting from 0.
        let mut label_for_char = HashMap::<u8, u8>::new();
        let mut char_for_label = Vec::<char>::new();
        for c in buffer.iter_mut() {
            *c = *label_for_char.entry(*c).or_insert_with(|| {
                let label = char_for_label.len() as u8;
                char_for_label.push(*c as char);
                label
            })
        }

        Ok(TextData { data: Tensor::from_slice(&buffer), char_for_label, label_for_char })
    }

    /// Returns the number of distinct labels (characters) in the dataset.
    pub fn labels(&self) -> i64 {
        self.char_for_label.len() as i64
    }

    /// Returns a shallow copy of the underlying label tensor.
    pub fn data(&self) -> Tensor {
        self.data.shallow_clone()
    }

    /// Returns the character associated with a given label.
    pub fn label_to_char(&self, label: i64) -> char {
        self.char_for_label[label as usize]
    }

    /// Returns the label associated with a given character, or an error if
    /// the character does not appear in the dataset.
    pub fn char_to_label(&self, c: char) -> Result<u8, TchError> {
        match self.label_for_char.get(&(c as u8)) {
            None => Err(TchError::Convert(format!("cannot find char {c}"))),
            Some(v) => Ok(*v),
        }
    }

    /// Returns an iterator over shuffled batches of label sequences of length
    /// `seq_len`, drawn from random offsets in the data.
    pub fn iter_shuffle(&self, seq_len: i64, batch_size: i64) -> TextDataIter {
        let indexes_len = self.data.size()[0] - seq_len + 1;
        TextDataIter {
            data: self.data.shallow_clone(),
            seq_len,
            batch_index: 0,
            batch_size,
            indexes: Tensor::randperm(indexes_len, kind::INT64_CPU),
            indexes_len,
        }
    }
}

impl Iterator for TextDataIter {
    type Item = Tensor;

    fn next(&mut self) -> Option<Self::Item> {
        let start = self.batch_index * self.batch_size;
        let size = std::cmp::min(self.batch_size, self.indexes_len - start);
        if size < self.batch_size {
            None
        } else {
            self.batch_index += 1;
            // Gather `batch_size` random offsets and stack the corresponding
            // sequences of length `seq_len` into a [batch_size, seq_len] tensor.
            let indexes = Vec::<i64>::try_from(&self.indexes.i(start..start + size)).unwrap();
            let batch: Vec<_> = indexes.iter().map(|&i| self.data.i(i..i + self.seq_len)).collect();
            Some(Tensor::stack(&batch, 0))
        }
    }
}
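
#[cfg(test)]
mod tests {
    // A minimal smoke-test sketch for the batching logic above (not part of
    // the upstream test suite); it assumes a working libtorch installation.
    use super::*;

    #[test]
    fn iter2_batches_have_expected_shapes() {
        let xs = Tensor::from_slice(&[0i64, 1, 2, 3, 4, 5]);
        let ys = Tensor::from_slice(&[10i64, 11, 12, 13, 14, 15]);
        let batches: Vec<_> = Iter2::new(&xs, &ys, 2).collect();
        // Six samples with a batch size of two yield three full batches.
        assert_eq!(batches.len(), 3);
        for (bxs, bys) in batches.iter() {
            assert_eq!(bxs.size(), [2]);
            assert_eq!(bys.size(), [2]);
        }
    }

    #[test]
    fn iter2_drops_smaller_last_batch_by_default() {
        let xs = Tensor::from_slice(&[0i64, 1, 2, 3, 4]);
        let ys = Tensor::from_slice(&[0i64, 1, 2, 3, 4]);
        // Five samples with a batch size of two: the trailing single sample is
        // dropped unless `return_smaller_last_batch` is requested.
        assert_eq!(Iter2::new(&xs, &ys, 2).count(), 2);
        let mut it = Iter2::new(&xs, &ys, 2);
        assert_eq!(it.return_smaller_last_batch().count(), 3);
    }
}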