1use crate::error::{CnnError, CnnResult};
7
/// A dense, contiguous `f32` tensor stored in row-major order.
///
/// Rank-4 tensors are interpreted as `[batch, height, width, channels]` by the
/// dimension accessors and the `get_4d`/`set_4d` element accessors.
#[derive(Debug, Clone, PartialEq)]
pub struct Tensor {
    /// Flat element buffer; its length equals the product of `shape`.
    data: Vec<f32>,
    /// Size of each dimension.
    shape: Vec<usize>,
    /// Row-major strides, in elements, derived from `shape`.
    strides: Vec<usize>,
}
18
19impl Tensor {
20 pub fn zeros(shape: &[usize]) -> Self {
22 let numel: usize = shape.iter().product();
23 let data = vec![0.0; numel];
24 let strides = Self::compute_strides(shape);
25
26 Self {
27 data,
28 shape: shape.to_vec(),
29 strides,
30 }
31 }
32
33 pub fn ones(shape: &[usize]) -> Self {
35 let numel: usize = shape.iter().product();
36 let data = vec![1.0; numel];
37 let strides = Self::compute_strides(shape);
38
39 Self {
40 data,
41 shape: shape.to_vec(),
42 strides,
43 }
44 }
45
46 pub fn from_data(data: Vec<f32>, shape: &[usize]) -> CnnResult<Self> {
48 let expected_numel: usize = shape.iter().product();
49 if data.len() != expected_numel {
50 return Err(CnnError::invalid_shape(
51 format!("data length {}", expected_numel),
52 format!("data length {}", data.len()),
53 ));
54 }
55
56 let strides = Self::compute_strides(shape);
57
58 Ok(Self {
59 data,
60 shape: shape.to_vec(),
61 strides,
62 })
63 }
64
65 pub fn full(shape: &[usize], value: f32) -> Self {
67 let numel: usize = shape.iter().product();
68 let data = vec![value; numel];
69 let strides = Self::compute_strides(shape);
70
71 Self {
72 data,
73 shape: shape.to_vec(),
74 strides,
75 }
76 }
77
78 fn compute_strides(shape: &[usize]) -> Vec<usize> {
80 let mut strides = vec![1; shape.len()];
81 for i in (0..shape.len().saturating_sub(1)).rev() {
82 strides[i] = strides[i + 1] * shape[i + 1];
83 }
84 strides
85 }
86
87 #[inline]
89 pub fn shape(&self) -> &[usize] {
90 &self.shape
91 }
92
93 #[inline]
95 pub fn strides(&self) -> &[usize] {
96 &self.strides
97 }
98
99 #[inline]
101 pub fn ndim(&self) -> usize {
102 self.shape.len()
103 }
104
105 #[inline]
107 pub fn numel(&self) -> usize {
108 self.data.len()
109 }
110
111 #[inline]
113 pub fn data(&self) -> &[f32] {
114 &self.data
115 }
116
117 #[inline]
119 pub fn data_mut(&mut self) -> &mut [f32] {
120 &mut self.data
121 }
122
123 #[inline]
125 pub fn get_4d(&self, n: usize, h: usize, w: usize, c: usize) -> f32 {
126 debug_assert!(self.shape.len() == 4);
127 let idx = n * self.strides[0] + h * self.strides[1] + w * self.strides[2] + c;
128 self.data[idx]
129 }
130
131 #[inline]
133 pub fn set_4d(&mut self, n: usize, h: usize, w: usize, c: usize, value: f32) {
134 debug_assert!(self.shape.len() == 4);
135 let idx = n * self.strides[0] + h * self.strides[1] + w * self.strides[2] + c;
136 self.data[idx] = value;
137 }
138
139 #[inline]
141 pub fn batch_size(&self) -> usize {
142 if self.shape.is_empty() {
143 0
144 } else {
145 self.shape[0]
146 }
147 }
148
149 #[inline]
151 pub fn height(&self) -> usize {
152 if self.shape.len() < 2 {
153 1
154 } else {
155 self.shape[1]
156 }
157 }
158
159 #[inline]
161 pub fn width(&self) -> usize {
162 if self.shape.len() < 3 {
163 1
164 } else {
165 self.shape[2]
166 }
167 }
168
169 #[inline]
171 pub fn channels(&self) -> usize {
172 if self.shape.len() < 4 {
173 1
174 } else {
175 self.shape[3]
176 }
177 }
178
179 pub fn reshape(&self, new_shape: &[usize]) -> CnnResult<Self> {
181 let new_numel: usize = new_shape.iter().product();
182 if new_numel != self.numel() {
183 return Err(CnnError::invalid_shape(
184 format!("numel {}", self.numel()),
185 format!("numel {}", new_numel),
186 ));
187 }
188
189 Self::from_data(self.data.clone(), new_shape)
190 }
191
192 pub fn view(&self, new_shape: &[usize]) -> CnnResult<Self> {
194 self.reshape(new_shape)
195 }
196
197 pub fn slice_batch(&self, start: usize, end: usize) -> CnnResult<Self> {
199 if self.shape.is_empty() {
200 return Err(CnnError::invalid_shape("non-empty tensor", "empty tensor"));
201 }
202
203 if start >= end || end > self.shape[0] {
204 return Err(CnnError::IndexOutOfBounds {
205 index: end,
206 size: self.shape[0],
207 });
208 }
209
210 let batch_stride = self.strides[0];
211 let start_idx = start * batch_stride;
212 let end_idx = end * batch_stride;
213
214 let mut new_shape = self.shape.clone();
215 new_shape[0] = end - start;
216
217 Self::from_data(self.data[start_idx..end_idx].to_vec(), &new_shape)
218 }
219
220 pub fn map<F>(&self, f: F) -> Self
222 where
223 F: Fn(f32) -> f32,
224 {
225 let data: Vec<f32> = self.data.iter().map(|&x| f(x)).collect();
226 Self {
227 data,
228 shape: self.shape.clone(),
229 strides: self.strides.clone(),
230 }
231 }
232
233 pub fn map_inplace<F>(&mut self, f: F)
235 where
236 F: Fn(f32) -> f32,
237 {
238 for x in &mut self.data {
239 *x = f(*x);
240 }
241 }
242
243 pub fn add(&self, other: &Self) -> CnnResult<Self> {
245 if self.shape != other.shape {
246 return Err(CnnError::shape_mismatch(format!(
247 "add: {:?} vs {:?}",
248 self.shape, other.shape
249 )));
250 }
251
252 let data: Vec<f32> = self
253 .data
254 .iter()
255 .zip(other.data.iter())
256 .map(|(&a, &b)| a + b)
257 .collect();
258
259 Self::from_data(data, &self.shape)
260 }
261
262 pub fn mul(&self, other: &Self) -> CnnResult<Self> {
264 if self.shape != other.shape {
265 return Err(CnnError::shape_mismatch(format!(
266 "mul: {:?} vs {:?}",
267 self.shape, other.shape
268 )));
269 }
270
271 let data: Vec<f32> = self
272 .data
273 .iter()
274 .zip(other.data.iter())
275 .map(|(&a, &b)| a * b)
276 .collect();
277
278 Self::from_data(data, &self.shape)
279 }
280
281 pub fn scale(&self, scalar: f32) -> Self {
283 self.map(|x| x * scalar)
284 }
285
286 pub fn sum(&self) -> f32 {
288 self.data.iter().sum()
289 }
290
291 pub fn mean(&self) -> f32 {
293 self.sum() / self.numel() as f32
294 }
295
296 pub fn max(&self) -> f32 {
298 self.data.iter().cloned().fold(f32::NEG_INFINITY, f32::max)
299 }
300
301 pub fn min(&self) -> f32 {
303 self.data.iter().cloned().fold(f32::INFINITY, f32::min)
304 }
305}
306
307impl Default for Tensor {
308 fn default() -> Self {
309 Self::zeros(&[])
310 }
311}
312
#[cfg(test)]
mod tests {
    use super::*;

    /// True when every element of `t` is exactly `expected`.
    fn every_element_is(t: &Tensor, expected: f32) -> bool {
        t.data().iter().all(|&x| x == expected)
    }

    #[test]
    fn test_tensor_zeros() {
        let dims: [usize; 4] = [2, 3, 4, 5];
        let zeros = Tensor::zeros(&dims);

        assert_eq!(zeros.shape(), &dims);
        assert_eq!(zeros.numel(), dims.iter().product::<usize>());
        assert!(every_element_is(&zeros, 0.0));
    }

    #[test]
    fn test_tensor_ones() {
        let ones = Tensor::ones(&[2, 2, 2, 2]);
        assert!(every_element_is(&ones, 1.0));
    }

    #[test]
    fn test_tensor_strides() {
        // Row-major: each stride is the product of all later dimensions.
        let expected: [usize; 4] = [3 * 4 * 5, 4 * 5, 5, 1];
        assert_eq!(Tensor::zeros(&[2, 3, 4, 5]).strides(), &expected);
    }

    #[test]
    fn test_tensor_get_set_4d() {
        let mut grid = Tensor::zeros(&[2, 3, 4, 5]);
        let (n, h, w, c) = (1, 2, 3, 4);
        grid.set_4d(n, h, w, c, 42.0);
        assert_eq!(grid.get_4d(n, h, w, c), 42.0);
    }

    #[test]
    fn test_tensor_reshape() {
        let original = Tensor::ones(&[2, 3, 4, 5]);
        let target: [usize; 3] = [6, 4, 5];
        let reshaped = original.reshape(&target).unwrap();

        assert_eq!(reshaped.shape(), &target);
        assert_eq!(reshaped.numel(), original.numel());
    }

    #[test]
    fn test_tensor_map() {
        let twos = Tensor::full(&[2, 2], 2.0);
        let squared = twos.map(|x| x * x);
        assert!(every_element_is(&squared, 4.0));
    }

    #[test]
    fn test_tensor_add() {
        let lhs = Tensor::ones(&[2, 2]);
        let rhs = Tensor::ones(&[2, 2]);
        let total = lhs.add(&rhs).unwrap();
        assert!(every_element_is(&total, 2.0));
    }

    #[test]
    fn test_tensor_sum_mean() {
        let six_ones = Tensor::ones(&[2, 3]);
        assert_eq!(six_ones.sum(), 6.0);
        assert_eq!(six_ones.mean(), 1.0);
    }
}