use crate::gradtrack::is_grad_enabled;
#[cfg(target_arch = "x86_64")]
use crate::tensor::core::memory::{detect_runtime_simd, simd_alignment_bytes, SimdLevel};
use crate::tensor::core::Tensor;
use std::iter::FromIterator;

impl Tensor {
    #[inline]
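    /// Collects an iterator of tensors into a single tensor with shape `dims`.
    ///
    /// The combined size of the collected tensors must equal the product of
    /// `dims`, otherwise this panics. If any input requires gradients and
    /// gradient tracking is enabled, the inputs are flattened, concatenated,
    /// and reshaped so the backward graph is preserved; otherwise their data
    /// is copied directly into uninitialized storage via `optimized_copy`.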
    pub fn collect_into_shape<I: IntoIterator<Item = Tensor>>(iter: I, dims: Vec<usize>) -> Tensor {
        let total: usize = dims.iter().copied().product();
        let elements: Vec<Tensor> = iter.into_iter().collect();
        let sum_sizes: usize = elements.iter().map(|t| t.size()).sum();
        assert_eq!(
            sum_sizes, total,
            "collect_into_shape: element sizes {} do not match target size {}",
            sum_sizes, total
        );
        let requires_grad = elements.iter().any(|t| t.requires_grad()) && is_grad_enabled();

        if requires_grad {
            let mut flat_parts: Vec<Tensor> = Vec::with_capacity(elements.len());
            for t in elements.into_iter() {
                flat_parts.push(t.flatten());
            }
            let concatenated = Tensor::cat(&flat_parts, 0);
            let new_shape: Vec<i32> = dims.iter().map(|&d| d as i32).collect();
            let out = concatenated.view(new_shape);
            return out;
        }

        let mut result = Tensor::new_uninitialized(dims);
        let mut offset = 0usize;
        unsafe {
            let dst = result.as_mut_ptr();
            for t in &elements {
                let sz = t.size();
                if sz == 0 {
                    continue;
                }
                optimized_copy(t.as_ptr(), dst.add(offset), sz);
                offset += sz;
            }
        }
        result
    }
}

use crate::tensor::iterator::chunks::{TensorChunksExactIterator, TensorChunksIterator};
use crate::tensor::iterator::element::TensorElementIterator;
use crate::tensor::iterator::viewdim::TensorDimIterator;
use crate::tensor::iterator::windows::TensorWindowsIterator;

impl<'a> TensorChunksIterator<'a> {
    #[inline]
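    /// Collects the chunks yielded by this iterator into a tensor with shape
    /// `dims`; see [`Tensor::collect_into_shape`].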
    pub fn collect_shape(self, dims: Vec<usize>) -> Tensor {
        Tensor::collect_into_shape(self, dims)
    }
}

impl<'a> TensorChunksExactIterator<'a> {
    #[inline]
    pub fn collect_shape(self, dims: Vec<usize>) -> Tensor {
        Tensor::collect_into_shape(self, dims)
    }
}

impl<'a> TensorWindowsIterator<'a> {
    #[inline]
    pub fn collect_shape(self, dims: Vec<usize>) -> Tensor {
        Tensor::collect_into_shape(self, dims)
    }
}

impl<'a> TensorDimIterator<'a> {
    #[inline]
    pub fn collect_shape(self, dims: Vec<usize>) -> Tensor {
        Tensor::collect_into_shape(self, dims)
    }
}

impl<'a> TensorElementIterator<'a> {
    #[inline]
    pub fn collect_shape(self, dims: Vec<usize>) -> Tensor {
        Tensor::collect_into_shape(self, dims)
    }
}

impl Tensor {
    #[inline]
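    /// Convenience wrapper around [`Tensor::collect_into_shape`].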
    pub fn collect_shape_from<I: IntoIterator<Item = Tensor>>(iter: I, dims: Vec<usize>) -> Tensor {
        Tensor::collect_into_shape(iter, dims)
    }

    #[inline]
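    /// Collects an iterator of `f32` values into a tensor with shape `dims`.
    ///
    /// Panics if the iterator yields fewer values than `dims` requires; extra
    /// values beyond the target size are ignored.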
    pub fn collect_values_shape<I: IntoIterator<Item = f32>>(iter: I, dims: Vec<usize>) -> Tensor {
        let total: usize = dims.iter().copied().product();
        let mut out = Tensor::new_uninitialized(dims);
        if total == 0 {
            return out;
        }
        unsafe {
            let dst = out.as_mut_ptr();
            let mut i = 0usize;
            for v in iter {
                if i >= total {
                    break;
                }
                *dst.add(i) = v;
                i += 1;
            }
            assert_eq!(
                i, total,
                "values collect_shape: provided iterator produced {} values, expected {}",
                i, total
            );
        }
        out
    }
}

impl FromIterator<f32> for Tensor {
    #[inline]
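    /// Builds a 1-D tensor from an iterator of `f32` values. An empty iterator
    /// produces an empty tensor of shape `[0]`.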
    fn from_iter<I: IntoIterator<Item = f32>>(iter: I) -> Self {
        let v: Vec<f32> = iter.into_iter().collect();
        let n = v.len();
        if n == 0 {
            return Tensor::new(vec![0]);
        }

        let mut out = Tensor::new_uninitialized(vec![n]);
        unsafe {
            optimized_copy(v.as_ptr(), out.as_mut_ptr(), n);
        }
        out
    }
}

#[inline]
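/// Copies `count` `f32` values from `src` to `dst`. Small copies (32 elements
/// or fewer) go straight through `ptr::copy_nonoverlapping`; larger copies are
/// dispatched to the widest SIMD kernel reported by `detect_runtime_simd`,
/// falling back to an unrolled scalar copy.
///
/// # Safety
///
/// `src` must be valid for `count` reads, `dst` must be valid for `count`
/// writes, and the two ranges must not overlap.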
pub(crate) unsafe fn optimized_copy(src: *const f32, dst: *mut f32, count: usize) {
    if count == 0 {
        return;
    }
    if count <= 32 {
        std::ptr::copy_nonoverlapping(src, dst, count);
        return;
    }

    #[cfg(target_arch = "x86_64")]
    {
        match detect_runtime_simd() {
            SimdLevel::Avx512 => {
                if simd_copy_avx512_best(src, dst, count) {
                    return;
                }
            }
            SimdLevel::Avx2 => {
                if simd_copy_avx2_best(src, dst, count) {
                    return;
                }
            }
            SimdLevel::Sse2 => {
                if simd_copy_sse_best(src, dst, count) {
                    return;
                }
            }
            SimdLevel::Scalar => {}
        }
    }

    scalar_copy_unrolled(src, dst, count);
}

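// The `*_best` helpers below share one alignment strategy: if both pointers
// are aligned for the chosen SIMD width, use the aligned kernel; if they share
// the same misalignment, copy a short scalar prologue to reach alignment and
// then run the aligned kernel on the remainder; otherwise use the unaligned
// kernel. Each helper returns `false` when the required CPU feature is
// unavailable or the copy is too small, letting `optimized_copy` fall back to
// the scalar path.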
#[cfg(target_arch = "x86_64")]
#[inline]
unsafe fn simd_copy_avx512_best(src: *const f32, dst: *mut f32, count: usize) -> bool {
    if !is_x86_feature_detected!("avx512f") || count < 16 {
        return false;
    }
    let align = simd_alignment_bytes(SimdLevel::Avx512);
    let src_mod = (src as usize) % align;
    let dst_mod = (dst as usize) % align;
    let src_al = src_mod == 0;
    let dst_al = dst_mod == 0;
    if src_al && dst_al {
        simd_copy_avx512_aligned(src, dst, count);
    } else if src_mod == dst_mod {
        let bytes_to_align = if src_mod == 0 { 0 } else { align - src_mod };
        let elems_to_align = (bytes_to_align / std::mem::size_of::<f32>()).min(count);
        if elems_to_align > 0 && elems_to_align < count {
            std::ptr::copy_nonoverlapping(src, dst, elems_to_align);
            let src2 = src.add(elems_to_align);
            let dst2 = dst.add(elems_to_align);
            let rem = count - elems_to_align;
            simd_copy_avx512_aligned(src2, dst2, rem);
        } else {
            simd_copy_avx512_unaligned(src, dst, count);
        }
    } else {
        simd_copy_avx512_unaligned(src, dst, count);
    }
    true
}

#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn simd_copy_avx512_aligned(src: *const f32, dst: *mut f32, count: usize) {
    use std::arch::x86_64::*;
    let mut offset = 0usize;
    let block = 64usize;
    let n_blocks = count / block;
    for _ in 0..n_blocks {
        let a = _mm512_load_ps(src.add(offset));
        let b = _mm512_load_ps(src.add(offset + 16));
        let c = _mm512_load_ps(src.add(offset + 32));
        let d = _mm512_load_ps(src.add(offset + 48));
        _mm512_store_ps(dst.add(offset), a);
        _mm512_store_ps(dst.add(offset + 16), b);
        _mm512_store_ps(dst.add(offset + 32), c);
        _mm512_store_ps(dst.add(offset + 48), d);
        offset += block;
    }
    let mut rem = count - offset;
    while rem >= 16 {
        let v = _mm512_load_ps(src.add(offset));
        _mm512_store_ps(dst.add(offset), v);
        offset += 16;
        rem -= 16;
    }
    if rem > 0 {
        std::ptr::copy_nonoverlapping(src.add(offset), dst.add(offset), rem);
    }
}

#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn simd_copy_avx512_unaligned(src: *const f32, dst: *mut f32, count: usize) {
    use std::arch::x86_64::*;
    let mut offset = 0usize;
    let block = 64usize;
    let n_blocks = count / block;
    for _ in 0..n_blocks {
        let a = _mm512_loadu_ps(src.add(offset));
        let b = _mm512_loadu_ps(src.add(offset + 16));
        let c = _mm512_loadu_ps(src.add(offset + 32));
        let d = _mm512_loadu_ps(src.add(offset + 48));
        _mm512_storeu_ps(dst.add(offset), a);
        _mm512_storeu_ps(dst.add(offset + 16), b);
        _mm512_storeu_ps(dst.add(offset + 32), c);
        _mm512_storeu_ps(dst.add(offset + 48), d);
        offset += block;
    }
    let mut rem = count - offset;
    while rem >= 16 {
        let v = _mm512_loadu_ps(src.add(offset));
        _mm512_storeu_ps(dst.add(offset), v);
        offset += 16;
        rem -= 16;
    }
    if rem > 0 {
        std::ptr::copy_nonoverlapping(src.add(offset), dst.add(offset), rem);
    }
}

#[cfg(target_arch = "x86_64")]
#[inline]
unsafe fn simd_copy_avx2_best(src: *const f32, dst: *mut f32, count: usize) -> bool {
    if !is_x86_feature_detected!("avx2") || count < 8 {
        return false;
    }
    let align = simd_alignment_bytes(SimdLevel::Avx2);
    let src_mod = (src as usize) % align;
    let dst_mod = (dst as usize) % align;
    let src_al = src_mod == 0;
    let dst_al = dst_mod == 0;
    if src_al && dst_al {
        simd_copy_avx2_aligned(src, dst, count);
    } else if src_mod == dst_mod {
        let bytes_to_align = if src_mod == 0 { 0 } else { align - src_mod };
        let elems_to_align = (bytes_to_align / std::mem::size_of::<f32>()).min(count);
        if elems_to_align > 0 && elems_to_align < count {
            std::ptr::copy_nonoverlapping(src, dst, elems_to_align);
            let src2 = src.add(elems_to_align);
            let dst2 = dst.add(elems_to_align);
            let rem = count - elems_to_align;
            simd_copy_avx2_aligned(src2, dst2, rem);
        } else {
            simd_copy_avx2_unaligned(src, dst, count);
        }
    } else {
        simd_copy_avx2_unaligned(src, dst, count);
    }
    true
}

#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn simd_copy_avx2_aligned(src: *const f32, dst: *mut f32, count: usize) {
    use std::arch::x86_64::*;
    let mut offset = 0usize;
    let block = 32usize;
    let n_blocks = count / block;
    for _ in 0..n_blocks {
        let v1 = _mm256_load_ps(src.add(offset));
        let v2 = _mm256_load_ps(src.add(offset + 8));
        let v3 = _mm256_load_ps(src.add(offset + 16));
        let v4 = _mm256_load_ps(src.add(offset + 24));
        _mm256_store_ps(dst.add(offset), v1);
        _mm256_store_ps(dst.add(offset + 8), v2);
        _mm256_store_ps(dst.add(offset + 16), v3);
        _mm256_store_ps(dst.add(offset + 24), v4);
        offset += block;
    }
    let mut rem = count - offset;
    while rem >= 8 {
        let v = _mm256_load_ps(src.add(offset));
        _mm256_store_ps(dst.add(offset), v);
        offset += 8;
        rem -= 8;
    }
    if rem > 0 {
        std::ptr::copy_nonoverlapping(src.add(offset), dst.add(offset), rem);
    }
}

#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn simd_copy_avx2_unaligned(src: *const f32, dst: *mut f32, count: usize) {
    use std::arch::x86_64::*;
    let mut offset = 0usize;
    let block = 32usize;
    let n_blocks = count / block;
    for _ in 0..n_blocks {
        let v1 = _mm256_loadu_ps(src.add(offset));
        let v2 = _mm256_loadu_ps(src.add(offset + 8));
        let v3 = _mm256_loadu_ps(src.add(offset + 16));
        let v4 = _mm256_loadu_ps(src.add(offset + 24));
        _mm256_storeu_ps(dst.add(offset), v1);
        _mm256_storeu_ps(dst.add(offset + 8), v2);
        _mm256_storeu_ps(dst.add(offset + 16), v3);
        _mm256_storeu_ps(dst.add(offset + 24), v4);
        offset += block;
    }
    let mut rem = count - offset;
    while rem >= 8 {
        let v = _mm256_loadu_ps(src.add(offset));
        _mm256_storeu_ps(dst.add(offset), v);
        offset += 8;
        rem -= 8;
    }
    if rem > 0 {
        std::ptr::copy_nonoverlapping(src.add(offset), dst.add(offset), rem);
    }
}

#[cfg(target_arch = "x86_64")]
#[inline]
unsafe fn simd_copy_sse_best(src: *const f32, dst: *mut f32, count: usize) -> bool {
    if !is_x86_feature_detected!("sse2") || count < 4 {
        return false;
    }
    let align = simd_alignment_bytes(SimdLevel::Sse2);
    let src_mod = (src as usize) % align;
    let dst_mod = (dst as usize) % align;
    let src_al = src_mod == 0;
    let dst_al = dst_mod == 0;
    if src_al && dst_al {
        simd_copy_sse_aligned(src, dst, count);
    } else if src_mod == dst_mod {
        let bytes_to_align = if src_mod == 0 { 0 } else { align - src_mod };
        let elems_to_align = (bytes_to_align / std::mem::size_of::<f32>()).min(count);
        if elems_to_align > 0 && elems_to_align < count {
            std::ptr::copy_nonoverlapping(src, dst, elems_to_align);
            let src2 = src.add(elems_to_align);
            let dst2 = dst.add(elems_to_align);
            let rem = count - elems_to_align;
            simd_copy_sse_aligned(src2, dst2, rem);
        } else {
            simd_copy_sse_unaligned(src, dst, count);
        }
    } else {
        simd_copy_sse_unaligned(src, dst, count);
    }
    true
}

#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "sse2")]
unsafe fn simd_copy_sse_aligned(src: *const f32, dst: *mut f32, count: usize) {
    use std::arch::x86_64::*;
    let mut offset = 0usize;
    let block = 16usize;
    let n_blocks = count / block;
    for _ in 0..n_blocks {
        let a = _mm_load_ps(src.add(offset));
        let b = _mm_load_ps(src.add(offset + 4));
        let c = _mm_load_ps(src.add(offset + 8));
        let d = _mm_load_ps(src.add(offset + 12));
        _mm_store_ps(dst.add(offset), a);
        _mm_store_ps(dst.add(offset + 4), b);
        _mm_store_ps(dst.add(offset + 8), c);
        _mm_store_ps(dst.add(offset + 12), d);
        offset += block;
    }
    let mut rem = count - offset;
    while rem >= 4 {
        let v = _mm_load_ps(src.add(offset));
        _mm_store_ps(dst.add(offset), v);
        offset += 4;
        rem -= 4;
    }
    if rem > 0 {
        std::ptr::copy_nonoverlapping(src.add(offset), dst.add(offset), rem);
    }
}

#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "sse2")]
unsafe fn simd_copy_sse_unaligned(src: *const f32, dst: *mut f32, count: usize) {
    use std::arch::x86_64::*;
    let mut offset = 0usize;
    let block = 16usize;
    let n_blocks = count / block;
    for _ in 0..n_blocks {
        let a = _mm_loadu_ps(src.add(offset));
        let b = _mm_loadu_ps(src.add(offset + 4));
        let c = _mm_loadu_ps(src.add(offset + 8));
        let d = _mm_loadu_ps(src.add(offset + 12));
        _mm_storeu_ps(dst.add(offset), a);
        _mm_storeu_ps(dst.add(offset + 4), b);
        _mm_storeu_ps(dst.add(offset + 8), c);
        _mm_storeu_ps(dst.add(offset + 12), d);
        offset += block;
    }
    let mut rem = count - offset;
    while rem >= 4 {
        let v = _mm_loadu_ps(src.add(offset));
        _mm_storeu_ps(dst.add(offset), v);
        offset += 4;
        rem -= 4;
    }
    if rem > 0 {
        std::ptr::copy_nonoverlapping(src.add(offset), dst.add(offset), rem);
    }
}

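// Portable fallback: copies eight elements per iteration, then finishes any
// remaining tail with `ptr::copy_nonoverlapping`.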
#[inline]
unsafe fn scalar_copy_unrolled(src: *const f32, dst: *mut f32, count: usize) {
    let unroll = 8;
    let blocks = count / unroll;
    let mut offset = 0usize;
    for _ in 0..blocks {
        *dst.add(offset) = *src.add(offset);
        *dst.add(offset + 1) = *src.add(offset + 1);
        *dst.add(offset + 2) = *src.add(offset + 2);
        *dst.add(offset + 3) = *src.add(offset + 3);
        *dst.add(offset + 4) = *src.add(offset + 4);
        *dst.add(offset + 5) = *src.add(offset + 5);
        *dst.add(offset + 6) = *src.add(offset + 6);
        *dst.add(offset + 7) = *src.add(offset + 7);
        offset += unroll;
    }
    if offset < count {
        std::ptr::copy_nonoverlapping(src.add(offset), dst.add(offset), count - offset);
    }
}

pub trait TensorCollectExt: Iterator<Item = Tensor> + Sized {
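    /// Collects the tensors yielded by this iterator into a single tensor with
    /// shape `dims`; equivalent to calling [`Tensor::collect_into_shape`].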
    fn collect_shape(self, dims: Vec<usize>) -> Tensor;
}

impl<I> TensorCollectExt for I
where
    I: Iterator<Item = Tensor> + Sized,
{
    #[inline]
    fn collect_shape(self, dims: Vec<usize>) -> Tensor {
        Tensor::collect_into_shape(self, dims)
    }
}

pub trait ValuesCollectExt: Iterator<Item = f32> + Sized {
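    /// Collects the `f32` values yielded by this iterator into a tensor with
    /// shape `dims`. Small targets (64 elements or fewer) are written in place;
    /// larger targets are buffered into a `Vec` and copied with `optimized_copy`.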
    fn collect_shape(self, dims: Vec<usize>) -> Tensor;
}

impl<I> ValuesCollectExt for I
where
    I: Iterator<Item = f32> + Sized,
{
    #[inline]
    fn collect_shape(self, dims: Vec<usize>) -> Tensor {
        let total: usize = dims.iter().copied().product();
        let mut out = Tensor::new_uninitialized(dims);
        if total == 0 {
            return out;
        }

        if total <= 64 {
            unsafe {
                let dst = out.as_mut_ptr();
                let mut i = 0usize;
                for v in self {
                    if i >= total {
                        break;
                    }
                    *dst.add(i) = v;
                    i += 1;
                }
                assert_eq!(
                    i, total,
                    "values collect_shape: provided iterator produced {} values, expected {}",
                    i, total
                );
            }
            return out;
        }

        let temp_data: Vec<f32> = self.collect();
        assert_eq!(
            temp_data.len(),
            total,
            "values collect_shape: provided iterator produced {} values, expected {}",
            temp_data.len(),
            total
        );

        unsafe {
            optimized_copy(temp_data.as_ptr(), out.as_mut_ptr(), total);
        }
        out
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_collect_shape() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]).unwrap();
        let mat = t.iter_chunks(2).collect_shape(vec![3, 2]);
        assert_eq!(mat.shape().dims(), &[3, 2]);
        assert_eq!(mat.data(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_collect_shape_with_grad_preserves_backward() {
        use crate::gradtrack::is_grad_enabled;
        use crate::tensor::core::Tensor;

        if !is_grad_enabled() {
            // Gradient tracking is globally disabled; nothing to verify here.
            return;
        }

        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4])
            .unwrap()
            .with_requires_grad();
        let parts: Vec<Tensor> = t.iter_chunks(2).map(|c| c.mul_scalar(3.0)).collect();
        let y = parts.into_iter().collect_shape(vec![2, 2]);
        assert!(y.requires_grad());

        let mut loss = y.sum();
        loss.backward(None);
        let g = t.grad_owned().unwrap();
        assert_eq!(g.data(), &[3.0, 3.0, 3.0, 3.0]);
    }

    #[test]
    fn test_collect_from_values_into_tensor() {
        let vals = (0..16).map(|i| i as f32);
        let t: Tensor = vals.collect();
        assert_eq!(t.shape().dims(), &[16]);
        assert_eq!(t.data()[0], 0.0);
        assert_eq!(t.data()[15], 15.0);
    }

    #[test]
    fn test_values_iter_then_collect_shape() {
        let base =
            Tensor::from_slice(&(0..12).map(|i| i as f32).collect::<Vec<_>>(), vec![3, 4]).unwrap();
        let flat_vals = base.iter_values();
        let collected: Tensor = flat_vals.collect();
        assert_eq!(collected.shape().dims(), &[12]);
        let shaped = collected.view(vec![3, 4]);
        assert_eq!(shaped.shape().dims(), &[3, 4]);
        assert_eq!(shaped.get(&[2, 3]), 11.0);
    }

    #[test]
    fn test_values_collect_shape_direct() {
        let shaped: Tensor = (0..12).map(|i| i as f32).collect_shape(vec![3, 4]);
        assert_eq!(shaped.shape().dims(), &[3, 4]);
        assert_eq!(shaped.get(&[0, 0]), 0.0);
        assert_eq!(shaped.get(&[2, 3]), 11.0);
    }
}