tenso_rs/core/tensor/iterator.rs

1use std::{cell::RefCell, rc::Rc};
2
3use crate::utils::index::increment_dim_index_unchecked;
4
5use super::{storage::TensorStorage, tensor::Tensor};
6
7/// The Iterator struct for Tensors
8#[derive(Debug)]
9pub struct TensorIterator<'a, T> {
10    storage: Rc<RefCell<TensorStorage<T>>>,
11    index: Vec<usize>,
12    storage_index: usize,
13    dims: &'a [usize],
14    strides: &'a [usize],
15    done: bool,
16}
17
18impl<'a, T: Copy> IntoIterator for &'a Tensor<T> {
19    type Item = T;
20
21    type IntoIter = TensorIterator<'a, T>;
22
23    fn into_iter(self) -> Self::IntoIter {
24        TensorIterator {
25            storage: Rc::clone(&self.storage),
26            index: vec![0; self.no_dim],
27            storage_index: self.offset,
28            dims: &self.dims,
29            strides: &self.strides,
30            done: false,
31        }
32    }
33}
34
35impl<'a, T: Copy> Iterator for TensorIterator<'a, T> {
36    type Item = T;
37
38    fn next(&mut self) -> Option<Self::Item> {
39        if self.done {
40            return None;
41        }
42
43        let val = self.storage.borrow().get_unchecked(self.storage_index);
44        (self.storage_index, self.done) = increment_dim_index_unchecked(
45            &mut self.index,
46            self.storage_index,
47            &self.dims,
48            &self.strides,
49        );
50
51        Some(val)
52    }
53}
54
#[cfg(test)]
mod tests {
    use crate::core::tensor::{storage::TensorStorage, tensor::Tensor};
    use std::{
        cell::RefCell,
        f64::consts::{E, PI, SQRT_2, TAU},
        rc::Rc,
    };

    /// Builds shared storage from `$data`, views it as a `Tensor<$ty>` with the given
    /// offset, dims and strides, and collects everything the view iterates over.
    macro_rules! collect_view {
        ($ty:ty; $data:expr, $offset:expr, $dims:expr, $strides:expr) => {{
            let shared = Rc::new(RefCell::new(TensorStorage::from_slice($data)));
            let view: Tensor<$ty> = Tensor::new_unchecked(shared, $offset, $dims, $strides);
            view.into_iter().collect::<Vec<$ty>>()
        }};
    }

    #[test]
    fn full_tensor_2d() {
        // A contiguous row-major view over the whole buffer reproduces it unchanged.
        let ints = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
        assert_eq!(collect_view!(i32; &ints, 0, &[3, 3], &[3, 1]), ints);

        let floats = vec![PI, E, TAU, SQRT_2];
        assert_eq!(collect_view!(f64; &floats, 0, &[2, 2], &[2, 1]), floats);
    }

    #[test]
    fn with_strides_2d() {
        // 3x3 buffer in row-major order: [[1, 2, 3], [4, 5, 6], [7, 8, 9]].
        let grid: Vec<i32> = (1..=9).collect();

        // Bottom-right 2x2 sub-block, whose corner is element (1, 1).
        assert_eq!(collect_view!(i32; &grid, 4, &[2, 2], &[3, 1]), vec![5, 6, 8, 9]);

        // Middle column: three elements, each a full row (3 slots) apart.
        assert_eq!(collect_view!(i32; &grid, 1, &[3], &[3]), vec![2, 5, 8]);

        // Second column of a 2x2 float buffer.
        let floats = vec![PI, E, TAU, SQRT_2];
        assert_eq!(collect_view!(f64; &floats, 1, &[2, 1], &[2, 1]), vec![E, SQRT_2]);
    }

    #[test]
    fn with_strides_3d() {
        // 3x3x3 buffer: a stride of 9 walks the outermost axis.
        let cube: Vec<i32> = (0..27).collect();
        assert_eq!(collect_view!(i32; &cube, 4, &[3], &[9]), vec![4, 13, 22]);

        // 4x4x4 buffer: a 2x2x2 sub-cube whose first element is at storage index 42.
        let cube: Vec<i128> = (0..64).collect();
        assert_eq!(
            collect_view!(i128; &cube, 42, &[2, 2, 2], &[16, 4, 1]),
            vec![42, 43, 46, 47, 58, 59, 62, 63]
        );
    }
}