scivex_core/tensor/
create.rs1use crate::error::{CoreError, Result};
4use crate::{Float, Scalar};
5
6use super::{Tensor, compute_strides};
7
8impl<T: Scalar> Tensor<T> {
9 pub fn zeros(shape: Vec<usize>) -> Self {
18 let numel: usize = shape.iter().product();
19 let strides = compute_strides(&shape);
20 Self {
21 data: vec![T::zero(); numel],
22 shape,
23 strides,
24 }
25 }
26
27 pub fn ones(shape: Vec<usize>) -> Self {
35 let numel: usize = shape.iter().product();
36 let strides = compute_strides(&shape);
37 Self {
38 data: vec![T::one(); numel],
39 shape,
40 strides,
41 }
42 }
43
44 pub fn full(shape: Vec<usize>, value: T) -> Self {
52 let numel: usize = shape.iter().product();
53 let strides = compute_strides(&shape);
54 Self {
55 data: vec![value; numel],
56 shape,
57 strides,
58 }
59 }
60
61 pub fn arange(n: usize) -> Self {
69 let data: Vec<T> = (0..n).map(T::from_usize).collect();
70 let strides = compute_strides(&[n]);
71 Self {
72 data,
73 shape: vec![n],
74 strides,
75 }
76 }
77
78 pub fn eye(n: usize) -> Self {
88 let mut data = vec![T::zero(); n * n];
89 for i in 0..n {
90 data[i * n + i] = T::one();
91 }
92 let strides = compute_strides(&[n, n]);
93 Self {
94 data,
95 shape: vec![n, n],
96 strides,
97 }
98 }
99}
100
101impl<T: Float> Tensor<T> {
102 pub fn linspace(start: T, end: T, n: usize) -> Result<Self> {
113 if n < 2 {
114 return Err(CoreError::InvalidArgument {
115 reason: "linspace requires n >= 2",
116 });
117 }
118 let step = (end - start) / T::from_usize(n - 1);
119 let data: Vec<T> = (0..n).map(|i| start + step * T::from_usize(i)).collect();
120 let strides = compute_strides(&[n]);
121 Ok(Self {
122 data,
123 shape: vec![n],
124 strides,
125 })
126 }
127}
128
#[cfg(test)]
// The assertions below compare floating-point values for exact equality.
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;

    /// `zeros` reports the requested shape and element count and holds only zeros.
    #[test]
    fn test_zeros() {
        let zeros = Tensor::<f64>::zeros(vec![3, 4]);
        assert_eq!(zeros.shape(), &[3, 4]);
        assert_eq!(zeros.numel(), 12);
        let has_nonzero = zeros.iter().any(|&value| value != 0.0);
        assert!(!has_nonzero);
    }

    /// `ones` fills every element with one.
    #[test]
    fn test_ones() {
        let ones = Tensor::<f32>::ones(vec![2, 2]);
        let has_non_one = ones.iter().any(|&value| value != 1.0);
        assert!(!has_non_one);
    }

    /// `full` fills every element with the supplied value.
    #[test]
    fn test_full() {
        let fill = 7_i32;
        let sevens = Tensor::full(vec![2, 3], fill);
        let has_other = sevens.iter().any(|&value| value != 7);
        assert!(!has_other);
    }

    /// `arange(5)` yields 0 through 4 in a one-dimensional tensor.
    #[test]
    fn test_arange() {
        let range = Tensor::<i32>::arange(5);
        assert_eq!(range.as_slice(), &[0, 1, 2, 3, 4]);
        assert_eq!(range.shape(), &[5]);
    }

    /// `arange(0)` yields an empty tensor of shape `[0]`.
    #[test]
    fn test_arange_zero() {
        let empty = Tensor::<i32>::arange(0);
        assert!(empty.is_empty());
        assert_eq!(empty.shape(), &[0]);
    }

    /// `eye(3)` has ones on the diagonal and zeros at the sampled off-diagonal cells.
    #[test]
    fn test_eye() {
        let identity = Tensor::<f64>::eye(3);
        assert_eq!(identity.shape(), &[3, 3]);
        for i in 0..3 {
            assert_eq!(*identity.get(&[i, i]).unwrap(), 1.0);
        }
        for (row, col) in [(0, 1), (1, 0)] {
            assert_eq!(*identity.get(&[row, col]).unwrap(), 0.0);
        }
    }

    /// `linspace(0, 1, 5)` hits both endpoints and the midpoint.
    #[test]
    fn test_linspace() {
        let samples = Tensor::<f64>::linspace(0.0, 1.0, 5).unwrap();
        assert_eq!(samples.shape(), &[5]);
        assert_eq!(*samples.get(&[0]).unwrap(), 0.0);
        assert_eq!(*samples.get(&[4]).unwrap(), 1.0);
        let midpoint_error = (samples.as_slice()[2] - 0.5).abs();
        assert!(midpoint_error < 1e-15);
    }

    /// `linspace` rejects a sample count below two.
    #[test]
    fn test_linspace_invalid() {
        let result = Tensor::<f64>::linspace(0.0, 1.0, 1);
        assert!(result.is_err());
    }
}