// tml_utils/tensor.rs

1use std::{fmt, marker::PhantomData, ops};
2
3use rand::{Rng, SeedableRng, rngs::StdRng};
4
5use crate::shape::{Dim, Nil, NonScalarShape, TensorShape};
6use crate::{Float, ReshapePreservesElementCount};
7
8struct StorageTensor<Storage, Shape: TensorShape> {
9    storage: Storage,
10    _shape_marker: PhantomData<Shape>,
11}
12
13pub struct Tensor<Shape: TensorShape>(StorageTensor<Box<[Float]>, Shape>);
14pub struct TensorRef<'a, Shape: TensorShape>(StorageTensor<&'a [Float], Shape>);
15pub struct TensorMut<'a, Shape: TensorShape>(StorageTensor<&'a mut [Float], Shape>);
16
17trait StorageRef {
18    fn as_slice(&self) -> &[Float];
19}
20
21trait StorageMut: StorageRef {
22    fn as_mut_slice(&mut self) -> &mut [Float];
23}
24
25impl StorageRef for Box<[Float]> {
26    fn as_slice(&self) -> &[Float] {
27        self
28    }
29}
30
31impl StorageMut for Box<[Float]> {
32    fn as_mut_slice(&mut self) -> &mut [Float] {
33        self
34    }
35}
36
37impl StorageRef for &[Float] {
38    fn as_slice(&self) -> &[Float] {
39        self
40    }
41}
42
43impl StorageRef for &mut [Float] {
44    fn as_slice(&self) -> &[Float] {
45        self
46    }
47}
48
49impl StorageMut for &mut [Float] {
50    fn as_mut_slice(&mut self) -> &mut [Float] {
51        self
52    }
53}
54
55pub trait TensorLiteral {
56    type Shape: TensorShape;
57
58    fn write_flat(self, out: &mut Vec<Float>);
59}
60
61impl<Storage, Shape> StorageTensor<Storage, Shape>
62where
63    Shape: TensorShape,
64{
65    fn from_storage(storage: Storage) -> Self {
66        Self {
67            storage,
68            _shape_marker: PhantomData,
69        }
70    }
71}
72
73impl<Storage, Shape> StorageTensor<Storage, Shape>
74where
75    Storage: StorageRef,
76    Shape: TensorShape,
77{
78    fn as_slice(&self) -> &[Float] {
79        StorageRef::as_slice(&self.storage)
80    }
81
82    fn at(&self, index: [usize; Shape::RANK]) -> &Float {
83        let offset = Shape::offset(&index);
84        &self.as_slice()[offset]
85    }
86
87    fn sum(&self) -> Float {
88        self.as_slice().iter().copied().sum()
89    }
90
91    fn mean(&self) -> Float {
92        self.sum() / Shape::SIZE as Float
93    }
94}
95
96impl<Storage, Shape> StorageTensor<Storage, Shape>
97where
98    Storage: StorageMut,
99    Shape: TensorShape,
100{
101    fn as_mut_slice(&mut self) -> &mut [Float] {
102        StorageMut::as_mut_slice(&mut self.storage)
103    }
104
105    fn set(&mut self, index: [usize; Shape::RANK], value: Float) {
106        let offset = Shape::offset(&index);
107        self.as_mut_slice()[offset] = value;
108    }
109
110    fn fill(&mut self, value: Float) {
111        self.as_mut_slice().fill(value);
112    }
113}
114
115impl<Storage, Shape> Clone for StorageTensor<Storage, Shape>
116where
117    Storage: Clone,
118    Shape: TensorShape,
119{
120    fn clone(&self) -> Self {
121        Self {
122            storage: self.storage.clone(),
123            _shape_marker: PhantomData,
124        }
125    }
126}
127
128impl<Storage, Shape> Copy for StorageTensor<Storage, Shape>
129where
130    Storage: Copy,
131    Shape: TensorShape,
132{
133}
134
135impl<Shape> Clone for Tensor<Shape>
136where
137    Shape: TensorShape,
138{
139    fn clone(&self) -> Self {
140        Self(self.0.clone())
141    }
142}
143
144impl<'a, Shape> Clone for TensorRef<'a, Shape>
145where
146    Shape: TensorShape,
147{
148    fn clone(&self) -> Self {
149        *self
150    }
151}
152
153impl<'a, Shape> Copy for TensorRef<'a, Shape> where Shape: TensorShape {}
154
155impl<Shape> fmt::Debug for Tensor<Shape>
156where
157    Shape: TensorShape,
158{
159    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
160        f.debug_struct("Tensor")
161            .field("rank", &Shape::RANK)
162            .field("elements", &self.as_slice())
163            .finish()
164    }
165}
166
167impl<'a, Shape> fmt::Debug for TensorRef<'a, Shape>
168where
169    Shape: TensorShape,
170{
171    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
172        f.debug_struct("TensorRef")
173            .field("rank", &Shape::RANK)
174            .field("elements", &self.as_slice())
175            .finish()
176    }
177}
178
179impl<'a, Shape> fmt::Debug for TensorMut<'a, Shape>
180where
181    Shape: TensorShape,
182{
183    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
184        f.debug_struct("TensorMut")
185            .field("rank", &Shape::RANK)
186            .field("elements", &self.as_slice())
187            .finish()
188    }
189}
190
191impl<Shape> Default for Tensor<Shape>
192where
193    Shape: TensorShape,
194{
195    fn default() -> Self {
196        Self::zeros()
197    }
198}
199
200impl<Shape> Tensor<Shape>
201where
202    Shape: TensorShape,
203{
204    pub(crate) fn from_boxed(storage: Box<[Float]>) -> Self {
205        assert_eq!(storage.len(), Shape::SIZE, "tensor storage size mismatch");
206        Self(StorageTensor::from_storage(storage))
207    }
208
209    pub fn from_flat(data: [Float; Shape::SIZE]) -> Self {
210        Self::from_boxed(Vec::from(data).into_boxed_slice())
211    }
212
213    pub fn from_elem(value: Float) -> Self {
214        Self::from_boxed(vec![value; Shape::SIZE].into_boxed_slice())
215    }
216
217    pub(crate) fn raw_slice(&self) -> &[Float] {
218        self.as_slice()
219    }
220
221    pub(crate) fn raw_mut_slice(&mut self) -> &mut [Float] {
222        self.as_mut_slice()
223    }
224
225    pub fn len(&self) -> usize {
226        Shape::SIZE
227    }
228
229    pub fn is_empty(&self) -> bool {
230        self.len() == 0
231    }
232
233    pub fn rank(&self) -> usize {
234        Shape::RANK
235    }
236
237    pub fn as_slice(&self) -> &[Float] {
238        self.0.as_slice()
239    }
240
241    pub fn as_mut_slice(&mut self) -> &mut [Float] {
242        self.0.as_mut_slice()
243    }
244
245    pub fn at(&self, index: [usize; Shape::RANK]) -> &Float {
246        self.0.at(index)
247    }
248
249    pub fn set(&mut self, index: [usize; Shape::RANK], value: Float) {
250        self.0.set(index, value);
251    }
252
253    pub fn fill(&mut self, value: Float) {
254        self.0.fill(value);
255    }
256
257    pub fn zeros() -> Self {
258        Self::from_elem(0.0)
259    }
260
261    pub fn random() -> Self {
262        let mut rng = rand::rng();
263        Self::random_with(&mut rng)
264    }
265
266    pub fn random_with_seed(seed: u64) -> Self {
267        let mut rng = StdRng::seed_from_u64(seed);
268        Self::random_with(&mut rng)
269    }
270
271    pub fn random_with<R>(rng: &mut R) -> Self
272    where
273        R: Rng + ?Sized,
274    {
275        let mut out = Self::zeros();
276        for value in out.as_mut_slice() {
277            *value = rng.random::<Float>();
278        }
279        out
280    }
281
282    pub fn reshape<NewShape>(self) -> Tensor<NewShape>
283    where
284        NewShape: TensorShape,
285        (): ReshapePreservesElementCount<{ Shape::SIZE }, { NewShape::SIZE }>,
286    {
287        Tensor::<NewShape>::from_boxed(self.0.storage)
288    }
289
290    pub fn as_ref(&self) -> TensorRef<'_, Shape> {
291        TensorRef(StorageTensor::from_storage(self.as_slice()))
292    }
293
294    pub fn as_mut(&mut self) -> TensorMut<'_, Shape> {
295        TensorMut(StorageTensor::from_storage(self.as_mut_slice()))
296    }
297
298    pub fn map_inplace<F>(&mut self, mut f: F)
299    where
300        F: FnMut(Float) -> Float,
301    {
302        for value in self.as_mut_slice() {
303            *value = f(*value);
304        }
305    }
306
307    pub fn map<F>(&self, f: F) -> Self
308    where
309        F: FnMut(Float) -> Float,
310    {
311        let mut out = self.clone();
312        out.map_inplace(f);
313        out
314    }
315
316    pub fn zip_map<F>(&self, rhs: &Self, mut f: F) -> Self
317    where
318        F: FnMut(Float, Float) -> Float,
319    {
320        let mut out = Self::zeros();
321        for ((dst, lhs), rhs) in out
322            .as_mut_slice()
323            .iter_mut()
324            .zip(self.as_slice().iter().copied())
325            .zip(rhs.as_slice().iter().copied())
326        {
327            *dst = f(lhs, rhs);
328        }
329        out
330    }
331
332    pub fn sum(&self) -> Float {
333        self.0.sum()
334    }
335
336    pub fn mean(&self) -> Float {
337        self.0.mean()
338    }
339
340    #[deprecated(note = "Tensor::slice is not implemented yet")]
341    pub fn slice<T: Iterator>(_range: T) {}
342}
343
344impl<'a, Shape> TensorRef<'a, Shape>
345where
346    Shape: TensorShape,
347{
348    pub fn len(&self) -> usize {
349        Shape::SIZE
350    }
351
352    pub fn is_empty(&self) -> bool {
353        self.len() == 0
354    }
355
356    pub fn rank(&self) -> usize {
357        Shape::RANK
358    }
359
360    pub fn as_slice(&self) -> &[Float] {
361        self.0.as_slice()
362    }
363
364    pub fn at(&self, index: [usize; Shape::RANK]) -> &Float {
365        self.0.at(index)
366    }
367
368    pub fn sum(&self) -> Float {
369        self.0.sum()
370    }
371
372    pub fn mean(&self) -> Float {
373        self.0.mean()
374    }
375
376    pub fn reshape<NewShape>(self) -> TensorRef<'a, NewShape>
377    where
378        NewShape: TensorShape,
379        (): ReshapePreservesElementCount<{ Shape::SIZE }, { NewShape::SIZE }>,
380    {
381        TensorRef(StorageTensor::from_storage(self.0.storage))
382    }
383}
384
385impl<'a, Shape> TensorMut<'a, Shape>
386where
387    Shape: TensorShape,
388{
389    pub fn len(&self) -> usize {
390        Shape::SIZE
391    }
392
393    pub fn is_empty(&self) -> bool {
394        self.len() == 0
395    }
396
397    pub fn rank(&self) -> usize {
398        Shape::RANK
399    }
400
401    pub fn as_slice(&self) -> &[Float] {
402        self.0.as_slice()
403    }
404
405    pub fn as_mut_slice(&mut self) -> &mut [Float] {
406        self.0.as_mut_slice()
407    }
408
409    pub fn at(&self, index: [usize; Shape::RANK]) -> &Float {
410        self.0.at(index)
411    }
412
413    pub fn set(&mut self, index: [usize; Shape::RANK], value: Float) {
414        self.0.set(index, value);
415    }
416
417    pub fn fill(&mut self, value: Float) {
418        self.0.fill(value);
419    }
420
421    pub fn sum(&self) -> Float {
422        self.0.sum()
423    }
424
425    pub fn mean(&self) -> Float {
426        self.0.mean()
427    }
428
429    pub fn reshape<NewShape>(self) -> TensorMut<'a, NewShape>
430    where
431        NewShape: TensorShape,
432        (): ReshapePreservesElementCount<{ Shape::SIZE }, { NewShape::SIZE }>,
433    {
434        TensorMut(StorageTensor::from_storage(self.0.storage))
435    }
436}
437
438impl<Shape> Tensor<Shape>
439where
440    Shape: NonScalarShape,
441{
442    pub fn get_ref(&self, index: usize) -> TensorRef<'_, Shape::Subshape> {
443        assert!(index < Shape::AXIS_LEN, "index out of bounds");
444        let stride = <Shape::Subshape as TensorShape>::SIZE;
445        let start = index * stride;
446        let end = start + stride;
447        TensorRef(StorageTensor::from_storage(&self.as_slice()[start..end]))
448    }
449
450    pub fn get_mut(&mut self, index: usize) -> TensorMut<'_, Shape::Subshape> {
451        assert!(index < Shape::AXIS_LEN, "index out of bounds");
452        let stride = <Shape::Subshape as TensorShape>::SIZE;
453        let start = index * stride;
454        let end = start + stride;
455        TensorMut(StorageTensor::from_storage(
456            &mut self.as_mut_slice()[start..end],
457        ))
458    }
459
460    pub fn get(&self, index: usize) -> Tensor<Shape::Subshape> {
461        let row = self.get_ref(index);
462        Tensor::<Shape::Subshape>::from_boxed(row.as_slice().to_vec().into_boxed_slice())
463    }
464}
465
466impl<'a, Shape> TensorRef<'a, Shape>
467where
468    Shape: NonScalarShape,
469{
470    pub fn get_ref(&self, index: usize) -> TensorRef<'_, Shape::Subshape> {
471        assert!(index < Shape::AXIS_LEN, "index out of bounds");
472        let stride = <Shape::Subshape as TensorShape>::SIZE;
473        let start = index * stride;
474        let end = start + stride;
475        TensorRef(StorageTensor::from_storage(&self.as_slice()[start..end]))
476    }
477}
478
479impl<'a, Shape> TensorMut<'a, Shape>
480where
481    Shape: NonScalarShape,
482{
483    pub fn get_ref(&self, index: usize) -> TensorRef<'_, Shape::Subshape> {
484        assert!(index < Shape::AXIS_LEN, "index out of bounds");
485        let stride = <Shape::Subshape as TensorShape>::SIZE;
486        let start = index * stride;
487        let end = start + stride;
488        TensorRef(StorageTensor::from_storage(&self.as_slice()[start..end]))
489    }
490
491    pub fn get_mut(&mut self, index: usize) -> TensorMut<'_, Shape::Subshape> {
492        assert!(index < Shape::AXIS_LEN, "index out of bounds");
493        let stride = <Shape::Subshape as TensorShape>::SIZE;
494        let start = index * stride;
495        let end = start + stride;
496        TensorMut(StorageTensor::from_storage(
497            &mut self.as_mut_slice()[start..end],
498        ))
499    }
500}
501
502impl<const N: usize> Tensor<Dim<N, Nil>> {
503    pub fn dot(&self, rhs: &Self) -> Float {
504        self.as_slice()
505            .iter()
506            .zip(rhs.as_slice())
507            .map(|(lhs, rhs)| lhs * rhs)
508            .sum()
509    }
510}
511
512impl<const ROWS: usize, const COLS: usize> Tensor<Dim<ROWS, Dim<COLS, Nil>>> {
513    pub fn transpose(&self) -> Tensor<Dim<COLS, Dim<ROWS, Nil>>> {
514        let mut out = Tensor::<Dim<COLS, Dim<ROWS, Nil>>>::zeros();
515        let input = self.as_slice();
516        let output = out.as_mut_slice();
517        for row in 0..ROWS {
518            for col in 0..COLS {
519                output[col * ROWS + row] = input[row * COLS + col];
520            }
521        }
522        out
523    }
524
525    pub fn matvec(&self, rhs: &Tensor<Dim<COLS, Nil>>) -> Tensor<Dim<ROWS, Nil>> {
526        let mut out = Tensor::<Dim<ROWS, Nil>>::zeros();
527        let lhs = self.as_slice();
528        let rhs = rhs.as_slice();
529        for row in 0..ROWS {
530            let mut acc = 0.0;
531            for col in 0..COLS {
532                acc += lhs[row * COLS + col] * rhs[col];
533            }
534            out.as_mut_slice()[row] = acc;
535        }
536        out
537    }
538
539    pub fn matmul<const OUT_COLS: usize>(
540        &self,
541        rhs: &Tensor<Dim<COLS, Dim<OUT_COLS, Nil>>>,
542    ) -> Tensor<Dim<ROWS, Dim<OUT_COLS, Nil>>> {
543        let mut out = Tensor::<Dim<ROWS, Dim<OUT_COLS, Nil>>>::zeros();
544        let lhs = self.as_slice();
545        let rhs = rhs.as_slice();
546        let output = out.as_mut_slice();
547        for row in 0..ROWS {
548            for out_col in 0..OUT_COLS {
549                let mut acc = 0.0;
550                for inner in 0..COLS {
551                    acc += lhs[row * COLS + inner] * rhs[inner * OUT_COLS + out_col];
552                }
553                output[row * OUT_COLS + out_col] = acc;
554            }
555        }
556        out
557    }
558}
559
/// Implements an element-wise binary operator between two tensors of one shape.
///
/// For `$trait` / `$assign_trait` (e.g. `Add` / `AddAssign`) this generates every
/// owned/borrowed operand combination:
///
/// * `Tensor op &Tensor` and `Tensor op Tensor` reuse the left operand's buffer;
/// * `&Tensor op &Tensor` and `&Tensor op Tensor` clone the left operand first;
/// * `Tensor op= &Tensor` and `Tensor op= Tensor` update the left operand in place.
///
/// Both operands share the `Shape` type, so their buffers have the same length
/// and the zip in the assign impl visits every element.
macro_rules! impl_tensor_binop {
    ($trait:ident, $method:ident, $assign_trait:ident, $assign_method:ident, $op:tt) => {
        impl<Shape> ops::$trait<&Tensor<Shape>> for Tensor<Shape>
        where
            Shape: TensorShape,
        {
            type Output = Tensor<Shape>;

            fn $method(mut self, rhs: &Tensor<Shape>) -> Self::Output {
                ops::$assign_trait::$assign_method(&mut self, rhs);
                self
            }
        }

        impl<Shape> ops::$trait<Tensor<Shape>> for Tensor<Shape>
        where
            Shape: TensorShape,
        {
            type Output = Tensor<Shape>;

            fn $method(self, rhs: Tensor<Shape>) -> Self::Output {
                ops::$trait::$method(self, &rhs)
            }
        }

        impl<Shape> ops::$trait<&Tensor<Shape>> for &Tensor<Shape>
        where
            Shape: TensorShape,
        {
            type Output = Tensor<Shape>;

            fn $method(self, rhs: &Tensor<Shape>) -> Self::Output {
                ops::$trait::$method(self.clone(), rhs)
            }
        }

        impl<Shape> ops::$trait<Tensor<Shape>> for &Tensor<Shape>
        where
            Shape: TensorShape,
        {
            type Output = Tensor<Shape>;

            fn $method(self, rhs: Tensor<Shape>) -> Self::Output {
                ops::$trait::$method(self, &rhs)
            }
        }

        impl<Shape> ops::$assign_trait<&Tensor<Shape>> for Tensor<Shape>
        where
            Shape: TensorShape,
        {
            fn $assign_method(&mut self, rhs: &Tensor<Shape>) {
                for (lhs, rhs) in self.as_mut_slice().iter_mut().zip(rhs.as_slice().iter().copied()) {
                    *lhs = *lhs $op rhs;
                }
            }
        }

        impl<Shape> ops::$assign_trait<Tensor<Shape>> for Tensor<Shape>
        where
            Shape: TensorShape,
        {
            fn $assign_method(&mut self, rhs: Tensor<Shape>) {
                ops::$assign_trait::$assign_method(self, &rhs);
            }
        }
    };
}
597
/// Implements a tensor-scalar operator that applies `$op rhs` to every element.
///
/// Generates `Tensor op Float` (reuses the tensor's buffer), `&Tensor op Float`
/// (clones the tensor first, leaving it untouched) and `Tensor op= Float`.
macro_rules! impl_tensor_scalar_binop {
    ($trait:ident, $method:ident, $assign_trait:ident, $assign_method:ident, $op:tt) => {
        impl<Shape> ops::$trait<Float> for Tensor<Shape>
        where
            Shape: TensorShape,
        {
            type Output = Tensor<Shape>;

            fn $method(mut self, rhs: Float) -> Self::Output {
                ops::$assign_trait::$assign_method(&mut self, rhs);
                self
            }
        }

        impl<Shape> ops::$trait<Float> for &Tensor<Shape>
        where
            Shape: TensorShape,
        {
            type Output = Tensor<Shape>;

            fn $method(self, rhs: Float) -> Self::Output {
                ops::$trait::$method(self.clone(), rhs)
            }
        }

        impl<Shape> ops::$assign_trait<Float> for Tensor<Shape>
        where
            Shape: TensorShape,
        {
            fn $assign_method(&mut self, rhs: Float) {
                for value in self.as_mut_slice() {
                    *value = *value $op rhs;
                }
            }
        }
    };
}
624
625impl_tensor_binop!(Add, add, AddAssign, add_assign, +);
626impl_tensor_binop!(Sub, sub, SubAssign, sub_assign, -);
627impl_tensor_binop!(Mul, mul, MulAssign, mul_assign, *);
628
629impl_tensor_scalar_binop!(Add, add, AddAssign, add_assign, +);
630impl_tensor_scalar_binop!(Sub, sub, SubAssign, sub_assign, -);
631impl_tensor_scalar_binop!(Mul, mul, MulAssign, mul_assign, *);
632impl_tensor_scalar_binop!(Div, div, DivAssign, div_assign, /);
633
634impl<const N: usize> From<[Float; N]> for Tensor<Dim<N, Nil>> {
635    fn from(value: [Float; N]) -> Self {
636        Self::from_boxed(Vec::from(value).into_boxed_slice())
637    }
638}
639
640impl TensorLiteral for Float {
641    type Shape = Nil;
642
643    fn write_flat(self, out: &mut Vec<Float>) {
644        out.push(self);
645    }
646}
647
648impl<T, const N: usize> TensorLiteral for [T; N]
649where
650    T: TensorLiteral,
651{
652    type Shape = Dim<N, T::Shape>;
653
654    fn write_flat(self, out: &mut Vec<Float>) {
655        for item in self {
656            item.write_flat(out);
657        }
658    }
659}
660
661#[doc(hidden)]
662pub fn __tensor_from_literal<T>(value: T) -> Tensor<T::Shape>
663where
664    T: TensorLiteral,
665{
666    let mut flat = Vec::with_capacity(<T::Shape as TensorShape>::SIZE);
667    value.write_flat(&mut flat);
668    Tensor::<T::Shape>::from_boxed(flat.into_boxed_slice())
669}
670
/// Builds a tensor from a (possibly nested) array literal, inferring the shape
/// from the nesting: `tensor![[1.0, 2.0], [3.0, 4.0]]` is a 2x2 tensor.
///
/// The tokens are wrapped in one outer array and handed to
/// `__tensor_from_literal`.
#[macro_export]
macro_rules! tensor {
    ($($items:tt)*) => {
        $crate::__tensor_from_literal([$($items)*])
    };
}
677
#[cfg(test)]
mod tests {
    use super::*;

    type T3 = crate::shape!(2, 3, 4);

    // Writes 0, 1, 2, ... through `set` in (i, j, k) order, then reads one element
    // back through each access path: owned `at`, a borrowed row, an owned row copy,
    // and finally a mutable row view that must write through to the tensor.
    #[test]
    fn indexing_borrows_and_owned_get_match_layout() {
        let mut tensor = Tensor::<T3>::zeros();
        let mut counter = 0.0;
        for i in 0..2 {
            for j in 0..3 {
                for k in 0..4 {
                    tensor.set([i, j, k], counter);
                    counter += 1.0;
                }
            }
        }

        assert_eq!(*tensor.at([1, 2, 3]), 23.0);

        let borrowed_row = tensor.get_ref(1);
        assert_eq!(*borrowed_row.at([2, 3]), 23.0);

        let owned_row = tensor.get(1);
        assert_eq!(*owned_row.at([2, 3]), 23.0);

        let mut view = tensor.as_mut();
        let mut first_row = view.get_mut(0);
        first_row.set([0, 0], 99.0);
        assert_eq!(*tensor.at([0, 0, 0]), 99.0);
    }

    #[test]
    #[should_panic(expected = "index out of bounds")]
    fn get_ref_panics_on_oob_index() {
        let tensor = Tensor::<T3>::zeros();
        // The leading axis has length 2, so index 2 is one past the end.
        let _ = tensor.get_ref(2);
    }

    #[test]
    fn reshape_changes_shape_type_without_reordering_data() {
        let flat = Tensor::<crate::shape!(6)>::from_flat([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let matrix = flat.reshape::<crate::shape!(2, 3)>();
        assert_eq!(matrix.as_slice(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        assert_eq!(*matrix.at([1, 2]), 6.0);
    }

    #[test]
    fn tensor_literal_infers_shape_and_layout() {
        let square = crate::tensor![[1.0, 2.0], [3.0, 4.0]];
        assert_eq!(square.as_slice(), &[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(*square.at([1, 0]), 3.0);
    }

    #[test]
    fn tensor_debug_uses_public_type_names() {
        let tensor = crate::tensor![[1.0, 2.0], [3.0, 4.0]];
        let owned_repr = format!("{tensor:?}");
        assert!(owned_repr.starts_with("Tensor {"));
        let row = tensor.get_ref(1);
        let view_repr = format!("{row:?}");
        assert!(view_repr.starts_with("TensorRef {"));
    }

    #[test]
    fn elementwise_ops_and_reductions_work() {
        let lhs = crate::tensor![1.0, 2.0, 3.0];
        let rhs = crate::tensor![4.0, 5.0, 6.0];

        let total = &lhs + &rhs;
        let difference = &rhs - &lhs;
        let product = &lhs * &rhs;
        let shifted = lhs.clone() + 1.0;

        assert_eq!(total.as_slice(), &[5.0, 7.0, 9.0]);
        assert_eq!(difference.as_slice(), &[3.0, 3.0, 3.0]);
        assert_eq!(product.as_slice(), &[4.0, 10.0, 18.0]);
        assert_eq!(shifted.as_slice(), &[2.0, 3.0, 4.0]);
        assert_eq!(lhs.sum(), 6.0);
        assert_eq!(lhs.mean(), 2.0);
    }

    #[test]
    fn dot_transpose_and_matmul_work() {
        let u = crate::tensor![1.0, 2.0, 3.0];
        let v = crate::tensor![4.0, 5.0, 6.0];
        assert_eq!(u.dot(&v), 32.0);

        let a = crate::tensor![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let b = crate::tensor![[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]];
        assert_eq!(a.matmul(&b).as_slice(), &[58.0, 64.0, 139.0, 154.0]);
        assert_eq!(a.transpose().as_slice(), &[1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn seeded_random_is_reproducible() {
        type Grid = crate::shape!(2, 3);
        let first = Tensor::<Grid>::random_with_seed(7);
        let again = Tensor::<Grid>::random_with_seed(7);
        let other = Tensor::<Grid>::random_with_seed(9);

        assert_eq!(first.as_slice(), again.as_slice());
        assert_ne!(first.as_slice(), other.as_slice());
    }

    #[test]
    fn randomized_shape_stress_preserves_row_major_layout() {
        let mut tensor = Tensor::<crate::shape!(2, 3, 4)>::zeros();
        let mut rng = StdRng::seed_from_u64(42);

        for slot in tensor.as_mut_slice() {
            *slot = rng.random::<Float>();
        }

        for i in 0..2 {
            for j in 0..3 {
                for k in 0..4 {
                    let flat_index = (i * 3 + j) * 4 + k;
                    assert_eq!(*tensor.at([i, j, k]), tensor.as_slice()[flat_index]);
                }
            }
        }

        let reshaped = tensor.clone().reshape::<crate::shape!(4, 3, 2)>();
        assert_eq!(tensor.as_slice(), reshaped.as_slice());
    }

    #[test]
    fn borrowed_reshape_preserves_view_semantics() {
        let mut tensor = crate::tensor![[1.0, 2.0], [3.0, 4.0]];

        let flat_view = tensor.as_ref().reshape::<crate::shape!(4)>();
        assert_eq!(flat_view.as_slice(), &[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(*flat_view.at([2]), 3.0);

        {
            // The mutable reshaped view must write through to `tensor`.
            let mut flat_mut = tensor.as_mut().reshape::<crate::shape!(4)>();
            flat_mut.set([3], 9.0);
        }

        assert_eq!(tensor.as_slice(), &[1.0, 2.0, 3.0, 9.0]);
    }
}