1use crate::{DType, Shape, StorageTrait, TensorBase, TensorView};
7use std::marker::PhantomData;
8
9#[cfg(feature = "rayon")]
10use rayon::iter::plumbing::{bridge, Consumer, Producer, ProducerCallback};
11#[cfg(feature = "rayon")]
12use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
13
/// Iterator over views ("slices") of a tensor taken along one dimension.
///
/// Iteration state is initialized lazily: construction only records the axis
/// length, and the pointer fields stay null until the first consuming call
/// (`next`/`nth`/`last`/`next_back`) triggers `ensure_iteration_ready`.
/// `len()`/`is_empty()` work before that via `cached_len`.
#[repr(C)]
pub struct DimIter<'a, S: StorageTrait> {
    // Tensor being iterated; produced views borrow its storage for 'a.
    tensor: &'a TensorBase<S>,
    // Current front position in bytes; null until iteration is armed.
    ptr: *mut u8,
    // One-past-the-end position in bytes (start + axis_len * stride).
    end_ptr: *mut u8,
    // Start position, kept to compute byte offsets for produced views.
    original_ptr: *mut u8,
    // Byte stride between consecutive slices along `dim`.
    stride: isize,
    // Shape of each produced slice (every dimension except `dim`).
    slice_shape: Shape,
    // Strides matching `slice_shape` (element strides, as stored in the tensor).
    slice_strides: Shape,
    // Element type, cached for view construction.
    dtype: DType,
    // The axis being iterated.
    dim: usize,
    // Length of the iterated axis, captured at construction time.
    cached_len: usize,
    _phantom: PhantomData<&'a ()>,
}
75
76impl<'a, S: StorageTrait> DimIter<'a, S> {
77 #[inline(always)]
93 pub fn from_tensor(tensor: &'a TensorBase<S>, dim: usize) -> Self {
94 debug_assert!(dim < tensor.rank(), "Dim {} >= {}", dim, tensor.rank());
95
96 let axis_len = tensor.shape[dim];
97
98 Self {
100 tensor,
101 ptr: std::ptr::null_mut(), end_ptr: std::ptr::null_mut(), original_ptr: std::ptr::null_mut(), stride: 0, slice_shape: Shape::empty(), slice_strides: Shape::empty(), dtype: tensor.dtype,
108 dim,
109 cached_len: axis_len,
110 _phantom: PhantomData,
111 }
112 }
113
114 #[inline(always)]
125 fn ensure_iteration_ready(&mut self) {
126 if self.ptr.is_null() {
128 let axis_len = self.cached_len;
129 let axis_stride = (self.tensor.strides[self.dim] * self.dtype.size_in_bytes()) as isize;
130
131 let data_ptr = self.tensor.as_ptr() as *mut u8;
132 let end_ptr = unsafe { data_ptr.offset(axis_len as isize * axis_stride) };
133
134 self.ptr = data_ptr;
135 self.end_ptr = end_ptr;
136 self.original_ptr = data_ptr;
137 self.stride = axis_stride;
138
139 if self.tensor.rank() > 1 {
141 for (i, &dim_size) in self.tensor.shape.as_slice().iter().enumerate() {
143 if i != self.dim {
144 self.slice_shape.push(dim_size);
145 self.slice_strides.push(self.tensor.strides[i]);
146 }
147 }
148 }
149 }
150 }
151
152 #[inline(always)]
154 pub fn len(&self) -> usize {
155 if self.ptr.is_null() {
156 return self.cached_len;
158 }
159
160 if self.stride == 0 {
161 return if self.ptr >= self.end_ptr { 0 } else { 1 };
162 }
163
164 let remaining_bytes = self.end_ptr as isize - self.ptr as isize;
165 if remaining_bytes <= 0 {
166 0
167 } else {
168 (remaining_bytes / self.stride) as usize
169 }
170 }
171
172 #[inline(always)]
174 pub fn is_empty(&self) -> bool {
175 if self.ptr.is_null() {
176 return self.cached_len == 0;
177 }
178 self.ptr >= self.end_ptr
179 }
180
181 #[cfg(feature = "rayon")]
182 #[inline]
184 pub fn par_iter(self) -> ParDimIter<'a, S>
185 where
186 S: Send + Sync,
187 {
188 ParDimIter::new(self)
189 }
190}
191
192impl<'a, S: StorageTrait> Iterator for DimIter<'a, S> {
193 type Item = TensorView<'a>;
194
195 #[inline(always)]
196 fn next(&mut self) -> Option<Self::Item> {
197 if self.ptr.is_null() {
199 self.ensure_iteration_ready();
200 }
201
202 if self.ptr >= self.end_ptr {
203 return None;
204 }
205
206 let current = self.ptr;
207 self.ptr = unsafe { self.ptr.offset(self.stride) };
208
209 let offset_bytes = (current as isize - self.original_ptr as isize) as usize;
211
212 Some(unsafe {
213 TensorView::from_raw_parts(
214 self.tensor.storage.as_storage(),
215 self.tensor.storage.ptr(),
216 self.slice_shape,
217 self.slice_strides,
218 self.tensor.offset_bytes + offset_bytes,
219 self.dtype,
220 )
221 })
222 }
223
224 #[inline(always)]
225 fn size_hint(&self) -> (usize, Option<usize>) {
226 let len = self.len();
227 (len, Some(len))
228 }
229
230 #[inline(always)]
251 fn count(self) -> usize {
252 self.cached_len
254 }
255
256 #[inline(always)]
257 fn nth(&mut self, n: usize) -> Option<Self::Item> {
258 if n == 0 {
259 return self.next();
260 }
261
262 if self.ptr.is_null() {
264 self.ensure_iteration_ready();
265 }
266
267 let skip_bytes = self.stride * n as isize;
268 let new_ptr = unsafe { self.ptr.offset(skip_bytes) };
269
270 if new_ptr >= self.end_ptr {
271 self.ptr = self.end_ptr;
272 return None;
273 }
274
275 self.ptr = new_ptr;
276 self.next()
277 }
278
279 #[inline(always)]
280 fn last(mut self) -> Option<Self::Item> {
281 if self.cached_len == 0 {
282 return None;
283 }
284
285 if self.ptr.is_null() {
287 self.ensure_iteration_ready();
288 }
289
290 let last_ptr = unsafe { self.end_ptr.offset(-self.stride) };
292 self.ptr = last_ptr;
293
294 let offset_bytes = (last_ptr as isize - self.original_ptr as isize) as usize;
295
296 Some(unsafe {
297 TensorView::from_raw_parts(
298 self.tensor.storage.as_storage(),
299 self.tensor.storage.ptr(),
300 self.slice_shape,
301 self.slice_strides,
302 self.tensor.offset_bytes + offset_bytes,
303 self.dtype,
304 )
305 })
306 }
307}
308
// `len()` reports the exact number of remaining slices (see `size_hint`), and
// `next()` keeps returning `None` once `ptr` reaches `end_ptr`, so both
// marker traits hold.
impl<S: StorageTrait> ExactSizeIterator for DimIter<'_, S> {}
impl<S: StorageTrait> std::iter::FusedIterator for DimIter<'_, S> {}
311
312impl<S: StorageTrait> DoubleEndedIterator for DimIter<'_, S> {
313 fn next_back(&mut self) -> Option<Self::Item> {
314 if self.ptr.is_null() {
316 self.ensure_iteration_ready();
317 }
318
319 if self.ptr >= self.end_ptr {
320 return None;
321 }
322
323 self.end_ptr = unsafe { self.end_ptr.offset(-self.stride) };
325
326 let current = self.end_ptr;
327 let offset_bytes = (current as isize - self.original_ptr as isize) as usize;
328
329 Some(unsafe {
330 TensorView::from_raw_parts(
331 self.tensor.storage.as_storage(),
332 self.tensor.storage.ptr(),
333 self.slice_shape,
334 self.slice_strides,
335 self.tensor.offset_bytes + offset_bytes,
336 self.dtype,
337 )
338 })
339 }
340}
341
342impl<S: StorageTrait> DimIter<'_, S> {
343 pub fn split_at(mut self, index: usize) -> (Self, Self) {
345 let len = self.cached_len;
346 assert!(index <= len, "Split index {index} exceeds length {len}");
347
348 if index == 0 {
349 let empty = self.empty();
350 return (empty, self);
351 }
352 if index == len {
353 let empty = self.empty();
354 return (self, empty);
355 }
356
357 self.ensure_iteration_ready();
359
360 let right = Self {
362 tensor: self.tensor,
363 ptr: unsafe { self.ptr.offset(index as isize * self.stride) },
364 end_ptr: self.end_ptr,
365 original_ptr: self.original_ptr,
366 stride: self.stride,
367 slice_shape: self.slice_shape,
368 slice_strides: self.slice_strides,
369 dtype: self.dtype,
370 dim: self.dim,
371 cached_len: len - index,
372 _phantom: PhantomData,
373 };
374
375 let mut left = self;
377 left.end_ptr = unsafe { left.ptr.offset(index as isize * left.stride) };
378 left.cached_len = index;
379
380 (left, right)
381 }
382
383 fn empty(&self) -> Self {
385 Self {
386 tensor: self.tensor,
387 ptr: std::ptr::null_mut(),
388 end_ptr: std::ptr::null_mut(),
389 original_ptr: std::ptr::null_mut(),
390 stride: 0,
391 slice_shape: Shape::empty(),
392 slice_strides: Shape::empty(),
393 dtype: self.dtype,
394 dim: self.dim,
395 cached_len: 0,
396 _phantom: PhantomData,
397 }
398 }
399}
400
// SAFETY: `DimIter` only reads through its raw pointers, which point into
// storage borrowed from `tensor` for the iterator's lifetime; it performs no
// interior mutation of shared state, so the auto-trait opt-out caused by the
// raw pointers is overly conservative here.
// NOTE(review): sending a `&'a TensorBase<S>` normally requires
// `TensorBase<S>: Sync`; the `S: Send` bound on the `Send` impl looks weaker
// than that — verify this is intentional for the storage types in use.
unsafe impl<S: StorageTrait> Send for DimIter<'_, S> where S: Send {}
unsafe impl<S: StorageTrait> Sync for DimIter<'_, S> where S: Sync {}
403
/// Rayon parallel iterator over tensor slices along one dimension.
#[cfg(feature = "rayon")]
pub struct ParDimIter<'a, S: StorageTrait> {
    // Sequential iterator that gets split across worker threads.
    inner: DimIter<'a, S>,
    // Minimum number of slices per work unit (see `with_min_len`).
    min_len: usize,
}
409
#[cfg(feature = "rayon")]
impl<'a, S: StorageTrait> ParDimIter<'a, S> {
    /// Wraps a sequential `DimIter` for parallel consumption, with the
    /// default splitting granularity of one slice.
    pub fn new(inner: DimIter<'a, S>) -> Self {
        ParDimIter { inner, min_len: 1 }
    }

    /// Sets the minimum number of slices each parallel job should receive.
    ///
    /// # Panics
    /// Panics if `min_len` is zero.
    pub fn with_min_len(self, min_len: usize) -> Self {
        assert_ne!(min_len, 0, "Minimum number of elements must be at least one");
        Self { min_len, ..self }
    }
}
425
#[cfg(feature = "rayon")]
impl<'a, S: StorageTrait + Send + Sync> IntoParallelIterator for DimIter<'a, S> {
    type Item = TensorView<'a>;
    type Iter = ParDimIter<'a, S>;

    /// Equivalent to calling [`DimIter::par_iter`].
    fn into_par_iter(self) -> Self::Iter {
        self.par_iter()
    }
}
435
#[cfg(feature = "rayon")]
impl<'a, S: StorageTrait + Send + Sync> ParallelIterator for ParDimIter<'a, S> {
    type Item = TensorView<'a>;

    // Delegate to the indexed `bridge`, which splits work via the
    // `Producer` implementation below.
    fn drive_unindexed<C>(self, consumer: C) -> C::Result
    where
        C: rayon::iter::plumbing::UnindexedConsumer<Self::Item>,
    {
        bridge(self, consumer)
    }

    // The exact length is always known, which lets rayon pick the indexed
    // (cheaper) driving strategy.
    fn opt_len(&self) -> Option<usize> {
        Some(self.inner.len())
    }
}
451
#[cfg(feature = "rayon")]
impl<'a, S: StorageTrait + Send + Sync> IndexedParallelIterator for ParDimIter<'a, S> {
    fn drive<C>(self, consumer: C) -> C::Result
    where
        C: Consumer<Self::Item>,
    {
        bridge(self, consumer)
    }

    fn len(&self) -> usize {
        self.inner.len()
    }

    // Hands the splittable producer (iterator + splitting granularity)
    // to rayon's plumbing.
    fn with_producer<CB>(self, callback: CB) -> CB::Output
    where
        CB: ProducerCallback<Self::Item>,
    {
        callback.callback(ParDimProducer {
            inner: self.inner,
            min_len: self.min_len,
        })
    }
}
475
/// Rayon `Producer` backing `ParDimIter`: a splittable range of slices.
#[cfg(feature = "rayon")]
struct ParDimProducer<'a, S: StorageTrait> {
    // Sequential iterator over this producer's share of the axis.
    inner: DimIter<'a, S>,
    // Minimum slices per split, propagated from `ParDimIter::with_min_len`.
    min_len: usize,
}
481
#[cfg(feature = "rayon")]
impl<'a, S: StorageTrait + Send + Sync> Producer for ParDimProducer<'a, S> {
    type Item = TensorView<'a>;
    type IntoIter = DimIter<'a, S>;

    fn into_iter(self) -> Self::IntoIter {
        self.inner
    }

    /// Fix: honor the user-requested split granularity. `min_len` was stored
    /// and propagated but nothing ever consumed it, so
    /// `ParDimIter::with_min_len` had no effect and rayon always split down
    /// to its own default of one slice. Overriding the provided
    /// `Producer::min_len` hook makes the setting effective.
    fn min_len(&self) -> usize {
        self.min_len
    }

    fn split_at(self, index: usize) -> (Self, Self) {
        // Split the underlying range; both halves keep the same granularity.
        let (left, right) = self.inner.split_at(index);
        (
            ParDimProducer {
                inner: left,
                min_len: self.min_len,
            },
            ParDimProducer {
                inner: right,
                min_len: self.min_len,
            },
        )
    }
}
505
// Lets the producer be consumed as a plain sequential iterator once rayon
// has finished splitting it.
#[cfg(feature = "rayon")]
impl<'a, S: StorageTrait + Send + Sync> IntoIterator for ParDimProducer<'a, S> {
    type Item = TensorView<'a>;
    type IntoIter = DimIter<'a, S>;

    fn into_iter(self) -> Self::IntoIter {
        self.inner
    }
}
515
516impl<S: StorageTrait> TensorBase<S> {
517 #[inline]
550 pub fn iter_dim(&self, dim: usize) -> DimIter<'_, S> {
551 DimIter::from_tensor(self, dim)
552 }
553
554 #[inline(always)]
569 pub fn dim_len(&self, dim: usize) -> usize {
570 debug_assert!(dim < self.rank(), "Dim {} >= {}", dim, self.rank());
571 self.shape[dim]
572 }
573
574 #[cfg(feature = "rayon")]
575 #[inline]
576 pub fn par_iter_dim(&self, dim: usize) -> ParDimIter<'_, S>
577 where
578 S: Send + Sync,
579 {
580 self.iter_dim(dim).par_iter()
581 }
582}
583
#[cfg(test)]
mod tests {
    //! Tests covering sequential iteration (`iter_dim`), the optimized
    //! `Iterator` overrides (`count`/`nth`/`last`), double-ended iteration,
    //! `split_at`, and the rayon-backed parallel iterator.
    use crate::Tensor;
    #[cfg(feature = "rayon")]
    use rayon::iter::{IndexedParallelIterator, ParallelIterator};

    #[test]
    fn test_unified_dim_iter_basic() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = Tensor::from_vec(data, vec![2, 3]).unwrap();

        let iter = tensor.iter_dim(0);
        assert_eq!(iter.len(), 2);

        let ptrs: Vec<_> = iter.collect();
        assert_eq!(ptrs.len(), 2);
    }

    #[test]
    fn test_unified_dim_iter_empty() {
        // Axis of length 0: the iterator reports empty before being armed.
        let data: Vec<f32> = vec![];
        let tensor = Tensor::from_vec(data, vec![1, 0]).unwrap();

        let iter = tensor.iter_dim(1);
        assert_eq!(iter.len(), 0);
        assert!(iter.is_empty());
    }

    #[test]
    fn test_dim_iter_count_optimization() {
        let data = vec![1.0f32; 1000];
        let tensor = Tensor::from_vec(data, vec![100, 10]).unwrap();

        let count = tensor.iter_dim(0).count();
        assert_eq!(count, 100);

        let count = tensor.iter_dim(1).count();
        assert_eq!(count, 10);

        assert_eq!(tensor.dim_len(0), 100);
        assert_eq!(tensor.dim_len(1), 10);
    }

    #[test]
    fn test_dim_iter_last_optimization() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = Tensor::from_vec(data, vec![2, 3]).unwrap();

        let last = tensor.iter_dim(0).last();
        assert!(last.is_some());

        // Last row of a [2, 3] tensor starts at flat index 3 => value 4.0.
        let last_view = last.unwrap();
        assert_eq!(last_view.at::<f32>([0]), 4.0);
    }

    #[test]
    fn test_dim_iter_nth_optimization() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let tensor = Tensor::from_vec(data, vec![4, 2]).unwrap();

        let mut iter = tensor.iter_dim(0);
        let third = iter.nth(2);
        assert!(third.is_some());

        // Row 2 of a [4, 2] tensor starts at flat index 4 => value 5.0.
        let third_view = third.unwrap();
        assert_eq!(third_view.at::<f32>([0]), 5.0);
    }

    #[cfg(feature = "rayon")]
    #[test]
    fn test_par_dim_iter_basic() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = Tensor::from_vec(data, vec![2, 3]).unwrap();

        let par_iter = tensor.iter_dim(0).par_iter();
        assert_eq!(par_iter.len(), 2);

        let count = par_iter.count();
        assert_eq!(count, 2);

        assert_eq!(tensor.dim_len(0), count);
    }

    #[cfg(feature = "rayon")]
    #[test]
    fn test_par_dim_iter_map() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = Tensor::from_vec(data, vec![2, 3]).unwrap();

        let par_iter = tensor.iter_dim(0).par_iter();
        let results: Vec<f32> = par_iter.map(|view| view.at::<f32>([0])).collect();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0], 1.0);
        assert_eq!(results[1], 4.0);
    }

    #[cfg(feature = "rayon")]
    #[test]
    fn test_par_dim_iter_filter() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = Tensor::from_vec(data, vec![2, 3]).unwrap();

        let par_iter = tensor.iter_dim(0).par_iter();
        let results: Vec<crate::TensorView> =
            par_iter.filter(|view| view.at::<f32>([0]) > 2.0).collect();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].at::<f32>([0]), 4.0);
    }

    #[cfg(feature = "rayon")]
    #[test]
    fn test_par_dim_iter_large() {
        let size = 1000;
        let data: Vec<f32> = (0..size).map(|i| i as f32).collect();
        let tensor = Tensor::from_vec(data, vec![size, 1]).unwrap();

        let par_iter = tensor.iter_dim(0).par_iter();
        let sum: f32 = par_iter.map(|view| view.at::<f32>([0])).sum();

        let expected_sum: f32 = (0..size).map(|i| i as f32).sum();
        assert!((sum - expected_sum).abs() < f32::EPSILON);
    }

    #[cfg(feature = "rayon")]
    #[test]
    fn test_par_dim_iter_rayon_style() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = Tensor::from_vec(data, vec![2, 3]).unwrap();

        let results: Vec<f32> = tensor
            .iter_dim(0)
            .par_iter()
            .map(|view| view.at::<f32>([0]))
            .collect();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0], 1.0);
        assert_eq!(results[1], 4.0);
    }

    #[test]
    fn test_lightweight_count_performance() {
        // `count` should be O(1) regardless of axis length.
        let size = 10000;
        let data: Vec<f32> = (0..size).map(|i| i as f32).collect();
        let tensor = Tensor::from_vec(data, vec![size, 1]).unwrap();

        assert_eq!(tensor.iter_dim(0).count(), size);
        assert_eq!(tensor.dim_len(0), size);
        assert_eq!(tensor.iter_dim(1).count(), 1);
        assert_eq!(tensor.dim_len(1), 1);
    }

    #[test]
    fn test_multi_dimensional_count() {
        let data: Vec<f32> = (0..120).map(|i| i as f32).collect();
        let tensor = Tensor::from_vec(data, vec![2, 3, 4, 5]).unwrap();

        assert_eq!(tensor.iter_dim(0).count(), 2);
        assert_eq!(tensor.iter_dim(1).count(), 3);
        assert_eq!(tensor.iter_dim(2).count(), 4);
        assert_eq!(tensor.iter_dim(3).count(), 5);

        assert_eq!(tensor.dim_len(0), 2);
        assert_eq!(tensor.dim_len(1), 3);
        assert_eq!(tensor.dim_len(2), 4);
        assert_eq!(tensor.dim_len(3), 5);
    }

    #[test]
    fn test_iter_dim_data_correctness() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = Tensor::from_vec(data, vec![2, 3]).unwrap();

        let rows: Vec<_> = tensor.iter_dim(0).collect();
        assert_eq!(rows.len(), 2);

        let row0_data = rows[0].as_slice::<f32>().unwrap();
        assert_eq!(row0_data, &[1.0, 2.0, 3.0]);

        let row1_data = rows[1].as_slice::<f32>().unwrap();
        assert_eq!(row1_data, &[4.0, 5.0, 6.0]);

        // Iterating the column axis yields 3 column views of shape [2].
        let dim1_slices: Vec<_> = tensor.iter_dim(1).collect();
        assert_eq!(dim1_slices.len(), 3);

        for slice in dim1_slices.iter() {
            assert_eq!(slice.shape().as_slice(), &[2]);
        }
    }

    #[test]
    fn test_iter_dim_3d_tensor() {
        let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let tensor = Tensor::from_vec(data, vec![2, 3, 4]).unwrap();

        let slices: Vec<_> = tensor.iter_dim(0).collect();
        assert_eq!(slices.len(), 2);

        let slice0 = slices[0].as_slice::<f32>().unwrap();
        let expected0: Vec<f32> = (0..12).map(|i| i as f32).collect();
        assert_eq!(slice0, expected0.as_slice());

        let slice1 = slices[1].as_slice::<f32>().unwrap();
        let expected1: Vec<f32> = (12..24).map(|i| i as f32).collect();
        assert_eq!(slice1, expected1.as_slice());
    }

    #[test]
    fn test_iter_dim_edge_cases() {
        // Zero-length leading axis.
        let tensor_empty = Tensor::from_vec(Vec::<f32>::new(), vec![0, 5]).unwrap();
        let empty_iter: Vec<_> = tensor_empty.iter_dim(0).collect();
        assert_eq!(empty_iter.len(), 0);
        assert!(tensor_empty.iter_dim(0).is_empty());

        // Single-slice axis.
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let tensor_single = Tensor::from_vec(data, vec![1, 5]).unwrap();
        let single_iter: Vec<_> = tensor_single.iter_dim(0).collect();
        assert_eq!(single_iter.len(), 1);

        let slice_data = single_iter[0].as_slice::<f32>().unwrap();
        assert_eq!(slice_data, &[1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_iter_dim_iterator_methods() {
        let data: Vec<f32> = (0..20).map(|i| i as f32).collect();
        let tensor = Tensor::from_vec(data, vec![4, 5]).unwrap();

        let taken: Vec<_> = tensor.iter_dim(0).take(2).collect();
        assert_eq!(taken.len(), 2);

        let skipped: Vec<_> = tensor.iter_dim(0).skip(1).collect();
        assert_eq!(skipped.len(), 3);

        for (i, slice) in tensor.iter_dim(0).enumerate() {
            let slice_data = slice.as_slice::<f32>().unwrap();
            let expected_start = i * 5;
            assert_eq!(slice_data[0], expected_start as f32);
        }
    }

    #[test]
    fn test_iter_dim_split_at() {
        let data: Vec<f32> = (0..20).map(|i| i as f32).collect();
        let tensor = Tensor::from_vec(data, vec![4, 5]).unwrap();

        let iter = tensor.iter_dim(0);
        let (left, right) = iter.split_at(2);

        let left_slices: Vec<_> = left.collect();
        let right_slices: Vec<_> = right.collect();

        assert_eq!(left_slices.len(), 2);
        assert_eq!(right_slices.len(), 2);

        let left0_data = left_slices[0].as_slice::<f32>().unwrap();
        assert_eq!(left0_data, &[0.0, 1.0, 2.0, 3.0, 4.0]);

        let right0_data = right_slices[0].as_slice::<f32>().unwrap();
        assert_eq!(right0_data, &[10.0, 11.0, 12.0, 13.0, 14.0]);
    }

    #[test]
    fn test_iter_dim_nested_iteration() {
        let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let tensor = Tensor::from_vec(data, vec![2, 3, 4]).unwrap();

        for (i, outer_slice) in tensor.iter_dim(0).enumerate() {
            assert_eq!(outer_slice.rank(), 2);
            assert_eq!(outer_slice.shape().as_slice(), &[3, 4]);

            let nested_slices: Vec<_> = outer_slice.iter_dim(0).collect();
            assert_eq!(nested_slices.len(), 3);

            for (j, inner_slice) in nested_slices.iter().enumerate() {
                let slice_data = inner_slice.as_slice::<f32>().unwrap();
                assert_eq!(slice_data.len(), 4);

                // Row-major layout: block i starts at i*12, row j at +j*4.
                let expected_first = (i * 12 + j * 4) as f32;
                assert_eq!(slice_data[0], expected_first);
            }
        }
    }

    #[test]
    fn test_iter_dim_large_tensor_performance() {
        let size = 1000;
        let data: Vec<f32> = (0..size * 100).map(|i| i as f32).collect();
        let tensor = Tensor::from_vec(data, vec![size, 100]).unwrap();

        assert_eq!(tensor.iter_dim(0).count(), size);
        assert_eq!(tensor.iter_dim(0).len(), size);
        assert!(!tensor.iter_dim(0).is_empty());

        let first = tensor.iter_dim(0).next().unwrap();
        let first_data = first.as_slice::<f32>().unwrap();
        assert_eq!(first_data[0], 0.0);

        let last = tensor.iter_dim(0).last().unwrap();
        let last_data = last.as_slice::<f32>().unwrap();
        assert_eq!(last_data[0], (size - 1) as f32 * 100.0);
    }

    #[test]
    fn test_iter_dim_double_ended() {
        let data: Vec<f32> = (0..12).map(|i| i as f32).collect();
        let tensor = Tensor::from_vec(data, vec![3, 4]).unwrap();

        let mut iter = tensor.iter_dim(0);

        let first = iter.next().unwrap();
        let first_data = first.as_slice::<f32>().unwrap();
        assert_eq!(first_data, &[0.0, 1.0, 2.0, 3.0]);

        let last = iter.next_back().unwrap();
        let last_data = last.as_slice::<f32>().unwrap();
        assert_eq!(last_data, &[8.0, 9.0, 10.0, 11.0]);

        // Front and back meet in the middle.
        let middle = iter.next().unwrap();
        let middle_data = middle.as_slice::<f32>().unwrap();
        assert_eq!(middle_data, &[4.0, 5.0, 6.0, 7.0]);

        assert!(iter.next().is_none());
        assert!(iter.next_back().is_none());
    }

    #[test]
    fn test_iter_dim_various_shapes() {
        let test_shapes = vec![
            vec![2, 3],
            vec![3, 4, 5],
            vec![2, 3, 4, 5],
            vec![1, 10],
            vec![10, 1],
        ];

        for shape in test_shapes {
            let total_elements: usize = shape.iter().product();
            let data: Vec<f32> = (0..total_elements).map(|i| i as f32).collect();
            let tensor = Tensor::from_vec(data, shape.clone()).unwrap();

            let slices: Vec<_> = tensor.iter_dim(0).collect();
            assert_eq!(slices.len(), shape[0]);

            let expected_slice_size = if shape.len() > 1 {
                shape[1..].iter().product()
            } else {
                1
            };

            for slice in slices {
                let slice_data = slice.as_slice::<f32>().unwrap();
                assert_eq!(slice_data.len(), expected_slice_size);
            }
        }
    }

    #[test]
    fn test_iter_dim_correctness_with_strides() {
        let data: Vec<f32> = (0..12).map(|i| i as f32).collect();
        let tensor = Tensor::from_vec(data, vec![3, 4]).unwrap();

        let slices: Vec<_> = tensor.iter_dim(0).collect();
        assert_eq!(slices.len(), 3);

        for slice in slices.iter() {
            assert_eq!(slice.shape().as_slice(), &[4]);
        }

        let dim1_slices: Vec<_> = tensor.iter_dim(1).collect();
        assert_eq!(dim1_slices.len(), 4);

        for slice in dim1_slices.iter() {
            assert_eq!(slice.shape().as_slice(), &[3]);
        }
    }

    #[test]
    fn test_iter_dim_boundary_conditions() {
        let scalar_like = Tensor::from_vec(vec![42.0f32], vec![1]).unwrap();
        let slices: Vec<_> = scalar_like.iter_dim(0).collect();
        assert_eq!(slices.len(), 1);
        let slice_data = slices[0].as_slice::<f32>().unwrap();
        assert_eq!(slice_data, &[42.0]);

        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = Tensor::from_vec(data, vec![3, 2]).unwrap();

        // Degenerate splits at both ends.
        let (left, right) = tensor.iter_dim(0).split_at(0);
        assert_eq!(left.len(), 0);
        assert_eq!(right.len(), 3);

        let (left, right) = tensor.iter_dim(0).split_at(3);
        assert_eq!(left.len(), 3);
        assert_eq!(right.len(), 0);
    }

    #[test]
    fn test_iter_dim_memory_safety() {
        let data: Vec<f32> = (0..100).map(|i| i as f32).collect();
        let tensor = Tensor::from_vec(data, vec![10, 10]).unwrap();

        // Two concurrent iterators over the same tensor must agree.
        let mut iter1 = tensor.iter_dim(0);
        let mut iter2 = tensor.iter_dim(0);

        let slice1 = iter1.nth(5).unwrap();
        let slice2 = iter2.nth(5).unwrap();

        let data1 = slice1.as_slice::<f32>().unwrap();
        let data2 = slice2.as_slice::<f32>().unwrap();

        assert_eq!(data1, data2);
        assert_eq!(data1[0], 50.0);
    }

    #[test]
    fn test_iter_dim_consistency_across_dimensions() {
        let data: Vec<f32> = (0..60).map(|i| i as f32).collect();
        let tensor = Tensor::from_vec(data, vec![3, 4, 5]).unwrap();

        assert_eq!(tensor.iter_dim(0).count(), 3);
        assert_eq!(tensor.iter_dim(1).count(), 4);
        assert_eq!(tensor.iter_dim(2).count(), 5);

        // Each slice drops exactly the iterated axis from the shape.
        let dim0_slice = tensor.iter_dim(0).next().unwrap();
        assert_eq!(dim0_slice.shape().as_slice(), &[4, 5]);

        let dim1_slice = tensor.iter_dim(1).next().unwrap();
        assert_eq!(dim1_slice.shape().as_slice(), &[3, 5]);

        let dim2_slice = tensor.iter_dim(2).next().unwrap();
        assert_eq!(dim2_slice.shape().as_slice(), &[3, 4]);
    }

    #[test]
    fn test_iter_dim_offset_correctness() {
        let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let tensor = Tensor::from_vec(data, vec![4, 6]).unwrap();

        let slices: Vec<_> = tensor.iter_dim(0).collect();
        let view = &slices[1];
        assert_eq!(view.shape().as_slice(), &[6]);

        let view_data = view.as_slice::<f32>().unwrap();
        assert_eq!(view_data, &[6.0, 7.0, 8.0, 9.0, 10.0, 11.0]);

        // skip/take must land on the same byte offsets.
        let nested_slices: Vec<_> = tensor.iter_dim(0).skip(1).take(2).collect();
        assert_eq!(nested_slices.len(), 2);

        let slice0_data = nested_slices[0].as_slice::<f32>().unwrap();
        assert_eq!(slice0_data, &[6.0, 7.0, 8.0, 9.0, 10.0, 11.0]);

        let slice1_data = nested_slices[1].as_slice::<f32>().unwrap();
        assert_eq!(slice1_data, &[12.0, 13.0, 14.0, 15.0, 16.0, 17.0]);
    }

    #[test]
    fn test_iter_dim_extreme_shapes() {
        let data: Vec<f32> = (0..1000).map(|i| i as f32).collect();
        let wide_tensor = Tensor::from_vec(data, vec![1, 1000]).unwrap();

        let slices: Vec<_> = wide_tensor.iter_dim(0).collect();
        assert_eq!(slices.len(), 1);

        let slice_data = slices[0].as_slice::<f32>().unwrap();
        assert_eq!(slice_data.len(), 1000);
        assert_eq!(slice_data[0], 0.0);
        assert_eq!(slice_data[999], 999.0);

        let data: Vec<f32> = (0..1000).map(|i| i as f32).collect();
        let tall_tensor = Tensor::from_vec(data, vec![1000, 1]).unwrap();

        let slices: Vec<_> = tall_tensor.iter_dim(0).collect();
        assert_eq!(slices.len(), 1000);

        for (i, slice) in slices.iter().enumerate() {
            let slice_data = slice.as_slice::<f32>().unwrap();
            assert_eq!(slice_data.len(), 1);
            assert_eq!(slice_data[0], i as f32);
        }
    }

    #[test]
    fn test_iter_dim_zero_stride_edge_case() {
        // All-ones shape: every axis has length 1.
        let data = vec![42.0f32];
        let tensor = Tensor::from_vec(data, vec![1, 1, 1, 1]).unwrap();

        for dim in 0..4 {
            let slices: Vec<_> = tensor.iter_dim(dim).collect();
            assert_eq!(slices.len(), 1);

            let remaining_dims: Vec<usize> = (0..4).filter(|&d| d != dim).map(|_| 1).collect();
            if remaining_dims.is_empty() {
                let slice_data = slices[0].as_slice::<f32>().unwrap();
                assert_eq!(slice_data, &[42.0]);
            }
        }
    }
}