train_station/tensor/iterator/collect.rs

//! Collection helpers for iterators yielding tensors

use crate::gradtrack::is_grad_enabled;
#[cfg(target_arch = "x86_64")]
use crate::tensor::core::memory::{detect_runtime_simd, simd_alignment_bytes, SimdLevel};
use crate::tensor::core::Tensor;
use std::iter::FromIterator;

impl Tensor {
    /// Collect tensors into a single tensor with target shape, copying data in iterator order.
    /// Optimizes copy using SIMD when available; asserts total size matches.
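    ///
    /// # Example
    ///
    /// A minimal illustration mirroring the unit tests in this module; it assumes the
    /// `from_slice` and `iter_chunks` APIs as exercised there.
    ///
    /// ```
    /// # use train_station::Tensor;
    /// let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]).unwrap();
    /// let parts: Vec<Tensor> = t.iter_chunks(2).collect();
    /// let mat = Tensor::collect_into_shape(parts, vec![3, 2]);
    /// assert_eq!(mat.shape().dims(), &[3, 2]);
    /// ```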
    #[inline]
    pub fn collect_into_shape<I: IntoIterator<Item = Tensor>>(iter: I, dims: Vec<usize>) -> Tensor {
        let total: usize = dims.iter().copied().product();
        let elements: Vec<Tensor> = iter.into_iter().collect();
        let sum_sizes: usize = elements.iter().map(|t| t.size()).sum();
        assert_eq!(
            sum_sizes, total,
            "collect_into_shape: element sizes {} do not match target size {}",
            sum_sizes, total
        );
        let requires_grad = elements.iter().any(|t| t.requires_grad()) && is_grad_enabled();

        if requires_grad {
            // Autograd-preserving path: flatten each element, concatenate with cat,
            // then reshape to the requested dims via view/reshape to preserve GradFn wiring.
            let mut flat_parts: Vec<Tensor> = Vec::with_capacity(elements.len());
            for t in elements.into_iter() {
                // Ensure rank 1 for cat
                flat_parts.push(t.flatten());
            }
            let concatenated = Tensor::cat(&flat_parts, 0); // [total]
            // Use view for reshape (registers GradFn::Reshape)
            let new_shape: Vec<i32> = dims.iter().map(|&d| d as i32).collect();
            let out = concatenated.view(new_shape);
            return out;
        }

        // Fast forward-only copy when no gradients are required
        let mut result = Tensor::new_uninitialized(dims);
        let mut offset = 0usize;
        unsafe {
            let dst = result.as_mut_ptr();
            for t in &elements {
                let sz = t.size();
                if sz == 0 {
                    continue;
                }
                optimized_copy(t.as_ptr(), dst.add(offset), sz);
                offset += sz;
            }
        }
        result
    }
}

// ===== Inherent convenience methods on core iterators =====
// These allow calling `.collect_shape(dims)` directly on the iterator returned by
// Tensor's iterator constructors without importing any extension traits. For more
// complex iterator chains that use adapters like `.map(...)`, the extension traits
// are still available (and re-exported at crate root) to enable method-style usage.

use crate::tensor::iterator::chunks::{TensorChunksExactIterator, TensorChunksIterator};
use crate::tensor::iterator::element::TensorElementIterator;
use crate::tensor::iterator::viewdim::TensorDimIterator;
use crate::tensor::iterator::windows::TensorWindowsIterator;

impl<'a> TensorChunksIterator<'a> {
    /// Collect this chunks iterator into a tensor with the provided shape.
    ///
    /// Gradient tracking: If any produced chunk tensor requires gradients and
    /// gradients are enabled, the resulting tensor will preserve autograd
    /// connections back to the original source tensor.
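    ///
    /// # Example
    ///
    /// A short sketch mirroring `test_collect_shape` below; it assumes `from_slice`
    /// and `iter_chunks` behave as exercised there.
    ///
    /// ```
    /// # use train_station::Tensor;
    /// let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]).unwrap();
    /// let mat = t.iter_chunks(2).collect_shape(vec![3, 2]);
    /// assert_eq!(mat.shape().dims(), &[3, 2]);
    /// ```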
    #[inline]
    pub fn collect_shape(self, dims: Vec<usize>) -> Tensor {
        Tensor::collect_into_shape(self, dims)
    }
}

impl<'a> TensorChunksExactIterator<'a> {
    /// Collect this exact-sized chunks iterator into a tensor with the provided shape.
    /// See [`TensorChunksIterator::collect_shape`] for gradient behavior.
    #[inline]
    pub fn collect_shape(self, dims: Vec<usize>) -> Tensor {
        Tensor::collect_into_shape(self, dims)
    }
}

impl<'a> TensorWindowsIterator<'a> {
    /// Collect this windows iterator into a tensor with the provided shape.
    /// See [`TensorChunksIterator::collect_shape`] for gradient behavior.
    #[inline]
    pub fn collect_shape(self, dims: Vec<usize>) -> Tensor {
        Tensor::collect_into_shape(self, dims)
    }
}

impl<'a> TensorDimIterator<'a> {
    /// Collect this dimension iterator into a tensor with the provided shape.
    /// See [`TensorChunksIterator::collect_shape`] for gradient behavior.
    #[inline]
    pub fn collect_shape(self, dims: Vec<usize>) -> Tensor {
        Tensor::collect_into_shape(self, dims)
    }
}

impl<'a> TensorElementIterator<'a> {
    /// Collect this element iterator into a tensor with the provided shape.
    /// See [`TensorChunksIterator::collect_shape`] for gradient behavior.
    #[inline]
    pub fn collect_shape(self, dims: Vec<usize>) -> Tensor {
        Tensor::collect_into_shape(self, dims)
    }
}

impl Tensor {
    /// Inherent helper to collect any iterator of `Tensor` into the specified shape.
    ///
    /// This is equivalent to calling `.collect_shape(dims)` on the iterator via the
    /// extension trait, but provided here as a convenience that works without bringing
    /// the trait into scope. Example:
    ///
    /// ```
    /// # use train_station::Tensor;
    /// let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
    /// let y = Tensor::collect_shape_from(x.iter_elements().map(|e| e.mul_scalar(2.0)), vec![4]);
    /// ```
    ///
    /// Gradient tracking behavior matches [`Tensor::collect_into_shape`].
    #[inline]
    pub fn collect_shape_from<I: IntoIterator<Item = Tensor>>(iter: I, dims: Vec<usize>) -> Tensor {
        Tensor::collect_into_shape(iter, dims)
    }

    /// Inherent helper to collect any iterator of `f32` into a shaped tensor.
    /// This mirrors the `ValuesCollectExt::collect_shape` functionality but does
    /// not require importing the extension trait.
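    ///
    /// # Example
    ///
    /// A minimal sketch of the intended usage, in the doctest style used elsewhere in
    /// this file; the iterator must produce at least as many values as the target size.
    ///
    /// ```
    /// # use train_station::Tensor;
    /// let t = Tensor::collect_values_shape((0..6).map(|i| i as f32), vec![2, 3]);
    /// assert_eq!(t.shape().dims(), &[2, 3]);
    /// assert_eq!(t.get(&[1, 2]), 5.0);
    /// ```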
    #[inline]
    pub fn collect_values_shape<I: IntoIterator<Item = f32>>(iter: I, dims: Vec<usize>) -> Tensor {
        let total: usize = dims.iter().copied().product();
        let mut out = Tensor::new_uninitialized(dims);
        if total == 0 {
            return out;
        }
        unsafe {
            let dst = out.as_mut_ptr();
            let mut i = 0usize;
            for v in iter {
                if i >= total {
                    break;
                }
                *dst.add(i) = v;
                i += 1;
            }
            assert_eq!(
                i, total,
                "values collect_shape: provided iterator produced {} values, expected {}",
                i, total
            );
        }
        out
    }
}

// Optimized collection from Iterator<Item=f32>
impl FromIterator<f32> for Tensor {
    /// Collect f32 values into a 1D contiguous, SIMD-aligned Tensor
    ///
    /// - Pre-allocates with optimized alignment and possible padding
    /// - Copies with AVX512→AVX2→SSE→scalar fallback
    /// - No gradient tracking is set on the result
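    ///
    /// # Example
    ///
    /// A small illustration mirroring `test_collect_from_values_into_tensor` below.
    ///
    /// ```
    /// # use train_station::Tensor;
    /// let t: Tensor = (0..4).map(|i| i as f32).collect();
    /// assert_eq!(t.shape().dims(), &[4]);
    /// assert_eq!(t.data(), &[0.0, 1.0, 2.0, 3.0]);
    /// ```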
    #[inline]
    fn from_iter<I: IntoIterator<Item = f32>>(iter: I) -> Self {
        // First pass: collect into a Vec<f32> to know exact length
        // Note: we could attempt a multi-pass size_hint growth strategy, but
        // a single Vec collect is generally fastest in practice and keeps code simple.
        let v: Vec<f32> = iter.into_iter().collect();
        let n = v.len();
        if n == 0 {
            return Tensor::new(vec![0]);
        }

        let mut out = Tensor::new_uninitialized(vec![n]);
        unsafe {
            optimized_copy(v.as_ptr(), out.as_mut_ptr(), n);
        }
        out
    }
}

/// Copies `count` f32 values using SIMD when available; falls back to an unrolled scalar copy.
#[inline]
pub(crate) unsafe fn optimized_copy(src: *const f32, dst: *mut f32, count: usize) {
    if count == 0 {
        return;
    }
    if count <= 32 {
        std::ptr::copy_nonoverlapping(src, dst, count);
        return;
    }

    #[cfg(target_arch = "x86_64")]
    {
        match detect_runtime_simd() {
            SimdLevel::Avx512 => {
                if simd_copy_avx512_best(src, dst, count) {
                    return;
                }
            }
            SimdLevel::Avx2 => {
                if simd_copy_avx2_best(src, dst, count) {
                    return;
                }
            }
            SimdLevel::Sse2 => {
                if simd_copy_sse_best(src, dst, count) {
                    return;
                }
            }
            SimdLevel::Scalar => {}
        }
    }

    scalar_copy_unrolled(src, dst, count);
}

/// AVX-512 copy dispatcher: uses aligned loads/stores when both pointers are aligned
/// (or can be co-aligned with a short scalar prefix), otherwise uses unaligned
/// intrinsics. Returns false if AVX-512 is unavailable or the copy is too small to benefit.
#[cfg(target_arch = "x86_64")]
#[inline]
unsafe fn simd_copy_avx512_best(src: *const f32, dst: *mut f32, count: usize) -> bool {
    if !is_x86_feature_detected!("avx512f") || count < 16 {
        return false;
    }
    let align = simd_alignment_bytes(SimdLevel::Avx512);
    let src_mod = (src as usize) % align;
    let dst_mod = (dst as usize) % align;
    let src_al = src_mod == 0;
    let dst_al = dst_mod == 0;
    if src_al && dst_al {
        simd_copy_avx512_aligned(src, dst, count);
    } else if src_mod == dst_mod {
        // Same misalignment: copy a scalar prefix to reach alignment, then use aligned ops.
        let bytes_to_align = if src_mod == 0 { 0 } else { align - src_mod };
        let elems_to_align = (bytes_to_align / std::mem::size_of::<f32>()).min(count);
        if elems_to_align > 0 && elems_to_align < count {
            std::ptr::copy_nonoverlapping(src, dst, elems_to_align);
            let src2 = src.add(elems_to_align);
            let dst2 = dst.add(elems_to_align);
            let rem = count - elems_to_align;
            simd_copy_avx512_aligned(src2, dst2, rem);
        } else {
            simd_copy_avx512_unaligned(src, dst, count);
        }
    } else {
        simd_copy_avx512_unaligned(src, dst, count);
    }
    true
}

/// AVX-512 copy using aligned loads/stores; processes 64 elements per iteration,
/// then 16 at a time, then a scalar tail.
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn simd_copy_avx512_aligned(src: *const f32, dst: *mut f32, count: usize) {
    use std::arch::x86_64::*;
    let mut offset = 0usize;
    let block = 64usize;
    let n_blocks = count / block;
    for _ in 0..n_blocks {
        let a = _mm512_load_ps(src.add(offset));
        let b = _mm512_load_ps(src.add(offset + 16));
        let c = _mm512_load_ps(src.add(offset + 32));
        let d = _mm512_load_ps(src.add(offset + 48));
        _mm512_store_ps(dst.add(offset), a);
        _mm512_store_ps(dst.add(offset + 16), b);
        _mm512_store_ps(dst.add(offset + 32), c);
        _mm512_store_ps(dst.add(offset + 48), d);
        offset += block;
    }
    let mut rem = count - offset;
    while rem >= 16 {
        let v = _mm512_load_ps(src.add(offset));
        _mm512_store_ps(dst.add(offset), v);
        offset += 16;
        rem -= 16;
    }
    if rem > 0 {
        std::ptr::copy_nonoverlapping(src.add(offset), dst.add(offset), rem);
    }
}

/// AVX-512 copy using unaligned loads/stores; same block structure as the aligned variant.
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn simd_copy_avx512_unaligned(src: *const f32, dst: *mut f32, count: usize) {
    use std::arch::x86_64::*;
    let mut offset = 0usize;
    let block = 64usize;
    let n_blocks = count / block;
    for _ in 0..n_blocks {
        let a = _mm512_loadu_ps(src.add(offset));
        let b = _mm512_loadu_ps(src.add(offset + 16));
        let c = _mm512_loadu_ps(src.add(offset + 32));
        let d = _mm512_loadu_ps(src.add(offset + 48));
        _mm512_storeu_ps(dst.add(offset), a);
        _mm512_storeu_ps(dst.add(offset + 16), b);
        _mm512_storeu_ps(dst.add(offset + 32), c);
        _mm512_storeu_ps(dst.add(offset + 48), d);
        offset += block;
    }
    let mut rem = count - offset;
    while rem >= 16 {
        let v = _mm512_loadu_ps(src.add(offset));
        _mm512_storeu_ps(dst.add(offset), v);
        offset += 16;
        rem -= 16;
    }
    if rem > 0 {
        std::ptr::copy_nonoverlapping(src.add(offset), dst.add(offset), rem);
    }
}

/// AVX2 copy dispatcher: uses aligned loads/stores when both pointers are aligned
/// (or can be co-aligned with a short scalar prefix), otherwise uses unaligned
/// intrinsics. Returns false if AVX2 is unavailable or the copy is too small to benefit.
#[cfg(target_arch = "x86_64")]
#[inline]
unsafe fn simd_copy_avx2_best(src: *const f32, dst: *mut f32, count: usize) -> bool {
    if !is_x86_feature_detected!("avx2") || count < 8 {
        return false;
    }
    let align = simd_alignment_bytes(SimdLevel::Avx2);
    let src_mod = (src as usize) % align;
    let dst_mod = (dst as usize) % align;
    let src_al = src_mod == 0;
    let dst_al = dst_mod == 0;
    if src_al && dst_al {
        simd_copy_avx2_aligned(src, dst, count);
    } else if src_mod == dst_mod {
        // Same misalignment: copy a scalar prefix to reach alignment, then use aligned ops.
        let bytes_to_align = if src_mod == 0 { 0 } else { align - src_mod };
        let elems_to_align = (bytes_to_align / std::mem::size_of::<f32>()).min(count);
        if elems_to_align > 0 && elems_to_align < count {
            std::ptr::copy_nonoverlapping(src, dst, elems_to_align);
            let src2 = src.add(elems_to_align);
            let dst2 = dst.add(elems_to_align);
            let rem = count - elems_to_align;
            simd_copy_avx2_aligned(src2, dst2, rem);
        } else {
            simd_copy_avx2_unaligned(src, dst, count);
        }
    } else {
        simd_copy_avx2_unaligned(src, dst, count);
    }
    true
}

/// AVX2 copy using aligned loads/stores; processes 32 elements per iteration,
/// then 8 at a time, then a scalar tail.
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn simd_copy_avx2_aligned(src: *const f32, dst: *mut f32, count: usize) {
    use std::arch::x86_64::*;
    let mut offset = 0usize;
    let block = 32usize;
    let n_blocks = count / block;
    for _ in 0..n_blocks {
        let v1 = _mm256_load_ps(src.add(offset));
        let v2 = _mm256_load_ps(src.add(offset + 8));
        let v3 = _mm256_load_ps(src.add(offset + 16));
        let v4 = _mm256_load_ps(src.add(offset + 24));
        _mm256_store_ps(dst.add(offset), v1);
        _mm256_store_ps(dst.add(offset + 8), v2);
        _mm256_store_ps(dst.add(offset + 16), v3);
        _mm256_store_ps(dst.add(offset + 24), v4);
        offset += block;
    }
    let mut rem = count - offset;
    while rem >= 8 {
        let v = _mm256_load_ps(src.add(offset));
        _mm256_store_ps(dst.add(offset), v);
        offset += 8;
        rem -= 8;
    }
    if rem > 0 {
        std::ptr::copy_nonoverlapping(src.add(offset), dst.add(offset), rem);
    }
}

/// AVX2 copy using unaligned loads/stores; same block structure as the aligned variant.
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn simd_copy_avx2_unaligned(src: *const f32, dst: *mut f32, count: usize) {
    use std::arch::x86_64::*;
    let mut offset = 0usize;
    let block = 32usize;
    let n_blocks = count / block;
    for _ in 0..n_blocks {
        let v1 = _mm256_loadu_ps(src.add(offset));
        let v2 = _mm256_loadu_ps(src.add(offset + 8));
        let v3 = _mm256_loadu_ps(src.add(offset + 16));
        let v4 = _mm256_loadu_ps(src.add(offset + 24));
        _mm256_storeu_ps(dst.add(offset), v1);
        _mm256_storeu_ps(dst.add(offset + 8), v2);
        _mm256_storeu_ps(dst.add(offset + 16), v3);
        _mm256_storeu_ps(dst.add(offset + 24), v4);
        offset += block;
    }
    let mut rem = count - offset;
    while rem >= 8 {
        let v = _mm256_loadu_ps(src.add(offset));
        _mm256_storeu_ps(dst.add(offset), v);
        offset += 8;
        rem -= 8;
    }
    if rem > 0 {
        std::ptr::copy_nonoverlapping(src.add(offset), dst.add(offset), rem);
    }
}

/// SSE2 copy dispatcher: uses aligned loads/stores when both pointers are aligned
/// (or can be co-aligned with a short scalar prefix), otherwise uses unaligned
/// intrinsics. Returns false if SSE2 is unavailable or the copy is too small to benefit.
#[cfg(target_arch = "x86_64")]
#[inline]
unsafe fn simd_copy_sse_best(src: *const f32, dst: *mut f32, count: usize) -> bool {
    if !is_x86_feature_detected!("sse2") || count < 4 {
        return false;
    }
    let align = simd_alignment_bytes(SimdLevel::Sse2);
    let src_mod = (src as usize) % align;
    let dst_mod = (dst as usize) % align;
    let src_al = src_mod == 0;
    let dst_al = dst_mod == 0;
    if src_al && dst_al {
        simd_copy_sse_aligned(src, dst, count);
    } else if src_mod == dst_mod {
        // Same misalignment: copy a scalar prefix to reach alignment, then use aligned ops.
        let bytes_to_align = if src_mod == 0 { 0 } else { align - src_mod };
        let elems_to_align = (bytes_to_align / std::mem::size_of::<f32>()).min(count);
        if elems_to_align > 0 && elems_to_align < count {
            std::ptr::copy_nonoverlapping(src, dst, elems_to_align);
            let src2 = src.add(elems_to_align);
            let dst2 = dst.add(elems_to_align);
            let rem = count - elems_to_align;
            simd_copy_sse_aligned(src2, dst2, rem);
        } else {
            simd_copy_sse_unaligned(src, dst, count);
        }
    } else {
        simd_copy_sse_unaligned(src, dst, count);
    }
    true
}

/// SSE2 copy using aligned loads/stores; processes 16 elements per iteration,
/// then 4 at a time, then a scalar tail.
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "sse2")]
unsafe fn simd_copy_sse_aligned(src: *const f32, dst: *mut f32, count: usize) {
    use std::arch::x86_64::*;
    let mut offset = 0usize;
    let block = 16usize;
    let n_blocks = count / block;
    for _ in 0..n_blocks {
        let a = _mm_load_ps(src.add(offset));
        let b = _mm_load_ps(src.add(offset + 4));
        let c = _mm_load_ps(src.add(offset + 8));
        let d = _mm_load_ps(src.add(offset + 12));
        _mm_store_ps(dst.add(offset), a);
        _mm_store_ps(dst.add(offset + 4), b);
        _mm_store_ps(dst.add(offset + 8), c);
        _mm_store_ps(dst.add(offset + 12), d);
        offset += block;
    }
    let mut rem = count - offset;
    while rem >= 4 {
        let v = _mm_load_ps(src.add(offset));
        _mm_store_ps(dst.add(offset), v);
        offset += 4;
        rem -= 4;
    }
    if rem > 0 {
        std::ptr::copy_nonoverlapping(src.add(offset), dst.add(offset), rem);
    }
}

/// SSE2 copy using unaligned loads/stores; same block structure as the aligned variant.
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "sse2")]
unsafe fn simd_copy_sse_unaligned(src: *const f32, dst: *mut f32, count: usize) {
    use std::arch::x86_64::*;
    let mut offset = 0usize;
    let block = 16usize;
    let n_blocks = count / block;
    for _ in 0..n_blocks {
        let a = _mm_loadu_ps(src.add(offset));
        let b = _mm_loadu_ps(src.add(offset + 4));
        let c = _mm_loadu_ps(src.add(offset + 8));
        let d = _mm_loadu_ps(src.add(offset + 12));
        _mm_storeu_ps(dst.add(offset), a);
        _mm_storeu_ps(dst.add(offset + 4), b);
        _mm_storeu_ps(dst.add(offset + 8), c);
        _mm_storeu_ps(dst.add(offset + 12), d);
        offset += block;
    }
    let mut rem = count - offset;
    while rem >= 4 {
        let v = _mm_loadu_ps(src.add(offset));
        _mm_storeu_ps(dst.add(offset), v);
        offset += 4;
        rem -= 4;
    }
    if rem > 0 {
        std::ptr::copy_nonoverlapping(src.add(offset), dst.add(offset), rem);
    }
}

/// Scalar fallback copy, unrolled by 8 with a `copy_nonoverlapping` tail.
#[inline]
unsafe fn scalar_copy_unrolled(src: *const f32, dst: *mut f32, count: usize) {
    let unroll = 8;
    let blocks = count / unroll;
    let mut offset = 0usize;
    for _ in 0..blocks {
        *dst.add(offset) = *src.add(offset);
        *dst.add(offset + 1) = *src.add(offset + 1);
        *dst.add(offset + 2) = *src.add(offset + 2);
        *dst.add(offset + 3) = *src.add(offset + 3);
        *dst.add(offset + 4) = *src.add(offset + 4);
        *dst.add(offset + 5) = *src.add(offset + 5);
        *dst.add(offset + 6) = *src.add(offset + 6);
        *dst.add(offset + 7) = *src.add(offset + 7);
        offset += unroll;
    }
    if offset < count {
        std::ptr::copy_nonoverlapping(src.add(offset), dst.add(offset), count - offset);
    }
}

/// Extension trait to collect an iterator of tensors into a tensor with the provided shape.
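///
/// # Example
///
/// A brief sketch; it assumes the crate-root re-export of this trait mentioned in the
/// module comments above.
///
/// ```
/// # use train_station::{Tensor, TensorCollectExt};
/// let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
/// let y = t.iter_chunks(2).map(|c| c.mul_scalar(2.0)).collect_shape(vec![2, 2]);
/// assert_eq!(y.shape().dims(), &[2, 2]);
/// ```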
pub trait TensorCollectExt: Iterator<Item = Tensor> + Sized {
    fn collect_shape(self, dims: Vec<usize>) -> Tensor;
}

impl<I> TensorCollectExt for I
where
    I: Iterator<Item = Tensor> + Sized,
{
    #[inline]
    fn collect_shape(self, dims: Vec<usize>) -> Tensor {
        Tensor::collect_into_shape(self, dims)
    }
}

/// Extension trait to collect an `Iterator<Item = f32>` directly into a shaped Tensor
///
/// This trait is blanket-implemented for any `Iterator<Item = f32>`, so once it is in
/// scope (it is re-exported at the crate root) you can call `collect_shape()` directly
/// on iterators yielding f32 values.
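///
/// # Example
///
/// A brief sketch mirroring `test_values_collect_shape_direct` below; it assumes the
/// crate-root re-export of this trait mentioned in the module comments above.
///
/// ```
/// # use train_station::{Tensor, ValuesCollectExt};
/// let shaped: Tensor = (0..12).map(|i| i as f32).collect_shape(vec![3, 4]);
/// assert_eq!(shaped.shape().dims(), &[3, 4]);
/// assert_eq!(shaped.get(&[2, 3]), 11.0);
/// ```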
pub trait ValuesCollectExt: Iterator<Item = f32> + Sized {
    /// Collect f32 values from this iterator into a tensor with the specified shape
    fn collect_shape(self, dims: Vec<usize>) -> Tensor;
}

impl<I> ValuesCollectExt for I
where
    I: Iterator<Item = f32> + Sized,
{
    #[inline]
    fn collect_shape(self, dims: Vec<usize>) -> Tensor {
        let total: usize = dims.iter().copied().product();
        let mut out = Tensor::new_uninitialized(dims);
        if total == 0 {
            return out;
        }

        // For small datasets, use direct iteration to avoid allocation overhead
        if total <= 64 {
            unsafe {
                let dst = out.as_mut_ptr();
                let mut i = 0usize;
                for v in self {
                    if i >= total {
                        break;
                    }
                    *dst.add(i) = v;
                    i += 1;
                }
                assert_eq!(
                    i, total,
                    "values collect_shape: provided iterator produced {} values, expected {}",
                    i, total
                );
            }
            return out;
        }

        // For larger datasets, collect into temporary buffer then use optimized_copy
        let temp_data: Vec<f32> = self.collect();
        assert_eq!(
            temp_data.len(),
            total,
            "values collect_shape: provided iterator produced {} values, expected {}",
            temp_data.len(),
            total
        );

        unsafe {
            optimized_copy(temp_data.as_ptr(), out.as_mut_ptr(), total);
        }
        out
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_collect_shape() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]).unwrap();
        let mat = t.iter_chunks(2).collect_shape(vec![3, 2]);
        assert_eq!(mat.shape().dims(), &[3, 2]);
        assert_eq!(mat.data(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_collect_shape_with_grad_preserves_backward() {
        use crate::gradtrack::is_grad_enabled;
        use crate::tensor::core::Tensor;

        if !is_grad_enabled() {
            // Gradients are expected to be enabled in normal test runs; skip if not.
            return;
        }

        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4])
            .unwrap()
            .with_requires_grad();
        // Split in chunks of 2, scale each chunk, collect with target shape [2,2]
        let parts: Vec<Tensor> = t.iter_chunks(2).map(|c| c.mul_scalar(3.0)).collect();
        let y = parts.into_iter().collect_shape(vec![2, 2]);
        assert!(y.requires_grad());

        let mut loss = y.sum();
        loss.backward(None);
        let g = t.grad_owned().unwrap();
        // Each input element appears exactly once in the collected tensor and is scaled by 3
        assert_eq!(g.data(), &[3.0, 3.0, 3.0, 3.0]);
    }

    #[test]
    fn test_collect_from_values_into_tensor() {
        // Collect Iterator<f32> into Tensor via FromIterator<f32>
        let vals = (0..16).map(|i| i as f32);
        let t: Tensor = vals.collect();
        assert_eq!(t.shape().dims(), &[16]);
        assert_eq!(t.data()[0], 0.0);
        assert_eq!(t.data()[15], 15.0);
    }

    #[test]
    fn test_values_iter_then_collect_shape() {
        // iter_values() + collect into a 1-D tensor, then reshape with view
        let base =
            Tensor::from_slice(&(0..12).map(|i| i as f32).collect::<Vec<_>>(), vec![3, 4]).unwrap();
        let flat_vals = base.iter_values();
        let collected: Tensor = flat_vals.collect();
        assert_eq!(collected.shape().dims(), &[12]);
        // reshape using view semantics
        let shaped = collected.view(vec![3, 4]);
        assert_eq!(shaped.shape().dims(), &[3, 4]);
        assert_eq!(shaped.get(&[2, 3]), 11.0);
    }

    #[test]
    fn test_values_collect_shape_direct() {
        // ValuesCollectExt is in scope via `use super::*`, so collect_shape is
        // available directly on an Iterator<Item = f32>.
        let shaped: Tensor = (0..12).map(|i| i as f32).collect_shape(vec![3, 4]);
        assert_eq!(shaped.shape().dims(), &[3, 4]);
        assert_eq!(shaped.get(&[0, 0]), 0.0);
        assert_eq!(shaped.get(&[2, 3]), 11.0);
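    }

    // Additional illustrative test (a minimal sketch): exercises the inherent helpers
    // `Tensor::collect_values_shape` and `Tensor::collect_shape_from` defined above,
    // using only APIs already used by the other tests in this module.
    #[test]
    fn test_inherent_collect_helpers() {
        // f32 values into a shaped tensor without the extension trait
        let a = Tensor::collect_values_shape((0..6).map(|i| i as f32), vec![2, 3]);
        assert_eq!(a.shape().dims(), &[2, 3]);
        assert_eq!(a.get(&[1, 2]), 5.0);

        // An iterator of tensors into a shaped tensor via the inherent helper
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let b = Tensor::collect_shape_from(t.iter_chunks(2).map(|c| c.mul_scalar(2.0)), vec![2, 2]);
        assert_eq!(b.shape().dims(), &[2, 2]);
        assert_eq!(b.data(), &[2.0, 4.0, 6.0, 8.0]);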
    }
}
669}