1use std::{fmt, marker::PhantomData, ops};
2
3use rand::{Rng, SeedableRng, rngs::StdRng};
4
5use crate::shape::{Dim, Nil, NonScalarShape, TensorShape};
6use crate::{Float, ReshapePreservesElementCount};
7
/// Internal storage-generic tensor: pairs a flat element buffer with a
/// compile-time shape carried only at the type level.
///
/// `Storage` is the backing buffer — owned (`Box<[Float]>`) or a borrowed
/// slice; `Shape` occupies no memory, hence the `PhantomData` marker.
struct StorageTensor<Storage, Shape: TensorShape> {
    // Flat element buffer; its length must equal `Shape::SIZE`
    // (enforced by `Tensor::from_boxed`).
    storage: Storage,
    // Zero-sized marker tying the value to its compile-time shape.
    _shape_marker: PhantomData<Shape>,
}
12
/// Owned tensor backed by a heap-allocated, exactly-sized buffer.
pub struct Tensor<Shape: TensorShape>(StorageTensor<Box<[Float]>, Shape>);
/// Immutable borrowed view of a tensor's elements.
pub struct TensorRef<'a, Shape: TensorShape>(StorageTensor<&'a [Float], Shape>);
/// Mutable borrowed view of a tensor's elements.
pub struct TensorMut<'a, Shape: TensorShape>(StorageTensor<&'a mut [Float], Shape>);
16
/// Read access to a flat element buffer, abstracting over owned and
/// borrowed storage kinds.
trait StorageRef {
    fn as_slice(&self) -> &[Float];
}

/// Read-write access to a flat element buffer.
trait StorageMut: StorageRef {
    fn as_mut_slice(&mut self) -> &mut [Float];
}
24
// Owned storage: a boxed slice supports both read and write access
// (both bodies rely on deref coercion from `Box<[Float]>` to `[Float]`).
impl StorageRef for Box<[Float]> {
    fn as_slice(&self) -> &[Float] {
        self
    }
}

impl StorageMut for Box<[Float]> {
    fn as_mut_slice(&mut self) -> &mut [Float] {
        self
    }
}
36
// Shared borrows are read-only storage.
impl StorageRef for &[Float] {
    fn as_slice(&self) -> &[Float] {
        self
    }
}

// Mutable borrows support both read and write access.
impl StorageRef for &mut [Float] {
    fn as_slice(&self) -> &[Float] {
        self
    }
}

impl StorageMut for &mut [Float] {
    fn as_mut_slice(&mut self) -> &mut [Float] {
        self
    }
}
54
/// Implemented by nested array literals (and bare `Float`s) so the
/// `tensor!` macro can infer a shape and flatten the elements.
pub trait TensorLiteral {
    /// Shape described by this literal's nesting structure.
    type Shape: TensorShape;

    /// Appends this literal's elements to `out`, outermost axis first.
    fn write_flat(self, out: &mut Vec<Float>);
}
60
impl<Storage, Shape> StorageTensor<Storage, Shape>
where
    Shape: TensorShape,
{
    // Wraps raw storage without validating its length; callers must
    // supply exactly `Shape::SIZE` elements (see `Tensor::from_boxed`,
    // which is the checked entry point).
    fn from_storage(storage: Storage) -> Self {
        Self {
            storage,
            _shape_marker: PhantomData,
        }
    }
}
72
73impl<Storage, Shape> StorageTensor<Storage, Shape>
74where
75 Storage: StorageRef,
76 Shape: TensorShape,
77{
78 fn as_slice(&self) -> &[Float] {
79 StorageRef::as_slice(&self.storage)
80 }
81
82 fn at(&self, index: [usize; Shape::RANK]) -> &Float {
83 let offset = Shape::offset(&index);
84 &self.as_slice()[offset]
85 }
86
87 fn sum(&self) -> Float {
88 self.as_slice().iter().copied().sum()
89 }
90
91 fn mean(&self) -> Float {
92 self.sum() / Shape::SIZE as Float
93 }
94}
95
96impl<Storage, Shape> StorageTensor<Storage, Shape>
97where
98 Storage: StorageMut,
99 Shape: TensorShape,
100{
101 fn as_mut_slice(&mut self) -> &mut [Float] {
102 StorageMut::as_mut_slice(&mut self.storage)
103 }
104
105 fn set(&mut self, index: [usize; Shape::RANK], value: Float) {
106 let offset = Shape::offset(&index);
107 self.as_mut_slice()[offset] = value;
108 }
109
110 fn fill(&mut self, value: Float) {
111 self.as_mut_slice().fill(value);
112 }
113}
114
115impl<Storage, Shape> Clone for StorageTensor<Storage, Shape>
116where
117 Storage: Clone,
118 Shape: TensorShape,
119{
120 fn clone(&self) -> Self {
121 Self {
122 storage: self.storage.clone(),
123 _shape_marker: PhantomData,
124 }
125 }
126}
127
128impl<Storage, Shape> Copy for StorageTensor<Storage, Shape>
129where
130 Storage: Copy,
131 Shape: TensorShape,
132{
133}
134
impl<Shape> Clone for Tensor<Shape>
where
    Shape: TensorShape,
{
    // Deep copy: duplicates the boxed element buffer.
    fn clone(&self) -> Self {
        Self(self.0.clone())
    }
}

impl<'a, Shape> Clone for TensorRef<'a, Shape>
where
    Shape: TensorShape,
{
    // Views are `Copy`, so cloning is just a bitwise copy of the reference.
    fn clone(&self) -> Self {
        *self
    }
}

impl<'a, Shape> Copy for TensorRef<'a, Shape> where Shape: TensorShape {}
154
155impl<Shape> fmt::Debug for Tensor<Shape>
156where
157 Shape: TensorShape,
158{
159 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
160 f.debug_struct("Tensor")
161 .field("rank", &Shape::RANK)
162 .field("elements", &self.as_slice())
163 .finish()
164 }
165}
166
167impl<'a, Shape> fmt::Debug for TensorRef<'a, Shape>
168where
169 Shape: TensorShape,
170{
171 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
172 f.debug_struct("TensorRef")
173 .field("rank", &Shape::RANK)
174 .field("elements", &self.as_slice())
175 .finish()
176 }
177}
178
179impl<'a, Shape> fmt::Debug for TensorMut<'a, Shape>
180where
181 Shape: TensorShape,
182{
183 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
184 f.debug_struct("TensorMut")
185 .field("rank", &Shape::RANK)
186 .field("elements", &self.as_slice())
187 .finish()
188 }
189}
190
impl<Shape> Default for Tensor<Shape>
where
    Shape: TensorShape,
{
    // The default tensor is all zeros.
    fn default() -> Self {
        Self::zeros()
    }
}
199
200impl<Shape> Tensor<Shape>
201where
202 Shape: TensorShape,
203{
204 pub(crate) fn from_boxed(storage: Box<[Float]>) -> Self {
205 assert_eq!(storage.len(), Shape::SIZE, "tensor storage size mismatch");
206 Self(StorageTensor::from_storage(storage))
207 }
208
209 pub fn from_flat(data: [Float; Shape::SIZE]) -> Self {
210 Self::from_boxed(Vec::from(data).into_boxed_slice())
211 }
212
213 pub fn from_elem(value: Float) -> Self {
214 Self::from_boxed(vec![value; Shape::SIZE].into_boxed_slice())
215 }
216
217 pub(crate) fn raw_slice(&self) -> &[Float] {
218 self.as_slice()
219 }
220
221 pub(crate) fn raw_mut_slice(&mut self) -> &mut [Float] {
222 self.as_mut_slice()
223 }
224
225 pub fn len(&self) -> usize {
226 Shape::SIZE
227 }
228
229 pub fn is_empty(&self) -> bool {
230 self.len() == 0
231 }
232
233 pub fn rank(&self) -> usize {
234 Shape::RANK
235 }
236
237 pub fn as_slice(&self) -> &[Float] {
238 self.0.as_slice()
239 }
240
241 pub fn as_mut_slice(&mut self) -> &mut [Float] {
242 self.0.as_mut_slice()
243 }
244
245 pub fn at(&self, index: [usize; Shape::RANK]) -> &Float {
246 self.0.at(index)
247 }
248
249 pub fn set(&mut self, index: [usize; Shape::RANK], value: Float) {
250 self.0.set(index, value);
251 }
252
253 pub fn fill(&mut self, value: Float) {
254 self.0.fill(value);
255 }
256
257 pub fn zeros() -> Self {
258 Self::from_elem(0.0)
259 }
260
261 pub fn random() -> Self {
262 let mut rng = rand::rng();
263 Self::random_with(&mut rng)
264 }
265
266 pub fn random_with_seed(seed: u64) -> Self {
267 let mut rng = StdRng::seed_from_u64(seed);
268 Self::random_with(&mut rng)
269 }
270
271 pub fn random_with<R>(rng: &mut R) -> Self
272 where
273 R: Rng + ?Sized,
274 {
275 let mut out = Self::zeros();
276 for value in out.as_mut_slice() {
277 *value = rng.random::<Float>();
278 }
279 out
280 }
281
282 pub fn reshape<NewShape>(self) -> Tensor<NewShape>
283 where
284 NewShape: TensorShape,
285 (): ReshapePreservesElementCount<{ Shape::SIZE }, { NewShape::SIZE }>,
286 {
287 Tensor::<NewShape>::from_boxed(self.0.storage)
288 }
289
290 pub fn as_ref(&self) -> TensorRef<'_, Shape> {
291 TensorRef(StorageTensor::from_storage(self.as_slice()))
292 }
293
294 pub fn as_mut(&mut self) -> TensorMut<'_, Shape> {
295 TensorMut(StorageTensor::from_storage(self.as_mut_slice()))
296 }
297
298 pub fn map_inplace<F>(&mut self, mut f: F)
299 where
300 F: FnMut(Float) -> Float,
301 {
302 for value in self.as_mut_slice() {
303 *value = f(*value);
304 }
305 }
306
307 pub fn map<F>(&self, f: F) -> Self
308 where
309 F: FnMut(Float) -> Float,
310 {
311 let mut out = self.clone();
312 out.map_inplace(f);
313 out
314 }
315
316 pub fn zip_map<F>(&self, rhs: &Self, mut f: F) -> Self
317 where
318 F: FnMut(Float, Float) -> Float,
319 {
320 let mut out = Self::zeros();
321 for ((dst, lhs), rhs) in out
322 .as_mut_slice()
323 .iter_mut()
324 .zip(self.as_slice().iter().copied())
325 .zip(rhs.as_slice().iter().copied())
326 {
327 *dst = f(lhs, rhs);
328 }
329 out
330 }
331
332 pub fn sum(&self) -> Float {
333 self.0.sum()
334 }
335
336 pub fn mean(&self) -> Float {
337 self.0.mean()
338 }
339
340 #[deprecated(note = "Tensor::slice is not implemented yet")]
341 pub fn slice<T: Iterator>(_range: T) {}
342}
343
impl<'a, Shape> TensorRef<'a, Shape>
where
    Shape: TensorShape,
{
    /// Total number of elements (a compile-time constant).
    pub fn len(&self) -> usize {
        Shape::SIZE
    }

    /// True only for zero-sized shapes.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Number of axes (a compile-time constant).
    pub fn rank(&self) -> usize {
        Shape::RANK
    }

    /// Borrows the viewed elements as a flat slice.
    pub fn as_slice(&self) -> &[Float] {
        self.0.as_slice()
    }

    /// Borrows the element at `index` (one coordinate per axis).
    pub fn at(&self, index: [usize; Shape::RANK]) -> &Float {
        self.0.at(index)
    }

    /// Sum of all viewed elements.
    pub fn sum(&self) -> Float {
        self.0.sum()
    }

    /// Mean of all viewed elements.
    pub fn mean(&self) -> Float {
        self.0.mean()
    }

    /// Reinterprets the view under a new shape with the same element
    /// count; the borrow keeps the original lifetime `'a`.
    pub fn reshape<NewShape>(self) -> TensorRef<'a, NewShape>
    where
        NewShape: TensorShape,
        (): ReshapePreservesElementCount<{ Shape::SIZE }, { NewShape::SIZE }>,
    {
        TensorRef(StorageTensor::from_storage(self.0.storage))
    }
}
384
impl<'a, Shape> TensorMut<'a, Shape>
where
    Shape: TensorShape,
{
    /// Total number of elements (a compile-time constant).
    pub fn len(&self) -> usize {
        Shape::SIZE
    }

    /// True only for zero-sized shapes.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Number of axes (a compile-time constant).
    pub fn rank(&self) -> usize {
        Shape::RANK
    }

    /// Borrows the viewed elements as a flat slice.
    pub fn as_slice(&self) -> &[Float] {
        self.0.as_slice()
    }

    /// Mutably borrows the viewed elements as a flat slice.
    pub fn as_mut_slice(&mut self) -> &mut [Float] {
        self.0.as_mut_slice()
    }

    /// Borrows the element at `index` (one coordinate per axis).
    pub fn at(&self, index: [usize; Shape::RANK]) -> &Float {
        self.0.at(index)
    }

    /// Writes `value` at `index` (one coordinate per axis).
    pub fn set(&mut self, index: [usize; Shape::RANK], value: Float) {
        self.0.set(index, value);
    }

    /// Overwrites every viewed element with `value`.
    pub fn fill(&mut self, value: Float) {
        self.0.fill(value);
    }

    /// Sum of all viewed elements.
    pub fn sum(&self) -> Float {
        self.0.sum()
    }

    /// Mean of all viewed elements.
    pub fn mean(&self) -> Float {
        self.0.mean()
    }

    /// Reinterprets the view under a new shape with the same element
    /// count; the mutable borrow keeps the original lifetime `'a`.
    pub fn reshape<NewShape>(self) -> TensorMut<'a, NewShape>
    where
        NewShape: TensorShape,
        (): ReshapePreservesElementCount<{ Shape::SIZE }, { NewShape::SIZE }>,
    {
        TensorMut(StorageTensor::from_storage(self.0.storage))
    }
}
437
438impl<Shape> Tensor<Shape>
439where
440 Shape: NonScalarShape,
441{
442 pub fn get_ref(&self, index: usize) -> TensorRef<'_, Shape::Subshape> {
443 assert!(index < Shape::AXIS_LEN, "index out of bounds");
444 let stride = <Shape::Subshape as TensorShape>::SIZE;
445 let start = index * stride;
446 let end = start + stride;
447 TensorRef(StorageTensor::from_storage(&self.as_slice()[start..end]))
448 }
449
450 pub fn get_mut(&mut self, index: usize) -> TensorMut<'_, Shape::Subshape> {
451 assert!(index < Shape::AXIS_LEN, "index out of bounds");
452 let stride = <Shape::Subshape as TensorShape>::SIZE;
453 let start = index * stride;
454 let end = start + stride;
455 TensorMut(StorageTensor::from_storage(
456 &mut self.as_mut_slice()[start..end],
457 ))
458 }
459
460 pub fn get(&self, index: usize) -> Tensor<Shape::Subshape> {
461 let row = self.get_ref(index);
462 Tensor::<Shape::Subshape>::from_boxed(row.as_slice().to_vec().into_boxed_slice())
463 }
464}
465
466impl<'a, Shape> TensorRef<'a, Shape>
467where
468 Shape: NonScalarShape,
469{
470 pub fn get_ref(&self, index: usize) -> TensorRef<'_, Shape::Subshape> {
471 assert!(index < Shape::AXIS_LEN, "index out of bounds");
472 let stride = <Shape::Subshape as TensorShape>::SIZE;
473 let start = index * stride;
474 let end = start + stride;
475 TensorRef(StorageTensor::from_storage(&self.as_slice()[start..end]))
476 }
477}
478
479impl<'a, Shape> TensorMut<'a, Shape>
480where
481 Shape: NonScalarShape,
482{
483 pub fn get_ref(&self, index: usize) -> TensorRef<'_, Shape::Subshape> {
484 assert!(index < Shape::AXIS_LEN, "index out of bounds");
485 let stride = <Shape::Subshape as TensorShape>::SIZE;
486 let start = index * stride;
487 let end = start + stride;
488 TensorRef(StorageTensor::from_storage(&self.as_slice()[start..end]))
489 }
490
491 pub fn get_mut(&mut self, index: usize) -> TensorMut<'_, Shape::Subshape> {
492 assert!(index < Shape::AXIS_LEN, "index out of bounds");
493 let stride = <Shape::Subshape as TensorShape>::SIZE;
494 let start = index * stride;
495 let end = start + stride;
496 TensorMut(StorageTensor::from_storage(
497 &mut self.as_mut_slice()[start..end],
498 ))
499 }
500}
501
502impl<const N: usize> Tensor<Dim<N, Nil>> {
503 pub fn dot(&self, rhs: &Self) -> Float {
504 self.as_slice()
505 .iter()
506 .zip(rhs.as_slice())
507 .map(|(lhs, rhs)| lhs * rhs)
508 .sum()
509 }
510}
511
512impl<const ROWS: usize, const COLS: usize> Tensor<Dim<ROWS, Dim<COLS, Nil>>> {
513 pub fn transpose(&self) -> Tensor<Dim<COLS, Dim<ROWS, Nil>>> {
514 let mut out = Tensor::<Dim<COLS, Dim<ROWS, Nil>>>::zeros();
515 let input = self.as_slice();
516 let output = out.as_mut_slice();
517 for row in 0..ROWS {
518 for col in 0..COLS {
519 output[col * ROWS + row] = input[row * COLS + col];
520 }
521 }
522 out
523 }
524
525 pub fn matvec(&self, rhs: &Tensor<Dim<COLS, Nil>>) -> Tensor<Dim<ROWS, Nil>> {
526 let mut out = Tensor::<Dim<ROWS, Nil>>::zeros();
527 let lhs = self.as_slice();
528 let rhs = rhs.as_slice();
529 for row in 0..ROWS {
530 let mut acc = 0.0;
531 for col in 0..COLS {
532 acc += lhs[row * COLS + col] * rhs[col];
533 }
534 out.as_mut_slice()[row] = acc;
535 }
536 out
537 }
538
539 pub fn matmul<const OUT_COLS: usize>(
540 &self,
541 rhs: &Tensor<Dim<COLS, Dim<OUT_COLS, Nil>>>,
542 ) -> Tensor<Dim<ROWS, Dim<OUT_COLS, Nil>>> {
543 let mut out = Tensor::<Dim<ROWS, Dim<OUT_COLS, Nil>>>::zeros();
544 let lhs = self.as_slice();
545 let rhs = rhs.as_slice();
546 let output = out.as_mut_slice();
547 for row in 0..ROWS {
548 for out_col in 0..OUT_COLS {
549 let mut acc = 0.0;
550 for inner in 0..COLS {
551 acc += lhs[row * COLS + inner] * rhs[inner * OUT_COLS + out_col];
552 }
553 output[row * OUT_COLS + out_col] = acc;
554 }
555 }
556 out
557 }
558}
559
// Generates the element-wise tensor-tensor operator impls for one
// operator: `Tensor op &Tensor`, `&Tensor op &Tensor`, and
// `Tensor op= &Tensor`. `$op` is the scalar operator token (e.g. `+`).
macro_rules! impl_tensor_binop {
    ($trait:ident, $method:ident, $assign_trait:ident, $assign_method:ident, $op:tt) => {
        impl<Shape> ops::$trait<&Tensor<Shape>> for Tensor<Shape>
        where
            Shape: TensorShape,
        {
            type Output = Tensor<Shape>;

            // Owned lhs: mutate in place and return it (no extra allocation).
            fn $method(mut self, rhs: &Tensor<Shape>) -> Self::Output {
                ops::$assign_trait::$assign_method(&mut self, rhs);
                self
            }
        }

        impl<Shape> ops::$trait<&Tensor<Shape>> for &Tensor<Shape>
        where
            Shape: TensorShape,
        {
            type Output = Tensor<Shape>;

            // Borrowed lhs: clone once, then reuse the owned impl.
            fn $method(self, rhs: &Tensor<Shape>) -> Self::Output {
                self.clone().$method(rhs)
            }
        }

        impl<Shape> ops::$assign_trait<&Tensor<Shape>> for Tensor<Shape>
        where
            Shape: TensorShape,
        {
            // Element-wise in-place update; both tensors have `Shape::SIZE`
            // elements, so the zip covers every slot.
            fn $assign_method(&mut self, rhs: &Tensor<Shape>) {
                for (lhs, rhs) in self.as_mut_slice().iter_mut().zip(rhs.as_slice().iter().copied()) {
                    *lhs = *lhs $op rhs;
                }
            }
        }
    };
}
597
// Generates the broadcast tensor-scalar operator impls for one operator:
// `Tensor op Float`, `&Tensor op Float`, and `Tensor op= Float`.
// `$op` is the scalar operator token (e.g. `+`).
macro_rules! impl_tensor_scalar_binop {
    ($trait:ident, $method:ident, $assign_trait:ident, $assign_method:ident, $op:tt) => {
        impl<Shape> ops::$trait<Float> for Tensor<Shape>
        where
            Shape: TensorShape,
        {
            type Output = Tensor<Shape>;

            // Owned lhs: mutate in place and return it (no extra allocation).
            fn $method(mut self, rhs: Float) -> Self::Output {
                ops::$assign_trait::$assign_method(&mut self, rhs);
                self
            }
        }

        // Borrowed lhs, for parity with `impl_tensor_binop` (which already
        // provides a `&Tensor` impl): clone once, then reuse the owned impl.
        impl<Shape> ops::$trait<Float> for &Tensor<Shape>
        where
            Shape: TensorShape,
        {
            type Output = Tensor<Shape>;

            fn $method(self, rhs: Float) -> Self::Output {
                ops::$trait::$method(self.clone(), rhs)
            }
        }

        impl<Shape> ops::$assign_trait<Float> for Tensor<Shape>
        where
            Shape: TensorShape,
        {
            // Applies the scalar to every element in place.
            fn $assign_method(&mut self, rhs: Float) {
                for value in self.as_mut_slice() {
                    *value = *value $op rhs;
                }
            }
        }
    };
}
624
// Element-wise tensor-tensor arithmetic (no element-wise tensor division
// is defined).
impl_tensor_binop!(Add, add, AddAssign, add_assign, +);
impl_tensor_binop!(Sub, sub, SubAssign, sub_assign, -);
impl_tensor_binop!(Mul, mul, MulAssign, mul_assign, *);

// Broadcast tensor-scalar arithmetic.
impl_tensor_scalar_binop!(Add, add, AddAssign, add_assign, +);
impl_tensor_scalar_binop!(Sub, sub, SubAssign, sub_assign, -);
impl_tensor_scalar_binop!(Mul, mul, MulAssign, mul_assign, *);
impl_tensor_scalar_binop!(Div, div, DivAssign, div_assign, /);
633
634impl<const N: usize> From<[Float; N]> for Tensor<Dim<N, Nil>> {
635 fn from(value: [Float; N]) -> Self {
636 Self::from_boxed(Vec::from(value).into_boxed_slice())
637 }
638}
639
// Base case of the literal recursion: a bare scalar contributes one
// element and has the scalar shape `Nil`.
impl TensorLiteral for Float {
    type Shape = Nil;

    fn write_flat(self, out: &mut Vec<Float>) {
        out.push(self);
    }
}
647
// Recursive case: an `[T; N]` literal prepends an axis of length `N` to
// the shape of its items and flattens the items in order.
impl<T, const N: usize> TensorLiteral for [T; N]
where
    T: TensorLiteral,
{
    type Shape = Dim<N, T::Shape>;

    fn write_flat(self, out: &mut Vec<Float>) {
        for item in self {
            item.write_flat(out);
        }
    }
}
660
661#[doc(hidden)]
662pub fn __tensor_from_literal<T>(value: T) -> Tensor<T::Shape>
663where
664 T: TensorLiteral,
665{
666 let mut flat = Vec::with_capacity(<T::Shape as TensorShape>::SIZE);
667 value.write_flat(&mut flat);
668 Tensor::<T::Shape>::from_boxed(flat.into_boxed_slice())
669}
670
/// Builds a [`Tensor`] from a nested array literal, inferring the shape
/// from the nesting, e.g. `tensor![[1.0, 2.0], [3.0, 4.0]]`.
#[macro_export]
macro_rules! tensor {
    [$($items:tt)*] => {
        $crate::__tensor_from_literal([$($items)*])
    };
}
677
#[cfg(test)]
mod tests {
    use super::*;

    // Canonical 3-axis test shape: 2 x 3 x 4 = 24 elements.
    type T3 = crate::shape!(2, 3, 4);

    #[test]
    fn indexing_borrows_and_owned_get_match_layout() {
        // Fill with 0..24 in index order (innermost axis fastest).
        let mut t = Tensor::<T3>::zeros();
        let mut value = 0.0;
        for i in 0..2 {
            for j in 0..3 {
                for k in 0..4 {
                    t.set([i, j, k], value);
                    value += 1.0;
                }
            }
        }

        // Last coordinate maps to the last flat element.
        assert_eq!(*t.at([1, 2, 3]), 23.0);

        // A borrowed subtensor sees the same element at the reduced index.
        let row = t.get_ref(1);
        assert_eq!(*row.at([2, 3]), 23.0);

        // An owned copy of the subtensor agrees with the borrowed view.
        let owned = t.get(1);
        assert_eq!(*owned.at([2, 3]), 23.0);

        // Writes through a mutable subview land in the original tensor.
        let mut tmut = t.as_mut();
        let mut row_mut = tmut.get_mut(0);
        row_mut.set([0, 0], 99.0);
        assert_eq!(*t.at([0, 0, 0]), 99.0);
    }

    #[test]
    #[should_panic(expected = "index out of bounds")]
    fn get_ref_panics_on_oob_index() {
        // Leading axis has length 2, so index 2 is out of bounds.
        let t = Tensor::<T3>::zeros();
        let _ = t.get_ref(2);
    }

    #[test]
    fn reshape_changes_shape_type_without_reordering_data() {
        let flat = Tensor::<crate::shape!(6)>::from_flat([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let reshaped = flat.reshape::<crate::shape!(2, 3)>();
        // Same buffer, new index arithmetic.
        assert_eq!(reshaped.as_slice(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        assert_eq!(*reshaped.at([1, 2]), 6.0);
    }

    #[test]
    fn tensor_literal_infers_shape_and_layout() {
        // Nested literal -> 2x2 tensor, outer axis first in storage.
        let t = crate::tensor![[1.0, 2.0], [3.0, 4.0]];
        assert_eq!(t.as_slice(), &[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(*t.at([1, 0]), 3.0);
    }

    #[test]
    fn tensor_debug_uses_public_type_names() {
        // Debug output must name the public wrappers, not `StorageTensor`.
        let tensor = crate::tensor![[1.0, 2.0], [3.0, 4.0]];
        assert!(format!("{tensor:?}").starts_with("Tensor {"));
        let row = tensor.get_ref(1);
        assert!(format!("{row:?}").starts_with("TensorRef {"));
    }

    #[test]
    fn elementwise_ops_and_reductions_work() {
        let lhs = crate::tensor![1.0, 2.0, 3.0];
        let rhs = crate::tensor![4.0, 5.0, 6.0];

        assert_eq!((&lhs + &rhs).as_slice(), &[5.0, 7.0, 9.0]);
        assert_eq!((&rhs - &lhs).as_slice(), &[3.0, 3.0, 3.0]);
        assert_eq!((&lhs * &rhs).as_slice(), &[4.0, 10.0, 18.0]);
        assert_eq!((lhs.clone() + 1.0).as_slice(), &[2.0, 3.0, 4.0]);
        assert_eq!(lhs.sum(), 6.0);
        assert_eq!(lhs.mean(), 2.0);
    }

    #[test]
    fn dot_transpose_and_matmul_work() {
        // 1*4 + 2*5 + 3*6 = 32.
        let vec_a = crate::tensor![1.0, 2.0, 3.0];
        let vec_b = crate::tensor![4.0, 5.0, 6.0];
        assert_eq!(vec_a.dot(&vec_b), 32.0);

        // (2x3) * (3x2) -> (2x2).
        let lhs = crate::tensor![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let rhs = crate::tensor![[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]];
        let product = lhs.matmul(&rhs);
        assert_eq!(product.as_slice(), &[58.0, 64.0, 139.0, 154.0]);

        let transposed = lhs.transpose();
        assert_eq!(transposed.as_slice(), &[1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn seeded_random_is_reproducible() {
        // Same seed -> identical elements; different seed -> different.
        let a = Tensor::<crate::shape!(2, 3)>::random_with_seed(7);
        let b = Tensor::<crate::shape!(2, 3)>::random_with_seed(7);
        let c = Tensor::<crate::shape!(2, 3)>::random_with_seed(9);

        assert_eq!(a.as_slice(), b.as_slice());
        assert_ne!(a.as_slice(), c.as_slice());
    }

    #[test]
    fn randomized_shape_stress_preserves_row_major_layout() {
        let mut tensor = Tensor::<crate::shape!(2, 3, 4)>::zeros();
        let mut rng = StdRng::seed_from_u64(42);

        for index in 0..tensor.len() {
            tensor.as_mut_slice()[index] = rng.random::<Float>();
        }

        // Flat offset of [i, j, k] in a 2x3x4 tensor is i*12 + j*4 + k.
        for i in 0..2 {
            for j in 0..3 {
                for k in 0..4 {
                    let flat_index = i * 12 + j * 4 + k;
                    assert_eq!(*tensor.at([i, j, k]), tensor.as_slice()[flat_index]);
                }
            }
        }

        // Reshape keeps the flat element order untouched.
        let reshaped = tensor.clone().reshape::<crate::shape!(4, 3, 2)>();
        assert_eq!(tensor.as_slice(), reshaped.as_slice());
    }

    #[test]
    fn borrowed_reshape_preserves_view_semantics() {
        let mut tensor = crate::tensor![[1.0, 2.0], [3.0, 4.0]];

        // Reshaped immutable view aliases the same elements.
        let flat_ref = tensor.as_ref().reshape::<crate::shape!(4)>();
        assert_eq!(flat_ref.as_slice(), &[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(*flat_ref.at([2]), 3.0);

        {
            // Writes through a reshaped mutable view hit the owner.
            let mut flat_mut = tensor.as_mut().reshape::<crate::shape!(4)>();
            flat_mut.set([3], 9.0);
        }

        assert_eq!(tensor.as_slice(), &[1.0, 2.0, 3.0, 9.0]);
    }
}
816}