Skip to main content

scivex_core/tensor/
create.rs

1//! Tensor creation functions analogous to `np.zeros`, `np.ones`, etc.
2
3use crate::error::{CoreError, Result};
4use crate::{Float, Scalar};
5
6use super::{Tensor, compute_strides};
7
8impl<T: Scalar> Tensor<T> {
9    /// Create a tensor filled with zeros.
10    ///
11    /// ```
12    /// # use scivex_core::tensor::Tensor;
13    /// let t = Tensor::<f64>::zeros(vec![2, 3]);
14    /// assert_eq!(t.shape(), &[2, 3]);
15    /// assert!(t.iter().all(|&x| x == 0.0));
16    /// ```
17    pub fn zeros(shape: Vec<usize>) -> Self {
18        let numel: usize = shape.iter().product();
19        let strides = compute_strides(&shape);
20        Self {
21            data: vec![T::zero(); numel],
22            shape,
23            strides,
24        }
25    }
26
27    /// Create a tensor filled with ones.
28    ///
29    /// ```
30    /// # use scivex_core::tensor::Tensor;
31    /// let t = Tensor::<f64>::ones(vec![2, 3]);
32    /// assert!(t.iter().all(|&x| x == 1.0));
33    /// ```
34    pub fn ones(shape: Vec<usize>) -> Self {
35        let numel: usize = shape.iter().product();
36        let strides = compute_strides(&shape);
37        Self {
38            data: vec![T::one(); numel],
39            shape,
40            strides,
41        }
42    }
43
44    /// Create a tensor filled with a constant value.
45    ///
46    /// ```
47    /// # use scivex_core::tensor::Tensor;
48    /// let t = Tensor::full(vec![2, 2], 7_i32);
49    /// assert!(t.iter().all(|&x| x == 7));
50    /// ```
51    pub fn full(shape: Vec<usize>, value: T) -> Self {
52        let numel: usize = shape.iter().product();
53        let strides = compute_strides(&shape);
54        Self {
55            data: vec![value; numel],
56            shape,
57            strides,
58        }
59    }
60
61    /// Create a 1-D tensor with values `[0, 1, 2, ..., n-1]`.
62    ///
63    /// ```
64    /// # use scivex_core::tensor::Tensor;
65    /// let t = Tensor::<i32>::arange(5);
66    /// assert_eq!(t.as_slice(), &[0, 1, 2, 3, 4]);
67    /// ```
68    pub fn arange(n: usize) -> Self {
69        let data: Vec<T> = (0..n).map(T::from_usize).collect();
70        let strides = compute_strides(&[n]);
71        Self {
72            data,
73            shape: vec![n],
74            strides,
75        }
76    }
77
78    /// Create an identity matrix of size `n x n`.
79    ///
80    /// ```
81    /// # use scivex_core::tensor::Tensor;
82    /// let eye = Tensor::<f64>::eye(3);
83    /// assert_eq!(eye.shape(), &[3, 3]);
84    /// assert_eq!(*eye.get(&[0, 0]).unwrap(), 1.0);
85    /// assert_eq!(*eye.get(&[0, 1]).unwrap(), 0.0);
86    /// ```
87    pub fn eye(n: usize) -> Self {
88        let mut data = vec![T::zero(); n * n];
89        for i in 0..n {
90            data[i * n + i] = T::one();
91        }
92        let strides = compute_strides(&[n, n]);
93        Self {
94            data,
95            shape: vec![n, n],
96            strides,
97        }
98    }
99}
100
101impl<T: Float> Tensor<T> {
102    /// Create a 1-D tensor with `n` evenly spaced values from `start` to `end`
103    /// (inclusive).
104    ///
105    /// Returns an error if `n < 2`.
106    ///
107    /// ```
108    /// # use scivex_core::tensor::Tensor;
109    /// let t = Tensor::<f64>::linspace(0.0, 1.0, 5).unwrap();
110    /// assert_eq!(t.shape(), &[5]);
111    /// ```
112    pub fn linspace(start: T, end: T, n: usize) -> Result<Self> {
113        if n < 2 {
114            return Err(CoreError::InvalidArgument {
115                reason: "linspace requires n >= 2",
116            });
117        }
118        let step = (end - start) / T::from_usize(n - 1);
119        let data: Vec<T> = (0..n).map(|i| start + step * T::from_usize(i)).collect();
120        let strides = compute_strides(&[n]);
121        Ok(Self {
122            data,
123            shape: vec![n],
124            strides,
125        })
126    }
127}
128
#[cfg(test)]
// Exact float comparisons are intentional: the expected values below are
// chosen to be exactly representable, so no tolerance is needed except where
// stated.
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;

    #[test]
    fn test_zeros() {
        let t = Tensor::<f64>::zeros(vec![3, 4]);
        assert_eq!(t.shape(), &[3, 4]);
        assert_eq!(t.numel(), 12);
        assert!(t.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_zeros_with_empty_dim() {
        let t = Tensor::<f64>::zeros(vec![2, 0]);
        assert_eq!(t.shape(), &[2, 0]);
        assert_eq!(t.numel(), 0);
        assert!(t.is_empty());
    }

    #[test]
    fn test_ones() {
        let t = Tensor::<f32>::ones(vec![2, 2]);
        assert_eq!(t.shape(), &[2, 2]);
        assert!(t.iter().all(|&x| x == 1.0));
    }

    #[test]
    fn test_full() {
        let t = Tensor::full(vec![2, 3], 7_i32);
        assert_eq!(t.shape(), &[2, 3]);
        assert_eq!(t.numel(), 6);
        assert!(t.iter().all(|&x| x == 7));
    }

    #[test]
    fn test_arange() {
        let t = Tensor::<i32>::arange(5);
        assert_eq!(t.as_slice(), &[0, 1, 2, 3, 4]);
        assert_eq!(t.shape(), &[5]);
    }

    #[test]
    fn test_arange_zero() {
        let t = Tensor::<i32>::arange(0);
        assert!(t.is_empty());
        assert_eq!(t.shape(), &[0]);
    }

    #[test]
    fn test_eye() {
        let t = Tensor::<f64>::eye(3);
        assert_eq!(t.shape(), &[3, 3]);
        // Check every entry, not just a sample of the diagonal.
        for i in 0..3 {
            for j in 0..3 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert_eq!(*t.get(&[i, j]).unwrap(), expected);
            }
        }
    }

    #[test]
    fn test_eye_zero() {
        let t = Tensor::<f64>::eye(0);
        assert_eq!(t.shape(), &[0, 0]);
        assert!(t.is_empty());
    }

    #[test]
    fn test_linspace() {
        let t = Tensor::<f64>::linspace(0.0, 1.0, 5).unwrap();
        assert_eq!(t.shape(), &[5]);
        assert_eq!(*t.get(&[0]).unwrap(), 0.0);
        assert_eq!(*t.get(&[4]).unwrap(), 1.0);
        assert!((t.as_slice()[2] - 0.5).abs() < 1e-15);
    }

    #[test]
    fn test_linspace_descending() {
        let t = Tensor::<f64>::linspace(1.0, -1.0, 3).unwrap();
        assert_eq!(t.as_slice(), &[1.0, 0.0, -1.0]);
    }

    #[test]
    fn test_linspace_invalid() {
        assert!(Tensor::<f64>::linspace(0.0, 1.0, 1).is_err());
        assert!(Tensor::<f64>::linspace(0.0, 1.0, 0).is_err());
    }
}
192}