Skip to main content

tensor_forge/
tensor.rs

1//! Representations for dense, multidimensional arrays stored in contiguous memory.
2use std::fmt;
3
/// Representation for a multidimensional array of numbers.
///
/// Tensors are used as data inputs to ML operations and are
/// the basic datatype for the `tensor-forge` library.
///
/// # Examples
/// ```
/// # use tensor_forge::tensor::Tensor;
/// # // TODO: add examples of later use with Graphs
/// # // and flesh out documentation
/// let shape = vec![4, 4];
/// let data: Vec<f64> = (0..16).map(|x| x as f64).collect();
/// let tensor = Tensor::from_vec(shape, data);
/// assert!(tensor.is_ok());
/// ```
///
/// Data will be stored in a contiguous array of IEEE 754 double-precision floating-point values.
///
/// Tensors can be cloned and compared for equality. Two tensors are equal when both their
/// shapes and their element data are equal. Elements are compared with `f64` equality, so a
/// tensor containing `NaN` never compares equal to anything, including itself.
#[derive(Clone, Debug, PartialEq)]
pub struct Tensor {
    /// Size of each dimension. The public constructors reject any zero dimension.
    shape: Vec<usize>,
    /// Element storage. The public constructors only accept data whose length equals the
    /// product of `shape`.
    data: Vec<f64>,
}
26
27impl Tensor {
28    /// Constructs a zero-filled [`Tensor`] of a given shape.
29    ///
30    /// Use [`Tensor::from_vec`] to construct a tensor with
31    /// data, or fill the tensor after a call to [`Tensor::data_mut`].
32    ///
33    /// # Errors
34    /// - [`TensorError::InvalidShape`] if shape contains a zeroed dimension.
35    /// - [`TensorError::ShapeMismatch`] if shape contains a zeroed dimension.
36    ///
37    /// # Examples
38    ///
39    /// ```
40    /// # use tensor_forge::tensor::Tensor;
41    /// let shape = vec![4, 4];
42    /// let tensor = Tensor::zeros(shape);
43    /// assert_eq!(tensor.unwrap().data(), [0_f64; 4 * 4]);
44    /// ```
45    ///
46    /// Data will be stored in a contiguous array of IEEE 754 double-precision floating-point.
47    pub fn zeros(shape: impl Into<Vec<usize>>) -> Result<Tensor, TensorError> {
48        let shape: Vec<usize> = shape.into();
49        let num_elements = shape.iter().product();
50        let mut data = Vec::with_capacity(num_elements);
51        data.resize(num_elements, 0_f64);
52        Tensor::from_vec(shape, data)
53    }
54
55    /// Constructs a [`Tensor`] of the given dimensions in `shape` from the input `data`.
56    ///
57    /// # Errors
58    /// - [`TensorError::InvalidShape`] if shape contains a zeroed dimension.
59    /// - [`TensorError::ShapeMismatch`] if the data cannot fit into the tensor's dimension.
60    ///
61    /// # Examples
62    ///
63    /// ```
64    /// # use tensor_forge::tensor::Tensor;
65    /// // Example (toy) data
66    /// let data: Vec<f64> = (0..16).map(|x| x as f64).collect();
67    /// let shape = vec![4, 4];
68    ///
69    /// let tensor = Tensor::from_vec(shape, data);
70    /// assert!(tensor.is_ok());
71    /// ```
72    ///
73    /// Tensor data can fit into multiple valid shapes. For example, the above data with 16 total elements can fit into
74    /// an 8x2, 2x8, 1x16, or 16x1 tensor.
75    ///
76    /// ```
77    /// # use tensor_forge::tensor::Tensor;
78    /// // Example (toy) data from above
79    /// let data: Vec<f64> = (0..16).map(|x| x as f64).collect();
80    /// let shape = vec![16, 1];
81    ///
82    /// let tensor = Tensor::from_vec(shape, data);
83    /// assert!(tensor.is_ok());
84    /// ```
85    pub fn from_vec(shape: impl Into<Vec<usize>>, data: Vec<f64>) -> Result<Tensor, TensorError> {
86        let shape: Vec<usize> = shape.into();
87        let num_elements: usize = shape.iter().product();
88        if num_elements == 0 {
89            return Err(TensorError::InvalidShape);
90        }
91        if num_elements != data.len() {
92            return Err(TensorError::ShapeMismatch);
93        }
94        Ok(Tensor { shape, data })
95    }
96
97    /// Returns the shape of this tensor.
98    ///
99    /// # Examples
100    /// ```
101    /// # use tensor_forge::tensor::Tensor;
102    /// let shape = vec![4, 4];
103    /// let tensor = Tensor::zeros(shape);
104    /// assert_eq!(tensor.unwrap().shape(), vec![4, 4]);
105    /// ```
106    #[must_use]
107    pub fn shape(&self) -> &[usize] {
108        &self.shape
109    }
110
111    /// Returns the total number of elements in this tensor.
112    ///
113    /// # Examples
114    /// ```
115    /// # use tensor_forge::tensor::Tensor;
116    /// let shape = vec![4, 4];
117    /// let tensor = Tensor::zeros(shape);
118    /// assert_eq!(tensor.unwrap().numel(), 16);  // 4x4 = 16 elements
119    /// ```
120    #[must_use]
121    pub fn numel(&self) -> usize {
122        self.shape.iter().product()
123    }
124
125    /// Returns an immutable reference to the data in this tensor.
126    ///
127    /// # Examples
128    /// ```
129    /// # use tensor_forge::tensor::Tensor;
130    /// let shape = vec![4, 4];
131    /// let tensor = Tensor::zeros(shape);
132    /// assert_eq!(tensor.unwrap().data(), vec![0_f64; 4 * 4]);
133    /// ```
134    #[must_use]
135    pub fn data(&self) -> &[f64] {
136        &self.data
137    }
138
139    /// Returns a mutable reference to the data in this tensor.
140    ///
141    /// # Examples
142    /// ```
143    /// # use tensor_forge::tensor::Tensor;
144    /// let shape = vec![4, 4];
145    /// let mut tensor = Tensor::zeros(shape);
146    /// assert_eq!(tensor.unwrap().data_mut(), vec![0_f64; 4 * 4]);
147    /// ```
148    pub fn data_mut(&mut self) -> &mut [f64] {
149        &mut self.data
150    }
151}
152
/// Error types for [`Tensor`] construction.
///
/// Errors are plain, copyable values that can be compared directly, e.g. with `assert_eq!`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum TensorError {
    /// Raised by [`Tensor::from_vec`] or [`Tensor::zeros`] when the requested shape cannot
    /// describe a valid tensor, e.g. because one of its dimensions is zero.
    InvalidShape,
    /// Raised by [`Tensor::from_vec`] when the number of elements implied by the shape
    /// (the product of its dimensions) differs from the length of the data passed in.
    ShapeMismatch,
}
163
164impl fmt::Display for TensorError {
165    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
166        match self {
167            TensorError::InvalidShape => write!(
168                f,
169                "Invalid tensor dimensions. Tensor shape must not contain a zero."
170            ),
171            TensorError::ShapeMismatch => {
172                write!(f, "Tensor shape cannot store the size of the data.")
173            }
174        }
175    }
176}