use crate::gradtrack::is_grad_enabled;
#[cfg(target_arch = "x86_64")]
use crate::tensor::core::memory::{detect_runtime_simd, simd_alignment_bytes, SimdLevel};
use crate::tensor::core::Tensor;
use std::iter::FromIterator;

impl Tensor {
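    /// Collects an iterator of tensors into a single tensor with shape `dims`.
    ///
    /// Elements are flattened and copied in iteration order, and the summed
    /// element count must equal `dims.iter().product()`, otherwise this panics.
    /// When gradient tracking is enabled and any input requires grad, the
    /// result is built via `cat` + `view` so the backward pass is preserved;
    /// otherwise a raw-copy fast path is used.
    ///
    /// Illustrative use (mirrors the unit tests at the bottom of this file):
    ///
    /// ```ignore
    /// let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]).unwrap();
    /// let mat = t.chunks(2).collect_shape(vec![3, 2]);
    /// assert_eq!(mat.shape().dims(), &[3, 2]);
    /// ```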
    #[inline]
    #[track_caller]
    pub fn collect_into_shape<I: IntoIterator<Item = Tensor>>(iter: I, dims: Vec<usize>) -> Tensor {
        let total: usize = dims.iter().copied().product();

        if !is_grad_enabled() {
            let mut result = Tensor::new_uninitialized(dims);
            let mut offset = 0usize;
            let mut sum_sizes = 0usize;
            unsafe {
                let dst = result.as_mut_ptr();
                for t in iter.into_iter() {
                    let sz = t.size();
                    if sz == 0 {
                        continue;
                    }
                    optimized_copy(t.as_ptr(), dst.add(offset), sz);
                    offset += sz;
                    sum_sizes += sz;
                }
            }
            assert_eq!(
                sum_sizes, total,
                "collect_into_shape: element sizes {} do not match target size {}",
                sum_sizes, total
            );
            return result;
        }

        let elements: Vec<Tensor> = iter.into_iter().collect();
        let sum_sizes: usize = elements.iter().map(|t| t.size()).sum();
        assert_eq!(
            sum_sizes, total,
            "collect_into_shape: element sizes {} do not match target size {}",
            sum_sizes, total
        );
        let requires_grad = elements.iter().any(|t| t.requires_grad());

        if requires_grad {
            let mut flat_parts: Vec<Tensor> = Vec::with_capacity(elements.len());
            for t in elements.into_iter() {
                flat_parts.push(t.flatten());
            }
            let concatenated = Tensor::cat(&flat_parts, 0);
            let new_shape: Vec<i32> = dims.iter().map(|&d| d as i32).collect();
            let out = concatenated.view(new_shape);
            return out;
        }

        let mut result = Tensor::new_uninitialized(dims);
        let mut offset = 0usize;
        unsafe {
            let dst = result.as_mut_ptr();
            for t in &elements {
                let sz = t.size();
                if sz == 0 {
                    continue;
                }
                optimized_copy(t.as_ptr(), dst.add(offset), sz);
                offset += sz;
            }
        }
        result
    }
}

use crate::tensor::iterator::chunks::{TensorChunksExactIterator, TensorChunksIterator};
use crate::tensor::iterator::element::TensorElementIterator;
use crate::tensor::iterator::viewdim::TensorDimIterator;
use crate::tensor::iterator::windows::TensorWindowsIterator;

impl<'a> TensorChunksIterator<'a> {
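    /// Collects the chunks produced by this iterator into a tensor with shape
    /// `dims`. Thin wrapper around [`Tensor::collect_into_shape`].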
    #[inline]
    #[track_caller]
    pub fn collect_shape(self, dims: Vec<usize>) -> Tensor {
        Tensor::collect_into_shape(self, dims)
    }
}

impl<'a> TensorChunksExactIterator<'a> {
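    /// Collects the exact-size chunks produced by this iterator into a tensor
    /// with shape `dims`. Thin wrapper around [`Tensor::collect_into_shape`].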
    #[inline]
    #[track_caller]
    pub fn collect_shape(self, dims: Vec<usize>) -> Tensor {
        Tensor::collect_into_shape(self, dims)
    }
}

impl<'a> TensorWindowsIterator<'a> {
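    /// Collects the windows produced by this iterator into a tensor with shape
    /// `dims`. Thin wrapper around [`Tensor::collect_into_shape`].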
    #[inline]
    #[track_caller]
    pub fn collect_shape(self, dims: Vec<usize>) -> Tensor {
        Tensor::collect_into_shape(self, dims)
    }
}

impl<'a> TensorDimIterator<'a> {
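    /// Collects the sub-tensors yielded by this dimension iterator into a
    /// tensor with shape `dims`. Thin wrapper around [`Tensor::collect_into_shape`].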
    #[inline]
    #[track_caller]
    pub fn collect_shape(self, dims: Vec<usize>) -> Tensor {
        Tensor::collect_into_shape(self, dims)
    }
}

impl<'a> TensorElementIterator<'a> {
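    /// Collects the tensors yielded by this element iterator into a tensor
    /// with shape `dims`. Thin wrapper around [`Tensor::collect_into_shape`].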
    #[inline]
    #[track_caller]
    pub fn collect_shape(self, dims: Vec<usize>) -> Tensor {
        Tensor::collect_into_shape(self, dims)
    }
}

impl Tensor {
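    /// Collects an iterator of `f32` values into a tensor with shape `dims`.
    ///
    /// Panics if the iterator yields fewer than `dims.iter().product()` values;
    /// surplus values are left unconsumed. Small outputs (<= 64 elements) are
    /// written directly, larger ones are staged through a chunked buffer and
    /// copied with `optimized_copy`.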
    #[inline]
    #[track_caller]
    pub fn collect_values_shape<I: IntoIterator<Item = f32>>(iter: I, dims: Vec<usize>) -> Tensor {
        let total: usize = dims.iter().copied().product();
        let mut out = Tensor::new_uninitialized(dims);
        if total == 0 {
            return out;
        }

        if total <= 64 {
            unsafe {
                let dst = out.as_mut_ptr();
                let mut i = 0usize;
                for v in iter {
                    if i >= total {
                        break;
                    }
                    *dst.add(i) = v;
                    i += 1;
                }
                assert_eq!(
                    i, total,
                    "values collect_shape: provided iterator produced {} values, expected {}",
                    i, total
                );
            }
            return out;
        }

        let mut it = iter.into_iter();
        let chunk_elems = crate::tensor::core::memory::choose_fast_chunk_size(total);
        let mut buffer: Vec<f32> = Vec::with_capacity(chunk_elems);

        unsafe {
            let dst = out.as_mut_ptr();
            let mut written = 0usize;
            while written < total {
                buffer.clear();
                let to_take = buffer.capacity().min(total - written);
                for _ in 0..to_take {
                    if let Some(v) = it.next() {
                        buffer.push(v);
                    } else {
                        break;
                    }
                }
                let got = buffer.len();
                if got == 0 {
                    break;
                }
                optimized_copy(buffer.as_ptr(), dst.add(written), got);
                written += got;
            }
            assert_eq!(
                written, total,
                "values collect_shape: provided iterator produced {} values, expected {}",
                written, total
            );
        }
        out
    }
}

impl FromIterator<f32> for Tensor {
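    /// Builds a 1-D tensor from an iterator of `f32` values.
    ///
    /// When the iterator reports an exact `size_hint`, values are written
    /// straight into an uninitialized buffer; otherwise they are first
    /// collected into a `Vec<f32>` and then copied.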
    #[inline]
    fn from_iter<I: IntoIterator<Item = f32>>(iter: I) -> Self {
        let it = iter.into_iter();
        if let (lower, Some(upper)) = it.size_hint() {
            if lower == upper {
                let n = lower;
                if n == 0 {
                    return Tensor::new(vec![0]);
                }
                let mut out = Tensor::new_uninitialized(vec![n]);
                unsafe {
                    let dst = out.as_mut_ptr();
                    let mut i = 0usize;
                    for v in it {
                        // `size_hint` is only advisory, so guard the unchecked
                        // write against iterators that yield more than reported.
                        if i >= n {
                            break;
                        }
                        *dst.add(i) = v;
                        i += 1;
                    }
                    debug_assert_eq!(i, n);
                }
                return out;
            }
        }

        let v: Vec<f32> = it.collect();
        let n = v.len();
        if n == 0 {
            return Tensor::new(vec![0]);
        }
        let mut out = Tensor::new_uninitialized(vec![n]);
        unsafe { optimized_copy(v.as_ptr(), out.as_mut_ptr(), n) };
        out
    }
}

impl From<Vec<f32>> for Tensor {
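    /// Builds a 1-D tensor of length `v.len()` by copying the vector's contents.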
    #[inline]
    #[track_caller]
    fn from(v: Vec<f32>) -> Self {
        let n = v.len();
        if n == 0 {
            return Tensor::new(vec![0]);
        }
        let mut out = Tensor::new_uninitialized(vec![n]);
        unsafe { optimized_copy(v.as_ptr(), out.as_mut_ptr(), n) };
        out
    }
}

impl From<Tensor> for Vec<f32> {
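    /// Copies the tensor's elements into a `Vec<f32>`, materializing a
    /// contiguous copy first when the tensor is not already contiguous.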
    #[inline]
    #[track_caller]
    fn from(tensor: Tensor) -> Vec<f32> {
        let n = tensor.size();
        if n == 0 {
            return Vec::new();
        }
        if tensor.is_contiguous() {
            let mut v = vec![0.0f32; n];
            unsafe {
                crate::tensor::iterator::collect::optimized_copy(tensor.as_ptr(), v.as_mut_ptr(), n)
            };
            v
        } else {
            let c = tensor.contiguous();
            let mut v = vec![0.0f32; n];
            unsafe {
                crate::tensor::iterator::collect::optimized_copy(c.as_ptr(), v.as_mut_ptr(), n)
            };
            v
        }
    }
}

#[inline]
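/// Copies `count` `f32` values from `src` to `dst`, dispatching to the widest
/// SIMD path reported by `detect_runtime_simd` on x86_64 and falling back to an
/// unrolled scalar loop otherwise. Very small copies go straight through
/// `copy_nonoverlapping`.
///
/// # Safety
///
/// `src` must be valid for reads of `count` values, `dst` must be valid for
/// writes of `count` values, and the two ranges must not overlap.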
pub(crate) unsafe fn optimized_copy(src: *const f32, dst: *mut f32, count: usize) {
    if count == 0 {
        return;
    }
    if count <= 32 {
        std::ptr::copy_nonoverlapping(src, dst, count);
        return;
    }

    #[cfg(target_arch = "x86_64")]
    {
        match detect_runtime_simd() {
            SimdLevel::Avx512 => {
                if simd_copy_avx512_best(src, dst, count) {
                    return;
                }
            }
            SimdLevel::Avx2 => {
                if simd_copy_avx2_best(src, dst, count) {
                    return;
                }
            }
            SimdLevel::Sse2 => {
                if simd_copy_sse_best(src, dst, count) {
                    return;
                }
            }
            SimdLevel::Scalar => {}
        }
    }

    scalar_copy_unrolled(src, dst, count);
}

#[cfg(target_arch = "x86_64")]
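/// Attempts an AVX-512 copy. Returns `false` (so the caller can fall back)
/// when AVX-512F is unavailable or `count` is below one vector width;
/// otherwise selects the aligned, align-prefix, or unaligned variant based on
/// pointer alignment and returns `true`.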
#[inline]
unsafe fn simd_copy_avx512_best(src: *const f32, dst: *mut f32, count: usize) -> bool {
    if !is_x86_feature_detected!("avx512f") || count < 16 {
        return false;
    }
    let align = simd_alignment_bytes(SimdLevel::Avx512);
    let src_mod = (src as usize) % align;
    let dst_mod = (dst as usize) % align;
    let src_al = src_mod == 0;
    let dst_al = dst_mod == 0;
    if src_al && dst_al {
        simd_copy_avx512_aligned(src, dst, count);
    } else if src_mod == dst_mod {
        let bytes_to_align = if src_mod == 0 { 0 } else { align - src_mod };
        let elems_to_align = (bytes_to_align / std::mem::size_of::<f32>()).min(count);
        if elems_to_align > 0 && elems_to_align < count {
            std::ptr::copy_nonoverlapping(src, dst, elems_to_align);
            let src2 = src.add(elems_to_align);
            let dst2 = dst.add(elems_to_align);
            let rem = count - elems_to_align;
            simd_copy_avx512_aligned(src2, dst2, rem);
        } else {
            simd_copy_avx512_unaligned(src, dst, count);
        }
    } else {
        simd_copy_avx512_unaligned(src, dst, count);
    }
    true
}

#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn simd_copy_avx512_aligned(src: *const f32, dst: *mut f32, count: usize) {
    use std::arch::x86_64::*;
    let stream_threshold = crate::tensor::core::memory::stream_min_elems();
    let pf_distance = crate::tensor::core::memory::prefetch_distance_elems();
    let mut offset = 0usize;
    let block = 64usize;
    let n_blocks = count / block;
    for _ in 0..n_blocks {
        if pf_distance > 0 {
            _mm_prefetch(src.add(offset + pf_distance) as *const i8, _MM_HINT_T0);
        }
        let a = _mm512_load_ps(src.add(offset));
        let b = _mm512_load_ps(src.add(offset + 16));
        let c = _mm512_load_ps(src.add(offset + 32));
        let d = _mm512_load_ps(src.add(offset + 48));
        if count >= stream_threshold {
            _mm512_stream_ps(dst.add(offset), a);
            _mm512_stream_ps(dst.add(offset + 16), b);
            _mm512_stream_ps(dst.add(offset + 32), c);
            _mm512_stream_ps(dst.add(offset + 48), d);
        } else {
            _mm512_store_ps(dst.add(offset), a);
            _mm512_store_ps(dst.add(offset + 16), b);
            _mm512_store_ps(dst.add(offset + 32), c);
            _mm512_store_ps(dst.add(offset + 48), d);
        }
        offset += block;
    }
    let mut rem = count - offset;
    while rem >= 16 {
        let v = _mm512_load_ps(src.add(offset));
        if count >= stream_threshold {
            _mm512_stream_ps(dst.add(offset), v);
        } else {
            _mm512_store_ps(dst.add(offset), v);
        }
        offset += 16;
        rem -= 16;
    }
    if rem > 0 {
        std::ptr::copy_nonoverlapping(src.add(offset), dst.add(offset), rem);
    }
}

#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn simd_copy_avx512_unaligned(src: *const f32, dst: *mut f32, count: usize) {
    use std::arch::x86_64::*;
    let stream_threshold = crate::tensor::core::memory::stream_min_elems();
    let pf_distance = crate::tensor::core::memory::prefetch_distance_elems();
    // Non-temporal stores require a 64-byte aligned destination, so only
    // stream when `dst` happens to be aligned; otherwise fall back to
    // unaligned stores even above the streaming threshold.
    let dst_streamable = (dst as usize) % 64 == 0;
    let mut offset = 0usize;
    let block = 64usize;
    let n_blocks = count / block;
    for _ in 0..n_blocks {
        if pf_distance > 0 {
            _mm_prefetch(src.add(offset + pf_distance) as *const i8, _MM_HINT_T0);
        }
        let a = _mm512_loadu_ps(src.add(offset));
        let b = _mm512_loadu_ps(src.add(offset + 16));
        let c = _mm512_loadu_ps(src.add(offset + 32));
        let d = _mm512_loadu_ps(src.add(offset + 48));
        if dst_streamable && count >= stream_threshold {
            _mm512_stream_ps(dst.add(offset), a);
            _mm512_stream_ps(dst.add(offset + 16), b);
            _mm512_stream_ps(dst.add(offset + 32), c);
            _mm512_stream_ps(dst.add(offset + 48), d);
        } else {
            _mm512_storeu_ps(dst.add(offset), a);
            _mm512_storeu_ps(dst.add(offset + 16), b);
            _mm512_storeu_ps(dst.add(offset + 32), c);
            _mm512_storeu_ps(dst.add(offset + 48), d);
        }
        offset += block;
    }
    let mut rem = count - offset;
    while rem >= 16 {
        let v = _mm512_loadu_ps(src.add(offset));
        if dst_streamable && count >= stream_threshold {
            _mm512_stream_ps(dst.add(offset), v);
        } else {
            _mm512_storeu_ps(dst.add(offset), v);
        }
        offset += 16;
        rem -= 16;
    }
    if rem > 0 {
        std::ptr::copy_nonoverlapping(src.add(offset), dst.add(offset), rem);
    }
}

#[cfg(target_arch = "x86_64")]
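/// Attempts an AVX2 copy. Returns `false` when AVX2 is unavailable or `count`
/// is below one vector width; otherwise selects the aligned, align-prefix, or
/// unaligned variant based on pointer alignment and returns `true`.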
#[inline]
unsafe fn simd_copy_avx2_best(src: *const f32, dst: *mut f32, count: usize) -> bool {
    if !is_x86_feature_detected!("avx2") || count < 8 {
        return false;
    }
    let align = simd_alignment_bytes(SimdLevel::Avx2);
    let src_mod = (src as usize) % align;
    let dst_mod = (dst as usize) % align;
    let src_al = src_mod == 0;
    let dst_al = dst_mod == 0;
    if src_al && dst_al {
        simd_copy_avx2_aligned(src, dst, count);
    } else if src_mod == dst_mod {
        let bytes_to_align = if src_mod == 0 { 0 } else { align - src_mod };
        let elems_to_align = (bytes_to_align / std::mem::size_of::<f32>()).min(count);
        if elems_to_align > 0 && elems_to_align < count {
            std::ptr::copy_nonoverlapping(src, dst, elems_to_align);
            let src2 = src.add(elems_to_align);
            let dst2 = dst.add(elems_to_align);
            let rem = count - elems_to_align;
            simd_copy_avx2_aligned(src2, dst2, rem);
        } else {
            simd_copy_avx2_unaligned(src, dst, count);
        }
    } else {
        simd_copy_avx2_unaligned(src, dst, count);
    }
    true
}

#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn simd_copy_avx2_aligned(src: *const f32, dst: *mut f32, count: usize) {
    use std::arch::x86_64::*;
    let stream_threshold = crate::tensor::core::memory::stream_min_elems();
    let pf_distance = crate::tensor::core::memory::prefetch_distance_elems();
    let mut offset = 0usize;
    let block = 32usize;
    let n_blocks = count / block;
    for _ in 0..n_blocks {
        if pf_distance > 0 {
            _mm_prefetch(src.add(offset + pf_distance) as *const i8, _MM_HINT_T0);
        }
        let v1 = _mm256_load_ps(src.add(offset));
        let v2 = _mm256_load_ps(src.add(offset + 8));
        let v3 = _mm256_load_ps(src.add(offset + 16));
        let v4 = _mm256_load_ps(src.add(offset + 24));
        if count >= stream_threshold {
            _mm256_stream_ps(dst.add(offset), v1);
            _mm256_stream_ps(dst.add(offset + 8), v2);
            _mm256_stream_ps(dst.add(offset + 16), v3);
            _mm256_stream_ps(dst.add(offset + 24), v4);
        } else {
            _mm256_store_ps(dst.add(offset), v1);
            _mm256_store_ps(dst.add(offset + 8), v2);
            _mm256_store_ps(dst.add(offset + 16), v3);
            _mm256_store_ps(dst.add(offset + 24), v4);
        }
        offset += block;
    }
    let mut rem = count - offset;
    while rem >= 8 {
        let v = _mm256_load_ps(src.add(offset));
        if count >= stream_threshold {
            _mm256_stream_ps(dst.add(offset), v);
        } else {
            _mm256_store_ps(dst.add(offset), v);
        }
        offset += 8;
        rem -= 8;
    }
    if rem > 0 {
        std::ptr::copy_nonoverlapping(src.add(offset), dst.add(offset), rem);
    }
}

#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn simd_copy_avx2_unaligned(src: *const f32, dst: *mut f32, count: usize) {
    use std::arch::x86_64::*;
    let stream_threshold = crate::tensor::core::memory::stream_min_elems();
    let pf_distance = crate::tensor::core::memory::prefetch_distance_elems();
    // Non-temporal stores require a 32-byte aligned destination, so only
    // stream when `dst` happens to be aligned; otherwise use unaligned stores.
    let dst_streamable = (dst as usize) % 32 == 0;
    let mut offset = 0usize;
    let block = 32usize;
    let n_blocks = count / block;
    for _ in 0..n_blocks {
        if pf_distance > 0 {
            _mm_prefetch(src.add(offset + pf_distance) as *const i8, _MM_HINT_T0);
        }
        let v1 = _mm256_loadu_ps(src.add(offset));
        let v2 = _mm256_loadu_ps(src.add(offset + 8));
        let v3 = _mm256_loadu_ps(src.add(offset + 16));
        let v4 = _mm256_loadu_ps(src.add(offset + 24));
        if dst_streamable && count >= stream_threshold {
            _mm256_stream_ps(dst.add(offset), v1);
            _mm256_stream_ps(dst.add(offset + 8), v2);
            _mm256_stream_ps(dst.add(offset + 16), v3);
            _mm256_stream_ps(dst.add(offset + 24), v4);
        } else {
            _mm256_storeu_ps(dst.add(offset), v1);
            _mm256_storeu_ps(dst.add(offset + 8), v2);
            _mm256_storeu_ps(dst.add(offset + 16), v3);
            _mm256_storeu_ps(dst.add(offset + 24), v4);
        }
        offset += block;
    }
    let mut rem = count - offset;
    while rem >= 8 {
        let v = _mm256_loadu_ps(src.add(offset));
        if dst_streamable && count >= stream_threshold {
            _mm256_stream_ps(dst.add(offset), v);
        } else {
            _mm256_storeu_ps(dst.add(offset), v);
        }
        offset += 8;
        rem -= 8;
    }
    if rem > 0 {
        std::ptr::copy_nonoverlapping(src.add(offset), dst.add(offset), rem);
    }
}

#[cfg(target_arch = "x86_64")]
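/// Attempts an SSE2 copy. Returns `false` when SSE2 is unavailable or `count`
/// is below one vector width; otherwise selects the aligned, align-prefix, or
/// unaligned variant based on pointer alignment and returns `true`.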
#[inline]
unsafe fn simd_copy_sse_best(src: *const f32, dst: *mut f32, count: usize) -> bool {
    if !is_x86_feature_detected!("sse2") || count < 4 {
        return false;
    }
    let align = simd_alignment_bytes(SimdLevel::Sse2);
    let src_mod = (src as usize) % align;
    let dst_mod = (dst as usize) % align;
    let src_al = src_mod == 0;
    let dst_al = dst_mod == 0;
    if src_al && dst_al {
        simd_copy_sse_aligned(src, dst, count);
    } else if src_mod == dst_mod {
        let bytes_to_align = if src_mod == 0 { 0 } else { align - src_mod };
        let elems_to_align = (bytes_to_align / std::mem::size_of::<f32>()).min(count);
        if elems_to_align > 0 && elems_to_align < count {
            std::ptr::copy_nonoverlapping(src, dst, elems_to_align);
            let src2 = src.add(elems_to_align);
            let dst2 = dst.add(elems_to_align);
            let rem = count - elems_to_align;
            simd_copy_sse_aligned(src2, dst2, rem);
        } else {
            simd_copy_sse_unaligned(src, dst, count);
        }
    } else {
        simd_copy_sse_unaligned(src, dst, count);
    }
    true
}

#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "sse2")]
unsafe fn simd_copy_sse_aligned(src: *const f32, dst: *mut f32, count: usize) {
    use std::arch::x86_64::*;
    let mut offset = 0usize;
    let block = 16usize;
    let n_blocks = count / block;
    for _ in 0..n_blocks {
        let a = _mm_load_ps(src.add(offset));
        let b = _mm_load_ps(src.add(offset + 4));
        let c = _mm_load_ps(src.add(offset + 8));
        let d = _mm_load_ps(src.add(offset + 12));
        _mm_store_ps(dst.add(offset), a);
        _mm_store_ps(dst.add(offset + 4), b);
        _mm_store_ps(dst.add(offset + 8), c);
        _mm_store_ps(dst.add(offset + 12), d);
        offset += block;
    }
    let mut rem = count - offset;
    while rem >= 4 {
        let v = _mm_load_ps(src.add(offset));
        _mm_store_ps(dst.add(offset), v);
        offset += 4;
        rem -= 4;
    }
    if rem > 0 {
        std::ptr::copy_nonoverlapping(src.add(offset), dst.add(offset), rem);
    }
}

#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "sse2")]
unsafe fn simd_copy_sse_unaligned(src: *const f32, dst: *mut f32, count: usize) {
    use std::arch::x86_64::*;
    let mut offset = 0usize;
    let block = 16usize;
    let n_blocks = count / block;
    for _ in 0..n_blocks {
        let a = _mm_loadu_ps(src.add(offset));
        let b = _mm_loadu_ps(src.add(offset + 4));
        let c = _mm_loadu_ps(src.add(offset + 8));
        let d = _mm_loadu_ps(src.add(offset + 12));
        _mm_storeu_ps(dst.add(offset), a);
        _mm_storeu_ps(dst.add(offset + 4), b);
        _mm_storeu_ps(dst.add(offset + 8), c);
        _mm_storeu_ps(dst.add(offset + 12), d);
        offset += block;
    }
    let mut rem = count - offset;
    while rem >= 4 {
        let v = _mm_loadu_ps(src.add(offset));
        _mm_storeu_ps(dst.add(offset), v);
        offset += 4;
        rem -= 4;
    }
    if rem > 0 {
        std::ptr::copy_nonoverlapping(src.add(offset), dst.add(offset), rem);
    }
}

#[inline]
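/// Portable fallback: copies eight elements per iteration and hands the tail
/// to `copy_nonoverlapping`. Same safety contract as `optimized_copy`.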
unsafe fn scalar_copy_unrolled(src: *const f32, dst: *mut f32, count: usize) {
    let unroll = 8;
    let blocks = count / unroll;
    let mut offset = 0usize;
    for _ in 0..blocks {
        *dst.add(offset) = *src.add(offset);
        *dst.add(offset + 1) = *src.add(offset + 1);
        *dst.add(offset + 2) = *src.add(offset + 2);
        *dst.add(offset + 3) = *src.add(offset + 3);
        *dst.add(offset + 4) = *src.add(offset + 4);
        *dst.add(offset + 5) = *src.add(offset + 5);
        *dst.add(offset + 6) = *src.add(offset + 6);
        *dst.add(offset + 7) = *src.add(offset + 7);
        offset += unroll;
    }
    if offset < count {
        std::ptr::copy_nonoverlapping(src.add(offset), dst.add(offset), count - offset);
    }
}

pub trait TensorCollectExt: Iterator<Item = Tensor> + Sized {
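    /// Collects every remaining tensor in this iterator into a tensor with
    /// shape `dims`. See [`Tensor::collect_into_shape`] for the size and
    /// gradient-tracking semantics.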
    fn collect_shape(self, dims: Vec<usize>) -> Tensor;
}

impl<I> TensorCollectExt for I
where
    I: Iterator<Item = Tensor> + Sized,
{
    #[inline]
    #[track_caller]
    fn collect_shape(self, dims: Vec<usize>) -> Tensor {
        Tensor::collect_into_shape(self, dims)
    }
}

pub trait ValuesCollectExt: Iterator<Item = f32> + Sized {
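    /// Collects every remaining `f32` value in this iterator into a tensor
    /// with shape `dims`, panicking if the value count does not match.
    ///
    /// Illustrative use (mirrors `test_values_collect_shape_direct` below):
    ///
    /// ```ignore
    /// let shaped = (0..12).map(|i| i as f32).collect_shape(vec![3, 4]);
    /// assert_eq!(shaped.shape().dims(), &[3, 4]);
    /// ```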
    fn collect_shape(self, dims: Vec<usize>) -> Tensor;
}

impl<I> ValuesCollectExt for I
where
    I: Iterator<Item = f32> + Sized,
{
    #[inline]
    #[track_caller]
    fn collect_shape(self, dims: Vec<usize>) -> Tensor {
        Tensor::collect_values_shape(self, dims)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::gradtrack::NoGradTrack;

    #[test]
    fn test_collect_shape() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]).unwrap();
        let mat = t.chunks(2).collect_shape(vec![3, 2]);
        assert_eq!(mat.shape().dims(), &[3, 2]);
        assert_eq!(mat.data(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_collect_shape_with_grad_preserves_backward() {
        use crate::gradtrack::is_grad_enabled;
        use crate::tensor::core::Tensor;

        if !is_grad_enabled() {
            // Gradient tracking is globally disabled in this environment;
            // there is nothing to verify, so skip the test body.
            return;
        }

        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4])
            .unwrap()
            .with_requires_grad();
        let parts: Vec<Tensor> = t.chunks(2).map(|c| c.mul_scalar(3.0)).collect();
        let y = parts.into_iter().collect_shape(vec![2, 2]);
        assert!(y.requires_grad());

        let mut loss = y.sum();
        loss.backward(None);
        let g = t.grad_owned().unwrap();
        assert_eq!(g.data(), &[3.0, 3.0, 3.0, 3.0]);
    }

    #[test]
    fn test_collect_from_values_into_tensor() {
        let vals = (0..16).map(|i| i as f32);
        let t: Tensor = vals.collect();
        assert_eq!(t.shape().dims(), &[16]);
        assert_eq!(t.data()[0], 0.0);
        assert_eq!(t.data()[15], 15.0);
    }

    #[test]
    fn test_values_iter_then_collect_shape() {
        let base =
            Tensor::from_slice(&(0..12).map(|i| i as f32).collect::<Vec<_>>(), vec![3, 4]).unwrap();
        let collected: Tensor = base.iter_elements().map(|e| e.value()).collect();
        assert_eq!(collected.shape().dims(), &[12]);
        let shaped = collected.view(vec![3, 4]);
        assert_eq!(shaped.shape().dims(), &[3, 4]);
        assert_eq!(shaped.get(&[2, 3]), 11.0);
    }

    #[test]
    fn test_values_collect_shape_direct() {
        let shaped: Tensor = (0..12).map(|i| i as f32).collect_shape(vec![3, 4]);
        assert_eq!(shaped.shape().dims(), &[3, 4]);
        assert_eq!(shaped.get(&[0, 0]), 0.0);
        assert_eq!(shaped.get(&[2, 3]), 11.0);
    }

    #[test]
    fn test_collect_into_shape_exact_sizes_and_zero() {
        let empty: Vec<Tensor> = Vec::new();
        let out = Tensor::collect_into_shape(empty, vec![0]);
        assert_eq!(out.size(), 0);

        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let y = t.chunks(2).collect_shape(vec![2, 2]);
        assert_eq!(y.shape().dims(), &[2, 2]);
        assert_eq!(y.data(), &[1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_collect_into_shape_no_grad_guard_fast_path() {
        let t = Tensor::from_slice(&(0..8).map(|i| i as f32).collect::<Vec<_>>(), vec![8])
            .unwrap()
            .with_requires_grad();
        let _guard = NoGradTrack::new();
        let y = t
            .iter_elements()
            .map(|e| e.mul_scalar(2.0))
            .collect_shape(vec![8]);
        assert!(!y.requires_grad());
        assert_eq!(y.size(), 8);
    }
}