// slsl/core/dim_iter.rs

1//! High-performance dimension iterator optimized for tensor operations
2//!
3//! This module provides `DimIter`, a specialized iterator that efficiently iterates
4//! over tensor dimensions while maintaining zero-cost abstractions and optimal performance.
5
6use crate::{DType, Shape, StorageTrait, TensorBase, TensorView};
7use std::marker::PhantomData;
8
9#[cfg(feature = "rayon")]
10use rayon::iter::plumbing::{bridge, Consumer, Producer, ProducerCallback};
11#[cfg(feature = "rayon")]
12use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
13
/// High-performance iterator over tensor dimensions
///
/// `DimIter` provides efficient iteration over a specific dimension of a tensor,
/// yielding `TensorView`s that represent slices along that dimension. This iterator
/// is heavily optimized for performance with several key features:
///
/// ## Performance Optimizations
///
/// - **Lazy Initialization**: Expensive computations (pointer arithmetic, shape computation)
///   are deferred until actually needed
/// - **Zero-Cost Count**: The `count()` method returns a cached value without iteration
/// - **Memory Layout Awareness**: Optimized for cache-friendly access patterns
/// - **Minimal Allocations**: Pre-computed shapes avoid repeated heap allocations
///
/// ## Usage
///
/// ```rust
/// use slsl::Tensor;
///
/// let tensor = Tensor::from_vec(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], [2, 3]).unwrap();
///
/// // Iterate over first dimension (yields 2 views of shape [3])
/// for view in tensor.iter_dim(0) {
///     println!("Slice: {:?}", view.as_slice::<f32>().unwrap());
/// }
///
/// // Ultra-fast count without iteration overhead
/// let count = tensor.iter_dim(0).count(); // Returns 2 instantly
/// ```
///
/// ## Implementation Details
///
/// A null `ptr` is the sentinel for "not yet initialized": the constructor only
/// records the axis length, and `ensure_iteration_ready()` performs the pointer
/// arithmetic and slice-shape computation the first time an element is needed.
/// This is why `count()`/`len()` can answer without touching the data.
#[repr(C)]
pub struct DimIter<'a, S: StorageTrait> {
    /// Reference to the original tensor or tensor view
    tensor: &'a TensorBase<S>,
    /// Current cursor position in bytes (null until lazily initialized)
    ptr: *mut u8,
    /// One-past-the-end pointer boundary (lazy initialized)
    end_ptr: *mut u8,
    /// Starting pointer, kept to compute each view's byte offset (lazy initialized)
    original_ptr: *mut u8,
    /// Step size in bytes between consecutive slices along `dim` (lazy initialized)
    stride: isize,
    /// Pre-computed shape of each yielded view (original shape with `dim` removed)
    slice_shape: Shape,
    /// Pre-computed strides of each yielded view (original strides with `dim` removed)
    slice_strides: Shape,
    /// Data type of tensor elements
    dtype: DType,
    /// Dimension being iterated over
    dim: usize,
    /// Axis length captured at construction; backs the O(1) `count()`/`len()` paths
    cached_len: usize,
    /// Ties the raw pointers above to the lifetime of the tensor borrow
    _phantom: PhantomData<&'a ()>,
}
75
impl<'a, S: StorageTrait> DimIter<'a, S> {
    /// Creates a new dimension iterator with lazy initialization for optimal performance
    ///
    /// Only the axis length is captured here; pointer arithmetic and slice-shape
    /// computation are deferred to `ensure_iteration_ready()`, which lets
    /// `count()`/`len()` answer without ever touching the data.
    ///
    /// # Arguments
    /// * `tensor` - The tensor to iterate over
    /// * `dim` - The dimension index to iterate along
    ///
    /// # Panics
    /// In debug builds, panics if `dim >= tensor.rank()`. In release builds the
    /// `tensor.shape[dim]` access below presumably also fails for an
    /// out-of-range `dim` — confirm `Shape`'s indexing behavior.
    #[inline(always)]
    pub fn from_tensor(tensor: &'a TensorBase<S>, dim: usize) -> Self {
        debug_assert!(dim < tensor.rank(), "Dim {} >= {}", dim, tensor.rank());

        let axis_len = tensor.shape[dim];

        // Minimal initialization - defer expensive computations for optimal performance
        Self {
            tensor,
            ptr: std::ptr::null_mut(), // Lazy init sentinel: null means "not started"
            end_ptr: std::ptr::null_mut(), // Lazy init: set when iteration begins
            original_ptr: std::ptr::null_mut(), // Lazy init: set when iteration begins
            stride: 0,                 // Lazy init: computed when iteration begins
            slice_shape: Shape::empty(), // Lazy init: computed only when needed
            slice_strides: Shape::empty(), // Lazy init: computed only when needed
            dtype: tensor.dtype,
            dim,
            cached_len: axis_len,
            _phantom: PhantomData,
        }
    }

    /// Lazy initialization of iteration state - only called when actually iterating
    ///
    /// Performs the deferred setup: computes the byte stride along `dim`, the
    /// start/end pointer range, and the shape/strides of the yielded views
    /// (the original shape/strides with `dim` removed).
    ///
    /// Guarded by `ptr.is_null()`, so it runs at most once per iterator
    /// (assuming the tensor's data pointer is non-null).
    #[inline(always)]
    fn ensure_iteration_ready(&mut self) {
        // Initialize pointers and stride if not already done
        if self.ptr.is_null() {
            let axis_len = self.cached_len;
            // Element stride along `dim`, scaled to bytes.
            let axis_stride = (self.tensor.strides[self.dim] * self.dtype.size_in_bytes()) as isize;

            let data_ptr = self.tensor.as_ptr() as *mut u8;
            // One-past-the-end of the iterated range.
            // NOTE(review): for a zero-stride axis this makes `end_ptr == data_ptr`,
            // so iteration would yield nothing even when `cached_len > 0` —
            // confirm that broadcast/zero-stride axes cannot reach this iterator.
            let end_ptr = unsafe { data_ptr.offset(axis_len as isize * axis_stride) };

            self.ptr = data_ptr;
            self.end_ptr = end_ptr;
            self.original_ptr = data_ptr;
            self.stride = axis_stride;

            // Initialize slice shapes immediately if needed (avoid repeated checks).
            // For rank-1 tensors the yielded views keep an empty shape.
            if self.tensor.rank() > 1 {
                // Build slice shape and strides by excluding the iteration dimension
                for (i, &dim_size) in self.tensor.shape.as_slice().iter().enumerate() {
                    if i != self.dim {
                        self.slice_shape.push(dim_size);
                        self.slice_strides.push(self.tensor.strides[i]);
                    }
                }
            }
        }
    }

    /// Get remaining length - O(1) in both the lazy and initialized states
    #[inline(always)]
    pub fn len(&self) -> usize {
        if self.ptr.is_null() {
            // Not yet initialized - the full cached axis length remains
            return self.cached_len;
        }

        if self.stride == 0 {
            // Degenerate stride: at most one position can be distinguished
            return if self.ptr >= self.end_ptr { 0 } else { 1 };
        }

        // Remaining elements = remaining byte span / byte stride
        let remaining_bytes = self.end_ptr as isize - self.ptr as isize;
        if remaining_bytes <= 0 {
            0
        } else {
            (remaining_bytes / self.stride) as usize
        }
    }

    /// Check if empty - ultra-fast check using cached length (pre-init) or
    /// the pointer range (post-init)
    #[inline(always)]
    pub fn is_empty(&self) -> bool {
        if self.ptr.is_null() {
            return self.cached_len == 0;
        }
        self.ptr >= self.end_ptr
    }

    #[cfg(feature = "rayon")]
    /// Convert to parallel iterator (rayon style)
    #[inline]
    pub fn par_iter(self) -> ParDimIter<'a, S>
    where
        S: Send + Sync,
    {
        ParDimIter::new(self)
    }
}
191
192impl<'a, S: StorageTrait> Iterator for DimIter<'a, S> {
193    type Item = TensorView<'a>;
194
195    #[inline(always)]
196    fn next(&mut self) -> Option<Self::Item> {
197        // Ensure iteration state is ready (inlined hot path)
198        if self.ptr.is_null() {
199            self.ensure_iteration_ready();
200        }
201
202        if self.ptr >= self.end_ptr {
203            return None;
204        }
205
206        let current = self.ptr;
207        self.ptr = unsafe { self.ptr.offset(self.stride) };
208
209        // Fast path: calculate offset directly without intermediate variables
210        let offset_bytes = (current as isize - self.original_ptr as isize) as usize;
211
212        Some(unsafe {
213            TensorView::from_raw_parts(
214                self.tensor.storage.as_storage(),
215                self.tensor.storage.ptr(),
216                self.slice_shape,
217                self.slice_strides,
218                self.tensor.offset_bytes + offset_bytes,
219                self.dtype,
220            )
221        })
222    }
223
224    #[inline(always)]
225    fn size_hint(&self) -> (usize, Option<usize>) {
226        let len = self.len();
227        (len, Some(len))
228    }
229
230    /// Ultra-fast count operation - returns result without any computation
231    ///
232    /// This method achieves optimal performance by returning a pre-cached value
233    /// instead of actually iterating through elements. This is possible because
234    /// the iterator length is known at construction time.
235    ///
236    /// # Performance
237    /// - Time complexity: O(1) - constant time regardless of tensor size
238    /// - No memory allocations, pointer arithmetic, or iterator state changes
239    /// - No TensorView constructions or shape computations
240    /// - Benchmark: ~412 ps (equivalent to ndarray performance)
241    ///
242    /// # Example
243    /// ```rust
244    /// use slsl::Tensor;
245    ///
246    /// let tensor = Tensor::from_vec(vec![1.0f32; 1_000_000], [1000, 1000]).unwrap();
247    /// let count = tensor.iter_dim(0).count(); // Instant, regardless of tensor size
248    /// assert_eq!(count, 1000);
249    /// ```
250    #[inline(always)]
251    fn count(self) -> usize {
252        // Ultra-fast path: return pre-cached length without any computation overhead
253        self.cached_len
254    }
255
256    #[inline(always)]
257    fn nth(&mut self, n: usize) -> Option<Self::Item> {
258        if n == 0 {
259            return self.next();
260        }
261
262        // Ensure iteration state is ready (inlined hot path)
263        if self.ptr.is_null() {
264            self.ensure_iteration_ready();
265        }
266
267        let skip_bytes = self.stride * n as isize;
268        let new_ptr = unsafe { self.ptr.offset(skip_bytes) };
269
270        if new_ptr >= self.end_ptr {
271            self.ptr = self.end_ptr;
272            return None;
273        }
274
275        self.ptr = new_ptr;
276        self.next()
277    }
278
279    #[inline(always)]
280    fn last(mut self) -> Option<Self::Item> {
281        if self.cached_len == 0 {
282            return None;
283        }
284
285        // Ensure iteration state is ready (inlined hot path)
286        if self.ptr.is_null() {
287            self.ensure_iteration_ready();
288        }
289
290        // Jump directly to the last element
291        let last_ptr = unsafe { self.end_ptr.offset(-self.stride) };
292        self.ptr = last_ptr;
293
294        let offset_bytes = (last_ptr as isize - self.original_ptr as isize) as usize;
295
296        Some(unsafe {
297            TensorView::from_raw_parts(
298                self.tensor.storage.as_storage(),
299                self.tensor.storage.ptr(),
300                self.slice_shape,
301                self.slice_strides,
302                self.tensor.offset_bytes + offset_bytes,
303                self.dtype,
304            )
305        })
306    }
307}
308
// `size_hint()` always returns `(len, Some(len))`, so the exact-size contract holds.
impl<S: StorageTrait> ExactSizeIterator for DimIter<'_, S> {}
// Once `ptr >= end_ptr`, every further `next()` returns `None`, satisfying the
// fused-iterator contract.
impl<S: StorageTrait> std::iter::FusedIterator for DimIter<'_, S> {}
311
impl<S: StorageTrait> DoubleEndedIterator for DimIter<'_, S> {
    /// Yields slices from the back by walking `end_ptr` toward `ptr`.
    fn next_back(&mut self) -> Option<Self::Item> {
        // Lazy setup on first call, same as `next()`.
        if self.ptr.is_null() {
            self.ensure_iteration_ready();
        }

        if self.ptr >= self.end_ptr {
            return None;
        }

        // Move end pointer backwards first; afterwards `end_ptr` addresses the
        // element being yielded, which keeps `len()` consistent.
        self.end_ptr = unsafe { self.end_ptr.offset(-self.stride) };

        let current = self.end_ptr;
        // Byte offset of this slice relative to the iteration start.
        let offset_bytes = (current as isize - self.original_ptr as isize) as usize;

        // SAFETY: `current` lies within the remaining [ptr, old end_ptr) range
        // checked above; the view borrows the same storage as `self.tensor`.
        Some(unsafe {
            TensorView::from_raw_parts(
                self.tensor.storage.as_storage(),
                self.tensor.storage.ptr(),
                self.slice_shape,
                self.slice_strides,
                self.tensor.offset_bytes + offset_bytes,
                self.dtype,
            )
        })
    }
}
341
342impl<S: StorageTrait> DimIter<'_, S> {
343    /// Split the iterator at the given index
344    pub fn split_at(mut self, index: usize) -> (Self, Self) {
345        let len = self.cached_len;
346        assert!(index <= len, "Split index {index} exceeds length {len}");
347
348        if index == 0 {
349            let empty = self.empty();
350            return (empty, self);
351        }
352        if index == len {
353            let empty = self.empty();
354            return (self, empty);
355        }
356
357        // Ensure iteration state is ready for splitting
358        self.ensure_iteration_ready();
359
360        // Create a copy for the right side
361        let right = Self {
362            tensor: self.tensor,
363            ptr: unsafe { self.ptr.offset(index as isize * self.stride) },
364            end_ptr: self.end_ptr,
365            original_ptr: self.original_ptr,
366            stride: self.stride,
367            slice_shape: self.slice_shape,
368            slice_strides: self.slice_strides,
369            dtype: self.dtype,
370            dim: self.dim,
371            cached_len: len - index,
372            _phantom: PhantomData,
373        };
374
375        // Adjust the left side
376        let mut left = self;
377        left.end_ptr = unsafe { left.ptr.offset(index as isize * left.stride) };
378        left.cached_len = index;
379
380        (left, right)
381    }
382
383    /// Create an empty iterator
384    fn empty(&self) -> Self {
385        Self {
386            tensor: self.tensor,
387            ptr: std::ptr::null_mut(),
388            end_ptr: std::ptr::null_mut(),
389            original_ptr: std::ptr::null_mut(),
390            stride: 0,
391            slice_shape: Shape::empty(),
392            slice_strides: Shape::empty(),
393            dtype: self.dtype,
394            dim: self.dim,
395            cached_len: 0,
396            _phantom: PhantomData,
397        }
398    }
399}
400
401unsafe impl<S: StorageTrait> Send for DimIter<'_, S> where S: Send {}
402unsafe impl<S: StorageTrait> Sync for DimIter<'_, S> where S: Sync {}
403
/// Rayon-compatible parallel iterator over a tensor dimension.
///
/// Created via `DimIter::par_iter` or `TensorBase::par_iter_dim`; yields the
/// same `TensorView`s as the sequential iterator, distributed across threads.
#[cfg(feature = "rayon")]
pub struct ParDimIter<'a, S: StorageTrait> {
    /// Underlying sequential iterator that gets split across worker threads.
    inner: DimIter<'a, S>,
    /// Minimum number of views per rayon work unit (defaults to 1).
    min_len: usize,
}
409
#[cfg(feature = "rayon")]
impl<'a, S: StorageTrait> ParDimIter<'a, S> {
    /// Wraps a sequential [`DimIter`], defaulting the minimum split size to 1.
    pub fn new(inner: DimIter<'a, S>) -> Self {
        Self { inner, min_len: 1 }
    }

    /// Sets the minimum number of views each rayon work unit should receive.
    ///
    /// # Panics
    /// Panics when `min_len` is zero.
    pub fn with_min_len(self, min_len: usize) -> Self {
        assert_ne!(
            min_len, 0,
            "Minimum number of elements must be at least one"
        );
        Self { min_len, ..self }
    }
}
425
#[cfg(feature = "rayon")]
impl<'a, S: StorageTrait + Send + Sync> IntoParallelIterator for DimIter<'a, S> {
    type Item = TensorView<'a>;
    type Iter = ParDimIter<'a, S>;

    /// Adapter so `DimIter` works with rayon's `into_par_iter()` idiom;
    /// equivalent to calling `DimIter::par_iter`.
    fn into_par_iter(self) -> Self::Iter {
        self.par_iter()
    }
}
435
#[cfg(feature = "rayon")]
impl<'a, S: StorageTrait + Send + Sync> ParallelIterator for ParDimIter<'a, S> {
    type Item = TensorView<'a>;

    /// Drives an unindexed consumer by bridging to the indexed machinery.
    fn drive_unindexed<C>(self, consumer: C) -> C::Result
    where
        C: rayon::iter::plumbing::UnindexedConsumer<Self::Item>,
    {
        bridge(self, consumer)
    }

    /// The exact length is always known, letting rayon choose the indexed path.
    fn opt_len(&self) -> Option<usize> {
        Some(self.inner.len())
    }
}
451
#[cfg(feature = "rayon")]
impl<'a, S: StorageTrait + Send + Sync> IndexedParallelIterator for ParDimIter<'a, S> {
    /// Drives an indexed consumer via rayon's `bridge` helper.
    fn drive<C>(self, consumer: C) -> C::Result
    where
        C: Consumer<Self::Item>,
    {
        bridge(self, consumer)
    }

    /// Number of views remaining in the underlying sequential iterator.
    fn len(&self) -> usize {
        self.inner.len()
    }

    /// Hands rayon a splittable producer wrapping the sequential iterator.
    fn with_producer<CB>(self, callback: CB) -> CB::Output
    where
        CB: ProducerCallback<Self::Item>,
    {
        callback.callback(ParDimProducer {
            inner: self.inner,
            min_len: self.min_len,
        })
    }
}
475
/// Splittable unit of work handed to rayon; wraps a sequential `DimIter`.
#[cfg(feature = "rayon")]
struct ParDimProducer<'a, S: StorageTrait> {
    /// Sequential iterator covering this producer's share of the axis.
    inner: DimIter<'a, S>,
    /// Minimum number of views per work unit, propagated on every split.
    min_len: usize,
}
481
#[cfg(feature = "rayon")]
impl<'a, S: StorageTrait + Send + Sync> Producer for ParDimProducer<'a, S> {
    type Item = TensorView<'a>;
    type IntoIter = DimIter<'a, S>;

    fn into_iter(self) -> Self::IntoIter {
        self.inner
    }

    /// Lower bound on how many views a single rayon job should process.
    ///
    /// Fix: this override was missing, so the value configured through
    /// `ParDimIter::with_min_len` was stored but never consulted — rayon fell
    /// back to the `Producer` trait's default of 1 and could split the work
    /// arbitrarily fine regardless of the caller's setting.
    fn min_len(&self) -> usize {
        self.min_len
    }

    /// Split the remaining work into `[0, index)` and `[index, len)` halves,
    /// each keeping the configured minimum split size.
    fn split_at(self, index: usize) -> (Self, Self) {
        let (left, right) = self.inner.split_at(index);
        (
            ParDimProducer {
                inner: left,
                min_len: self.min_len,
            },
            ParDimProducer {
                inner: right,
                min_len: self.min_len,
            },
        )
    }
}
505
/// Lets a producer be consumed as a plain sequential iterator (mirrors the
/// `Producer::into_iter` conversion for ordinary `for`-loop use).
#[cfg(feature = "rayon")]
impl<'a, S: StorageTrait + Send + Sync> IntoIterator for ParDimProducer<'a, S> {
    type Item = TensorView<'a>;
    type IntoIter = DimIter<'a, S>;

    fn into_iter(self) -> Self::IntoIter {
        self.inner
    }
}
515
516impl<S: StorageTrait> TensorBase<S> {
517    /// Creates an iterator over the specified dimension
518    ///
519    /// Returns a `DimIter` that yields `TensorView`s representing slices
520    /// along the specified dimension. The iterator is optimized for performance
521    /// with lazy initialization and zero-cost abstractions.
522    ///
523    /// # Arguments
524    /// * `dim` - The dimension index to iterate over (0-based)
525    ///
526    /// # Returns
527    /// A `DimIter` that can be used with standard Rust iterator methods
528    ///
529    /// # Performance
530    /// - Iterator construction: O(1) with lazy initialization
531    /// - `count()` operations: O(1) using cached values
532    /// - Actual iteration: Optimized for cache-friendly memory access
533    ///
534    /// # Example
535    /// ```rust
536    /// use slsl::Tensor;
537    ///
538    /// let tensor = Tensor::from_vec(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], [2, 3]).unwrap();
539    ///
540    /// // Iterate over rows (dimension 0)
541    /// for (i, row) in tensor.iter_dim(0).enumerate() {
542    ///     println!("Row {}: {:?}", i, row.as_slice::<f32>().unwrap());
543    /// }
544    ///
545    /// // Ultra-fast count
546    /// assert_eq!(tensor.iter_dim(0).count(), 2);
547    /// assert_eq!(tensor.iter_dim(1).count(), 3);
548    /// ```
549    #[inline]
550    pub fn iter_dim(&self, dim: usize) -> DimIter<'_, S> {
551        DimIter::from_tensor(self, dim)
552    }
553
554    /// Get the size of a specific dimension (ultra-fast alternative to iter_dim().count())
555    ///
556    /// This method provides direct access to dimension sizes without any iterator
557    /// construction overhead. While `iter_dim(dim).count()` is also very fast due
558    /// to optimizations, this method is slightly faster for simple size queries.
559    ///
560    /// # Arguments
561    /// * `dim` - The dimension index (0-based)
562    ///
563    /// # Returns
564    /// The size of the specified dimension
565    ///
566    /// # Performance
567    /// Time complexity: O(1) - direct array access
568    #[inline(always)]
569    pub fn dim_len(&self, dim: usize) -> usize {
570        debug_assert!(dim < self.rank(), "Dim {} >= {}", dim, self.rank());
571        self.shape[dim]
572    }
573
574    #[cfg(feature = "rayon")]
575    #[inline]
576    pub fn par_iter_dim(&self, dim: usize) -> ParDimIter<'_, S>
577    where
578        S: Send + Sync,
579    {
580        self.iter_dim(dim).par_iter()
581    }
582}
583
584#[cfg(test)]
585mod tests {
586    use crate::Tensor;
587    #[cfg(feature = "rayon")]
588    use rayon::iter::{IndexedParallelIterator, ParallelIterator};
589
    // Basic smoke test: a [2, 3] tensor yields two row views along dim 0.
    #[test]
    fn test_unified_dim_iter_basic() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = Tensor::from_vec(data, vec![2, 3]).unwrap();

        let iter = tensor.iter_dim(0);
        assert_eq!(iter.len(), 2);

        let ptrs: Vec<_> = iter.collect();
        assert_eq!(ptrs.len(), 2);
    }

    // A zero-length axis must iterate as empty without panicking.
    #[test]
    fn test_unified_dim_iter_empty() {
        let data: Vec<f32> = vec![];
        let tensor = Tensor::from_vec(data, vec![1, 0]).unwrap();

        let iter = tensor.iter_dim(1);
        assert_eq!(iter.len(), 0);
        assert!(iter.is_empty());
    }
611
    // `count()` must not materialize any views; it reads the cached length.
    #[test]
    fn test_dim_iter_count_optimization() {
        let data = vec![1.0f32; 1000];
        let tensor = Tensor::from_vec(data, vec![100, 10]).unwrap();

        // This should be very fast now as it doesn't actually iterate
        let count = tensor.iter_dim(0).count();
        assert_eq!(count, 100);

        let count = tensor.iter_dim(1).count();
        assert_eq!(count, 10);

        // Test the ultra-fast dim_len method
        assert_eq!(tensor.dim_len(0), 100);
        assert_eq!(tensor.dim_len(1), 10);
    }

    // `last()` jumps straight to the final row ([4.0, 5.0, 6.0]) instead of
    // walking every element.
    #[test]
    fn test_dim_iter_last_optimization() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = Tensor::from_vec(data, vec![2, 3]).unwrap();

        let last = tensor.iter_dim(0).last();
        assert!(last.is_some());

        let last_view = last.unwrap();
        assert_eq!(last_view.at::<f32>([0]), 4.0);
    }

    // `nth(2)` skips two rows with a single pointer adjustment; the third row
    // of a [4, 2] tensor starts at 5.0.
    #[test]
    fn test_dim_iter_nth_optimization() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let tensor = Tensor::from_vec(data, vec![4, 2]).unwrap();

        let mut iter = tensor.iter_dim(0);
        let third = iter.nth(2);
        assert!(third.is_some());

        let third_view = third.unwrap();
        assert_eq!(third_view.at::<f32>([0]), 5.0);
    }
653
    // Parallel iterator reports the same length/count as the sequential one.
    #[cfg(feature = "rayon")]
    #[test]
    fn test_par_dim_iter_basic() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = Tensor::from_vec(data, vec![2, 3]).unwrap();

        let par_iter = tensor.iter_dim(0).par_iter();
        assert_eq!(par_iter.len(), 2);

        let count = par_iter.count();
        assert_eq!(count, 2);

        // Compare with ultra-fast alternative
        assert_eq!(tensor.dim_len(0), count);
    }

    // Parallel map collects each row's leading element in order.
    #[cfg(feature = "rayon")]
    #[test]
    fn test_par_dim_iter_map() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = Tensor::from_vec(data, vec![2, 3]).unwrap();

        let par_iter = tensor.iter_dim(0).par_iter();
        let results: Vec<f32> = par_iter.map(|view| view.at::<f32>([0])).collect();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0], 1.0);
        assert_eq!(results[1], 4.0);
    }

    // Only the second row (leading value 4.0) survives the filter predicate.
    #[cfg(feature = "rayon")]
    #[test]
    fn test_par_dim_iter_filter() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = Tensor::from_vec(data, vec![2, 3]).unwrap();

        let par_iter = tensor.iter_dim(0).par_iter();
        let results: Vec<crate::TensorView> =
            par_iter.filter(|view| view.at::<f32>([0]) > 2.0).collect();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].at::<f32>([0]), 4.0);
    }

    // A parallel sum over 1000 single-column rows must match the sequential sum.
    #[cfg(feature = "rayon")]
    #[test]
    fn test_par_dim_iter_large() {
        let size = 1000;
        let data: Vec<f32> = (0..size).map(|i| i as f32).collect();
        let tensor = Tensor::from_vec(data, vec![size, 1]).unwrap();

        let par_iter = tensor.iter_dim(0).par_iter();
        let sum: f32 = par_iter.map(|view| view.at::<f32>([0])).sum();

        let expected_sum: f32 = (0..size).map(|i| i as f32).sum();
        assert!((sum - expected_sum).abs() < f32::EPSILON);
    }

    // Full method-chaining style, as typical rayon callers would write it.
    #[cfg(feature = "rayon")]
    #[test]
    fn test_par_dim_iter_rayon_style() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = Tensor::from_vec(data, vec![2, 3]).unwrap();

        // Test rayon-style chaining
        let results: Vec<f32> = tensor
            .iter_dim(0)
            .par_iter()
            .map(|view| view.at::<f32>([0]))
            .collect();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0], 1.0);
        assert_eq!(results[1], 4.0);
    }
729
    // `count()` and `dim_len()` must stay O(1) even for large axes.
    #[test]
    fn test_lightweight_count_performance() {
        // Test for large tensors where construction overhead matters
        let size = 10000;
        let data: Vec<f32> = (0..size).map(|i| i as f32).collect();
        let tensor = Tensor::from_vec(data, vec![size, 1]).unwrap();

        // These should all be extremely fast
        assert_eq!(tensor.iter_dim(0).count(), size);
        assert_eq!(tensor.dim_len(0), size);
        assert_eq!(tensor.iter_dim(1).count(), 1);
        assert_eq!(tensor.dim_len(1), 1);
    }

    // Every axis of a rank-4 tensor reports its own length.
    #[test]
    fn test_multi_dimensional_count() {
        let data: Vec<f32> = (0..120).map(|i| i as f32).collect();
        let tensor = Tensor::from_vec(data, vec![2, 3, 4, 5]).unwrap();

        // Test all dimensions
        assert_eq!(tensor.iter_dim(0).count(), 2);
        assert_eq!(tensor.iter_dim(1).count(), 3);
        assert_eq!(tensor.iter_dim(2).count(), 4);
        assert_eq!(tensor.iter_dim(3).count(), 5);

        // Verify with direct dimension access
        assert_eq!(tensor.dim_len(0), 2);
        assert_eq!(tensor.dim_len(1), 3);
        assert_eq!(tensor.dim_len(2), 4);
        assert_eq!(tensor.dim_len(3), 5);
    }
761
    // Views must expose the correct underlying data, not just the right count.
    #[test]
    fn test_iter_dim_data_correctness() {
        // Test basic data correctness for 2D tensor
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = Tensor::from_vec(data, vec![2, 3]).unwrap();

        let rows: Vec<_> = tensor.iter_dim(0).collect();
        assert_eq!(rows.len(), 2);

        // First row should be [1.0, 2.0, 3.0]
        let row0_data = rows[0].as_slice::<f32>().unwrap();
        assert_eq!(row0_data, &[1.0, 2.0, 3.0]);

        // Second row should be [4.0, 5.0, 6.0]
        let row1_data = rows[1].as_slice::<f32>().unwrap();
        assert_eq!(row1_data, &[4.0, 5.0, 6.0]);

        // Test iteration over dimension 1 - verify count and structure
        let dim1_slices: Vec<_> = tensor.iter_dim(1).collect();
        assert_eq!(dim1_slices.len(), 3);

        // Each slice should have shape [2] (2 rows)
        for slice in dim1_slices.iter() {
            assert_eq!(slice.shape().as_slice(), &[2]);
        }
    }

    // Iterating dim 0 of a [2, 3, 4] tensor yields two contiguous [3, 4] blocks.
    #[test]
    fn test_iter_dim_3d_tensor() {
        // Test with 3D tensor [2, 3, 4] = 24 elements
        let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let tensor = Tensor::from_vec(data, vec![2, 3, 4]).unwrap();

        // Iterate over first dimension (2 slices of [3, 4])
        let slices: Vec<_> = tensor.iter_dim(0).collect();
        assert_eq!(slices.len(), 2);

        // First slice should contain elements 0-11
        let slice0 = slices[0].as_slice::<f32>().unwrap();
        let expected0: Vec<f32> = (0..12).map(|i| i as f32).collect();
        assert_eq!(slice0, expected0.as_slice());

        // Second slice should contain elements 12-23
        let slice1 = slices[1].as_slice::<f32>().unwrap();
        let expected1: Vec<f32> = (12..24).map(|i| i as f32).collect();
        assert_eq!(slice1, expected1.as_slice());
    }

    // Zero-length and single-element axes are the classic boundary cases.
    #[test]
    fn test_iter_dim_edge_cases() {
        // Test empty dimension
        let tensor_empty = Tensor::from_vec(Vec::<f32>::new(), vec![0, 5]).unwrap();
        let empty_iter: Vec<_> = tensor_empty.iter_dim(0).collect();
        assert_eq!(empty_iter.len(), 0);
        assert!(tensor_empty.iter_dim(0).is_empty());

        // Test single element dimension
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let tensor_single = Tensor::from_vec(data, vec![1, 5]).unwrap();
        let single_iter: Vec<_> = tensor_single.iter_dim(0).collect();
        assert_eq!(single_iter.len(), 1);

        let slice_data = single_iter[0].as_slice::<f32>().unwrap();
        assert_eq!(slice_data, &[1.0, 2.0, 3.0, 4.0, 5.0]);
    }
827
    // Standard iterator adaptors (take/skip/enumerate) compose with DimIter.
    #[test]
    fn test_iter_dim_iterator_methods() {
        let data: Vec<f32> = (0..20).map(|i| i as f32).collect();
        let tensor = Tensor::from_vec(data, vec![4, 5]).unwrap();

        // Test take()
        let taken: Vec<_> = tensor.iter_dim(0).take(2).collect();
        assert_eq!(taken.len(), 2);

        // Test skip()
        let skipped: Vec<_> = tensor.iter_dim(0).skip(1).collect();
        assert_eq!(skipped.len(), 3);

        // Test enumerate: row i of a [4, 5] tensor starts at value i * 5
        for (i, slice) in tensor.iter_dim(0).enumerate() {
            let slice_data = slice.as_slice::<f32>().unwrap();
            let expected_start = i * 5;
            assert_eq!(slice_data[0], expected_start as f32);
        }
    }

    // `split_at(2)` on 4 rows gives two halves of 2 rows each with disjoint data.
    #[test]
    fn test_iter_dim_split_at() {
        let data: Vec<f32> = (0..20).map(|i| i as f32).collect();
        let tensor = Tensor::from_vec(data, vec![4, 5]).unwrap();

        let iter = tensor.iter_dim(0);
        let (left, right) = iter.split_at(2);

        let left_slices: Vec<_> = left.collect();
        let right_slices: Vec<_> = right.collect();

        assert_eq!(left_slices.len(), 2);
        assert_eq!(right_slices.len(), 2);

        // Verify data correctness
        let left0_data = left_slices[0].as_slice::<f32>().unwrap();
        assert_eq!(left0_data, &[0.0, 1.0, 2.0, 3.0, 4.0]);

        let right0_data = right_slices[0].as_slice::<f32>().unwrap();
        assert_eq!(right0_data, &[10.0, 11.0, 12.0, 13.0, 14.0]);
    }

    // Views returned by iter_dim support their own nested iter_dim calls.
    #[test]
    fn test_iter_dim_nested_iteration() {
        // Test nested iteration for 3D tensor
        let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let tensor = Tensor::from_vec(data, vec![2, 3, 4]).unwrap();

        // Iterate over first dimension, then over nested dimension
        for (i, outer_slice) in tensor.iter_dim(0).enumerate() {
            assert_eq!(outer_slice.rank(), 2);
            assert_eq!(outer_slice.shape().as_slice(), &[3, 4]);

            // Nest iteration over the slice
            let nested_slices: Vec<_> = outer_slice.iter_dim(0).collect();
            assert_eq!(nested_slices.len(), 3);

            for (j, inner_slice) in nested_slices.iter().enumerate() {
                let slice_data = inner_slice.as_slice::<f32>().unwrap();
                assert_eq!(slice_data.len(), 4);

                // Verify first element matches expected pattern
                let expected_first = (i * 12 + j * 4) as f32;
                assert_eq!(slice_data[0], expected_first);
            }
        }
    }
896
897    #[test]
898    fn test_iter_dim_large_tensor_performance() {
899        // Test with moderately large tensor to ensure performance characteristics
900        let size = 1000;
901        let data: Vec<f32> = (0..size * 100).map(|i| i as f32).collect();
902        let tensor = Tensor::from_vec(data, vec![size, 100]).unwrap();
903
904        // Ultra-fast operations should complete instantly
905        assert_eq!(tensor.iter_dim(0).count(), size);
906        assert_eq!(tensor.iter_dim(0).len(), size);
907        assert!(!tensor.iter_dim(0).is_empty());
908
909        // Test first and last elements for correctness
910        let first = tensor.iter_dim(0).next().unwrap();
911        let first_data = first.as_slice::<f32>().unwrap();
912        assert_eq!(first_data[0], 0.0);
913
914        let last = tensor.iter_dim(0).last().unwrap();
915        let last_data = last.as_slice::<f32>().unwrap();
916        assert_eq!(last_data[0], (size - 1) as f32 * 100.0);
917    }
918
919    #[test]
920    fn test_iter_dim_double_ended() {
921        let data: Vec<f32> = (0..12).map(|i| i as f32).collect();
922        let tensor = Tensor::from_vec(data, vec![3, 4]).unwrap();
923
924        let mut iter = tensor.iter_dim(0);
925
926        // Get first element
927        let first = iter.next().unwrap();
928        let first_data = first.as_slice::<f32>().unwrap();
929        assert_eq!(first_data, &[0.0, 1.0, 2.0, 3.0]);
930
931        // Get last element
932        let last = iter.next_back().unwrap();
933        let last_data = last.as_slice::<f32>().unwrap();
934        assert_eq!(last_data, &[8.0, 9.0, 10.0, 11.0]);
935
936        // Get middle element
937        let middle = iter.next().unwrap();
938        let middle_data = middle.as_slice::<f32>().unwrap();
939        assert_eq!(middle_data, &[4.0, 5.0, 6.0, 7.0]);
940
941        // Iterator should be exhausted
942        assert!(iter.next().is_none());
943        assert!(iter.next_back().is_none());
944    }
945
946    #[test]
947    fn test_iter_dim_various_shapes() {
948        let test_shapes = vec![
949            vec![2, 3],       // Small 2D
950            vec![3, 4, 5],    // 3D
951            vec![2, 3, 4, 5], // 4D
952            vec![1, 10],      // Single row
953            vec![10, 1],      // Single column
954        ];
955
956        for shape in test_shapes {
957            let total_elements: usize = shape.iter().product();
958            let data: Vec<f32> = (0..total_elements).map(|i| i as f32).collect();
959            let tensor = Tensor::from_vec(data, shape.clone()).unwrap();
960
961            // Test iteration over first dimension
962            let slices: Vec<_> = tensor.iter_dim(0).collect();
963            assert_eq!(slices.len(), shape[0]);
964
965            // Verify each slice has correct size
966            let expected_slice_size = if shape.len() > 1 {
967                shape[1..].iter().product()
968            } else {
969                1
970            };
971
972            for slice in slices {
973                let slice_data = slice.as_slice::<f32>().unwrap();
974                assert_eq!(slice_data.len(), expected_slice_size);
975            }
976        }
977    }
978
979    #[test]
980    fn test_iter_dim_correctness_with_strides() {
981        // Test with non-contiguous tensor (different strides)
982        let data: Vec<f32> = (0..12).map(|i| i as f32).collect();
983        let tensor = Tensor::from_vec(data, vec![3, 4]).unwrap();
984
985        // Test basic iteration works with any stride configuration
986        let slices: Vec<_> = tensor.iter_dim(0).collect();
987        assert_eq!(slices.len(), 3);
988
989        // Verify all slices have correct shape
990        for slice in slices.iter() {
991            assert_eq!(slice.shape().as_slice(), &[4]);
992        }
993
994        // Test dimension 1 iteration
995        let dim1_slices: Vec<_> = tensor.iter_dim(1).collect();
996        assert_eq!(dim1_slices.len(), 4);
997
998        for slice in dim1_slices.iter() {
999            assert_eq!(slice.shape().as_slice(), &[3]);
1000        }
1001    }
1002
1003    #[test]
1004    fn test_iter_dim_boundary_conditions() {
1005        // Test with very small tensors
1006        let scalar_like = Tensor::from_vec(vec![42.0f32], vec![1]).unwrap();
1007        let slices: Vec<_> = scalar_like.iter_dim(0).collect();
1008        assert_eq!(slices.len(), 1);
1009        let slice_data = slices[0].as_slice::<f32>().unwrap();
1010        assert_eq!(slice_data, &[42.0]);
1011
1012        // Test split_at edge cases
1013        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
1014        let tensor = Tensor::from_vec(data, vec![3, 2]).unwrap();
1015
1016        // Split at beginning
1017        let (left, right) = tensor.iter_dim(0).split_at(0);
1018        assert_eq!(left.len(), 0);
1019        assert_eq!(right.len(), 3);
1020
1021        // Split at end
1022        let (left, right) = tensor.iter_dim(0).split_at(3);
1023        assert_eq!(left.len(), 3);
1024        assert_eq!(right.len(), 0);
1025    }
1026
1027    #[test]
1028    fn test_iter_dim_memory_safety() {
1029        // Test that iterators maintain correct memory references
1030        let data: Vec<f32> = (0..100).map(|i| i as f32).collect();
1031        let tensor = Tensor::from_vec(data, vec![10, 10]).unwrap();
1032
1033        // Create multiple iterators and verify they don't interfere
1034        let mut iter1 = tensor.iter_dim(0);
1035        let mut iter2 = tensor.iter_dim(0);
1036
1037        let slice1 = iter1.nth(5).unwrap();
1038        let slice2 = iter2.nth(5).unwrap();
1039
1040        let data1 = slice1.as_slice::<f32>().unwrap();
1041        let data2 = slice2.as_slice::<f32>().unwrap();
1042
1043        // Both should point to the same data
1044        assert_eq!(data1, data2);
1045        assert_eq!(data1[0], 50.0);
1046    }
1047
1048    #[test]
1049    fn test_iter_dim_consistency_across_dimensions() {
1050        // Test that iteration is consistent across different dimensions
1051        let data: Vec<f32> = (0..60).map(|i| i as f32).collect();
1052        let tensor = Tensor::from_vec(data, vec![3, 4, 5]).unwrap();
1053
1054        // Iterate over each dimension and verify counts
1055        assert_eq!(tensor.iter_dim(0).count(), 3);
1056        assert_eq!(tensor.iter_dim(1).count(), 4);
1057        assert_eq!(tensor.iter_dim(2).count(), 5);
1058
1059        // Verify slice shapes are correct
1060        let dim0_slice = tensor.iter_dim(0).next().unwrap();
1061        assert_eq!(dim0_slice.shape().as_slice(), &[4, 5]);
1062
1063        let dim1_slice = tensor.iter_dim(1).next().unwrap();
1064        assert_eq!(dim1_slice.shape().as_slice(), &[3, 5]);
1065
1066        let dim2_slice = tensor.iter_dim(2).next().unwrap();
1067        assert_eq!(dim2_slice.shape().as_slice(), &[3, 4]);
1068    }
1069
1070    #[test]
1071    fn test_iter_dim_offset_correctness() {
1072        // Create a slice of a tensor and verify iteration works correctly
1073        let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
1074        let tensor = Tensor::from_vec(data, vec![4, 6]).unwrap();
1075
1076        // Test with a simple subview to verify offset handling
1077        // Get a view of the tensor starting from row 1
1078        let slices: Vec<_> = tensor.iter_dim(0).collect();
1079        let view = &slices[1]; // This creates an offset view
1080        assert_eq!(view.shape().as_slice(), &[6]);
1081
1082        // The view should contain elements [6, 7, 8, 9, 10, 11]
1083        let view_data = view.as_slice::<f32>().unwrap();
1084        assert_eq!(view_data, &[6.0, 7.0, 8.0, 9.0, 10.0, 11.0]);
1085
1086        // Test nested iteration with proper offset handling
1087        let nested_slices: Vec<_> = tensor.iter_dim(0).skip(1).take(2).collect();
1088        assert_eq!(nested_slices.len(), 2);
1089
1090        // First nested slice (row 1)
1091        let slice0_data = nested_slices[0].as_slice::<f32>().unwrap();
1092        assert_eq!(slice0_data, &[6.0, 7.0, 8.0, 9.0, 10.0, 11.0]);
1093
1094        // Second nested slice (row 2)
1095        let slice1_data = nested_slices[1].as_slice::<f32>().unwrap();
1096        assert_eq!(slice1_data, &[12.0, 13.0, 14.0, 15.0, 16.0, 17.0]);
1097    }
1098
1099    #[test]
1100    fn test_iter_dim_extreme_shapes() {
1101        // Test with very wide tensor
1102        let data: Vec<f32> = (0..1000).map(|i| i as f32).collect();
1103        let wide_tensor = Tensor::from_vec(data, vec![1, 1000]).unwrap();
1104
1105        let slices: Vec<_> = wide_tensor.iter_dim(0).collect();
1106        assert_eq!(slices.len(), 1);
1107
1108        let slice_data = slices[0].as_slice::<f32>().unwrap();
1109        assert_eq!(slice_data.len(), 1000);
1110        assert_eq!(slice_data[0], 0.0);
1111        assert_eq!(slice_data[999], 999.0);
1112
1113        // Test with very tall tensor
1114        let data: Vec<f32> = (0..1000).map(|i| i as f32).collect();
1115        let tall_tensor = Tensor::from_vec(data, vec![1000, 1]).unwrap();
1116
1117        let slices: Vec<_> = tall_tensor.iter_dim(0).collect();
1118        assert_eq!(slices.len(), 1000);
1119
1120        for (i, slice) in slices.iter().enumerate() {
1121            let slice_data = slice.as_slice::<f32>().unwrap();
1122            assert_eq!(slice_data.len(), 1);
1123            assert_eq!(slice_data[0], i as f32);
1124        }
1125    }
1126
1127    #[test]
1128    fn test_iter_dim_zero_stride_edge_case() {
1129        // Test behavior with dimension size 1 (which could have zero stride optimization)
1130        let data = vec![42.0f32];
1131        let tensor = Tensor::from_vec(data, vec![1, 1, 1, 1]).unwrap();
1132
1133        for dim in 0..4 {
1134            let slices: Vec<_> = tensor.iter_dim(dim).collect();
1135            assert_eq!(slices.len(), 1);
1136
1137            // The remaining tensor should have one less dimension
1138            let remaining_dims: Vec<usize> = (0..4).filter(|&d| d != dim).map(|_| 1).collect();
1139            if remaining_dims.is_empty() {
1140                // If we're left with a scalar, as_slice should return array with one element
1141                let slice_data = slices[0].as_slice::<f32>().unwrap();
1142                assert_eq!(slice_data, &[42.0]);
1143            }
1144        }
1145    }
1146}