1use crate::Scalar;
5use crate::error::{CoreError, Result};
6
7use super::{Tensor, compute_strides};
8
9impl<T: Scalar> Tensor<T> {
10 pub fn reshape(mut self, new_shape: Vec<usize>) -> Result<Self> {
23 let new_numel: usize = new_shape.iter().product();
24 if new_numel != self.numel() {
25 return Err(CoreError::InvalidShape {
26 shape: new_shape,
27 reason: "new shape has different number of elements",
28 });
29 }
30 self.strides = compute_strides(&new_shape);
31 self.shape = new_shape;
32 Ok(self)
33 }
34
35 pub fn reshaped(&self, new_shape: Vec<usize>) -> Result<Self> {
47 self.clone().reshape(new_shape)
48 }
49
50 pub fn flatten(self) -> Self {
61 let n = self.numel();
62 Tensor {
63 data: self.data,
64 shape: vec![n],
65 strides: vec![1],
66 }
67 }
68
69 pub fn flattened(&self) -> Self {
81 let n = self.numel();
82 Tensor {
83 data: self.data.clone(),
84 shape: vec![n],
85 strides: vec![1],
86 }
87 }
88
89 pub fn transpose(&self) -> Result<Self> {
102 if self.ndim() != 2 {
103 return Err(CoreError::InvalidArgument {
104 reason: "transpose() requires a 2-D tensor; use permute() for higher ranks",
105 });
106 }
107 let (rows, cols) = (self.shape[0], self.shape[1]);
108 let mut data = vec![T::zero(); self.numel()];
109
110 for r in 0..rows {
111 for c in 0..cols {
112 data[c * rows + r] = self.data[r * cols + c];
113 }
114 }
115
116 Tensor::from_vec(data, vec![cols, rows])
117 }
118
119 pub fn permute(&self, axes: &[usize]) -> Result<Self> {
132 if axes.len() != self.ndim() {
133 return Err(CoreError::InvalidArgument {
134 reason: "axes length must match tensor rank",
135 });
136 }
137
138 let mut seen = vec![false; self.ndim()];
140 for &a in axes {
141 if a >= self.ndim() {
142 return Err(CoreError::AxisOutOfBounds {
143 axis: a,
144 ndim: self.ndim(),
145 });
146 }
147 if seen[a] {
148 return Err(CoreError::InvalidArgument {
149 reason: "duplicate axis in permutation",
150 });
151 }
152 seen[a] = true;
153 }
154
155 let new_shape: Vec<usize> = axes.iter().map(|&a| self.shape[a]).collect();
156 let new_strides = compute_strides(&new_shape);
157 let new_numel: usize = new_shape.iter().product();
158 let mut data = vec![T::zero(); new_numel];
159
160 let mut out_index = vec![0usize; self.ndim()];
162 for item in &mut data {
163 let mut flat_in = 0;
165 for (out_ax, &in_ax) in axes.iter().enumerate() {
166 flat_in += out_index[out_ax] * self.strides[in_ax];
167 }
168 *item = self.data[flat_in];
169
170 for d in (0..self.ndim()).rev() {
172 out_index[d] += 1;
173 if out_index[d] < new_shape[d] {
174 break;
175 }
176 out_index[d] = 0;
177 }
178 }
179
180 Ok(Tensor {
181 data,
182 shape: new_shape,
183 strides: new_strides,
184 })
185 }
186
187 pub fn unsqueeze(mut self, axis: usize) -> Result<Self> {
198 if axis > self.ndim() {
199 return Err(CoreError::AxisOutOfBounds {
200 axis,
201 ndim: self.ndim(),
202 });
203 }
204 self.shape.insert(axis, 1);
205 self.strides = compute_strides(&self.shape);
206 Ok(self)
207 }
208
209 pub fn squeeze(mut self) -> Self {
220 self.shape.retain(|&d| d != 1);
221 if self.shape.is_empty() && self.numel() == 1 {
222 self.shape = vec![];
223 }
224 self.strides = compute_strides(&self.shape);
225 self
226 }
227
228 pub fn concat(tensors: &[&Tensor<T>], axis: usize) -> Result<Self> {
242 if tensors.is_empty() {
243 return Err(CoreError::InvalidArgument {
244 reason: "cannot concatenate zero tensors",
245 });
246 }
247
248 let ndim = tensors[0].ndim();
249 if axis >= ndim {
250 return Err(CoreError::AxisOutOfBounds { axis, ndim });
251 }
252
253 for t in &tensors[1..] {
255 if t.ndim() != ndim {
256 return Err(CoreError::DimensionMismatch {
257 expected: tensors[0].shape.clone(),
258 got: t.shape.clone(),
259 });
260 }
261 for (d, (&a, &b)) in tensors[0].shape.iter().zip(t.shape.iter()).enumerate() {
262 if d != axis && a != b {
263 return Err(CoreError::DimensionMismatch {
264 expected: tensors[0].shape.clone(),
265 got: t.shape.clone(),
266 });
267 }
268 }
269 }
270
271 let mut new_shape = tensors[0].shape.clone();
272 new_shape[axis] = tensors.iter().map(|t| t.shape[axis]).sum();
273
274 let outer: usize = new_shape[..axis].iter().product();
275 let inner: usize = new_shape[axis + 1..].iter().product();
276 let total: usize = new_shape.iter().product();
277
278 let mut data = Vec::with_capacity(total);
279
280 for o in 0..outer {
281 for t in tensors {
282 let axis_len = t.shape[axis];
283 let src_start = o * axis_len * inner;
284 let src_end = src_start + axis_len * inner;
285 data.extend_from_slice(&t.data[src_start..src_end]);
286 }
287 }
288
289 Tensor::from_vec(data, new_shape)
290 }
291
292 pub fn stack(tensors: &[&Tensor<T>], axis: usize) -> Result<Self> {
306 if tensors.is_empty() {
307 return Err(CoreError::InvalidArgument {
308 reason: "cannot stack zero tensors",
309 });
310 }
311
312 let base_shape = &tensors[0].shape;
313 if axis > base_shape.len() {
314 return Err(CoreError::AxisOutOfBounds {
315 axis,
316 ndim: base_shape.len() + 1,
317 });
318 }
319
320 for t in &tensors[1..] {
321 if t.shape != *base_shape {
322 return Err(CoreError::DimensionMismatch {
323 expected: base_shape.clone(),
324 got: t.shape.clone(),
325 });
326 }
327 }
328
329 let expanded: Vec<Tensor<T>> = tensors
331 .iter()
332 .map(|t| {
334 (*t).clone()
335 .unsqueeze(axis)
336 .expect("axis is valid for all tensors since shapes were validated above")
337 })
338 .collect();
339 let refs: Vec<&Tensor<T>> = expanded.iter().collect();
340 Tensor::concat(&refs, axis)
341 }
342}
343
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_reshape() {
        let t = Tensor::from_vec(vec![1, 2, 3, 4, 5, 6], vec![6]).unwrap();
        let t = t.reshape(vec![2, 3]).unwrap();
        assert_eq!(t.shape(), &[2, 3]);
        assert_eq!(t.strides(), &[3, 1]);
        assert_eq!(*t.get(&[1, 0]).unwrap(), 4);
    }

    #[test]
    fn test_reshape_invalid() {
        let t = Tensor::from_vec(vec![1, 2, 3, 4], vec![4]).unwrap();
        assert!(t.reshape(vec![3, 2]).is_err());
    }

    #[test]
    fn test_reshaped_keeps_original() {
        let t = Tensor::from_vec(vec![1, 2, 3, 4, 5, 6], vec![6]).unwrap();
        let r = t.reshaped(vec![3, 2]).unwrap();
        assert_eq!(r.shape(), &[3, 2]);
        assert_eq!(t.shape(), &[6]);
        assert_eq!(r.as_slice(), t.as_slice());
    }

    #[test]
    fn test_flatten() {
        let t = Tensor::from_vec(vec![1, 2, 3, 4, 5, 6], vec![2, 3]).unwrap();
        let flat = t.flatten();
        assert_eq!(flat.shape(), &[6]);
        assert_eq!(flat.as_slice(), &[1, 2, 3, 4, 5, 6]);
    }

    #[test]
    fn test_flattened_keeps_original() {
        let t = Tensor::from_vec(vec![1, 2, 3, 4], vec![2, 2]).unwrap();
        let flat = t.flattened();
        assert_eq!(flat.shape(), &[4]);
        assert_eq!(t.shape(), &[2, 2]);
        assert_eq!(flat.as_slice(), &[1, 2, 3, 4]);
    }

    #[test]
    fn test_transpose() {
        let t = Tensor::from_vec(vec![1, 2, 3, 4, 5, 6], vec![2, 3]).unwrap();
        let tt = t.transpose().unwrap();
        assert_eq!(tt.shape(), &[3, 2]);
        assert_eq!(*tt.get(&[0, 0]).unwrap(), 1);
        assert_eq!(*tt.get(&[0, 1]).unwrap(), 4);
        assert_eq!(*tt.get(&[2, 0]).unwrap(), 3);
        assert_eq!(*tt.get(&[2, 1]).unwrap(), 6);
    }

    #[test]
    fn test_transpose_data_layout_and_roundtrip() {
        let t = Tensor::from_vec(vec![1, 2, 3, 4, 5, 6], vec![2, 3]).unwrap();
        let tt = t.transpose().unwrap();
        assert_eq!(tt.as_slice(), &[1, 4, 2, 5, 3, 6]);
        let back = tt.transpose().unwrap();
        assert_eq!(back.shape(), &[2, 3]);
        assert_eq!(back.as_slice(), t.as_slice());
    }

    #[test]
    fn test_transpose_not_2d() {
        let t = Tensor::from_vec(vec![1, 2, 3], vec![3]).unwrap();
        assert!(t.transpose().is_err());
    }

    #[test]
    fn test_permute() {
        let t = Tensor::<i32>::arange(24).reshape(vec![2, 3, 4]).unwrap();
        let p = t.permute(&[2, 0, 1]).unwrap();
        assert_eq!(p.shape(), &[4, 2, 3]);
        assert_eq!(*p.get(&[0, 0, 0]).unwrap(), 0);
        assert_eq!(*p.get(&[3, 1, 2]).unwrap(), *t.get(&[1, 2, 3]).unwrap());
    }

    #[test]
    fn test_permute_2d_matches_transpose() {
        let t = Tensor::<i32>::arange(6).reshape(vec![2, 3]).unwrap();
        let p = t.permute(&[1, 0]).unwrap();
        let tt = t.transpose().unwrap();
        assert_eq!(p.shape(), tt.shape());
        assert_eq!(p.as_slice(), tt.as_slice());
    }

    #[test]
    fn test_permute_invalid_axes() {
        let t = Tensor::<i32>::arange(6).reshape(vec![2, 3]).unwrap();
        // Wrong number of axes.
        assert!(t.permute(&[0]).is_err());
        // Axis out of range.
        assert!(t.permute(&[0, 2]).is_err());
        // Repeated axis.
        assert!(t.permute(&[1, 1]).is_err());
    }

    #[test]
    fn test_unsqueeze_squeeze() {
        let t = Tensor::from_vec(vec![1, 2, 3], vec![3]).unwrap();
        let t = t.unsqueeze(0).unwrap();
        assert_eq!(t.shape(), &[1, 3]);
        let t = t.squeeze();
        assert_eq!(t.shape(), &[3]);
    }

    #[test]
    fn test_unsqueeze_last_axis_and_out_of_bounds() {
        let t = Tensor::from_vec(vec![1, 2, 3], vec![3]).unwrap();
        assert!(t.clone().unsqueeze(2).is_err());
        let t = t.unsqueeze(1).unwrap();
        assert_eq!(t.shape(), &[3, 1]);
    }

    #[test]
    fn test_squeeze_all_ones() {
        let t = Tensor::from_vec(vec![7], vec![1, 1]).unwrap();
        let t = t.squeeze();
        assert!(t.shape().is_empty());
        assert_eq!(t.as_slice(), &[7]);
    }

    #[test]
    fn test_concat() {
        let a = Tensor::from_vec(vec![1, 2, 3], vec![1, 3]).unwrap();
        let b = Tensor::from_vec(vec![4, 5, 6], vec![1, 3]).unwrap();
        let c = Tensor::concat(&[&a, &b], 0).unwrap();
        assert_eq!(c.shape(), &[2, 3]);
        assert_eq!(c.as_slice(), &[1, 2, 3, 4, 5, 6]);
    }

    #[test]
    fn test_concat_axis1() {
        let a = Tensor::from_vec(vec![1, 2, 3, 4], vec![2, 2]).unwrap();
        let b = Tensor::from_vec(vec![5, 6, 7, 8], vec![2, 2]).unwrap();
        let c = Tensor::concat(&[&a, &b], 1).unwrap();
        assert_eq!(c.shape(), &[2, 4]);
        assert_eq!(c.as_slice(), &[1, 2, 5, 6, 3, 4, 7, 8]);
    }

    #[test]
    fn test_concat_shape_mismatch() {
        let a = Tensor::from_vec(vec![1, 2, 3], vec![1, 3]).unwrap();
        let b = Tensor::from_vec(vec![4, 5], vec![1, 2]).unwrap();
        assert!(Tensor::concat(&[&a, &b], 0).is_err());
    }

    #[test]
    fn test_concat_rank_mismatch() {
        let a = Tensor::from_vec(vec![1, 2], vec![1, 2]).unwrap();
        let b = Tensor::from_vec(vec![3, 4], vec![2]).unwrap();
        assert!(Tensor::concat(&[&a, &b], 0).is_err());
    }

    #[test]
    fn test_concat_empty_and_bad_axis() {
        let empty: [&Tensor<i32>; 0] = [];
        assert!(Tensor::concat(&empty, 0).is_err());
        let a = Tensor::from_vec(vec![1, 2], vec![2]).unwrap();
        assert!(Tensor::concat(&[&a], 1).is_err());
    }

    #[test]
    fn test_stack() {
        let a = Tensor::from_vec(vec![1, 2, 3], vec![3]).unwrap();
        let b = Tensor::from_vec(vec![4, 5, 6], vec![3]).unwrap();
        let c = Tensor::stack(&[&a, &b], 0).unwrap();
        assert_eq!(c.shape(), &[2, 3]);
        assert_eq!(c.as_slice(), &[1, 2, 3, 4, 5, 6]);
    }

    #[test]
    fn test_stack_axis1() {
        let a = Tensor::from_vec(vec![1, 2, 3], vec![3]).unwrap();
        let b = Tensor::from_vec(vec![4, 5, 6], vec![3]).unwrap();
        let c = Tensor::stack(&[&a, &b], 1).unwrap();
        assert_eq!(c.shape(), &[3, 2]);
        assert_eq!(c.as_slice(), &[1, 4, 2, 5, 3, 6]);
    }

    #[test]
    fn test_stack_2d_middle_axis() {
        let a = Tensor::from_vec(vec![1, 2, 3, 4], vec![2, 2]).unwrap();
        let b = Tensor::from_vec(vec![5, 6, 7, 8], vec![2, 2]).unwrap();
        let c = Tensor::stack(&[&a, &b], 1).unwrap();
        assert_eq!(c.shape(), &[2, 2, 2]);
        assert_eq!(c.as_slice(), &[1, 2, 5, 6, 3, 4, 7, 8]);
    }

    #[test]
    fn test_stack_single_tensor() {
        let a = Tensor::from_vec(vec![1, 2, 3], vec![3]).unwrap();
        let c = Tensor::stack(&[&a], 0).unwrap();
        assert_eq!(c.shape(), &[1, 3]);
        assert_eq!(c.as_slice(), &[1, 2, 3]);
    }

    #[test]
    fn test_stack_errors() {
        let a = Tensor::from_vec(vec![1, 2, 3], vec![3]).unwrap();
        let b = Tensor::from_vec(vec![4, 5], vec![2]).unwrap();
        // Shapes differ.
        assert!(Tensor::stack(&[&a, &b], 0).is_err());
        // Axis beyond rank + 1 positions.
        assert!(Tensor::stack(&[&a, &a], 2).is_err());
        // No inputs.
        let empty: [&Tensor<i32>; 0] = [];
        assert!(Tensor::stack(&empty, 0).is_err());
    }
}