tensor_forge/tensor.rs
1//! Representations for dense, multidimensional arrays stored in contiguous memory.
2use std::fmt;
3
/// Representation for a multidimensional array of numbers.
///
/// Tensors are used as data inputs to ML operations and are
/// the basic datatype for the `tensor-forge` library.
///
/// # Examples
/// ```
/// # use tensor_forge::tensor::Tensor;
/// # // TODO: add examples of later use with Graphs
/// # // and flesh out documentation
/// let shape = vec![4, 4];
/// let data: Vec<f64> = (0..16).map(|x| x as f64).collect();
/// let tensor = Tensor::from_vec(shape, data);
/// assert!(tensor.is_ok());
/// ```
///
/// Data is stored in a contiguous array of IEEE 754 double-precision floating-point values.
///
/// Tensors can be cloned and compared for equality. Equality compares both the shape and
/// the element data; as with any `f64` comparison, a tensor containing `NaN` is never equal
/// to itself.
#[derive(Clone, Debug, PartialEq)]
pub struct Tensor {
    /// Size of each dimension; validated by the constructors to contain no zero.
    shape: Vec<usize>,
    /// Flat element storage; its length is expected to equal the product of `shape`.
    data: Vec<f64>,
}
26
27impl Tensor {
28 /// Constructs a zero-filled [`Tensor`] of a given shape.
29 ///
30 /// Use [`Tensor::from_vec`] to construct a tensor with
31 /// data, or fill the tensor after a call to [`Tensor::data_mut`].
32 ///
33 /// # Errors
34 /// - [`TensorError::InvalidShape`] if shape contains a zeroed dimension.
35 /// - [`TensorError::ShapeMismatch`] if shape contains a zeroed dimension.
36 ///
37 /// # Examples
38 ///
39 /// ```
40 /// # use tensor_forge::tensor::Tensor;
41 /// let shape = vec![4, 4];
42 /// let tensor = Tensor::zeros(shape);
43 /// assert_eq!(tensor.unwrap().data(), [0_f64; 4 * 4]);
44 /// ```
45 ///
46 /// Data will be stored in a contiguous array of IEEE 754 double-precision floating-point.
47 pub fn zeros(shape: impl Into<Vec<usize>>) -> Result<Tensor, TensorError> {
48 let shape: Vec<usize> = shape.into();
49 let num_elements = shape.iter().product();
50 let mut data = Vec::with_capacity(num_elements);
51 data.resize(num_elements, 0_f64);
52 Tensor::from_vec(shape, data)
53 }
54
55 /// Constructs a [`Tensor`] of the given dimensions in `shape` from the input `data`.
56 ///
57 /// # Errors
58 /// - [`TensorError::InvalidShape`] if shape contains a zeroed dimension.
59 /// - [`TensorError::ShapeMismatch`] if the data cannot fit into the tensor's dimension.
60 ///
61 /// # Examples
62 ///
63 /// ```
64 /// # use tensor_forge::tensor::Tensor;
65 /// // Example (toy) data
66 /// let data: Vec<f64> = (0..16).map(|x| x as f64).collect();
67 /// let shape = vec![4, 4];
68 ///
69 /// let tensor = Tensor::from_vec(shape, data);
70 /// assert!(tensor.is_ok());
71 /// ```
72 ///
73 /// Tensor data can fit into multiple valid shapes. For example, the above data with 16 total elements can fit into
74 /// an 8x2, 2x8, 1x16, or 16x1 tensor.
75 ///
76 /// ```
77 /// # use tensor_forge::tensor::Tensor;
78 /// // Example (toy) data from above
79 /// let data: Vec<f64> = (0..16).map(|x| x as f64).collect();
80 /// let shape = vec![16, 1];
81 ///
82 /// let tensor = Tensor::from_vec(shape, data);
83 /// assert!(tensor.is_ok());
84 /// ```
85 pub fn from_vec(shape: impl Into<Vec<usize>>, data: Vec<f64>) -> Result<Tensor, TensorError> {
86 let shape: Vec<usize> = shape.into();
87 let num_elements: usize = shape.iter().product();
88 if num_elements == 0 {
89 return Err(TensorError::InvalidShape);
90 }
91 if num_elements != data.len() {
92 return Err(TensorError::ShapeMismatch);
93 }
94 Ok(Tensor { shape, data })
95 }
96
97 /// Returns the shape of this tensor.
98 ///
99 /// # Examples
100 /// ```
101 /// # use tensor_forge::tensor::Tensor;
102 /// let shape = vec![4, 4];
103 /// let tensor = Tensor::zeros(shape);
104 /// assert_eq!(tensor.unwrap().shape(), vec![4, 4]);
105 /// ```
106 #[must_use]
107 pub fn shape(&self) -> &[usize] {
108 &self.shape
109 }
110
111 /// Returns the total number of elements in this tensor.
112 ///
113 /// # Examples
114 /// ```
115 /// # use tensor_forge::tensor::Tensor;
116 /// let shape = vec![4, 4];
117 /// let tensor = Tensor::zeros(shape);
118 /// assert_eq!(tensor.unwrap().numel(), 16); // 4x4 = 16 elements
119 /// ```
120 #[must_use]
121 pub fn numel(&self) -> usize {
122 self.shape.iter().product()
123 }
124
125 /// Returns an immutable reference to the data in this tensor.
126 ///
127 /// # Examples
128 /// ```
129 /// # use tensor_forge::tensor::Tensor;
130 /// let shape = vec![4, 4];
131 /// let tensor = Tensor::zeros(shape);
132 /// assert_eq!(tensor.unwrap().data(), vec![0_f64; 4 * 4]);
133 /// ```
134 #[must_use]
135 pub fn data(&self) -> &[f64] {
136 &self.data
137 }
138
139 /// Returns a mutable reference to the data in this tensor.
140 ///
141 /// # Examples
142 /// ```
143 /// # use tensor_forge::tensor::Tensor;
144 /// let shape = vec![4, 4];
145 /// let mut tensor = Tensor::zeros(shape);
146 /// assert_eq!(tensor.unwrap().data_mut(), vec![0_f64; 4 * 4]);
147 /// ```
148 pub fn data_mut(&mut self) -> &mut [f64] {
149 &mut self.data
150 }
151}
152
/// Error types for Tensor construction.
///
/// The variants carry no data, so the type is `Copy` and errors can be compared
/// directly with `==` (e.g. in tests or when branching on a failure).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TensorError {
    /// Raised in [`Tensor::from_vec`] or [`Tensor::zeros`] if
    /// one of a tensor's dimensions is zero.
    InvalidShape,
    /// Raised in [`Tensor::from_vec`] when the number of elements implied by the
    /// tensor shape does not match the length of the data passed in.
    ShapeMismatch,
}
163
164impl fmt::Display for TensorError {
165 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
166 match self {
167 TensorError::InvalidShape => write!(
168 f,
169 "Invalid tensor dimensions. Tensor shape must not contain a zero."
170 ),
171 TensorError::ShapeMismatch => {
172 write!(f, "Tensor shape cannot store the size of the data.")
173 }
174 }
175 }
176}