1use std::{cell::RefCell, rc::Rc};
2
3use crate::utils::index::increment_dim_index_unchecked;
4
5use super::{storage::TensorStorage, tensor::Tensor};
6
7#[derive(Debug)]
9pub struct TensorIterator<'a, T> {
10 storage: Rc<RefCell<TensorStorage<T>>>,
11 index: Vec<usize>,
12 storage_index: usize,
13 dims: &'a [usize],
14 strides: &'a [usize],
15 done: bool,
16}
17
18impl<'a, T: Copy> IntoIterator for &'a Tensor<T> {
19 type Item = T;
20
21 type IntoIter = TensorIterator<'a, T>;
22
23 fn into_iter(self) -> Self::IntoIter {
24 TensorIterator {
25 storage: Rc::clone(&self.storage),
26 index: vec![0; self.no_dim],
27 storage_index: self.offset,
28 dims: &self.dims,
29 strides: &self.strides,
30 done: false,
31 }
32 }
33}
34
35impl<'a, T: Copy> Iterator for TensorIterator<'a, T> {
36 type Item = T;
37
38 fn next(&mut self) -> Option<Self::Item> {
39 if self.done {
40 return None;
41 }
42
43 let val = self.storage.borrow().get_unchecked(self.storage_index);
44 (self.storage_index, self.done) = increment_dim_index_unchecked(
45 &mut self.index,
46 self.storage_index,
47 &self.dims,
48 &self.strides,
49 );
50
51 Some(val)
52 }
53}
54
#[cfg(test)]
mod tests {
    use crate::core::tensor::{storage::TensorStorage, tensor::Tensor};
    use std::{
        cell::RefCell,
        f64::consts::{E, PI, SQRT_2, TAU},
        rc::Rc,
    };

    /// Builds a `Tensor<$elem>` view over `$data` with the given offset, dims
    /// and strides, then collects the view's iteration order into a `Vec<$elem>`.
    macro_rules! collect_view {
        ($elem:ty, $data:expr, $offset:expr, $dims:expr, $strides:expr) => {{
            let storage = TensorStorage::from_slice($data);
            let view: Tensor<$elem> =
                Tensor::new_unchecked(Rc::new(RefCell::new(storage)), $offset, $dims, $strides);
            let collected: Vec<$elem> = view.into_iter().collect();
            collected
        }};
    }

    /// A view covering the whole storage yields every element in storage order.
    #[test]
    fn full_tensor_2d() {
        let ints = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
        assert_eq!(collect_view!(i32, &ints, 0, &[3, 3], &[3, 1]), ints);

        let floats = vec![PI, E, TAU, SQRT_2];
        assert_eq!(collect_view!(f64, &floats, 0, &[2, 2], &[2, 1]), floats);
    }

    /// Offset and strided views over 2-D data yield only the selected elements.
    #[test]
    fn with_strides_2d() {
        // Bottom-right 2x2 corner of a 3x3 matrix.
        let ints = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
        assert_eq!(
            collect_view!(i32, &ints, 4, &[2, 2], &[3, 1]),
            vec![5, 6, 8, 9]
        );

        // Middle column of a 3x3 matrix.
        let ints = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
        assert_eq!(collect_view!(i32, &ints, 1, &[3], &[3]), vec![2, 5, 8]);

        // Second column of a 2x2 matrix, kept as a 2x1 view.
        let floats = vec![PI, E, TAU, SQRT_2];
        assert_eq!(
            collect_view!(f64, &floats, 1, &[2, 1], &[2, 1]),
            vec![E, SQRT_2]
        );
    }

    /// Offset and strided views over 3-D data yield only the selected elements.
    #[test]
    fn with_strides_3d() {
        let ints: Vec<i32> = (0..27).collect();
        assert_eq!(collect_view!(i32, &ints, 4, &[3], &[9]), vec![4, 13, 22]);

        let wide: Vec<i128> = (0..64).collect();
        assert_eq!(
            collect_view!(i128, &wide, 42, &[2, 2, 2], &[16, 4, 1]),
            vec![42, 43, 46, 47, 58, 59, 62, 63]
        );
    }
}