train_station/tensor/iterator/
chunks.rs1use crate::tensor::core::utils::should_use_fast_path;
15use crate::tensor::core::Tensor;
16use std::iter::{ExactSizeIterator, FusedIterator};
17
18pub struct TensorChunksIterator<'a> {
19 pub(crate) source: &'a Tensor,
20 pub(crate) chunk_size: usize,
21 pub(crate) position: usize,
22 pub(crate) end: usize,
23 pub(crate) owner: Option<Tensor>,
25}
26
27impl<'a> TensorChunksIterator<'a> {
28 #[inline]
29 pub fn new(source: &'a Tensor, chunk_size: usize) -> Self {
30 assert!(chunk_size > 0, "chunk_size must be > 0");
31 let fast = should_use_fast_path(&[source]);
32 let owner = if fast && !source.is_contiguous() && source.size() > 0 {
34 Some(source.contiguous())
35 } else {
36 None
37 };
38 Self {
39 source,
40 chunk_size,
41 position: 0,
42 end: source.size(),
43 owner,
44 }
45 }
46
47 #[inline]
48 fn create_chunk_view(&self, start: usize, len: usize) -> Tensor {
49 if len == 0 {
50 return Tensor::new(vec![0]);
51 }
52 let base: &Tensor = match &self.owner {
54 Some(o) => o,
55 None => self.source,
56 };
57 base.slice_view(start, 1, len)
58 }
59}
60
61impl<'a> Iterator for TensorChunksIterator<'a> {
62 type Item = Tensor;
63 #[inline]
64 fn next(&mut self) -> Option<Self::Item> {
65 if self.position >= self.end {
66 return None;
67 }
68 let start = self.position;
69 let remaining = self.end - start;
70 let take = remaining.min(self.chunk_size);
71 self.position += take;
72 Some(self.create_chunk_view(start, take))
73 }
74 #[inline]
75 fn size_hint(&self) -> (usize, Option<usize>) {
76 let remaining = self.end.saturating_sub(self.position);
77 let n = if remaining == 0 {
78 0
79 } else {
80 remaining.div_ceil(self.chunk_size)
81 };
82 (n, Some(n))
83 }
84}
85
86impl<'a> ExactSizeIterator for TensorChunksIterator<'a> {
87 #[inline]
88 fn len(&self) -> usize {
89 let remaining = self.end.saturating_sub(self.position);
90 if remaining == 0 {
91 0
92 } else {
93 remaining.div_ceil(self.chunk_size)
94 }
95 }
96}
97
98impl<'a> FusedIterator for TensorChunksIterator<'a> {}
99
100impl<'a> DoubleEndedIterator for TensorChunksIterator<'a> {
101 #[inline]
102 fn next_back(&mut self) -> Option<Self::Item> {
103 if self.position >= self.end {
104 return None;
105 }
106 let remaining = self.end - self.position;
107 let rem = remaining % self.chunk_size;
109 let take = if rem == 0 { self.chunk_size } else { rem };
110 self.end -= take;
111 Some(self.create_chunk_view(self.end, take))
112 }
113}
114
115pub struct TensorChunksExactIterator<'a> {
116 pub(crate) source: &'a Tensor,
117 pub(crate) chunk_size: usize,
118 pub(crate) position: usize,
119 pub(crate) exact_end: usize,
120 pub(crate) remainder_start: usize,
121 pub(crate) remainder_len: usize,
122 pub(crate) owner: Option<Tensor>,
124}
125
126impl<'a> TensorChunksExactIterator<'a> {
127 #[inline]
128 pub fn new(source: &'a Tensor, chunk_size: usize) -> Self {
129 assert!(chunk_size > 0, "chunk_size must be > 0");
130 let size = source.size();
131 let exact_chunks = size / chunk_size;
132 let exact_end = exact_chunks * chunk_size;
133 let remainder_len = size - exact_end;
134 let fast = should_use_fast_path(&[source]);
135 let owner = if fast && !source.is_contiguous() && size > 0 {
136 Some(source.contiguous())
137 } else {
138 None
139 };
140 Self {
141 source,
142 chunk_size,
143 position: 0,
144 exact_end,
145 remainder_start: exact_end,
146 remainder_len,
147 owner,
148 }
149 }
150
151 #[inline]
152 pub fn remainder(&self) -> Tensor {
153 if self.remainder_len == 0 {
154 Tensor::new(vec![0])
155 } else {
156 let base: &Tensor = match &self.owner {
157 Some(o) => o,
158 None => self.source,
159 };
160 base.slice_view(self.remainder_start, 1, self.remainder_len)
161 }
162 }
163
164 #[inline]
165 fn create_chunk_view(&self, start: usize) -> Tensor {
166 let base: &Tensor = match &self.owner {
167 Some(o) => o,
168 None => self.source,
169 };
170 base.slice_view(start, 1, self.chunk_size)
171 }
172}
173
174impl<'a> Iterator for TensorChunksExactIterator<'a> {
175 type Item = Tensor;
176 #[inline]
177 fn next(&mut self) -> Option<Self::Item> {
178 if self.position >= self.exact_end {
179 return None;
180 }
181 let start = self.position;
182 self.position += self.chunk_size;
183 Some(self.create_chunk_view(start))
184 }
185 #[inline]
186 fn size_hint(&self) -> (usize, Option<usize>) {
187 let remaining = self.exact_end.saturating_sub(self.position);
188 let n = remaining / self.chunk_size;
189 (n, Some(n))
190 }
191}
192
193impl<'a> ExactSizeIterator for TensorChunksExactIterator<'a> {
194 #[inline]
195 fn len(&self) -> usize {
196 (self.exact_end.saturating_sub(self.position)) / self.chunk_size
197 }
198}
199
200impl<'a> FusedIterator for TensorChunksExactIterator<'a> {}
201
202impl<'a> DoubleEndedIterator for TensorChunksExactIterator<'a> {
203 #[inline]
204 fn next_back(&mut self) -> Option<Self::Item> {
205 if self.position >= self.exact_end {
206 return None;
207 }
208 self.exact_end = self.exact_end.saturating_sub(self.chunk_size);
209 Some(self.create_chunk_view(self.exact_end))
210 }
211}
212
213impl Tensor {
214 #[inline]
235 pub fn chunks(&self, chunk_size: usize) -> TensorChunksIterator<'_> {
236 TensorChunksIterator::new(self, chunk_size)
237 }
238
239 #[inline]
257 pub fn chunks_exact(&self, chunk_size: usize) -> TensorChunksExactIterator<'_> {
258 TensorChunksExactIterator::new(self, chunk_size)
259 }
260
261 #[deprecated(note = "Use Tensor::chunks(...) instead. This alias will be removed before 1.0.")]
262 #[inline]
263 pub fn iter_chunks(&self, chunk_size: usize) -> TensorChunksIterator<'_> {
264 TensorChunksIterator::new(self, chunk_size)
265 }
266
267 #[deprecated(
268 note = "Use Tensor::chunks_exact(...) instead. This alias will be removed before 1.0."
269 )]
270 #[inline]
271 pub fn iter_chunks_exact(&self, chunk_size: usize) -> TensorChunksExactIterator<'_> {
272 TensorChunksExactIterator::new(self, chunk_size)
273 }
274}
275
#[cfg(test)]
mod tests {
    use super::*;
    use crate::gradtrack::NoGradTrack;

    #[test]
    fn test_chunks() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
        let v: Vec<Tensor> = t.chunks(2).collect();
        assert_eq!(v.len(), 3);
        assert_eq!(v[0].data(), &[1.0, 2.0]);
        assert_eq!(v[1].data(), &[3.0, 4.0]);
        assert_eq!(v[2].data(), &[5.0]);
    }

    #[test]
    fn test_chunks_exact() {
        let t = Tensor::from_slice(&[10.0, 20.0, 30.0, 40.0, 50.0], vec![5]).unwrap();
        let mut it = t.chunks_exact(2);
        let a = it.next().unwrap();
        let b = it.next().unwrap();
        assert!(it.next().is_none());
        assert_eq!(a.data(), &[10.0, 20.0]);
        assert_eq!(b.data(), &[30.0, 40.0]);
        assert_eq!(it.remainder().data(), &[50.0]);
    }

    // Exercises `iter_dim` + `split` on a 3-D tensor; it does not call `chunks` itself.
    #[test]
    fn test_chunks_over_3d_outerdim() {
        let vals: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let t = Tensor::from_slice(&vals, vec![2, 3, 4]).unwrap();
        let mut collected: Vec<Vec<Tensor>> = Vec::new();
        for b in t.iter_dim(0) {
            let chunks: Vec<Tensor> = b.split(2, 0);
            assert_eq!(chunks[0].shape().dims(), vec![2, 4]);
            assert_eq!(chunks[1].shape().dims(), vec![1, 4]);
            collected.push(chunks);
        }
        assert_eq!(collected.len(), 2);
    }

    // Exercises gradient flow through `split` + `cat`; it does not call `chunks` itself.
    #[test]
    fn test_chunks_gradient_after_collect() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])
            .unwrap()
            .with_requires_grad();
        let parts = t.split(2, 1);
        let parts2: Vec<Tensor> = parts.into_iter().map(|p| p.mul_scalar(3.0)).collect();
        let y = Tensor::cat(&parts2, 1);
        let mut loss = y.sum();
        loss.backward(None);
        let g = t.grad_owned().unwrap();
        assert_eq!(g.data(), &[3.0, 3.0, 3.0, 3.0, 3.0, 3.0]);
    }

    #[test]
    fn test_chunks_iter_gradient_propagation() {
        let t = Tensor::from_slice(&[10.0, 20.0, 30.0, 40.0], vec![4])
            .unwrap()
            .with_requires_grad();
        let parts: Vec<Tensor> = t.chunks(2).map(|c| c.add_scalar(1.0)).collect();
        let y = Tensor::cat(&parts, 0);
        let mut loss = y.sum();
        loss.backward(None);
        let g = t.grad_owned().unwrap();
        assert_eq!(g.data(), &[1.0, 1.0, 1.0, 1.0]);
    }

    #[test]
    fn test_chunks_double_ended_and_size_hint() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
        let mut it = t.chunks(2);
        assert_eq!(it.size_hint(), (3, Some(3)));
        assert_eq!(it.len(), 3);

        let last = it.next_back().unwrap();
        assert_eq!(last.data(), &[5.0]);
        assert_eq!(it.len(), 2);

        let first = it.next().unwrap();
        assert_eq!(first.data(), &[1.0, 2.0]);
        assert_eq!(it.len(), 1);

        let middle = it.next().unwrap();
        assert_eq!(middle.data(), &[3.0, 4.0]);
        assert!(it.next().is_none());
        assert!(it.next_back().is_none());
        assert_eq!(it.size_hint(), (0, Some(0)));
    }

    #[test]
    fn test_chunks_exact_double_ended_len_and_remainder() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], vec![7]).unwrap();
        let mut it = t.chunks_exact(3);
        assert_eq!(it.len(), 2);
        assert_eq!(it.size_hint(), (2, Some(2)));

        let last = it.next_back().unwrap();
        assert_eq!(last.data(), &[4.0, 5.0, 6.0]);
        assert_eq!(it.len(), 1);

        let first = it.next().unwrap();
        assert_eq!(first.data(), &[1.0, 2.0, 3.0]);
        assert!(it.next().is_none());
        assert!(it.next_back().is_none());
        assert_eq!(it.len(), 0);

        // Consuming chunks from the back must not shift the remainder.
        assert_eq!(it.remainder().data(), &[7.0]);
    }

    #[test]
    fn test_chunks_nth_skips_whole_chunks() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();

        let mut it = t.chunks(2);
        let c = it.nth(1).unwrap();
        assert_eq!(c.data(), &[3.0, 4.0]);
        assert_eq!(it.len(), 1);
        assert!(it.nth(5).is_none());
        assert!(it.next().is_none());
        assert_eq!(it.len(), 0);

        let mut ex = t.chunks_exact(2);
        assert_eq!(ex.nth(1).unwrap().data(), &[3.0, 4.0]);
        assert!(ex.nth(0).is_none());
        assert_eq!(ex.len(), 0);
        assert_eq!(ex.remainder().data(), &[5.0]);
    }

    #[test]
    fn test_chunks_zero_sized_tensor() {
        let t = Tensor::new(vec![0]);
        let it = t.chunks(3);
        assert_eq!(it.len(), 0);
        assert_eq!(it.size_hint(), (0, Some(0)));
        assert_eq!(it.collect::<Vec<_>>().len(), 0);
    }

    #[test]
    fn test_chunks_no_grad_guard_disables_requires_grad() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4])
            .unwrap()
            .with_requires_grad();
        let _guard = NoGradTrack::new();
        let mut it = t.chunks(2);
        let a = it.next().unwrap();
        let b = it.next().unwrap();
        assert!(!a.requires_grad());
        assert!(!b.requires_grad());
        let y: Tensor = t.chunks(2).collect_shape(vec![2, 2]);
        assert!(!y.requires_grad());
    }
}