train_station/tensor/iterator/
chunks.rs1use crate::tensor::core::Tensor;
4use std::iter::{ExactSizeIterator, FusedIterator};
5
6pub struct TensorChunksIterator<'a> {
7 pub(crate) source: &'a Tensor,
8 pub(crate) chunk_size: usize,
9 pub(crate) position: usize,
10 pub(crate) end: usize,
11}
12
13impl<'a> TensorChunksIterator<'a> {
14 #[inline]
15 pub fn new(source: &'a Tensor, chunk_size: usize) -> Self {
16 assert!(chunk_size > 0, "chunk_size must be > 0");
17 Self {
18 source,
19 chunk_size,
20 position: 0,
21 end: source.size(),
22 }
23 }
24
25 #[inline]
26 fn create_chunk_view(&self, start: usize, len: usize) -> Tensor {
27 if len == 0 {
28 return Tensor::new(vec![0]);
29 }
30 let v = self.source.slice_view(start, 1, len);
32 if v.is_contiguous() {
33 v
34 } else {
35 v.contiguous()
36 }
37 }
38}
39
40impl<'a> Iterator for TensorChunksIterator<'a> {
41 type Item = Tensor;
42 #[inline]
43 fn next(&mut self) -> Option<Self::Item> {
44 if self.position >= self.end {
45 return None;
46 }
47 let start = self.position;
48 let remaining = self.end - start;
49 let take = remaining.min(self.chunk_size);
50 self.position += take;
51 Some(self.create_chunk_view(start, take))
52 }
53 #[inline]
54 fn size_hint(&self) -> (usize, Option<usize>) {
55 let remaining = self.end.saturating_sub(self.position);
56 let n = if remaining == 0 {
57 0
58 } else {
59 remaining.div_ceil(self.chunk_size)
60 };
61 (n, Some(n))
62 }
63}
64
65impl<'a> ExactSizeIterator for TensorChunksIterator<'a> {
66 #[inline]
67 fn len(&self) -> usize {
68 let remaining = self.end.saturating_sub(self.position);
69 if remaining == 0 {
70 0
71 } else {
72 remaining.div_ceil(self.chunk_size)
73 }
74 }
75}
76
77impl<'a> FusedIterator for TensorChunksIterator<'a> {}
78
79impl<'a> DoubleEndedIterator for TensorChunksIterator<'a> {
80 #[inline]
81 fn next_back(&mut self) -> Option<Self::Item> {
82 if self.position >= self.end {
83 return None;
84 }
85 let remaining = self.end - self.position;
86 let take = remaining.min(self.chunk_size);
87 self.end -= take;
88 Some(self.create_chunk_view(self.end, take))
89 }
90}
91
92pub struct TensorChunksExactIterator<'a> {
93 pub(crate) source: &'a Tensor,
94 pub(crate) chunk_size: usize,
95 pub(crate) position: usize,
96 pub(crate) exact_end: usize,
97 pub(crate) remainder_start: usize,
98 pub(crate) remainder_len: usize,
99}
100
101impl<'a> TensorChunksExactIterator<'a> {
102 #[inline]
103 pub fn new(source: &'a Tensor, chunk_size: usize) -> Self {
104 assert!(chunk_size > 0, "chunk_size must be > 0");
105 let size = source.size();
106 let exact_chunks = size / chunk_size;
107 let exact_end = exact_chunks * chunk_size;
108 let remainder_len = size - exact_end;
109 Self {
110 source,
111 chunk_size,
112 position: 0,
113 exact_end,
114 remainder_start: exact_end,
115 remainder_len,
116 }
117 }
118
119 #[inline]
120 pub fn remainder(&self) -> Tensor {
121 if self.remainder_len == 0 {
122 Tensor::new(vec![0])
123 } else {
124 let v = self
125 .source
126 .slice_view(self.remainder_start, 1, self.remainder_len);
127 if v.is_contiguous() {
128 v
129 } else {
130 v.contiguous()
131 }
132 }
133 }
134
135 #[inline]
136 fn create_chunk_view(&self, start: usize) -> Tensor {
137 let v = self.source.slice_view(start, 1, self.chunk_size);
138 if v.is_contiguous() {
139 v
140 } else {
141 v.contiguous()
142 }
143 }
144}
145
146impl<'a> Iterator for TensorChunksExactIterator<'a> {
147 type Item = Tensor;
148 #[inline]
149 fn next(&mut self) -> Option<Self::Item> {
150 if self.position >= self.exact_end {
151 return None;
152 }
153 let start = self.position;
154 self.position += self.chunk_size;
155 Some(self.create_chunk_view(start))
156 }
157 #[inline]
158 fn size_hint(&self) -> (usize, Option<usize>) {
159 let remaining = self.exact_end.saturating_sub(self.position);
160 let n = remaining / self.chunk_size;
161 (n, Some(n))
162 }
163}
164
165impl<'a> ExactSizeIterator for TensorChunksExactIterator<'a> {
166 #[inline]
167 fn len(&self) -> usize {
168 (self.exact_end.saturating_sub(self.position)) / self.chunk_size
169 }
170}
171
172impl<'a> FusedIterator for TensorChunksExactIterator<'a> {}
173
174impl<'a> DoubleEndedIterator for TensorChunksExactIterator<'a> {
175 #[inline]
176 fn next_back(&mut self) -> Option<Self::Item> {
177 if self.position >= self.exact_end {
178 return None;
179 }
180 self.exact_end = self.exact_end.saturating_sub(self.chunk_size);
181 Some(self.create_chunk_view(self.exact_end))
182 }
183}
184
185impl Tensor {
186 #[inline]
187 pub fn iter_chunks(&self, chunk_size: usize) -> TensorChunksIterator<'_> {
188 TensorChunksIterator::new(self, chunk_size)
189 }
190
191 #[inline]
192 pub fn iter_chunks_exact(&self, chunk_size: usize) -> TensorChunksExactIterator<'_> {
193 TensorChunksExactIterator::new(self, chunk_size)
194 }
195
196 #[inline]
203 pub fn iter_fast_chunks(&self) -> TensorChunksIterator<'_> {
204 let n = self.size();
205 if n == 0 {
206 return TensorChunksIterator::new(self, 1);
207 }
208 let mut sz = 16_384usize;
210 if n < 16_384 {
212 sz = 4_096;
213 }
214 if n > 1_048_576 {
215 sz = 65_536;
216 }
217 let lane = crate::tensor::core::Tensor::simd_lane_width_elems_runtime();
219 if lane > 1 {
220 sz = sz.div_ceil(lane) * lane;
221 }
222 sz = sz.clamp(4_096, 262_144);
224 TensorChunksIterator::new(self, sz)
225 }
226}
227
#[cfg(test)]
mod tests {
    use super::*;

    /// Five elements in chunks of two: two full chunks, then a short one.
    #[test]
    fn test_chunks() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
        let chunks: Vec<Tensor> = tensor.iter_chunks(2).collect();

        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0].data(), &[1.0, 2.0]);
        assert_eq!(chunks[1].data(), &[3.0, 4.0]);
        assert_eq!(chunks[2].data(), &[5.0]);
    }

    /// Exact chunks stop after the last full chunk; the leftover element is
    /// only reachable through `remainder()`.
    #[test]
    fn test_chunks_exact() {
        let tensor = Tensor::from_slice(&[10.0, 20.0, 30.0, 40.0, 50.0], vec![5]).unwrap();
        let mut exact = tensor.iter_chunks_exact(2);

        let first = exact.next().unwrap();
        let second = exact.next().unwrap();
        assert!(exact.next().is_none());

        assert_eq!(first.data(), &[10.0, 20.0]);
        assert_eq!(second.data(), &[30.0, 40.0]);
        assert_eq!(exact.remainder().data(), &[50.0]);
    }

    /// Splits each item of `iter_dim(0)` on a `[2, 3, 4]` tensor with
    /// `split(2, 0)` and checks the resulting `[2, 4]` / `[1, 4]` shapes.
    #[test]
    fn test_chunks_over_3d_outerdim() {
        let values: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let tensor = Tensor::from_slice(&values, vec![2, 3, 4]).unwrap();

        let mut per_batch: Vec<Vec<Tensor>> = Vec::new();
        for batch in tensor.iter_dim(0) {
            let pieces: Vec<Tensor> = batch.split(2, 0);
            assert_eq!(pieces[0].shape().dims(), vec![2, 4]);
            assert_eq!(pieces[1].shape().dims(), vec![1, 4]);
            per_batch.push(pieces);
        }

        assert_eq!(per_batch.len(), 2);
    }

    /// Gradients reach the source tensor after its split parts are scaled by
    /// three, concatenated and summed: every element receives a gradient of 3.
    #[test]
    fn test_chunks_gradient_after_collect() {
        let input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])
            .unwrap()
            .with_requires_grad();

        let scaled: Vec<Tensor> = input
            .split(2, 1)
            .into_iter()
            .map(|part| part.mul_scalar(3.0))
            .collect();
        let joined = Tensor::cat(&scaled, 1);

        let mut loss = joined.sum();
        loss.backward(None);

        let grad = input.grad_owned().unwrap();
        assert_eq!(grad.data(), &[3.0, 3.0, 3.0, 3.0, 3.0, 3.0]);
    }
}