// ruvector_cnn/tensor.rs

//! Tensor type for CNN operations.
//!
//! Uses the NHWC (batch, height, width, channels) memory layout, which keeps
//! all channels of a pixel contiguous in memory and improves cache
//! utilization in convolution inner loops.

6use crate::error::{CnnError, CnnResult};
7
/// Dense, contiguous `f32` tensor stored in row-major order.
///
/// When the tensor is 4D its axes are read as NHWC
/// (`[batch, height, width, channels]`), so the channel axis is innermost
/// and contiguous in `data`.
#[derive(Clone, Debug)]
pub struct Tensor {
    /// Flat element buffer. The constructors in this module keep its length
    /// equal to the product of `shape`.
    data: Vec<f32>,
    /// Extent of every axis; `[batch, height, width, channels]` for 4D tensors.
    shape: Vec<usize>,
    /// Row-major element stride of every axis (the last axis has stride 1).
    strides: Vec<usize>,
}
18
19impl Tensor {
20    /// Create a new tensor with the given shape, initialized to zeros
21    pub fn zeros(shape: &[usize]) -> Self {
22        let numel: usize = shape.iter().product();
23        let data = vec![0.0; numel];
24        let strides = Self::compute_strides(shape);
25
26        Self {
27            data,
28            shape: shape.to_vec(),
29            strides,
30        }
31    }
32
33    /// Create a new tensor with the given shape, initialized to ones
34    pub fn ones(shape: &[usize]) -> Self {
35        let numel: usize = shape.iter().product();
36        let data = vec![1.0; numel];
37        let strides = Self::compute_strides(shape);
38
39        Self {
40            data,
41            shape: shape.to_vec(),
42            strides,
43        }
44    }
45
46    /// Create a tensor from raw data with the given shape
47    pub fn from_data(data: Vec<f32>, shape: &[usize]) -> CnnResult<Self> {
48        let expected_numel: usize = shape.iter().product();
49        if data.len() != expected_numel {
50            return Err(CnnError::invalid_shape(
51                format!("data length {}", expected_numel),
52                format!("data length {}", data.len()),
53            ));
54        }
55
56        let strides = Self::compute_strides(shape);
57
58        Ok(Self {
59            data,
60            shape: shape.to_vec(),
61            strides,
62        })
63    }
64
65    /// Create a tensor filled with a constant value
66    pub fn full(shape: &[usize], value: f32) -> Self {
67        let numel: usize = shape.iter().product();
68        let data = vec![value; numel];
69        let strides = Self::compute_strides(shape);
70
71        Self {
72            data,
73            shape: shape.to_vec(),
74            strides,
75        }
76    }
77
78    /// Compute strides for row-major (NHWC) layout
79    fn compute_strides(shape: &[usize]) -> Vec<usize> {
80        let mut strides = vec![1; shape.len()];
81        for i in (0..shape.len().saturating_sub(1)).rev() {
82            strides[i] = strides[i + 1] * shape[i + 1];
83        }
84        strides
85    }
86
87    /// Get the shape of the tensor
88    #[inline]
89    pub fn shape(&self) -> &[usize] {
90        &self.shape
91    }
92
93    /// Get the strides of the tensor
94    #[inline]
95    pub fn strides(&self) -> &[usize] {
96        &self.strides
97    }
98
99    /// Get the number of dimensions
100    #[inline]
101    pub fn ndim(&self) -> usize {
102        self.shape.len()
103    }
104
105    /// Get the total number of elements
106    #[inline]
107    pub fn numel(&self) -> usize {
108        self.data.len()
109    }
110
111    /// Get a reference to the raw data
112    #[inline]
113    pub fn data(&self) -> &[f32] {
114        &self.data
115    }
116
117    /// Get a mutable reference to the raw data
118    #[inline]
119    pub fn data_mut(&mut self) -> &mut [f32] {
120        &mut self.data
121    }
122
123    /// Get element at index (for 4D NHWC tensor)
124    #[inline]
125    pub fn get_4d(&self, n: usize, h: usize, w: usize, c: usize) -> f32 {
126        debug_assert!(self.shape.len() == 4);
127        let idx = n * self.strides[0] + h * self.strides[1] + w * self.strides[2] + c;
128        self.data[idx]
129    }
130
131    /// Set element at index (for 4D NHWC tensor)
132    #[inline]
133    pub fn set_4d(&mut self, n: usize, h: usize, w: usize, c: usize, value: f32) {
134        debug_assert!(self.shape.len() == 4);
135        let idx = n * self.strides[0] + h * self.strides[1] + w * self.strides[2] + c;
136        self.data[idx] = value;
137    }
138
139    /// Get batch size (first dimension)
140    #[inline]
141    pub fn batch_size(&self) -> usize {
142        if self.shape.is_empty() {
143            0
144        } else {
145            self.shape[0]
146        }
147    }
148
149    /// Get height (second dimension for NHWC)
150    #[inline]
151    pub fn height(&self) -> usize {
152        if self.shape.len() < 2 {
153            1
154        } else {
155            self.shape[1]
156        }
157    }
158
159    /// Get width (third dimension for NHWC)
160    #[inline]
161    pub fn width(&self) -> usize {
162        if self.shape.len() < 3 {
163            1
164        } else {
165            self.shape[2]
166        }
167    }
168
169    /// Get channels (fourth dimension for NHWC)
170    #[inline]
171    pub fn channels(&self) -> usize {
172        if self.shape.len() < 4 {
173            1
174        } else {
175            self.shape[3]
176        }
177    }
178
179    /// Reshape the tensor to a new shape
180    pub fn reshape(&self, new_shape: &[usize]) -> CnnResult<Self> {
181        let new_numel: usize = new_shape.iter().product();
182        if new_numel != self.numel() {
183            return Err(CnnError::invalid_shape(
184                format!("numel {}", self.numel()),
185                format!("numel {}", new_numel),
186            ));
187        }
188
189        Self::from_data(self.data.clone(), new_shape)
190    }
191
192    /// Clone with a new shape (must have same numel)
193    pub fn view(&self, new_shape: &[usize]) -> CnnResult<Self> {
194        self.reshape(new_shape)
195    }
196
197    /// Get a slice of the tensor along the batch dimension
198    pub fn slice_batch(&self, start: usize, end: usize) -> CnnResult<Self> {
199        if self.shape.is_empty() {
200            return Err(CnnError::invalid_shape("non-empty tensor", "empty tensor"));
201        }
202
203        if start >= end || end > self.shape[0] {
204            return Err(CnnError::IndexOutOfBounds {
205                index: end,
206                size: self.shape[0],
207            });
208        }
209
210        let batch_stride = self.strides[0];
211        let start_idx = start * batch_stride;
212        let end_idx = end * batch_stride;
213
214        let mut new_shape = self.shape.clone();
215        new_shape[0] = end - start;
216
217        Self::from_data(self.data[start_idx..end_idx].to_vec(), &new_shape)
218    }
219
220    /// Apply a function element-wise
221    pub fn map<F>(&self, f: F) -> Self
222    where
223        F: Fn(f32) -> f32,
224    {
225        let data: Vec<f32> = self.data.iter().map(|&x| f(x)).collect();
226        Self {
227            data,
228            shape: self.shape.clone(),
229            strides: self.strides.clone(),
230        }
231    }
232
233    /// Apply a function element-wise in place
234    pub fn map_inplace<F>(&mut self, f: F)
235    where
236        F: Fn(f32) -> f32,
237    {
238        for x in &mut self.data {
239            *x = f(*x);
240        }
241    }
242
243    /// Element-wise addition
244    pub fn add(&self, other: &Self) -> CnnResult<Self> {
245        if self.shape != other.shape {
246            return Err(CnnError::shape_mismatch(format!(
247                "add: {:?} vs {:?}",
248                self.shape, other.shape
249            )));
250        }
251
252        let data: Vec<f32> = self
253            .data
254            .iter()
255            .zip(other.data.iter())
256            .map(|(&a, &b)| a + b)
257            .collect();
258
259        Self::from_data(data, &self.shape)
260    }
261
262    /// Element-wise multiplication
263    pub fn mul(&self, other: &Self) -> CnnResult<Self> {
264        if self.shape != other.shape {
265            return Err(CnnError::shape_mismatch(format!(
266                "mul: {:?} vs {:?}",
267                self.shape, other.shape
268            )));
269        }
270
271        let data: Vec<f32> = self
272            .data
273            .iter()
274            .zip(other.data.iter())
275            .map(|(&a, &b)| a * b)
276            .collect();
277
278        Self::from_data(data, &self.shape)
279    }
280
281    /// Scalar multiplication
282    pub fn scale(&self, scalar: f32) -> Self {
283        self.map(|x| x * scalar)
284    }
285
286    /// Sum all elements
287    pub fn sum(&self) -> f32 {
288        self.data.iter().sum()
289    }
290
291    /// Mean of all elements
292    pub fn mean(&self) -> f32 {
293        self.sum() / self.numel() as f32
294    }
295
296    /// Maximum element
297    pub fn max(&self) -> f32 {
298        self.data.iter().cloned().fold(f32::NEG_INFINITY, f32::max)
299    }
300
301    /// Minimum element
302    pub fn min(&self) -> f32 {
303        self.data.iter().cloned().fold(f32::INFINITY, f32::min)
304    }
305}
306
307impl Default for Tensor {
308    fn default() -> Self {
309        Self::zeros(&[])
310    }
311}
312
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensor_zeros() {
        let t = Tensor::zeros(&[2, 3, 4, 5]);
        assert_eq!(t.shape(), &[2, 3, 4, 5]);
        assert_eq!(t.numel(), 2 * 3 * 4 * 5);
        assert!(t.data().iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_tensor_ones() {
        let t = Tensor::ones(&[2, 2, 2, 2]);
        assert!(t.data().iter().all(|&x| x == 1.0));
    }

    #[test]
    fn test_tensor_strides() {
        let t = Tensor::zeros(&[2, 3, 4, 5]);
        assert_eq!(t.strides(), &[60, 20, 5, 1]); // NHWC row-major
    }

    #[test]
    fn test_tensor_get_set_4d() {
        let mut t = Tensor::zeros(&[2, 3, 4, 5]);
        t.set_4d(1, 2, 3, 4, 42.0);
        assert_eq!(t.get_4d(1, 2, 3, 4), 42.0);
    }

    #[test]
    fn test_tensor_reshape() {
        let t = Tensor::ones(&[2, 3, 4, 5]);
        let reshaped = t.reshape(&[6, 4, 5]).unwrap();
        assert_eq!(reshaped.shape(), &[6, 4, 5]);
        assert_eq!(reshaped.numel(), t.numel());
    }

    #[test]
    fn test_tensor_reshape_rejects_numel_change() {
        let t = Tensor::ones(&[2, 3]);
        assert!(t.reshape(&[4, 2]).is_err());
    }

    #[test]
    fn test_tensor_from_data() {
        assert!(Tensor::from_data(vec![1.0, 2.0, 3.0], &[2, 2]).is_err());
        let t = Tensor::from_data(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        assert_eq!(t.data(), &[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(t.strides(), &[2, 1]);
    }

    #[test]
    fn test_tensor_slice_batch() {
        let t = Tensor::from_data((0..12).map(|x| x as f32).collect(), &[3, 2, 2]).unwrap();
        let s = t.slice_batch(1, 3).unwrap();
        assert_eq!(s.shape(), &[2, 2, 2]);
        assert_eq!(s.data(), &t.data()[4..12]);
        // Empty, inverted and past-the-end ranges are all rejected.
        assert!(t.slice_batch(2, 2).is_err());
        assert!(t.slice_batch(2, 1).is_err());
        assert!(t.slice_batch(0, 4).is_err());
    }

    #[test]
    fn test_tensor_map() {
        let t = Tensor::full(&[2, 2], 2.0);
        let squared = t.map(|x| x * x);
        assert!(squared.data().iter().all(|&x| x == 4.0));
    }

    #[test]
    fn test_tensor_add() {
        let a = Tensor::ones(&[2, 2]);
        let b = Tensor::ones(&[2, 2]);
        let c = a.add(&b).unwrap();
        assert!(c.data().iter().all(|&x| x == 2.0));
    }

    #[test]
    fn test_tensor_mul_and_shape_mismatch() {
        let a = Tensor::full(&[2, 2], 3.0);
        let b = Tensor::full(&[2, 2], 2.0);
        assert!(a.mul(&b).unwrap().data().iter().all(|&x| x == 6.0));
        // Same element count but different shape must still be rejected.
        assert!(a.add(&Tensor::ones(&[4])).is_err());
        assert!(a.mul(&Tensor::ones(&[4])).is_err());
    }

    #[test]
    fn test_tensor_sum_mean() {
        let t = Tensor::ones(&[2, 3]);
        assert_eq!(t.sum(), 6.0);
        assert_eq!(t.mean(), 1.0);
    }

    #[test]
    fn test_tensor_max_min_scale() {
        let t = Tensor::from_data(vec![-1.0, 4.0, 2.5], &[3]).unwrap();
        assert_eq!(t.max(), 4.0);
        assert_eq!(t.min(), -1.0);
        assert_eq!(t.scale(2.0).data(), &[-2.0, 8.0, 5.0]);
    }

    #[test]
    fn test_tensor_nhwc_accessors() {
        let t = Tensor::zeros(&[2, 3, 4, 5]);
        assert_eq!(
            (t.batch_size(), t.height(), t.width(), t.channels()),
            (2, 3, 4, 5)
        );
        assert_eq!(t.ndim(), 4);

        let flat = Tensor::zeros(&[7]);
        assert_eq!(
            (flat.batch_size(), flat.height(), flat.width(), flat.channels()),
            (7, 1, 1, 1)
        );
    }
}