tenso_rs/core/tensor/
impl_creation_ops.rs

1use num_traits::{Float, NumCast};
2use std::{cell::RefCell, fmt::Display, ops, rc::Rc};
3
4use crate::utils::{errors::Errors, strides::new_strides_from_dim};
5
6use super::{storage::TensorStorage, tensor::Tensor};
7
8impl<T: Copy> Tensor<T> {
9    /// Create a one dimensional tensor from a specified slice
10    ///
11    /// # Arguments
12    /// * arr - The slice containing the data
13    ///
14    /// # Examples
15    /// ```rust
16    /// let t = Tensor::from_slice(&[1, 2, 3, 4, 5, 6]).unwrap();
17    /// // t = [1, 2, 3, 4, 5, 6]
18    /// ```
19    pub fn from_slice(arr: &[T]) -> Result<Tensor<T>, Errors> {
20        if arr.is_empty() {
21            return Err(Errors::EmptyTensor);
22        }
23        let tensor_storage = TensorStorage::<T>::from_slice(&arr);
24        let dims = vec![arr.len()];
25        let strides = vec![1];
26        Ok(Tensor::new_unchecked(
27            Rc::new(RefCell::new(tensor_storage)),
28            0,
29            &dims,
30            &strides,
31        ))
32    }
33
34    /// Create a new tensor with specified data in a slice, with specified dimensions
35    ///
36    /// # Arguments
37    /// * arr - The slice containing the data
38    /// * dims - The required dimensions
39    ///
40    /// # Examples
41    /// ```rust
42    /// let t = Tensor::from_slice_and_dims(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], &[2, 5]).unwrap();
43    /// // t = [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]
44    /// ```
45    pub fn from_slice_and_dims(arr: &[T], dims: &[usize]) -> Result<Tensor<T>, Errors> {
46        if arr.is_empty() {
47            return Err(Errors::EmptyTensor);
48        }
49
50        if arr.len() != dims.iter().product() {
51            return Err(Errors::InputError(
52                "Expected product of dimensions to equal number of elements in storage".to_string(),
53            ));
54        }
55        let tensor_storage = TensorStorage::<T>::from_slice(&arr);
56        let strides = new_strides_from_dim(&dims);
57        Ok(Tensor::new_unchecked(
58            Rc::new(RefCell::new(tensor_storage)),
59            0,
60            &dims,
61            &strides,
62        ))
63    }
64
65    /// Return a new tensor with all elements equal to a specified value, with specified dimensions
66    ///
67    /// # Arguments
68    /// * dims - The required dimensions
69    /// * val - The value
70    ///
71    /// # Examples
72    /// ```rust
73    /// let t = Tensor::from_val(&[2, 2], 1729).unwrap();
74    /// // t = [[1729, 1729], [1729, 1729]]
75    /// ```
76    pub fn from_val(dims: &[usize], val: T) -> Result<Tensor<T>, Errors> {
77        let no_el = dims.iter().fold(1, |res, dim_sz| res * dim_sz);
78        if no_el == 0 {
79            return Err(Errors::EmptyTensor);
80        }
81
82        let tensor_storage = TensorStorage::<T>::from_val(no_el, val);
83        let new_strides = new_strides_from_dim(&dims);
84        Ok(Tensor::new_unchecked(
85            Rc::new(RefCell::new(tensor_storage)),
86            0,
87            &dims,
88            &new_strides,
89        ))
90    }
91}
92
93impl<T: Default + Copy> Tensor<T> {
94    /// Return a tensor with all elements equal to the type's default value, with specified
95    /// dimensions
96    ///
97    /// # Arguments
98    /// * dims - The required dimensions
99    ///
100    /// # Examples
101    /// ```rust
102    /// let t = Tensor::<i32>::from_default(&[2, 3]).unwrap();
103    /// // t = [[0, 0, 0], [0, 0, 0]]
104    /// ```
105    pub fn from_default(dims: &[usize]) -> Result<Tensor<T>, Errors> {
106        Tensor::from_val(dims, T::default())
107    }
108}
109
110impl<T: Copy + NumCast> Tensor<T> {
111    /// Return a tensor with all elements equal to 0, with specified dimensions
112    ///
113    /// # Arguments
114    /// * dims - The required dimensions
115    ///
116    /// # Examples
117    /// ```rust
118    /// let t = Tensor::<i32>::zeros(&[2, 3]).unwrap();
119    /// // t = [[0, 0, 0], [0, 0, 0]]
120    pub fn zeros(dims: &[usize]) -> Result<Tensor<T>, Errors> {
121        Tensor::<T>::from_val(&dims, NumCast::from(0).unwrap())
122    }
123
124    /// Return a tensor with all elements equal to 1, with specified dimensions
125    ///
126    /// # Arguments
127    /// * dims - The required dimensions
128    ///
129    /// # Examples
130    /// ```rust
131    /// let t = Tensor::<i32>::ones(&[2, 3]).unwrap();
132    /// // t = [[1, 1, 1], [1, 1, 1]]
133    pub fn ones(dims: &[usize]) -> Result<Tensor<T>, Errors> {
134        Tensor::<T>::from_val(&dims, NumCast::from(1).unwrap())
135    }
136
137    /// Return a 2D tensor with elements equal to 0 except the main diagonal (when row = column)
138    /// which is equal to 1
139    ///
140    /// # Arguments
141    /// * n - Number of rows
142    /// * m - Number of columns
143    ///
144    /// # Examples
145    /// ```rust
146    /// let t = Tensor::<i32>::eye(2, 3).unwrap();
147    /// // t = [[1, 0, 0], [0, 1, 0]]
148    /// ```
149    pub fn eye(n: usize, m: usize) -> Result<Tensor<T>, Errors> {
150        if n == 0 || m == 0 {
151            return Err(Errors::EmptyTensor);
152        }
153
154        let tensor = Tensor::<T>::zeros(&[n, m])?;
155        (0..(n.min(m))).for_each(|i| {
156            tensor.upd_unchecked(&[i, i], NumCast::from(1).unwrap());
157        });
158        Ok(tensor)
159    }
160}
161
162impl<T: Copy + ops::Add<Output = T> + ops::Mul<Output = T> + NumCast + PartialOrd> Tensor<T> {
163    /// Return a 1D tensor where the i th element is `st` + i * `step` such that each element is in
164    /// the range from `st` to `en` (exclusive)
165    ///
166    /// # Arguments
167    /// * st - The start of the range
168    /// * en - The end of the range
169    /// * step - The step size
170    ///
171    /// # Examples
172    /// ```rust
173    /// let t = Tensor::arange(1, 11, 1).unwrap();
174    /// // t = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
175    /// ```
176    pub fn arange(st: T, en: T, step: T) -> Result<Tensor<T>, Errors> {
177        let res: Vec<T> = (0..)
178            .map(|i| <T as NumCast>::from(i).unwrap() * step + st)
179            .take_while(|&v| v < en)
180            .collect();
181        Tensor::from_slice(&res)
182    }
183}
184
185impl<T: Copy + Float + Display> Tensor<T> {
186    /// Returns a 1D tensor with `cnt` elements equally spaced in the range from `st` to `en`
187    ///
188    /// # Arguments
189    /// * st - Start of the range
190    /// * en - End of the range
191    /// * cnt - The number of elements in the resultant tensor
192    ///
193    /// # Examples
194    /// ```rust
195    /// let t = Tensor::linspace(0.0, 1.0, 11).unwrap();
196    /// // t = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
197    /// ```
198    pub fn linspace(st: T, en: T, cnt: usize) -> Result<Tensor<T>, Errors> {
199        if cnt == 0 {
200            return Err(Errors::EmptyTensor);
201        }
202
203        if en <= st {
204            return Err(Errors::InputError(format!(
205                "linspace expected st < en, found {} >= {}",
206                st, en
207            )));
208        }
209
210        if cnt == 1 {
211            return Tensor::from_slice(&[st]);
212        }
213
214        let step_sz = (en - st) / NumCast::from(cnt - 1).unwrap();
215        let res: Vec<T> = (0..cnt)
216            .map(|x| st + <T as NumCast>::from(x).unwrap() * step_sz)
217            .collect();
218        Tensor::from_slice(&res)
219    }
220
221    /// Return a 1D Tensor of size `cnt` with elements evenly spaced from `base`^`st` to
222    /// `base`^`en` on a logarithmic scale with base `base`
223    ///
224    /// # Arguments
225    /// * base - Base for the logarithmic scale
226    /// * st - Start of the range
227    /// * en - End of the range
228    ///
229    /// # Examples
230    /// ```rust
231    /// let t = Tensor::logspace(10.0, -10.0, 10.0, 5).unwrap();
232    /// // t = [1e-10, 1e-5, 1, 1e5, 1e10]
233    /// ```
234    pub fn logspace(base: T, st: T, en: T, cnt: usize) -> Result<Tensor<T>, Errors> {
235        if cnt == 0 {
236            return Err(Errors::EmptyTensor);
237        }
238
239        if en <= st {
240            return Err(Errors::InputError(format!(
241                "logspace expected st < en, found {} >= {}",
242                st, en
243            )));
244        }
245
246        if cnt == 1 {
247            return Tensor::from_slice(&[base.powf(st)]);
248        }
249
250        let step_sz = (en - st) / NumCast::from(cnt - 1).unwrap();
251        let res: Vec<T> = (0..cnt)
252            .map(|x| st + <T as NumCast>::from(x).unwrap() * step_sz)
253            .map(|x| base.powf(x))
254            .collect();
255        Tensor::from_slice(&res)
256    }
257}
258
#[cfg(test)]
mod tests {
    use std::f64::consts::E;

    use crate::{core::tensor::tensor::Tensor, utils::errors::Errors};

    /// Absolute tolerance used when comparing floating point results.
    const TOL: f64 = 0.00001;

    /// Assert that `got` yields exactly `exp.len()` values and that each one is within `TOL`
    /// of its expected counterpart.
    ///
    /// The length is asserted explicitly because `zip` stops at the shorter side: comparing
    /// only the zipped pairs would let a result with missing or extra elements pass.
    fn assert_all_close<I>(got: I, exp: &[f64])
    where
        I: Iterator,
        I::Item: Into<f64>,
    {
        let got: Vec<f64> = got.map(Into::into).collect();
        assert_eq!(got.len(), exp.len(), "unexpected number of elements");
        for (i, (a, b)) in got.iter().zip(exp).enumerate() {
            assert!(
                (a - b).abs() < TOL,
                "element {}: got {}, expected {}",
                i,
                a,
                b
            );
        }
    }

    #[test]
    fn arange() {
        let t = Tensor::<u128>::arange(12, 37, 7).unwrap();
        let exp = vec![12, 19, 26, 33];
        assert_eq!(t.into_iter().collect::<Vec<u128>>(), exp);

        let t = Tensor::<f64>::arange(-10.0, -9.0, 0.1).unwrap();
        let exp = [-10.0, -9.9, -9.8, -9.7, -9.6, -9.5, -9.4, -9.3, -9.2, -9.1];
        assert_all_close(t.into_iter(), &exp);

        // A start beyond the end leaves nothing in the range.
        assert!(matches!(
            Tensor::arange(1, -1, 1),
            Err(Errors::EmptyTensor)
        ));
    }

    #[test]
    fn linspace() {
        let t = Tensor::<f32>::linspace(-10.0, 10.0, 5).unwrap();
        let exp = [-10.0, -5.0, 0.0, 5.0, 10.0];
        assert_all_close(t.into_iter(), &exp);

        let t = Tensor::<f32>::linspace(1.0, 10.0, 21).unwrap();
        let exp = [
            1., 1.45, 1.9, 2.35, 2.8, 3.25, 3.7, 4.15, 4.6, 5.05, 5.5, 5.95, 6.4, 6.85, 7.3, 7.75,
            8.2, 8.65, 9.1, 9.55, 10.,
        ];
        assert_all_close(t.into_iter(), &exp);

        assert!(matches!(
            Tensor::linspace(1.0, 2.0, 0),
            Err(Errors::EmptyTensor)
        ));

        assert!(matches!(
            Tensor::linspace(1.0, -2.0, 2),
            Err(Errors::InputError(_))
        ));
    }

    #[test]
    fn logspace() {
        let t = Tensor::<f32>::logspace(10.0, -10.0, 10.0, 5).unwrap();
        let exp = [1e-10, 1e-5, 1.0, 1e5, 1e10];
        assert_all_close(t.into_iter(), &exp);

        let t = Tensor::<f64>::logspace(E, 1.0, 5.0, 5).unwrap();
        let exp = [E, E.powi(2), E.powi(3), E.powi(4), E.powi(5)];
        assert_all_close(t.into_iter(), &exp);

        assert!(matches!(
            Tensor::logspace(E, 1.0, 2.0, 0),
            Err(Errors::EmptyTensor)
        ));

        assert!(matches!(
            Tensor::logspace(E, 1.0, -2.0, 2),
            Err(Errors::InputError(_))
        ));
    }

    #[test]
    fn eye() {
        let t = Tensor::<i128>::eye(2, 5).unwrap();
        let exp = vec![1, 0, 0, 0, 0, 0, 1, 0, 0, 0];
        assert_eq!(t.into_iter().collect::<Vec<i128>>(), exp);

        let t = Tensor::<f64>::eye(4, 3).unwrap();
        let exp = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0];
        assert_all_close(t.into_iter(), &exp);

        assert!(matches!(
            Tensor::<f32>::eye(0, 10),
            Err(Errors::EmptyTensor)
        ));
    }
}