1use num_traits::{Float, NumCast};
2use std::{cell::RefCell, fmt::Display, ops, rc::Rc};
3
4use crate::utils::{errors::Errors, strides::new_strides_from_dim};
5
6use super::{storage::TensorStorage, tensor::Tensor};
7
8impl<T: Copy> Tensor<T> {
9 pub fn from_slice(arr: &[T]) -> Result<Tensor<T>, Errors> {
20 if arr.is_empty() {
21 return Err(Errors::EmptyTensor);
22 }
23 let tensor_storage = TensorStorage::<T>::from_slice(&arr);
24 let dims = vec![arr.len()];
25 let strides = vec![1];
26 Ok(Tensor::new_unchecked(
27 Rc::new(RefCell::new(tensor_storage)),
28 0,
29 &dims,
30 &strides,
31 ))
32 }
33
34 pub fn from_slice_and_dims(arr: &[T], dims: &[usize]) -> Result<Tensor<T>, Errors> {
46 if arr.is_empty() {
47 return Err(Errors::EmptyTensor);
48 }
49
50 if arr.len() != dims.iter().product() {
51 return Err(Errors::InputError(
52 "Expected product of dimensions to equal number of elements in storage".to_string(),
53 ));
54 }
55 let tensor_storage = TensorStorage::<T>::from_slice(&arr);
56 let strides = new_strides_from_dim(&dims);
57 Ok(Tensor::new_unchecked(
58 Rc::new(RefCell::new(tensor_storage)),
59 0,
60 &dims,
61 &strides,
62 ))
63 }
64
65 pub fn from_val(dims: &[usize], val: T) -> Result<Tensor<T>, Errors> {
77 let no_el = dims.iter().fold(1, |res, dim_sz| res * dim_sz);
78 if no_el == 0 {
79 return Err(Errors::EmptyTensor);
80 }
81
82 let tensor_storage = TensorStorage::<T>::from_val(no_el, val);
83 let new_strides = new_strides_from_dim(&dims);
84 Ok(Tensor::new_unchecked(
85 Rc::new(RefCell::new(tensor_storage)),
86 0,
87 &dims,
88 &new_strides,
89 ))
90 }
91}
92
93impl<T: Default + Copy> Tensor<T> {
94 pub fn from_default(dims: &[usize]) -> Result<Tensor<T>, Errors> {
106 Tensor::from_val(dims, T::default())
107 }
108}
109
110impl<T: Copy + NumCast> Tensor<T> {
111 pub fn zeros(dims: &[usize]) -> Result<Tensor<T>, Errors> {
121 Tensor::<T>::from_val(&dims, NumCast::from(0).unwrap())
122 }
123
124 pub fn ones(dims: &[usize]) -> Result<Tensor<T>, Errors> {
134 Tensor::<T>::from_val(&dims, NumCast::from(1).unwrap())
135 }
136
137 pub fn eye(n: usize, m: usize) -> Result<Tensor<T>, Errors> {
150 if n == 0 || m == 0 {
151 return Err(Errors::EmptyTensor);
152 }
153
154 let tensor = Tensor::<T>::zeros(&[n, m])?;
155 (0..(n.min(m))).for_each(|i| {
156 tensor.upd_unchecked(&[i, i], NumCast::from(1).unwrap());
157 });
158 Ok(tensor)
159 }
160}
161
162impl<T: Copy + ops::Add<Output = T> + ops::Mul<Output = T> + NumCast + PartialOrd> Tensor<T> {
163 pub fn arange(st: T, en: T, step: T) -> Result<Tensor<T>, Errors> {
177 let res: Vec<T> = (0..)
178 .map(|i| <T as NumCast>::from(i).unwrap() * step + st)
179 .take_while(|&v| v < en)
180 .collect();
181 Tensor::from_slice(&res)
182 }
183}
184
185impl<T: Copy + Float + Display> Tensor<T> {
186 pub fn linspace(st: T, en: T, cnt: usize) -> Result<Tensor<T>, Errors> {
199 if cnt == 0 {
200 return Err(Errors::EmptyTensor);
201 }
202
203 if en <= st {
204 return Err(Errors::InputError(format!(
205 "linspace expected st < en, found {} >= {}",
206 st, en
207 )));
208 }
209
210 if cnt == 1 {
211 return Tensor::from_slice(&[st]);
212 }
213
214 let step_sz = (en - st) / NumCast::from(cnt - 1).unwrap();
215 let res: Vec<T> = (0..cnt)
216 .map(|x| st + <T as NumCast>::from(x).unwrap() * step_sz)
217 .collect();
218 Tensor::from_slice(&res)
219 }
220
221 pub fn logspace(base: T, st: T, en: T, cnt: usize) -> Result<Tensor<T>, Errors> {
235 if cnt == 0 {
236 return Err(Errors::EmptyTensor);
237 }
238
239 if en <= st {
240 return Err(Errors::InputError(format!(
241 "logspace expected st < en, found {} >= {}",
242 st, en
243 )));
244 }
245
246 if cnt == 1 {
247 return Tensor::from_slice(&[base.powf(st)]);
248 }
249
250 let step_sz = (en - st) / NumCast::from(cnt - 1).unwrap();
251 let res: Vec<T> = (0..cnt)
252 .map(|x| st + <T as NumCast>::from(x).unwrap() * step_sz)
253 .map(|x| base.powf(x))
254 .collect();
255 Tensor::from_slice(&res)
256 }
257}
258
#[cfg(test)]
mod tests {
    use std::f64::consts::E;

    use crate::{core::tensor::tensor::Tensor, utils::errors::Errors};

    /// Absolute tolerance for every floating-point comparison in this module.
    const TOLERANCE: f64 = 0.00001;

    /// Asserts that `actual` yields exactly `expected.len()` values and that
    /// each one is within `TOLERANCE` of its expected counterpart.
    ///
    /// The explicit length check matters: a bare `zip(..).all(..)` stops at
    /// the shorter side, so a truncated or empty result would pass unnoticed.
    fn assert_all_close<T: Copy + Into<f64>>(actual: impl Iterator<Item = T>, expected: &[T]) {
        let actual: Vec<f64> = actual.map(Into::into).collect();
        assert_eq!(actual.len(), expected.len(), "element count mismatch");
        for (idx, (got, want)) in actual.iter().zip(expected).enumerate() {
            let want: f64 = (*want).into();
            assert!(
                (got - want).abs() < TOLERANCE,
                "element {}: got {}, expected {}",
                idx,
                got,
                want
            );
        }
    }

    #[test]
    fn arange() {
        let t = Tensor::<u128>::arange(12, 37, 7).unwrap();
        let exp = vec![12, 19, 26, 33];
        assert_eq!(t.into_iter().collect::<Vec<u128>>(), exp);

        let t = Tensor::<f64>::arange(-10.0, -9.0, 0.1).unwrap();
        let exp: [f64; 10] = [-10.0, -9.9, -9.8, -9.7, -9.6, -9.5, -9.4, -9.3, -9.2, -9.1];
        assert_all_close(t.into_iter(), &exp);

        // An empty range must be rejected rather than yield an empty tensor.
        assert!(matches!(
            Tensor::arange(1, -1, 1),
            Err(Errors::EmptyTensor)
        ));
    }

    #[test]
    fn linspace() {
        let t = Tensor::<f32>::linspace(-10.0, 10.0, 5).unwrap();
        let exp: [f32; 5] = [-10.0, -5.0, 0.0, 5.0, 10.0];
        assert_all_close(t.into_iter(), &exp);

        let t = Tensor::<f32>::linspace(1.0, 10.0, 21).unwrap();
        let exp: [f32; 21] = [
            1., 1.45, 1.9, 2.35, 2.8, 3.25, 3.7, 4.15, 4.6, 5.05, 5.5, 5.95, 6.4, 6.85, 7.3, 7.75,
            8.2, 8.65, 9.1, 9.55, 10.,
        ];
        assert_all_close(t.into_iter(), &exp);

        assert!(matches!(
            Tensor::linspace(1.0, 2.0, 0),
            Err(Errors::EmptyTensor)
        ));

        assert!(matches!(
            Tensor::linspace(1.0, -2.0, 2),
            Err(Errors::InputError(_))
        ));
    }

    #[test]
    fn logspace() {
        let t = Tensor::<f32>::logspace(10.0, -10.0, 10.0, 5).unwrap();
        let exp: [f32; 5] = [1e-10, 1e-5, 1.0, 1e5, 1e10];
        assert_all_close(t.into_iter(), &exp);

        let t = Tensor::<f64>::logspace(E, 1.0, 5.0, 5).unwrap();
        let exp: [f64; 5] = [E, E.powi(2), E.powi(3), E.powi(4), E.powi(5)];
        assert_all_close(t.into_iter(), &exp);

        assert!(matches!(
            Tensor::logspace(E, 1.0, 2.0, 0),
            Err(Errors::EmptyTensor)
        ));

        assert!(matches!(
            Tensor::logspace(E, 1.0, -2.0, 2),
            Err(Errors::InputError(_))
        ));
    }

    #[test]
    fn eye() {
        let t = Tensor::<i128>::eye(2, 5).unwrap();
        let exp = vec![1, 0, 0, 0, 0, 0, 1, 0, 0, 0];
        assert_eq!(t.into_iter().collect::<Vec<i128>>(), exp);

        let t = Tensor::<f64>::eye(4, 3).unwrap();
        let exp: [f64; 12] = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0];
        assert_all_close(t.into_iter(), &exp);

        assert!(matches!(
            Tensor::<f32>::eye(0, 10),
            Err(Errors::EmptyTensor)
        ));
    }
}