1use crate::utils::{
2 errors::Errors,
3 index::{dim_index_to_storage_index, dim_index_to_storage_index_unchecked},
4};
5use std::{cell::RefCell, ops::Range, rc::Rc};
6
7use super::storage::TensorStorage;
8
9#[derive(Debug, Clone)]
11pub struct Tensor<T> {
12 pub(crate) storage: Rc<RefCell<TensorStorage<T>>>,
13 pub(crate) no_dim: usize,
14 pub(crate) no_el: usize,
15 pub(crate) offset: usize,
16 pub(crate) dims: Vec<usize>,
17 pub(crate) strides: Vec<usize>,
18}
19
20impl<T: Copy> Tensor<T> {
21 pub fn new(
24 storage: Rc<RefCell<TensorStorage<T>>>,
25 offset: usize,
26 dims: &[usize],
27 strides: &[usize],
28 ) -> Result<Tensor<T>, Errors> {
29 if dims.len() != strides.len() {
31 Err(Errors::DimsNeqStrides {
32 dim_len: dims.len(),
33 strides_len: strides.len(),
34 })
35 } else if dims.is_empty() || dims.iter().any(|&x| x == 0) {
36 Err(Errors::EmptyTensor)
37 } else {
38 let last_idx: Vec<_> = dims.iter().map(|&x| x - 1).collect();
39 let last_storage_idx = dim_index_to_storage_index(&last_idx, offset, &dims, &strides)?;
40 storage.borrow().get(last_storage_idx)?;
41 Ok(Tensor::new_unchecked(storage, offset, dims, strides))
42 }
43 }
44
45 pub fn new_unchecked(
49 storage: Rc<RefCell<TensorStorage<T>>>,
50 offset: usize,
51 dims: &[usize],
52 strides: &[usize],
53 ) -> Tensor<T> {
54 let no_dim = dims.len();
55 let no_el = dims.iter().fold(1, |res, dim_sz| res * dim_sz);
56 Tensor {
57 storage,
58 no_dim,
59 no_el,
60 offset,
61 dims: dims.to_vec(),
62 strides: strides.to_vec(),
63 }
64 }
65
66 pub fn get_storage_ptr(&self) -> Rc<RefCell<TensorStorage<T>>> {
68 Rc::clone(&self.storage)
69 }
70
71 pub fn no_dim(&self) -> usize {
73 self.no_dim
74 }
75
76 pub fn len(&self) -> usize {
78 self.no_el
79 }
80
81 pub fn is_view(&self) -> bool {
83 self.no_el != self.storage.borrow().len()
84 }
85
86 pub fn make_contiguous(&self) -> Tensor<T> {
88 let res: Vec<_> = self.into_iter().collect();
89 Tensor::from_slice_and_dims(&res, &self.dims).unwrap()
90 }
91
92 pub fn slice(&self, rngs: &[Range<usize>]) -> Result<Tensor<T>, Errors> {
94 if rngs.len() != self.no_dim {
95 Err(Errors::InvalidIndexSize {
96 expected: self.no_dim,
97 found: rngs.len(),
98 })
99 } else if let Some(idx) = rngs
100 .iter()
101 .zip(self.dims.iter())
102 .position(|(rng, &dim)| rng.end > dim)
103 {
104 Err(Errors::OutOfBounds {
105 expected: self.dims[idx],
106 found: rngs[idx].end,
107 axis: idx,
108 })
109 } else if rngs.iter().any(|rng| rng.is_empty()) {
110 Err(Errors::EmptyTensor)
111 } else {
112 Ok(self.slice_unchecked(rngs))
113 }
114 }
115
116 pub fn slice_unchecked(&self, rngs: &[Range<usize>]) -> Tensor<T> {
118 let new_offset = self.offset
119 + self
120 .strides
121 .iter()
122 .zip(rngs.iter())
123 .fold(0, |res, (&stride, rng)| res + stride * rng.start);
124 let new_dims: Vec<usize> = rngs
125 .iter()
126 .map(|rng| (rng.end - rng.start).max(1))
127 .collect();
128 Tensor::new_unchecked(
129 Rc::clone(&self.storage),
130 new_offset,
131 &new_dims,
132 &self.strides,
133 )
134 }
135
136 pub fn at(&self, index: &[usize]) -> Result<T, Errors> {
138 let storage_idx =
139 dim_index_to_storage_index(&index, self.offset, &self.dims, &self.strides)?;
140 self.storage.borrow().get(storage_idx)
141 }
142
143 pub fn at_unchecked(&self, index: &[usize]) -> T {
145 let storage_idx = dim_index_to_storage_index_unchecked(&index, self.offset, &self.strides);
146 self.storage.borrow().get_unchecked(storage_idx)
147 }
148
149 pub fn upd(&self, index: &[usize], new_val: T) -> Result<(), Errors> {
151 let storage_idx =
152 dim_index_to_storage_index(&index, self.offset, &self.dims, &self.strides)?;
153 self.storage.borrow_mut().upd(storage_idx, new_val)?;
154 Ok(())
155 }
156
157 pub fn upd_unchecked(&self, index: &[usize], new_val: T) {
159 let storage_idx = dim_index_to_storage_index_unchecked(&index, self.offset, &self.strides);
160 self.storage
161 .borrow_mut()
162 .upd_unchecked(storage_idx, new_val);
163 }
164}
165
#[cfg(test)]
mod tests {
    use std::f64::consts::{E, PI, SQRT_2, TAU};

    use super::*;

    #[test]
    fn splice_1d() {
        let vector: Vec<i32> = (0..27).collect();
        let storage = TensorStorage::from_slice(&vector);
        let tensor_view: Tensor<i32> =
            Tensor::new_unchecked(Rc::new(RefCell::new(storage)), 0, &[vector.len()], &[1]);
        // Covering the whole storage is not a view; a strict slice is.
        assert!(!tensor_view.is_view());
        let sliced = tensor_view.slice(&[4..7]).unwrap();
        assert!(sliced.is_view());
        let tensor_view_vec: Vec<i32> = sliced.into_iter().collect();
        assert_eq!(tensor_view_vec, vec![4, 5, 6]);
        assert!(matches!(
            tensor_view.slice(&[4..4]),
            Err(Errors::EmptyTensor)
        ));

        let vector = vec![PI, E, TAU, SQRT_2];
        let storage = TensorStorage::from_slice(&vector);
        let tensor_view: Tensor<f64> =
            Tensor::new_unchecked(Rc::new(RefCell::new(storage)), 0, &[vector.len()], &[1]);
        let sliced = tensor_view.slice(&[1..2]).unwrap();
        let tensor_view_vec: Vec<f64> = sliced.into_iter().collect();
        assert_eq!(tensor_view_vec, vec![E]);
        assert_eq!(sliced.dims, vec![1]);
        // The first offending axis is reported together with its bound.
        assert!(matches!(
            tensor_view.slice(&[4..5]),
            Err(Errors::OutOfBounds {
                expected: 4,
                found: 5,
                axis: 0,
            })
        ));
    }

    #[test]
    fn splice_3d() {
        let vector: Vec<i32> = (0..27).collect();
        let storage = TensorStorage::from_slice(&vector);
        let tensor_view: Tensor<i32> =
            Tensor::new_unchecked(Rc::new(RefCell::new(storage)), 4, &[3, 2, 2], &[9, 3, 1]);
        let sliced = tensor_view.slice(&[1..2, 0..2, 1..2]).unwrap();
        let tensor_view_vec: Vec<i32> = sliced.into_iter().collect();
        assert_eq!(tensor_view_vec, vec![14, 17]);
        assert_eq!(sliced.dims, vec![1, 2, 1]);
        assert!(matches!(
            tensor_view.slice(&[4..5, 4..5, 4..5]),
            Err(Errors::OutOfBounds {
                expected: 3,
                found: 5,
                axis: 0,
            })
        ));

        let vector: Vec<i128> = (0..64).collect();
        let storage = TensorStorage::from_slice(&vector);
        let tensor_view: Tensor<i128> =
            Tensor::new_unchecked(Rc::new(RefCell::new(storage)), 42, &[2, 2, 2], &[16, 4, 1]);
        let sliced = tensor_view.slice(&[1..2, 1..2, 1..2]).unwrap();
        let tensor_view_vec: Vec<i128> = sliced.into_iter().collect();
        assert_eq!(tensor_view_vec, vec![63]);
        assert_eq!(sliced.dims, vec![1, 1, 1]);
        assert!(matches!(
            tensor_view.slice(&[0..0, 0..0, 0..0]),
            Err(Errors::EmptyTensor)
        ));
    }

    #[test]
    fn get() {
        let vector: Vec<i32> = (0..27).collect();
        let storage = TensorStorage::from_slice(&vector);
        let tensor_view: Tensor<i32> =
            Tensor::new_unchecked(Rc::new(RefCell::new(storage)), 4, &[3, 2, 2], &[9, 3, 1]);
        assert_eq!(tensor_view.at(&[2, 1, 1]).unwrap(), 26);
        assert!(matches!(
            tensor_view.at(&[2, 1, 2]),
            Err(Errors::OutOfBounds { .. })
        ));

        let vector: Vec<i128> = (0..64).collect();
        let storage = TensorStorage::from_slice(&vector);
        let tensor_view: Tensor<i128> =
            Tensor::new_unchecked(Rc::new(RefCell::new(storage)), 42, &[2, 2, 2], &[16, 4, 1]);
        assert_eq!(tensor_view.at(&[1, 1, 0]).unwrap(), 62);
        assert!(matches!(
            tensor_view.at(&[2, 1, 0]),
            Err(Errors::OutOfBounds { .. })
        ));
    }

    #[test]
    fn upd() {
        let vector: Vec<i32> = (0..27).collect();
        let storage = TensorStorage::from_slice(&vector);
        let tensor_view: Tensor<i32> =
            Tensor::new_unchecked(Rc::new(RefCell::new(storage)), 4, &[3, 2, 2], &[9, 3, 1]);
        tensor_view.upd(&[2, 1, 1], -100).unwrap();
        assert_eq!(tensor_view.at(&[2, 1, 1]).unwrap(), -100);
        assert!(matches!(
            tensor_view.upd(&[2, 1, 2], -100),
            Err(Errors::OutOfBounds { .. })
        ));

        let vector: Vec<i128> = (0..64).collect();
        let storage = TensorStorage::from_slice(&vector);
        let tensor_view: Tensor<i128> =
            Tensor::new_unchecked(Rc::new(RefCell::new(storage)), 42, &[2, 2, 2], &[16, 4, 1]);
        tensor_view.upd(&[1, 1, 0], -100).unwrap();
        assert_eq!(tensor_view.at(&[1, 1, 0]).unwrap(), -100);
        assert!(matches!(
            tensor_view.upd(&[2, 1, 0], -100),
            Err(Errors::OutOfBounds { .. })
        ));
    }

    #[test]
    fn new_validates_layout() {
        let vector: Vec<i32> = (0..27).collect();
        let storage = Rc::new(RefCell::new(TensorStorage::from_slice(&vector)));

        // A valid layout: last element lands at 4 + 18 + 3 + 1 = 26, inside storage.
        let tensor = Tensor::new(Rc::clone(&storage), 4, &[3, 2, 2], &[9, 3, 1]).unwrap();
        assert_eq!(tensor.len(), 12);
        assert_eq!(tensor.no_dim(), 3);
        assert!(tensor.is_view());
        assert_eq!(tensor.at(&[2, 1, 1]).unwrap(), 26);

        assert!(matches!(
            Tensor::new(Rc::clone(&storage), 0, &[3, 2], &[9, 3, 1]),
            Err(Errors::DimsNeqStrides {
                dim_len: 2,
                strides_len: 3,
            })
        ));
        assert!(matches!(
            Tensor::new(Rc::clone(&storage), 0, &[], &[]),
            Err(Errors::EmptyTensor)
        ));
        assert!(matches!(
            Tensor::new(Rc::clone(&storage), 0, &[3, 0, 2], &[9, 3, 1]),
            Err(Errors::EmptyTensor)
        ));
        // Shifting the offset by one pushes the last element to index 27, past the end.
        assert!(Tensor::new(Rc::clone(&storage), 5, &[3, 2, 2], &[9, 3, 1]).is_err());
    }
}