tenso_rs/core/tensor/
tensor.rs

1use crate::utils::{
2    errors::Errors,
3    index::{dim_index_to_storage_index, dim_index_to_storage_index_unchecked},
4};
5use std::{cell::RefCell, ops::Range, rc::Rc};
6
7use super::storage::TensorStorage;
8
/// The N Dimensional Array.
///
/// A tensor is a window into a shared, reference-counted `TensorStorage`: `offset`,
/// `dims` and `strides` select which storage elements this tensor addresses, so several
/// tensors (e.g. slices produced by `Tensor::slice`) can point at the same buffer.
/// The derived `Clone` is shallow — it clones the `Rc`, so the clone shares `storage`.
#[derive(Debug, Clone)]
pub struct Tensor<T> {
    /// Backing buffer, shared with every tensor created from the same `Rc`.
    pub(crate) storage: Rc<RefCell<TensorStorage<T>>>,
    /// Number of dimensions (set to `dims.len()` by `Tensor::new_unchecked`).
    pub(crate) no_dim: usize,
    /// Number of elements (set to the product of `dims` by `Tensor::new_unchecked`).
    pub(crate) no_el: usize,
    /// Storage index added to every element's strided position.
    pub(crate) offset: usize,
    /// Extent of each dimension.
    pub(crate) dims: Vec<usize>,
    /// Storage-index step taken per unit increase along each dimension.
    pub(crate) strides: Vec<usize>,
}
19
20impl<T: Copy> Tensor<T> {
21    /// Create new tensor with custom TensorStorage, offset, dimensions, and strides.
22    /// Not recommended for use
23    pub fn new(
24        storage: Rc<RefCell<TensorStorage<T>>>,
25        offset: usize,
26        dims: &[usize],
27        strides: &[usize],
28    ) -> Result<Tensor<T>, Errors> {
29        // TODO: Check if any element is outside storage limit
30        if dims.len() != strides.len() {
31            Err(Errors::DimsNeqStrides {
32                dim_len: dims.len(),
33                strides_len: strides.len(),
34            })
35        } else if dims.is_empty() || dims.iter().any(|&x| x == 0) {
36            Err(Errors::EmptyTensor)
37        } else {
38            let last_idx: Vec<_> = dims.iter().map(|&x| x - 1).collect();
39            let last_storage_idx = dim_index_to_storage_index(&last_idx, offset, &dims, &strides)?;
40            storage.borrow().get(last_storage_idx)?;
41            Ok(Tensor::new_unchecked(storage, offset, dims, strides))
42        }
43    }
44
45    /// Create new tensor with custom TensorStorage, offset, dimensions, and strides, without any
46    /// checks
47    /// Not recommended for use
48    pub fn new_unchecked(
49        storage: Rc<RefCell<TensorStorage<T>>>,
50        offset: usize,
51        dims: &[usize],
52        strides: &[usize],
53    ) -> Tensor<T> {
54        let no_dim = dims.len();
55        let no_el = dims.iter().fold(1, |res, dim_sz| res * dim_sz);
56        Tensor {
57            storage,
58            no_dim,
59            no_el,
60            offset,
61            dims: dims.to_vec(),
62            strides: strides.to_vec(),
63        }
64    }
65
66    /// Get pointer to `self.storage`
67    pub fn get_storage_ptr(&self) -> Rc<RefCell<TensorStorage<T>>> {
68        Rc::clone(&self.storage)
69    }
70
71    /// Get number of dimensions
72    pub fn no_dim(&self) -> usize {
73        self.no_dim
74    }
75
76    /// Get number of elements
77    pub fn len(&self) -> usize {
78        self.no_el
79    }
80
81    /// Does `self` borrow memory from other tensors
82    pub fn is_view(&self) -> bool {
83        self.no_el != self.storage.borrow().len()
84    }
85
86    /// Make `self` own it's own data
87    pub fn make_contiguous(&self) -> Tensor<T> {
88        let res: Vec<_> = self.into_iter().collect();
89        Tensor::from_slice_and_dims(&res, &self.dims).unwrap()
90    }
91
92    /// Index `self` with `rngs`
93    pub fn slice(&self, rngs: &[Range<usize>]) -> Result<Tensor<T>, Errors> {
94        if rngs.len() != self.no_dim {
95            Err(Errors::InvalidIndexSize {
96                expected: self.no_dim,
97                found: rngs.len(),
98            })
99        } else if let Some(idx) = rngs
100            .iter()
101            .zip(self.dims.iter())
102            .position(|(rng, &dim)| rng.end > dim)
103        {
104            Err(Errors::OutOfBounds {
105                expected: self.dims[idx],
106                found: rngs[idx].end,
107                axis: idx,
108            })
109        } else if rngs.iter().any(|rng| rng.is_empty()) {
110            Err(Errors::EmptyTensor)
111        } else {
112            Ok(self.slice_unchecked(rngs))
113        }
114    }
115
116    /// Index `self` with `rngs` without checks
117    pub fn slice_unchecked(&self, rngs: &[Range<usize>]) -> Tensor<T> {
118        let new_offset = self.offset
119            + self
120                .strides
121                .iter()
122                .zip(rngs.iter())
123                .fold(0, |res, (&stride, rng)| res + stride * rng.start);
124        let new_dims: Vec<usize> = rngs
125            .iter()
126            .map(|rng| (rng.end - rng.start).max(1))
127            .collect();
128        Tensor::new_unchecked(
129            Rc::clone(&self.storage),
130            new_offset,
131            &new_dims,
132            &self.strides,
133        )
134    }
135
136    /// Get value in `self` at index `index`
137    pub fn at(&self, index: &[usize]) -> Result<T, Errors> {
138        let storage_idx =
139            dim_index_to_storage_index(&index, self.offset, &self.dims, &self.strides)?;
140        self.storage.borrow().get(storage_idx)
141    }
142
143    /// Get value in `self` at index `index` without checks
144    pub fn at_unchecked(&self, index: &[usize]) -> T {
145        let storage_idx = dim_index_to_storage_index_unchecked(&index, self.offset, &self.strides);
146        self.storage.borrow().get_unchecked(storage_idx)
147    }
148
149    /// Update value in `self` at index `index` to `new_val`
150    pub fn upd(&self, index: &[usize], new_val: T) -> Result<(), Errors> {
151        let storage_idx =
152            dim_index_to_storage_index(&index, self.offset, &self.dims, &self.strides)?;
153        self.storage.borrow_mut().upd(storage_idx, new_val)?;
154        Ok(())
155    }
156
157    /// Update value in `self` at index `index` to `new_val` without checks
158    pub fn upd_unchecked(&self, index: &[usize], new_val: T) {
159        let storage_idx = dim_index_to_storage_index_unchecked(&index, self.offset, &self.strides);
160        self.storage
161            .borrow_mut()
162            .upd_unchecked(storage_idx, new_val);
163    }
164}
165
#[cfg(test)]
mod tests {
    use std::f64::consts::{E, PI, SQRT_2, TAU};

    use super::*;

    #[test]
    fn splice_1d() {
        let vector: Vec<i32> = (0..27).collect();
        let storage = TensorStorage::from_slice(&vector);
        let tensor_view: Tensor<i32> =
            Tensor::new_unchecked(Rc::new(RefCell::new(storage)), 0, &[vector.len()], &[1]);
        // A tensor covering the whole storage is not a view; a partial slice is.
        assert!(!tensor_view.is_view());
        let sliced = tensor_view.slice(&[4..7]).unwrap();
        assert!(sliced.is_view());
        let tensor_view_vec: Vec<i32> = sliced.into_iter().collect();
        assert_eq!(tensor_view_vec, vec![4, 5, 6]);
        assert!(matches!(
            tensor_view.slice(&[4..4]),
            Err(Errors::EmptyTensor)
        ));
        // One range per dimension is required.
        assert!(matches!(
            tensor_view.slice(&[0..1, 0..1]),
            Err(Errors::InvalidIndexSize {
                expected: 1,
                found: 2
            })
        ));

        let vector = vec![PI, E, TAU, SQRT_2];
        let storage = TensorStorage::from_slice(&vector);
        let tensor_view: Tensor<f64> =
            Tensor::new_unchecked(Rc::new(RefCell::new(storage)), 0, &[vector.len()], &[1]);
        let sliced = tensor_view.slice(&[1..2]).unwrap();
        let tensor_view_vec: Vec<f64> = sliced.into_iter().collect();
        assert_eq!(tensor_view_vec, vec![E]);
        assert_eq!(sliced.dims, vec![1]);
        assert!(matches!(
            tensor_view.slice(&[4..5]),
            Err(Errors::OutOfBounds { .. })
        ));
    }

    #[test]
    fn splice_3d() {
        let vector: Vec<i32> = (0..27).collect();
        let storage = TensorStorage::from_slice(&vector);
        let tensor_view: Tensor<i32> =
            Tensor::new_unchecked(Rc::new(RefCell::new(storage)), 4, &[3, 2, 2], &[9, 3, 1]);
        let sliced = tensor_view.slice(&[1..2, 0..2, 1..2]).unwrap();
        let tensor_view_vec: Vec<i32> = sliced.into_iter().collect();
        assert_eq!(tensor_view_vec, vec![14, 17]);
        assert_eq!(sliced.dims, vec![1, 2, 1]);
        assert!(matches!(
            tensor_view.slice(&[4..5, 4..5, 4..5]),
            Err(Errors::OutOfBounds { .. })
        ));

        let vector: Vec<i128> = (0..64).collect();
        let storage = TensorStorage::from_slice(&vector);
        let tensor_view: Tensor<i128> =
            Tensor::new_unchecked(Rc::new(RefCell::new(storage)), 42, &[2, 2, 2], &[16, 4, 1]);
        let sliced = tensor_view.slice(&[1..2, 1..2, 1..2]).unwrap();
        let tensor_view_vec: Vec<i128> = sliced.into_iter().collect();
        assert_eq!(tensor_view_vec, vec![63]);
        assert_eq!(sliced.dims, vec![1, 1, 1]);
        assert!(matches!(
            tensor_view.slice(&[0..0, 0..0, 0..0]),
            Err(Errors::EmptyTensor)
        ));
    }

    #[test]
    fn get() {
        let vector: Vec<i32> = (0..27).collect();
        let storage = TensorStorage::from_slice(&vector);
        let tensor_view: Tensor<i32> =
            Tensor::new_unchecked(Rc::new(RefCell::new(storage)), 4, &[3, 2, 2], &[9, 3, 1]);
        assert_eq!(tensor_view.at(&[2, 1, 1]).unwrap(), 26);
        assert_eq!(tensor_view.at_unchecked(&[2, 1, 1]), 26);
        assert!(matches!(
            tensor_view.at(&[2, 1, 2]),
            Err(Errors::OutOfBounds { .. })
        ));

        let vector: Vec<i128> = (0..64).collect();
        let storage = TensorStorage::from_slice(&vector);
        let tensor_view: Tensor<i128> =
            Tensor::new_unchecked(Rc::new(RefCell::new(storage)), 42, &[2, 2, 2], &[16, 4, 1]);
        assert_eq!(tensor_view.at(&[1, 1, 0]).unwrap(), 62);
        assert!(matches!(
            tensor_view.at(&[2, 1, 0]),
            Err(Errors::OutOfBounds { .. })
        ));
    }

    #[test]
    fn upd() {
        let vector: Vec<i32> = (0..27).collect();
        let storage = TensorStorage::from_slice(&vector);
        let tensor_view: Tensor<i32> =
            Tensor::new_unchecked(Rc::new(RefCell::new(storage)), 4, &[3, 2, 2], &[9, 3, 1]);
        tensor_view.upd(&[2, 1, 1], -100).unwrap();
        assert_eq!(tensor_view.at(&[2, 1, 1]).unwrap(), -100);
        tensor_view.upd_unchecked(&[0, 0, 0], 7);
        assert_eq!(tensor_view.at(&[0, 0, 0]).unwrap(), 7);
        assert!(matches!(
            tensor_view.upd(&[2, 1, 2], -100),
            Err(Errors::OutOfBounds { .. })
        ));

        let vector: Vec<i128> = (0..64).collect();
        let storage = TensorStorage::from_slice(&vector);
        let tensor_view: Tensor<i128> =
            Tensor::new_unchecked(Rc::new(RefCell::new(storage)), 42, &[2, 2, 2], &[16, 4, 1]);
        tensor_view.upd(&[1, 1, 0], -100).unwrap();
        assert_eq!(tensor_view.at(&[1, 1, 0]).unwrap(), -100);
        assert!(matches!(
            tensor_view.upd(&[2, 1, 0], -100),
            Err(Errors::OutOfBounds { .. })
        ));
    }

    #[test]
    fn new_validates_arguments() {
        let vector: Vec<i32> = (0..6).collect();
        let storage = Rc::new(RefCell::new(TensorStorage::from_slice(&vector)));

        // A 2x3 row-major layout covering the whole storage is accepted.
        let tensor = Tensor::new(Rc::clone(&storage), 0, &[2, 3], &[3, 1]).unwrap();
        assert_eq!(tensor.no_dim(), 2);
        assert_eq!(tensor.len(), 6);
        assert_eq!(tensor.at(&[1, 2]).unwrap(), 5);

        assert!(matches!(
            Tensor::new(Rc::clone(&storage), 0, &[2, 3], &[3]),
            Err(Errors::DimsNeqStrides {
                dim_len: 2,
                strides_len: 1
            })
        ));
        assert!(matches!(
            Tensor::new(Rc::clone(&storage), 0, &[], &[]),
            Err(Errors::EmptyTensor)
        ));
        assert!(matches!(
            Tensor::new(Rc::clone(&storage), 0, &[2, 0], &[3, 1]),
            Err(Errors::EmptyTensor)
        ));
        // With offset 1 the last element would sit at storage index 1 + 3 + 2 = 6,
        // one past the end of the 6-element storage.
        assert!(Tensor::new(Rc::clone(&storage), 1, &[2, 3], &[3, 1]).is_err());
    }
}