smelte_rs/cpu/f32/
tensor.rs

1use crate::TensorError;
2use std::borrow::Cow;
3
4/// Tensor, can own, or borrow the underlying tensor
5#[derive(Clone)]
6pub struct Tensor<'data> {
7    pub(super) shape: Vec<usize>,
8    data: Cow<'data, [f32]>,
9}
10
11impl<'data> Tensor<'data> {
12    /// The shape of the tensor
13    /// ```
14    /// use smelte-rs::cpu::f32::Tensor;
15    ///
16    /// let tensor = Tensor::zeros(vec![2, 2]);
17    /// assert_eq!(tensor.shape(), vec![2, 2]);
18    /// ```
19    pub fn shape(&self) -> &[usize] {
20        &self.shape
21    }
22
23    /// A slice to the underlying tensor data
24    /// ```
25    /// use smelte-rs::cpu::f32::Tensor;
26    ///
27    /// let tensor = Tensor::zeros(vec![2, 2]);
28    /// assert_eq!(tensor.data(), vec![0.0; 4]);
29    /// ```
30    pub fn data(&self) -> &[f32] {
31        self.data.as_ref()
32    }
33
34    /// A mutable slice to the underlying tensor data
35    /// ```
36    /// use smelte-rs::cpu::f32::Tensor;
37    ///
38    /// let mut tensor = Tensor::zeros(vec![2, 2]);
39    /// tensor.data_mut().iter_mut().for_each(|v| *v += 1.0);
40    /// assert_eq!(tensor.data(), vec![1.0; 4]);
41    /// ```
42    pub fn data_mut(&mut self) -> &mut [f32] {
43        self.data.to_mut()
44    }
45
46    /// Creates a new nulled tensor with given shape
47    /// ```
48    /// use smelte-rs::cpu::f32::Tensor;
49    ///
50    /// let tensor = Tensor::zeros(vec![2, 2]);
51    /// ```
52    pub fn zeros(shape: Vec<usize>) -> Self {
53        let nelement: usize = shape.iter().product();
54        let data = Cow::Owned(vec![0.0; nelement]);
55        Self { shape, data }
56    }
57
58    /// Creates a new borrowed tensor with given shape. Can fail if data doesn't match the shape
59    /// ```
60    /// use smelte-rs::cpu::f32::Tensor;
61    ///
62    /// let data = [1.0, 2.0, 3.0, 4.0];
63    /// let tensor = Tensor::borrowed(&data, vec![2, 2]).unwrap();
64    /// ```
65    pub fn borrowed(data: &'data [f32], shape: Vec<usize>) -> Result<Self, TensorError> {
66        let cow: Cow<'data, [f32]> = data.into();
67        Self::new(cow, shape)
68    }
69
70    /// Creates a new tensor with given shape. Can fail if data doesn't match the shape
71    /// ```
72    /// use smelte-rs::cpu::f32::Tensor;
73    ///
74    /// let data = vec![1.0, 2.0, 3.0, 4.0];
75    /// let tensor = Tensor::new(data, vec![2, 2]).unwrap();
76    /// ```
77    pub fn new<T>(data: T, shape: Vec<usize>) -> Result<Self, TensorError>
78    where
79        T: Into<Cow<'data, [f32]>>,
80    {
81        let data = data.into();
82        if data.len() != shape.iter().product::<usize>() {
83            return Err(TensorError::InvalidBuffer {
84                buffer_size: data.len(),
85                shape,
86            });
87        }
88        Ok(Self { shape, data })
89    }
90}