train_station/tensor/iterator/collect.rs

//! Collection helpers for iterators yielding tensors

use crate::gradtrack::is_grad_enabled;
#[cfg(target_arch = "x86_64")]
use crate::tensor::core::memory::{detect_runtime_simd, simd_alignment_bytes, SimdLevel};
use crate::tensor::core::Tensor;
use std::iter::FromIterator;

impl Tensor {
    /// Collect tensors into a single tensor with target shape, copying data in iterator order.
    /// Optimizes copy using SIMD when available; asserts total size matches.
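    ///
    /// Illustrative sketch (mirrors this module's tests; `from_slice` and `chunks` come from
    /// elsewhere in the crate):
    ///
    /// ```ignore
    /// let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]).unwrap();
    /// let mat = Tensor::collect_into_shape(t.chunks(2), vec![3, 2]);
    /// assert_eq!(mat.shape().dims(), &[3, 2]);
    /// ```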
    #[inline]
    #[track_caller]
    pub fn collect_into_shape<I: IntoIterator<Item = Tensor>>(iter: I, dims: Vec<usize>) -> Tensor {
        let total: usize = dims.iter().copied().product();

        // If gradients are disabled, stream directly into the destination tensor without
        // building an intermediate Vec<Tensor> buffer.
        if !is_grad_enabled() {
            let mut result = Tensor::new_uninitialized(dims);
            let mut offset = 0usize;
            let mut sum_sizes = 0usize;
            unsafe {
                let dst = result.as_mut_ptr();
                for t in iter.into_iter() {
                    let sz = t.size();
                    if sz == 0 {
                        continue;
                    }
                    // Guard before writing: a too-long iterator must not copy past the
                    // destination allocation.
                    assert!(
                        offset + sz <= total,
                        "collect_into_shape: element sizes exceed target size {}",
                        total
                    );
                    optimized_copy(t.as_ptr(), dst.add(offset), sz);
                    offset += sz;
                    sum_sizes += sz;
                }
            }
            assert_eq!(
                sum_sizes, total,
                "collect_into_shape: element sizes {} do not match target size {}",
                sum_sizes, total
            );
            return result;
        }

        // Grad-enabled path: materialize elements to preserve autograd connections
        let elements: Vec<Tensor> = iter.into_iter().collect();
        let sum_sizes: usize = elements.iter().map(|t| t.size()).sum();
        assert_eq!(
            sum_sizes, total,
            "collect_into_shape: element sizes {} do not match target size {}",
            sum_sizes, total
        );
        let requires_grad = elements.iter().any(|t| t.requires_grad());

        if requires_grad {
            // Autograd-preserving path: flatten each element, concatenate with cat,
            // then reshape to the requested dims via view/reshape to preserve GradFn wiring.
            let mut flat_parts: Vec<Tensor> = Vec::with_capacity(elements.len());
            for t in elements.into_iter() {
                flat_parts.push(t.flatten());
            }
            let concatenated = Tensor::cat(&flat_parts, 0); // [total]
            let new_shape: Vec<i32> = dims.iter().map(|&d| d as i32).collect();
            let out = concatenated.view(new_shape);
            return out;
        }

        // No element requires gradients (but grad tracking enabled globally). Copy directly.
        let mut result = Tensor::new_uninitialized(dims);
        let mut offset = 0usize;
        unsafe {
            let dst = result.as_mut_ptr();
            for t in &elements {
                let sz = t.size();
                if sz == 0 {
                    continue;
                }
                optimized_copy(t.as_ptr(), dst.add(offset), sz);
                offset += sz;
            }
        }
        result
    }
}

// ===== Inherent convenience methods on core iterators =====
// These allow calling `.collect_shape(dims)` directly on the iterator returned by
// Tensor's iterator constructors without importing any extension traits. For more
// complex iterator chains that use adapters like `.map(...)`, the extension traits
// are still available (and re-exported at crate root) to enable method-style usage.
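//
// Illustrative sketch (hypothetical snippet; tensor construction mirrors this module's tests):
//
//     let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
//     // Inherent method on the chunks iterator: no trait import needed.
//     let a = t.chunks(2).collect_shape(vec![2, 2]);
//     // Adapter chain: bring `TensorCollectExt` into scope for method-style usage.
//     let b = t.chunks(2).map(|c| c.mul_scalar(2.0)).collect_shape(vec![2, 2]);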

use crate::tensor::iterator::chunks::{TensorChunksExactIterator, TensorChunksIterator};
use crate::tensor::iterator::element::TensorElementIterator;
use crate::tensor::iterator::viewdim::TensorDimIterator;
use crate::tensor::iterator::windows::TensorWindowsIterator;

impl<'a> TensorChunksIterator<'a> {
    /// Collect this chunks iterator into a tensor with the provided shape.
    ///
    /// Gradient tracking: If any produced chunk tensor requires gradients and
    /// gradients are enabled, the resulting tensor will preserve autograd
    /// connections back to the original source tensor.
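    ///
    /// Illustrative sketch (mirrors `test_collect_shape` below):
    ///
    /// ```ignore
    /// let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]).unwrap();
    /// let mat = t.chunks(2).collect_shape(vec![3, 2]);
    /// assert_eq!(mat.data(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    /// ```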
    #[inline]
    #[track_caller]
    pub fn collect_shape(self, dims: Vec<usize>) -> Tensor {
        Tensor::collect_into_shape(self, dims)
    }
}

impl<'a> TensorChunksExactIterator<'a> {
    /// Collect this exact-sized chunks iterator into a tensor with the provided shape.
    /// See [`TensorChunksIterator::collect_shape`] for gradient behavior.
    #[inline]
    #[track_caller]
    pub fn collect_shape(self, dims: Vec<usize>) -> Tensor {
        Tensor::collect_into_shape(self, dims)
    }
}

impl<'a> TensorWindowsIterator<'a> {
    /// Collect this windows iterator into a tensor with the provided shape.
    /// See [`TensorChunksIterator::collect_shape`] for gradient behavior.
    #[inline]
    #[track_caller]
    pub fn collect_shape(self, dims: Vec<usize>) -> Tensor {
        Tensor::collect_into_shape(self, dims)
    }
}

impl<'a> TensorDimIterator<'a> {
    /// Collect this dimension iterator into a tensor with the provided shape.
    /// See [`TensorChunksIterator::collect_shape`] for gradient behavior.
    #[inline]
    #[track_caller]
    pub fn collect_shape(self, dims: Vec<usize>) -> Tensor {
        Tensor::collect_into_shape(self, dims)
    }
}

impl<'a> TensorElementIterator<'a> {
    /// Collect this element iterator into a tensor with the provided shape.
    /// See [`TensorChunksIterator::collect_shape`] for gradient behavior.
    #[inline]
    #[track_caller]
    pub fn collect_shape(self, dims: Vec<usize>) -> Tensor {
        Tensor::collect_into_shape(self, dims)
    }
}

impl Tensor {
    /// Inherent helper to collect any iterator of `f32` into a shaped tensor.
    /// This mirrors the `ValuesCollectExt::collect_shape` functionality but does
    /// not require importing the extension trait.
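    ///
    /// Illustrative sketch (shape/indexing accessors mirror this module's tests):
    ///
    /// ```ignore
    /// let shaped = Tensor::collect_values_shape((0..12).map(|i| i as f32), vec![3, 4]);
    /// assert_eq!(shaped.shape().dims(), &[3, 4]);
    /// assert_eq!(shaped.get(&[2, 3]), 11.0);
    /// ```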
    #[inline]
    #[track_caller]
    pub fn collect_values_shape<I: IntoIterator<Item = f32>>(iter: I, dims: Vec<usize>) -> Tensor {
        let total: usize = dims.iter().copied().product();
        let mut out = Tensor::new_uninitialized(dims);
        if total == 0 {
            return out;
        }

        // Small sizes: trivial scalar fill is fastest
        if total <= 64 {
            unsafe {
                let dst = out.as_mut_ptr();
                let mut i = 0usize;
                for v in iter {
                    if i >= total {
                        break;
                    }
                    *dst.add(i) = v;
                    i += 1;
                }
                assert_eq!(
                    i, total,
                    "values collect_shape: provided iterator produced {} values, expected {}",
                    i, total
                );
            }
            return out;
        }

        // Large sizes: chunked buffer + SIMD-optimized copy into destination.
        // This avoids allocating a full-size Vec while enabling wide stores/streams.
        let mut it = iter.into_iter();
        let chunk_elems = crate::tensor::core::memory::choose_fast_chunk_size(total);
        let mut buffer: Vec<f32> = Vec::with_capacity(chunk_elems);

        unsafe {
            let dst = out.as_mut_ptr();
            let mut written = 0usize;
            while written < total {
                buffer.clear();
                let to_take = buffer.capacity().min(total - written);
                for _ in 0..to_take {
                    if let Some(v) = it.next() {
                        buffer.push(v);
                    } else {
                        break;
                    }
                }
                let got = buffer.len();
                if got == 0 {
                    break;
                }
                optimized_copy(buffer.as_ptr(), dst.add(written), got);
                written += got;
            }
            assert_eq!(
                written, total,
                "values collect_shape: provided iterator produced {} values, expected {}",
                written, total
            );
        }
        out
    }
}

// Optimized collection from Iterator<Item=f32>
impl FromIterator<f32> for Tensor {
    /// Collect f32 values into a 1D contiguous, SIMD-aligned Tensor
    ///
    /// - Streams directly when iterator reports exact size_hint
    /// - Falls back to temporary Vec and optimized_copy otherwise
    /// - No gradient tracking is set on the result
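    ///
    /// Illustrative sketch (mirrors `test_collect_from_values_into_tensor` below):
    ///
    /// ```ignore
    /// let t: Tensor = (0..16).map(|i| i as f32).collect();
    /// assert_eq!(t.shape().dims(), &[16]);
    /// ```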
    #[inline]
    fn from_iter<I: IntoIterator<Item = f32>>(iter: I) -> Self {
        let it = iter.into_iter();
        // Try to stream if we know the exact length up-front
        if let (lower, Some(upper)) = it.size_hint() {
            if lower == upper {
                let n = lower;
                if n == 0 {
                    return Tensor::new(vec![0]);
                }
                let mut out = Tensor::new_uninitialized(vec![n]);
                unsafe {
                    let dst = out.as_mut_ptr();
                    let mut i = 0usize;
                    for v in it {
                        // size_hint is a hint, not a safety guarantee: never write past
                        // the `n` elements that were allocated.
                        if i >= n {
                            break;
                        }
                        *dst.add(i) = v;
                        i += 1;
                    }
                    debug_assert_eq!(i, n);
                }
                return out;
            }
        }

        // Fallback: collect to Vec then SIMD copy
        let v: Vec<f32> = it.collect();
        let n = v.len();
        if n == 0 {
            return Tensor::new(vec![0]);
        }
        let mut out = Tensor::new_uninitialized(vec![n]);
        unsafe { optimized_copy(v.as_ptr(), out.as_mut_ptr(), n) };
        out
    }
}

impl From<Vec<f32>> for Tensor {
    /// Create a Tensor from a `Vec<f32>` by copying into an aligned/padded allocation.
    ///
    /// Note: We do not adopt the Vec's allocation to preserve alignment and padding
    /// guarantees for SIMD operations. Use `Tensor::into_vec()` to extract data back.
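    ///
    /// Illustrative sketch:
    ///
    /// ```ignore
    /// let t = Tensor::from(vec![1.0f32, 2.0, 3.0]);
    /// assert_eq!(t.shape().dims(), &[3]);
    /// ```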
    #[inline]
    #[track_caller]
    fn from(v: Vec<f32>) -> Self {
        let n = v.len();
        if n == 0 {
            return Tensor::new(vec![0]);
        }
        let mut out = Tensor::new_uninitialized(vec![n]);
        unsafe { optimized_copy(v.as_ptr(), out.as_mut_ptr(), n) };
        out
    }
}

impl From<Tensor> for Vec<f32> {
    /// Convert this tensor into a `Vec<f32>` in row-major order.
    ///
    /// - Contiguous fast path: single optimized copy
    /// - Non-contiguous: materialize a contiguous copy first
    /// - Gradient metadata is ignored; this is a pure data extraction API
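    ///
    /// Illustrative sketch (`from_slice` assumed from elsewhere in the crate):
    ///
    /// ```ignore
    /// let t = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
    /// let v: Vec<f32> = t.into();
    /// assert_eq!(v, vec![1.0, 2.0, 3.0]);
    /// ```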
    #[inline]
    #[track_caller]
    fn from(tensor: Tensor) -> Vec<f32> {
        let n = tensor.size();
        if n == 0 {
            return Vec::new();
        }
        // If already contiguous, copy directly; otherwise make a contiguous copy first
        if tensor.is_contiguous() {
            let mut v = vec![0.0f32; n];
            unsafe {
                crate::tensor::iterator::collect::optimized_copy(tensor.as_ptr(), v.as_mut_ptr(), n)
            };
            v
        } else {
            let c = tensor.contiguous();
            let mut v = vec![0.0f32; n];
            unsafe {
                crate::tensor::iterator::collect::optimized_copy(c.as_ptr(), v.as_mut_ptr(), n)
            };
            v
        }
    }
}

/// Uses a SIMD-optimized copy when available, falling back to scalar/unrolled copies.
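///
/// # Safety
///
/// `src` must be valid for `count` reads and `dst` must be valid for `count` writes,
/// and the two regions must not overlap (copies are issued with non-overlapping
/// loads/stores).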
#[inline]
pub(crate) unsafe fn optimized_copy(src: *const f32, dst: *mut f32, count: usize) {
    if count == 0 {
        return;
    }
    if count <= 32 {
        std::ptr::copy_nonoverlapping(src, dst, count);
        return;
    }

    #[cfg(target_arch = "x86_64")]
    {
        match detect_runtime_simd() {
            SimdLevel::Avx512 => {
                if simd_copy_avx512_best(src, dst, count) {
                    return;
                }
            }
            SimdLevel::Avx2 => {
                if simd_copy_avx2_best(src, dst, count) {
                    return;
                }
            }
            SimdLevel::Sse2 => {
                if simd_copy_sse_best(src, dst, count) {
                    return;
                }
            }
            SimdLevel::Scalar => {}
        }
    }

    scalar_copy_unrolled(src, dst, count);
}

#[cfg(target_arch = "x86_64")]
#[inline]
unsafe fn simd_copy_avx512_best(src: *const f32, dst: *mut f32, count: usize) -> bool {
    if !is_x86_feature_detected!("avx512f") || count < 16 {
        return false;
    }
    let align = simd_alignment_bytes(SimdLevel::Avx512);
    let src_mod = (src as usize) % align;
    let dst_mod = (dst as usize) % align;
    let src_al = src_mod == 0;
    let dst_al = dst_mod == 0;
    if src_al && dst_al {
        simd_copy_avx512_aligned(src, dst, count);
    } else if src_mod == dst_mod {
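        // src and dst share the same misalignment: copy a small scalar head so both
        // reach the alignment boundary, then use aligned loads/stores for the remainder.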
        let bytes_to_align = if src_mod == 0 { 0 } else { align - src_mod };
        let elems_to_align = (bytes_to_align / std::mem::size_of::<f32>()).min(count);
        if elems_to_align > 0 && elems_to_align < count {
            std::ptr::copy_nonoverlapping(src, dst, elems_to_align);
            let src2 = src.add(elems_to_align);
            let dst2 = dst.add(elems_to_align);
            let rem = count - elems_to_align;
            simd_copy_avx512_aligned(src2, dst2, rem);
        } else {
            simd_copy_avx512_unaligned(src, dst, count);
        }
    } else {
        simd_copy_avx512_unaligned(src, dst, count);
    }
    true
}

#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn simd_copy_avx512_aligned(src: *const f32, dst: *mut f32, count: usize) {
    use std::arch::x86_64::*;
    // Prefetch and optionally stream for very large copies
    let stream_threshold = crate::tensor::core::memory::stream_min_elems();
    let pf_distance = crate::tensor::core::memory::prefetch_distance_elems();
    let mut offset = 0usize;
    let block = 64usize;
    let n_blocks = count / block;
    for _ in 0..n_blocks {
        if pf_distance > 0 {
            _mm_prefetch(src.add(offset + pf_distance) as *const i8, _MM_HINT_T0);
        }
        let a = _mm512_load_ps(src.add(offset));
        let b = _mm512_load_ps(src.add(offset + 16));
        let c = _mm512_load_ps(src.add(offset + 32));
        let d = _mm512_load_ps(src.add(offset + 48));
        if count >= stream_threshold {
            _mm512_stream_ps(dst.add(offset), a);
            _mm512_stream_ps(dst.add(offset + 16), b);
            _mm512_stream_ps(dst.add(offset + 32), c);
            _mm512_stream_ps(dst.add(offset + 48), d);
        } else {
            _mm512_store_ps(dst.add(offset), a);
            _mm512_store_ps(dst.add(offset + 16), b);
            _mm512_store_ps(dst.add(offset + 32), c);
            _mm512_store_ps(dst.add(offset + 48), d);
        }
        offset += block;
    }
    let mut rem = count - offset;
    while rem >= 16 {
        let v = _mm512_load_ps(src.add(offset));
        if count >= stream_threshold {
            _mm512_stream_ps(dst.add(offset), v);
        } else {
            _mm512_store_ps(dst.add(offset), v);
        }
        offset += 16;
        rem -= 16;
    }
    if rem > 0 {
        std::ptr::copy_nonoverlapping(src.add(offset), dst.add(offset), rem);
    }
}

#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn simd_copy_avx512_unaligned(src: *const f32, dst: *mut f32, count: usize) {
    use std::arch::x86_64::*;
    let stream_threshold = crate::tensor::core::memory::stream_min_elems();
    let pf_distance = crate::tensor::core::memory::prefetch_distance_elems();
    // Non-temporal stream stores require a 64-byte-aligned destination; only use them
    // when `dst` happens to be aligned even though this is the unaligned-copy path.
    let use_stream = count >= stream_threshold && (dst as usize) % 64 == 0;
    let mut offset = 0usize;
    let block = 64usize;
    let n_blocks = count / block;
    for _ in 0..n_blocks {
        if pf_distance > 0 {
            _mm_prefetch(src.add(offset + pf_distance) as *const i8, _MM_HINT_T0);
        }
        let a = _mm512_loadu_ps(src.add(offset));
        let b = _mm512_loadu_ps(src.add(offset + 16));
        let c = _mm512_loadu_ps(src.add(offset + 32));
        let d = _mm512_loadu_ps(src.add(offset + 48));
        if use_stream {
            _mm512_stream_ps(dst.add(offset), a);
            _mm512_stream_ps(dst.add(offset + 16), b);
            _mm512_stream_ps(dst.add(offset + 32), c);
            _mm512_stream_ps(dst.add(offset + 48), d);
        } else {
            _mm512_storeu_ps(dst.add(offset), a);
            _mm512_storeu_ps(dst.add(offset + 16), b);
            _mm512_storeu_ps(dst.add(offset + 32), c);
            _mm512_storeu_ps(dst.add(offset + 48), d);
        }
        offset += block;
    }
    let mut rem = count - offset;
    while rem >= 16 {
        let v = _mm512_loadu_ps(src.add(offset));
        if use_stream {
            _mm512_stream_ps(dst.add(offset), v);
        } else {
            _mm512_storeu_ps(dst.add(offset), v);
        }
        offset += 16;
        rem -= 16;
    }
    if rem > 0 {
        std::ptr::copy_nonoverlapping(src.add(offset), dst.add(offset), rem);
    }
}

#[cfg(target_arch = "x86_64")]
#[inline]
unsafe fn simd_copy_avx2_best(src: *const f32, dst: *mut f32, count: usize) -> bool {
    if !is_x86_feature_detected!("avx2") || count < 8 {
        return false;
    }
    let align = simd_alignment_bytes(SimdLevel::Avx2);
    let src_mod = (src as usize) % align;
    let dst_mod = (dst as usize) % align;
    let src_al = src_mod == 0;
    let dst_al = dst_mod == 0;
    if src_al && dst_al {
        simd_copy_avx2_aligned(src, dst, count);
    } else if src_mod == dst_mod {
        let bytes_to_align = if src_mod == 0 { 0 } else { align - src_mod };
        let elems_to_align = (bytes_to_align / std::mem::size_of::<f32>()).min(count);
        if elems_to_align > 0 && elems_to_align < count {
            std::ptr::copy_nonoverlapping(src, dst, elems_to_align);
            let src2 = src.add(elems_to_align);
            let dst2 = dst.add(elems_to_align);
            let rem = count - elems_to_align;
            simd_copy_avx2_aligned(src2, dst2, rem);
        } else {
            simd_copy_avx2_unaligned(src, dst, count);
        }
    } else {
        simd_copy_avx2_unaligned(src, dst, count);
    }
    true
}

#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn simd_copy_avx2_aligned(src: *const f32, dst: *mut f32, count: usize) {
    use std::arch::x86_64::*;
    let stream_threshold = crate::tensor::core::memory::stream_min_elems();
    let pf_distance = crate::tensor::core::memory::prefetch_distance_elems();
    let mut offset = 0usize;
    let block = 32usize;
    let n_blocks = count / block;
    for _ in 0..n_blocks {
        if pf_distance > 0 {
            _mm_prefetch(src.add(offset + pf_distance) as *const i8, _MM_HINT_T0);
        }
        let v1 = _mm256_load_ps(src.add(offset));
        let v2 = _mm256_load_ps(src.add(offset + 8));
        let v3 = _mm256_load_ps(src.add(offset + 16));
        let v4 = _mm256_load_ps(src.add(offset + 24));
        if count >= stream_threshold {
            _mm256_stream_ps(dst.add(offset), v1);
            _mm256_stream_ps(dst.add(offset + 8), v2);
            _mm256_stream_ps(dst.add(offset + 16), v3);
            _mm256_stream_ps(dst.add(offset + 24), v4);
        } else {
            _mm256_store_ps(dst.add(offset), v1);
            _mm256_store_ps(dst.add(offset + 8), v2);
            _mm256_store_ps(dst.add(offset + 16), v3);
            _mm256_store_ps(dst.add(offset + 24), v4);
        }
        offset += block;
    }
    let mut rem = count - offset;
    while rem >= 8 {
        let v = _mm256_load_ps(src.add(offset));
        if count >= stream_threshold {
            _mm256_stream_ps(dst.add(offset), v);
        } else {
            _mm256_store_ps(dst.add(offset), v);
        }
        offset += 8;
        rem -= 8;
    }
    if rem > 0 {
        std::ptr::copy_nonoverlapping(src.add(offset), dst.add(offset), rem);
    }
}

#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn simd_copy_avx2_unaligned(src: *const f32, dst: *mut f32, count: usize) {
    use std::arch::x86_64::*;
    let stream_threshold = crate::tensor::core::memory::stream_min_elems();
    let pf_distance = crate::tensor::core::memory::prefetch_distance_elems();
    // Non-temporal stream stores require a 32-byte-aligned destination; only use them
    // when `dst` happens to be aligned even though this is the unaligned-copy path.
    let use_stream = count >= stream_threshold && (dst as usize) % 32 == 0;
    let mut offset = 0usize;
    let block = 32usize;
    let n_blocks = count / block;
    for _ in 0..n_blocks {
        if pf_distance > 0 {
            _mm_prefetch(src.add(offset + pf_distance) as *const i8, _MM_HINT_T0);
        }
        let v1 = _mm256_loadu_ps(src.add(offset));
        let v2 = _mm256_loadu_ps(src.add(offset + 8));
        let v3 = _mm256_loadu_ps(src.add(offset + 16));
        let v4 = _mm256_loadu_ps(src.add(offset + 24));
        if use_stream {
            _mm256_stream_ps(dst.add(offset), v1);
            _mm256_stream_ps(dst.add(offset + 8), v2);
            _mm256_stream_ps(dst.add(offset + 16), v3);
            _mm256_stream_ps(dst.add(offset + 24), v4);
        } else {
            _mm256_storeu_ps(dst.add(offset), v1);
            _mm256_storeu_ps(dst.add(offset + 8), v2);
            _mm256_storeu_ps(dst.add(offset + 16), v3);
            _mm256_storeu_ps(dst.add(offset + 24), v4);
        }
        offset += block;
    }
    let mut rem = count - offset;
    while rem >= 8 {
        let v = _mm256_loadu_ps(src.add(offset));
        if use_stream {
            _mm256_stream_ps(dst.add(offset), v);
        } else {
            _mm256_storeu_ps(dst.add(offset), v);
        }
        offset += 8;
        rem -= 8;
    }
    if rem > 0 {
        std::ptr::copy_nonoverlapping(src.add(offset), dst.add(offset), rem);
    }
}

#[cfg(target_arch = "x86_64")]
#[inline]
unsafe fn simd_copy_sse_best(src: *const f32, dst: *mut f32, count: usize) -> bool {
    if !is_x86_feature_detected!("sse2") || count < 4 {
        return false;
    }
    let align = simd_alignment_bytes(SimdLevel::Sse2);
    let src_mod = (src as usize) % align;
    let dst_mod = (dst as usize) % align;
    let src_al = src_mod == 0;
    let dst_al = dst_mod == 0;
    if src_al && dst_al {
        simd_copy_sse_aligned(src, dst, count);
    } else if src_mod == dst_mod {
        let bytes_to_align = if src_mod == 0 { 0 } else { align - src_mod };
        let elems_to_align = (bytes_to_align / std::mem::size_of::<f32>()).min(count);
        if elems_to_align > 0 && elems_to_align < count {
            std::ptr::copy_nonoverlapping(src, dst, elems_to_align);
            let src2 = src.add(elems_to_align);
            let dst2 = dst.add(elems_to_align);
            let rem = count - elems_to_align;
            simd_copy_sse_aligned(src2, dst2, rem);
        } else {
            simd_copy_sse_unaligned(src, dst, count);
        }
    } else {
        simd_copy_sse_unaligned(src, dst, count);
    }
    true
}

#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "sse2")]
unsafe fn simd_copy_sse_aligned(src: *const f32, dst: *mut f32, count: usize) {
    use std::arch::x86_64::*;
    let mut offset = 0usize;
    let block = 16usize;
    let n_blocks = count / block;
    for _ in 0..n_blocks {
        let a = _mm_load_ps(src.add(offset));
        let b = _mm_load_ps(src.add(offset + 4));
        let c = _mm_load_ps(src.add(offset + 8));
        let d = _mm_load_ps(src.add(offset + 12));
        _mm_store_ps(dst.add(offset), a);
        _mm_store_ps(dst.add(offset + 4), b);
        _mm_store_ps(dst.add(offset + 8), c);
        _mm_store_ps(dst.add(offset + 12), d);
        offset += block;
    }
    let mut rem = count - offset;
    while rem >= 4 {
        let v = _mm_load_ps(src.add(offset));
        _mm_store_ps(dst.add(offset), v);
        offset += 4;
        rem -= 4;
    }
    if rem > 0 {
        std::ptr::copy_nonoverlapping(src.add(offset), dst.add(offset), rem);
    }
}

#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "sse2")]
unsafe fn simd_copy_sse_unaligned(src: *const f32, dst: *mut f32, count: usize) {
    use std::arch::x86_64::*;
    let mut offset = 0usize;
    let block = 16usize;
    let n_blocks = count / block;
    for _ in 0..n_blocks {
        let a = _mm_loadu_ps(src.add(offset));
        let b = _mm_loadu_ps(src.add(offset + 4));
        let c = _mm_loadu_ps(src.add(offset + 8));
        let d = _mm_loadu_ps(src.add(offset + 12));
        _mm_storeu_ps(dst.add(offset), a);
        _mm_storeu_ps(dst.add(offset + 4), b);
        _mm_storeu_ps(dst.add(offset + 8), c);
        _mm_storeu_ps(dst.add(offset + 12), d);
        offset += block;
    }
    let mut rem = count - offset;
    while rem >= 4 {
        let v = _mm_loadu_ps(src.add(offset));
        _mm_storeu_ps(dst.add(offset), v);
        offset += 4;
        rem -= 4;
    }
    if rem > 0 {
        std::ptr::copy_nonoverlapping(src.add(offset), dst.add(offset), rem);
    }
}

#[inline]
unsafe fn scalar_copy_unrolled(src: *const f32, dst: *mut f32, count: usize) {
    let unroll = 8;
    let blocks = count / unroll;
    let mut offset = 0usize;
    for _ in 0..blocks {
        *dst.add(offset) = *src.add(offset);
        *dst.add(offset + 1) = *src.add(offset + 1);
        *dst.add(offset + 2) = *src.add(offset + 2);
        *dst.add(offset + 3) = *src.add(offset + 3);
        *dst.add(offset + 4) = *src.add(offset + 4);
        *dst.add(offset + 5) = *src.add(offset + 5);
        *dst.add(offset + 6) = *src.add(offset + 6);
        *dst.add(offset + 7) = *src.add(offset + 7);
        offset += unroll;
    }
    if offset < count {
        std::ptr::copy_nonoverlapping(src.add(offset), dst.add(offset), count - offset);
    }
}

/// Extension trait to collect an iterator of tensors into a tensor with the provided shape.
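///
/// Illustrative sketch (mirrors `test_collect_shape_with_grad_preserves_backward` below):
///
/// ```ignore
/// let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
/// let parts: Vec<Tensor> = t.chunks(2).map(|c| c.mul_scalar(3.0)).collect();
/// let y = parts.into_iter().collect_shape(vec![2, 2]);
/// assert_eq!(y.shape().dims(), &[2, 2]);
/// ```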
pub trait TensorCollectExt: Iterator<Item = Tensor> + Sized {
    /// Collect tensors from this iterator into a tensor with the specified shape
    fn collect_shape(self, dims: Vec<usize>) -> Tensor;
}

impl<I> TensorCollectExt for I
where
    I: Iterator<Item = Tensor> + Sized,
{
    #[inline]
    #[track_caller]
    fn collect_shape(self, dims: Vec<usize>) -> Tensor {
        Tensor::collect_into_shape(self, dims)
    }
}

/// Extension trait to collect an `Iterator<Item = f32>` directly into a shaped Tensor.
///
/// The trait is blanket-implemented for every `Iterator<Item = f32>`; bring it into scope
/// (it is re-exported at the crate root) and `collect_shape()` becomes available on any
/// iterator of f32 values.
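///
/// Illustrative sketch (mirrors `test_values_collect_shape_direct` below; the trait must be
/// in scope, e.g. via the crate-root re-export):
///
/// ```ignore
/// let shaped: Tensor = (0..12).map(|i| i as f32).collect_shape(vec![3, 4]);
/// assert_eq!(shaped.get(&[2, 3]), 11.0);
/// ```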
pub trait ValuesCollectExt: Iterator<Item = f32> + Sized {
    /// Collect f32 values from this iterator into a tensor with the specified shape
    fn collect_shape(self, dims: Vec<usize>) -> Tensor;
}

impl<I> ValuesCollectExt for I
where
    I: Iterator<Item = f32> + Sized,
{
    #[inline]
    #[track_caller]
    fn collect_shape(self, dims: Vec<usize>) -> Tensor {
        // Stream directly into the destination tensor to avoid large temporary buffers
        Tensor::collect_values_shape(self, dims)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::gradtrack::NoGradTrack;

    #[test]
    fn test_collect_shape() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]).unwrap();
        let mat = t.chunks(2).collect_shape(vec![3, 2]);
        assert_eq!(mat.shape().dims(), &[3, 2]);
        assert_eq!(mat.data(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_collect_shape_with_grad_preserves_backward() {
        use crate::gradtrack::is_grad_enabled;
        use crate::tensor::core::Tensor;

        // Gradients are enabled by default in test runs; skip this test if they are not.
        if !is_grad_enabled() {
            return;
        }

        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4])
            .unwrap()
            .with_requires_grad();
        // Split in chunks of 2, scale each chunk, collect with target shape [2,2]
        let parts: Vec<Tensor> = t.chunks(2).map(|c| c.mul_scalar(3.0)).collect();
        let y = parts.into_iter().collect_shape(vec![2, 2]);
        assert!(y.requires_grad());

        let mut loss = y.sum();
        loss.backward(None);
        let g = t.grad_owned().unwrap();
        // Each input element appears exactly once in the collected tensor and is scaled by 3
        assert_eq!(g.data(), &[3.0, 3.0, 3.0, 3.0]);
    }

    #[test]
    fn test_collect_from_values_into_tensor() {
        // Collect Iterator<f32> into Tensor via FromIterator<f32>
        let vals = (0..16).map(|i| i as f32);
        let t: Tensor = vals.collect();
        assert_eq!(t.shape().dims(), &[16]);
        assert_eq!(t.data()[0], 0.0);
        assert_eq!(t.data()[15], 15.0);
    }

    #[test]
    fn test_values_iter_then_collect_shape() {
        // iter_elements() values collected via FromIterator<f32>, then reshaped with view()
        let base =
            Tensor::from_slice(&(0..12).map(|i| i as f32).collect::<Vec<_>>(), vec![3, 4]).unwrap();
        let collected: Tensor = base.iter_elements().map(|e| e.value()).collect();
        assert_eq!(collected.shape().dims(), &[12]);
        // reshape using view semantics
        let shaped = collected.view(vec![3, 4]);
        assert_eq!(shaped.shape().dims(), &[3, 4]);
        assert_eq!(shaped.get(&[2, 3]), 11.0);
    }

    #[test]
    fn test_values_collect_shape_direct() {
        // `use super::*` brings ValuesCollectExt into scope, so collect_shape is available on Iterator<Item = f32>
        let shaped: Tensor = (0..12).map(|i| i as f32).collect_shape(vec![3, 4]);
        assert_eq!(shaped.shape().dims(), &[3, 4]);
        assert_eq!(shaped.get(&[0, 0]), 0.0);
        assert_eq!(shaped.get(&[2, 3]), 11.0);
    }

    #[test]
    fn test_collect_into_shape_exact_sizes_and_zero() {
        // Zero-sized
        let empty: Vec<Tensor> = Vec::new();
        let out = Tensor::collect_into_shape(empty, vec![0]);
        assert_eq!(out.size(), 0);

        // Exact size match across chunks
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let y = t.chunks(2).collect_shape(vec![2, 2]);
        assert_eq!(y.shape().dims(), &[2, 2]);
        assert_eq!(y.data(), &[1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_collect_into_shape_no_grad_guard_fast_path() {
        let t = Tensor::from_slice(&(0..8).map(|i| i as f32).collect::<Vec<_>>(), vec![8])
            .unwrap()
            .with_requires_grad();
        let _guard = NoGradTrack::new();
        let y = t
            .iter_elements()
            .map(|e| e.mul_scalar(2.0))
            .collect_shape(vec![8]);
        assert!(!y.requires_grad());
        assert_eq!(y.size(), 8);
843    }
844}