smelte_rs/cpu/f32/tensor.rs
1use crate::TensorError;
2use std::borrow::Cow;
3
4/// Tensor, can own, or borrow the underlying tensor
5#[derive(Clone)]
6pub struct Tensor<'data> {
7 pub(super) shape: Vec<usize>,
8 data: Cow<'data, [f32]>,
9}
10
11impl<'data> Tensor<'data> {
12 /// The shape of the tensor
13 /// ```
14 /// use smelte-rs::cpu::f32::Tensor;
15 ///
16 /// let tensor = Tensor::zeros(vec![2, 2]);
17 /// assert_eq!(tensor.shape(), vec![2, 2]);
18 /// ```
19 pub fn shape(&self) -> &[usize] {
20 &self.shape
21 }
22
23 /// A slice to the underlying tensor data
24 /// ```
25 /// use smelte-rs::cpu::f32::Tensor;
26 ///
27 /// let tensor = Tensor::zeros(vec![2, 2]);
28 /// assert_eq!(tensor.data(), vec![0.0; 4]);
29 /// ```
30 pub fn data(&self) -> &[f32] {
31 self.data.as_ref()
32 }
33
34 /// A mutable slice to the underlying tensor data
35 /// ```
36 /// use smelte-rs::cpu::f32::Tensor;
37 ///
38 /// let mut tensor = Tensor::zeros(vec![2, 2]);
39 /// tensor.data_mut().iter_mut().for_each(|v| *v += 1.0);
40 /// assert_eq!(tensor.data(), vec![1.0; 4]);
41 /// ```
42 pub fn data_mut(&mut self) -> &mut [f32] {
43 self.data.to_mut()
44 }
45
46 /// Creates a new nulled tensor with given shape
47 /// ```
48 /// use smelte-rs::cpu::f32::Tensor;
49 ///
50 /// let tensor = Tensor::zeros(vec![2, 2]);
51 /// ```
52 pub fn zeros(shape: Vec<usize>) -> Self {
53 let nelement: usize = shape.iter().product();
54 let data = Cow::Owned(vec![0.0; nelement]);
55 Self { shape, data }
56 }
57
58 /// Creates a new borrowed tensor with given shape. Can fail if data doesn't match the shape
59 /// ```
60 /// use smelte-rs::cpu::f32::Tensor;
61 ///
62 /// let data = [1.0, 2.0, 3.0, 4.0];
63 /// let tensor = Tensor::borrowed(&data, vec![2, 2]).unwrap();
64 /// ```
65 pub fn borrowed(data: &'data [f32], shape: Vec<usize>) -> Result<Self, TensorError> {
66 let cow: Cow<'data, [f32]> = data.into();
67 Self::new(cow, shape)
68 }
69
70 /// Creates a new tensor with given shape. Can fail if data doesn't match the shape
71 /// ```
72 /// use smelte-rs::cpu::f32::Tensor;
73 ///
74 /// let data = vec![1.0, 2.0, 3.0, 4.0];
75 /// let tensor = Tensor::new(data, vec![2, 2]).unwrap();
76 /// ```
77 pub fn new<T>(data: T, shape: Vec<usize>) -> Result<Self, TensorError>
78 where
79 T: Into<Cow<'data, [f32]>>,
80 {
81 let data = data.into();
82 if data.len() != shape.iter().product::<usize>() {
83 return Err(TensorError::InvalidBuffer {
84 buffer_size: data.len(),
85 shape,
86 });
87 }
88 Ok(Self { shape, data })
89 }
90}