// File: sophus_tensor/mut_tensor.rs

1use core::{
2    fmt::Debug,
3    marker::PhantomData,
4};
5
6use ndarray::{
7    Dim,
8    Ix,
9};
10use sophus_autodiff::linalg::{
11    SMat,
12    SVec,
13};
14
15use crate::{
16    prelude::*,
17    ArcTensor,
18    MutTensorView,
19    TensorView,
20};
21
22/// mutable tensor
23///
24/// See TensorView for more details of the tensor structure
25#[derive(Default, Debug, Clone)]
26pub struct MutTensor<
27    const TOTAL_RANK: usize,
28    const DRANK: usize,
29    const SRANK: usize,
30    Scalar: IsCoreScalar + 'static,
31    STensor: IsStaticTensor<Scalar, SRANK, ROWS, COLS> + 'static,
32    const ROWS: usize,
33    const COLS: usize,
34> where
35    ndarray::Dim<[ndarray::Ix; DRANK]>: ndarray::Dimension,
36{
37    /// ndarray of the static tensors with shape [D1, D2, ...]
38    pub mut_array: ndarray::Array<STensor, Dim<[Ix; DRANK]>>,
39    /// phantom data
40    pub phantom: PhantomData<(Scalar, STensor)>,
41}
42
43/// Converting a tensor of vectors to a tensor of Rx1 matrices
44pub trait InnerVecToMat<
45    const TOTAL_RANK: usize,
46    const DRANK: usize,
47    const SRANK: usize,
48    const HYBER_RANK_PLUS1: usize,
49    const SRANK_PLUS1: usize,
50    Scalar: IsCoreScalar + 'static,
51    const ROWS: usize,
52> where
53    SVec<Scalar, ROWS>: IsStaticTensor<Scalar, SRANK_PLUS1, ROWS, 1>,
54    ndarray::Dim<[ndarray::Ix; DRANK]>: ndarray::Dimension,
55{
56    /// The output tensor
57    type Output;
58
59    /// Convert to a tensor of Rx1 matrices
60    fn inner_vec_to_mat(self) -> Self::Output;
61}
62
63/// Converting a tensor of scalars to a tensor of 1-vectors
64pub trait InnerScalarToVec<
65    const TOTAL_RANK: usize,
66    const DRANK: usize,
67    const SRANK: usize,
68    const HYBER_RANK_PLUS1: usize,
69    const SRANK_PLUS1: usize,
70    Scalar: IsCoreScalar + 'static,
71> where
72    SVec<Scalar, 1>: IsStaticTensor<Scalar, SRANK_PLUS1, 1, 1>,
73    ndarray::Dim<[ndarray::Ix; DRANK]>: ndarray::Dimension,
74{
75    /// The output tensor
76    type Output;
77
78    /// Convert to a tensor of 1-vectors
79    fn inner_scalar_to_vec(self) -> Self::Output;
80}
81
82impl<Scalar: IsCoreScalar + 'static, const ROWS: usize> InnerVecToMat<3, 1, 2, 4, 2, Scalar, ROWS>
83    for MutTensorXR<3, 2, 1, Scalar, ROWS>
84{
85    type Output = MutTensorXRC<4, 2, 2, Scalar, ROWS, 1>;
86
87    fn inner_vec_to_mat(self) -> MutTensorXRC<4, 2, 2, Scalar, ROWS, 1> {
88        MutTensorXRC::<4, 2, 2, Scalar, ROWS, 1> {
89            mut_array: self.mut_array,
90            phantom: PhantomData,
91        }
92    }
93}
94
95impl<Scalar: IsCoreScalar + 'static> InnerScalarToVec<2, 0, 2, 3, 1, Scalar>
96    for MutTensorX<2, Scalar>
97{
98    type Output = MutTensorXR<3, 2, 1, Scalar, 1>;
99
100    fn inner_scalar_to_vec(self) -> MutTensorXR<3, 2, 1, Scalar, 1> {
101        MutTensorXR::<3, 2, 1, Scalar, 1> {
102            mut_array: self.mut_array.map(|x| SVec::<Scalar, 1>::new(x.clone())),
103            phantom: PhantomData,
104        }
105    }
106}
107
108/// Mutable tensor of scalars
109pub type MutTensorX<const DRANK: usize, Scalar> = MutTensor<DRANK, DRANK, 0, Scalar, Scalar, 1, 1>;
110
111/// Mutable tensor of vectors with shape R
112pub type MutTensorXR<
113    const TOTAL_RANK: usize,
114    const DRANK: usize,
115    const SRANK: usize,
116    Scalar,
117    const R: usize,
118> = MutTensor<TOTAL_RANK, DRANK, SRANK, Scalar, SVec<Scalar, R>, R, 1>;
119
120/// Mutable tensor of matrices with shape [R x C]
121pub type MutTensorXRC<
122    const TOTAL_RANK: usize,
123    const DRANK: usize,
124    const SRANK: usize,
125    Scalar,
126    const R: usize,
127    const C: usize,
128> = MutTensor<TOTAL_RANK, DRANK, SRANK, Scalar, SMat<Scalar, R, C>, R, C>;
129
130/// rank-1 mutable tensor of scalars with shape D0
131pub type MutTensorD<Scalar> = MutTensorX<1, Scalar>;
132
133/// rank-2 mutable tensor of scalars with shape [D0 x D1]
134pub type MutTensorDD<Scalar> = MutTensorX<2, Scalar>;
135
136/// rank-2 mutable tensor of vectors with shape [D0 x R]
137pub type MutTensorDR<Scalar, const R: usize> = MutTensorXR<2, 1, 1, Scalar, R>;
138
139/// rank-3 mutable tensor of scalars with shape [D0 x D1 x D2]
140pub type MutTensorDDD<Scalar> = MutTensorX<3, Scalar>;
141
142/// rank-3 mutable tensor of vectors with shape [D0 x D1 x R]
143pub type MutTensorDDR<Scalar, const R: usize> = MutTensorXR<3, 2, 1, Scalar, R>;
144
145/// rank-3 mutable tensor of matrices with shape [D0 x R x C]
146pub type MutTensorDRC<Scalar, const R: usize, const C: usize> = MutTensorXRC<3, 1, 2, Scalar, R, C>;
147
148/// rank-4 mutable tensor of scalars with shape [D0 x D1 x D2 x D3]
149pub type MutTensorDDDD<Scalar> = MutTensorX<4, Scalar>;
150
151/// rank-4 mutable tensor of vectors with shape [D0 x D1 x D2 x R]
152pub type MutTensorDDDR<Scalar, const R: usize> = MutTensorXR<4, 3, 1, Scalar, R>;
153
154/// rank-4 mutable tensor of matrices with shape [D0 x D1 x R x C]
155pub type MutTensorDDRC<Scalar, const R: usize, const C: usize> =
156    MutTensorXRC<4, 2, 2, Scalar, R, C>;
157
158/// rank-5 mutable tensor of scalars with shape [D0 x D1 x D2 x D3 x D4]
159pub type MutTensorDDDDD<Scalar> = MutTensorX<5, Scalar>;
160
161/// rank-5 mutable tensor of vectors with shape [D0 x D1 x D2 x D3 x R]
162pub type MutTensorDDDDR<Scalar, const R: usize> = MutTensorXR<5, 4, 1, Scalar, R>;
163
164/// rank-5 mutable tensor of matrices with shape [D0 x D1 x D2 x R x C]
165pub type MutTensorDDDRC<Scalar, const R: usize, const C: usize> =
166    MutTensorXRC<5, 3, 2, Scalar, R, C>;
167
168macro_rules! mut_tensor_is_view {
169    ($scalar_rank:literal, $srank:literal, $drank:literal) => {
170
171
172        impl<
173        'a,
174                Scalar: IsCoreScalar + 'static,
175                STensor: IsStaticTensor<Scalar, $srank,  ROWS, COLS> + 'static,
176                const ROWS: usize,
177                const COLS: usize,
178            > IsTensorLike<'a, $scalar_rank, $drank, $srank, Scalar, STensor,  ROWS, COLS>
179            for MutTensor<$scalar_rank, $drank, $srank, Scalar, STensor,  ROWS, COLS>
180        {
181            fn elem_view<'b:'a>(
182                &'b self,
183            ) -> ndarray::ArrayView<'a, STensor, ndarray::Dim<[ndarray::Ix; $drank]>> {
184                self.view().elem_view
185            }
186
187            fn get(& self, idx: [usize; $drank]) -> STensor {
188                self.view().get(idx)
189            }
190
191            fn dims(&self) -> [usize; $drank] {
192                self.view().dims()
193            }
194
195            fn scalar_view<'b:'a>(
196                &'b self,
197            ) -> ndarray::ArrayView<'a, Scalar, ndarray::Dim<[ndarray::Ix; $scalar_rank]>> {
198                self.view().scalar_view
199            }
200
201            fn scalar_get(&'a self, idx: [usize; $scalar_rank]) -> Scalar {
202                self.view().scalar_get(idx)
203            }
204
205            fn scalar_dims(&self) -> [usize; $scalar_rank] {
206                self.view().scalar_dims()
207            }
208
209            fn to_mut_tensor(
210                &self,
211            ) -> MutTensor<$scalar_rank, $drank, $srank, Scalar, STensor,  ROWS, COLS> {
212                MutTensor {
213                    mut_array: self.elem_view().to_owned(),
214                    phantom: PhantomData::default(),
215                }
216            }
217        }
218
219        impl<
220        'a,
221                Scalar: IsCoreScalar + 'static,
222                STensor: IsStaticTensor<Scalar, $srank,  ROWS, COLS> + 'static,
223                const ROWS: usize,
224                const COLS: usize,
225
226            >
227            IsMutTensorLike<'a,
228                $scalar_rank, $drank, $srank,
229                Scalar, STensor,
230                ROWS, COLS
231            >
232            for MutTensor<$scalar_rank, $drank, $srank, Scalar, STensor, ROWS, COLS>
233        {
234            fn elem_view_mut<'b:'a>(
235                &'b mut self,
236            ) -> ndarray::ArrayViewMut<'a, STensor, ndarray::Dim<[ndarray::Ix; $drank]>>{
237                self.mut_view().elem_view_mut
238            }
239            fn get_mut(& mut self, idx: [usize; $drank]) -> &mut STensor{
240                &mut self.mut_array[idx]
241            }
242
243            fn scalar_view_mut<'b:'a>(
244                &'b mut self,
245            ) -> ndarray::ArrayViewMut<'a, Scalar, ndarray::Dim<[ndarray::Ix; $scalar_rank]>>{
246                self.mut_view().scalar_view_mut
247            }
248        }
249
250        impl<'a,  Scalar: IsCoreScalar+ 'static,
251        STensor: IsStaticTensor<Scalar, $srank,  ROWS, COLS> + 'static,
252        const ROWS: usize,
253        const COLS: usize,
254
255        > PartialEq for
256            MutTensor<$scalar_rank, $drank, $srank, Scalar, STensor, ROWS, COLS>
257        {
258            fn eq(&self, other: &Self) -> bool {
259                self.view().scalar_view == other.view().scalar_view
260            }
261        }
262
263        impl<'a,  Scalar: IsCoreScalar+ 'static,
264                STensor: IsStaticTensor<Scalar, $srank,  ROWS, COLS> + 'static,
265                const ROWS: usize,
266                const COLS: usize,
267
268        >
269            MutTensor<$scalar_rank, $drank, $srank, Scalar, STensor, ROWS, COLS>
270        {
271
272            /// create a new tensor from a shape - filled with zeros
273            pub fn from_shape(size: [usize; $drank]) -> Self {
274                MutTensor::<$scalar_rank, $drank, $srank, Scalar, STensor,
275                            ROWS, COLS>::from_shape_and_val(
276                    size, num_traits::Zero::zero()
277                )
278            }
279
280             /// create a new mutable tensor by applying a binary operator to each element of two
281            /// other tensors
282            pub fn from_map2<
283                'b,
284                const OTHER_HRANK: usize, const OTHER_SRANK: usize,
285                OtherScalar: IsCoreScalar + 'static,
286                OtherSTensor: IsStaticTensor<
287                    OtherScalar, OTHER_SRANK, OTHER_ROWS, OTHER_COLS
288                > + 'static,
289                const OTHER_ROWS: usize, const OTHER_COLS: usize,
290            V : IsTensorView::<'b,
291                OTHER_HRANK, $drank, OTHER_SRANK,
292                OtherScalar, OtherSTensor,
293                OTHER_ROWS, OTHER_COLS
294            >,
295            const OTHER_HRANK2: usize, const OTHER_SRANK2: usize,
296            OtherScalar2: IsCoreScalar + 'static,
297            OtherSTensor2: IsStaticTensor<
298                OtherScalar2, OTHER_SRANK2, OTHER_ROWS2, OTHER_COLS2,
299            > + 'static,
300            const OTHER_ROWS2: usize, const OTHER_COLS2: usize,
301            V2 : IsTensorView::<'b,
302                OTHER_HRANK2, $drank, OTHER_SRANK2,
303                OtherScalar2, OtherSTensor2,
304                OTHER_ROWS2, OTHER_COLS2
305            >,
306            F: FnMut(&OtherSTensor, &OtherSTensor2)->STensor
307            >(
308                view: &'b V,
309                view2: &'b V2,
310                mut op: F,
311            )
312            -> Self
313            where
314                ndarray::Dim<[ndarray::Ix; OTHER_HRANK]>: ndarray::Dimension,
315                ndarray::Dim<[ndarray::Ix; OTHER_HRANK2]>: ndarray::Dimension
316
317            {
318                let mut out  = Self::from_shape(view.dims());
319                ndarray::Zip::from(&mut out.elem_view_mut())
320                .and(&view.elem_view())
321                .and(&view2.elem_view())
322                .for_each(
323                    |out, v, v2|{
324                      *out = op(v, v2);
325                    });
326                out
327            }
328        }
329
330        impl<'a,  Scalar: IsCoreScalar+ 'static,
331                STensor: IsStaticTensor<Scalar, $srank, ROWS, COLS> + 'static,
332                const ROWS: usize,
333                const COLS: usize,
334
335        >
336            MutTensor<$scalar_rank, $drank, $srank, Scalar, STensor, ROWS, COLS>
337        {
338
339
340            /// returns a mutable view of the tensor
341            pub fn mut_view<'b: 'a>(
342                &'b mut self,
343            ) -> MutTensorView<'a,
344                               $scalar_rank, $drank, $srank,
345                               Scalar, STensor,
346                               ROWS, COLS>
347            {
348                MutTensorView::<
349                    'a,
350                    $scalar_rank, $drank, $srank,
351                    Scalar, STensor, ROWS, COLS>::new
352                (
353                    self.mut_array.view_mut()
354                )
355            }
356
357            /// returns a view of the tensor
358            pub fn view<'b: 'a>(&'b self
359            ) -> TensorView<'a, $scalar_rank, $drank, $srank, Scalar, STensor,
360                            ROWS, COLS> {
361                TensorView::<'a, $scalar_rank, $drank, $srank, Scalar, STensor,
362                             ROWS, COLS>::new(
363                    self.mut_array.view())
364            }
365
366
367            /// create a new tensor from a shape and a value
368            pub fn from_shape_and_val
369            (
370                shape: [usize; $drank],
371                val: STensor,
372            ) -> Self
373            {
374                Self{
375                    mut_array: ndarray::Array::<STensor, Dim<[Ix; $drank]>>::from_elem(shape, val),
376                    phantom: PhantomData::default()
377                }
378            }
379
380            /// create a new mutable tensor by copying from another tensor
381            pub fn make_copy_from(
382                v: &TensorView<$scalar_rank, $drank, $srank, Scalar, STensor, ROWS, COLS>
383            ) -> Self
384            {
385                IsTensorLike::to_mut_tensor(v)
386            }
387
388            /// return ArcTensor copy of the mutable tensor
389            pub fn to_shared(self)
390                -> ArcTensor::<$scalar_rank, $drank, $srank, Scalar, STensor,  ROWS, COLS>
391            {
392                ArcTensor::<
393                    $scalar_rank,
394                    $drank, $srank,
395                    Scalar, STensor,
396                    ROWS, COLS>::from_mut_tensor(self)
397            }
398
399            /// create a new mutable tensor by applying a unary operator to each element of another
400            /// tensor
401            pub fn from_map<
402                'b,
403                const OTHER_HRANK: usize, const OTHER_SRANK: usize,
404                OtherScalar: IsCoreScalar+ 'static,
405                OtherSTensor: IsStaticTensor<
406                    OtherScalar, OTHER_SRANK,
407                    OTHER_ROWS, OTHER_COLS
408                > + 'static,
409                const OTHER_ROWS: usize, const OTHER_COLS: usize,
410                V : IsTensorView::<
411                    'b,
412                    OTHER_HRANK, $drank, OTHER_SRANK,
413                    OtherScalar, OtherSTensor,
414                    OTHER_ROWS, OTHER_COLS
415                >,
416                F: FnMut(&OtherSTensor)-> STensor
417            > (
418                view:  &'b V,
419                op: F,
420            )
421            -> Self where
422                ndarray::Dim<[ndarray::Ix; OTHER_HRANK]>: ndarray::Dimension,
423                ndarray::Dim<[ndarray::Ix; $drank]>: ndarray::Dimension,
424            {
425                Self {
426                    mut_array: view.elem_view().map(op),
427                    phantom: PhantomData::default()
428                }
429            }
430
431
432
433
434
435        }
436    };
437}
438
439mut_tensor_is_view!(1, 0, 1);
440mut_tensor_is_view!(2, 0, 2);
441mut_tensor_is_view!(2, 1, 1);
442mut_tensor_is_view!(3, 0, 3);
443mut_tensor_is_view!(3, 1, 2);
444mut_tensor_is_view!(3, 2, 1);
445mut_tensor_is_view!(4, 0, 4);
446mut_tensor_is_view!(4, 1, 3);
447mut_tensor_is_view!(4, 2, 2);
448mut_tensor_is_view!(5, 0, 5);
449mut_tensor_is_view!(5, 1, 4);
450mut_tensor_is_view!(5, 2, 3);
451
452macro_rules! mut_tensor_is_view_drank_1 {
453    ($scalar_rank:literal, $srank:literal) => {
454        impl<
455                'a,
456                Scalar: IsCoreScalar + 'static,
457                STensor: IsStaticTensor<Scalar, $srank, ROWS, COLS> + 'static,
458                const ROWS: usize,
459                const COLS: usize,
460            > MutTensor<$scalar_rank, 1, $srank, Scalar, STensor, ROWS, COLS>
461        {
462            /// create a new mutable tensor from fn
463            pub fn from_fn<F: FnMut([usize; 1]) -> STensor>(shape: [usize; 1], mut op: F) -> Self {
464                Self {
465                    mut_array: ndarray::Array::<STensor, Dim<[Ix; 1]>>::from_shape_fn(
466                        shape,
467                        |idx| op([idx]),
468                    ),
469                    phantom: PhantomData::default(),
470                }
471            }
472        }
473    };
474}
475
476mut_tensor_is_view_drank_1!(1, 0);
477mut_tensor_is_view_drank_1!(2, 1);
478mut_tensor_is_view_drank_1!(3, 2);
479
480macro_rules! mut_tensor_is_view_drank_2_plus {
481    ($scalar_rank:literal, $srank:literal, $drank:literal) => {
482        impl<
483                'a,
484                Scalar: IsCoreScalar + 'static,
485                STensor: IsStaticTensor<Scalar, $srank, ROWS, COLS> + 'static,
486                const ROWS: usize,
487                const COLS: usize,
488            > MutTensor<$scalar_rank, $drank, $srank, Scalar, STensor, ROWS, COLS>
489        {
490            /// create a new mutable tensor from fn
491            pub fn from_fn<F: FnMut([usize; $drank]) -> STensor>(
492                shape: [usize; $drank],
493                mut op: F,
494            ) -> Self {
495                Self {
496                    mut_array: ndarray::Array::<STensor, Dim<[Ix; $drank]>>::from_shape_fn(
497                        shape,
498                        |idx| op(idx.try_into().unwrap()),
499                    ),
500                    phantom: PhantomData::default(),
501                }
502            }
503        }
504    };
505}
506
507mut_tensor_is_view_drank_2_plus!(2, 0, 2);
508mut_tensor_is_view_drank_2_plus!(3, 0, 3);
509mut_tensor_is_view_drank_2_plus!(3, 1, 2);
510mut_tensor_is_view_drank_2_plus!(4, 0, 4);
511mut_tensor_is_view_drank_2_plus!(4, 1, 3);
512mut_tensor_is_view_drank_2_plus!(4, 2, 2);
513mut_tensor_is_view_drank_2_plus!(5, 0, 5);
514mut_tensor_is_view_drank_2_plus!(5, 1, 4);
515mut_tensor_is_view_drank_2_plus!(5, 2, 3);
516
#[test]
fn mut_tensor_tests() {
    use log::info;
    #[cfg(feature = "simd")]
    use sophus_autodiff::linalg::BatchMatF64;

    // Construction: default tensors are empty; shaped tensors report their shape.
    {
        let rank1_tensor = MutTensorD::<u8>::default();
        assert!(rank1_tensor.mut_array.is_empty());
        let shape = [2];
        let tensor_f32 = MutTensorD::<f32>::from_shape_and_val(shape, 0.0);
        assert!(!tensor_f32.mut_array.is_empty());
        assert_eq!(tensor_f32.view().dims(), shape);
    }
    {
        let rank2_tensor = MutTensorDD::<u8>::default();
        assert!(rank2_tensor.mut_array.is_empty());
        let shape = [3, 2];
        let tensor_f32 = MutTensorDD::<f32>::from_shape(shape);
        assert!(!tensor_f32.mut_array.is_empty());
        assert_eq!(tensor_f32.view().dims(), shape);
    }
    {
        let rank3_tensor = MutTensorDDD::<u8>::default();
        assert!(rank3_tensor.mut_array.is_empty());
        let shape = [3, 2, 4];
        let tensor_f32 = MutTensorDDD::<f32>::from_shape(shape);
        assert!(!tensor_f32.mut_array.is_empty());
        assert_eq!(tensor_f32.view().dims(), shape);
    }

    // from_map: tensors of scalars mapped to tensors of 3-vectors.
    {
        // Non-capturing closure, hence `Copy`: it can be passed to `from_map` and called again.
        let op = |v: &f32| {
            let mut value = SVec::<f32, 3>::default();
            value[0] = *v;
            value[1] = 0.2 * *v;
            value[2] = 0.3 * *v;
            value
        };
        {
            let shape = [3];
            let tensor_f32 = MutTensorD::from_shape_and_val(shape, 1.0);
            let pattern = MutTensorDR::<f32, 3>::from_map(&tensor_f32.view(), op);
            info!("p :{}", pattern.mut_array);
            assert_eq!(pattern, MutTensorDR::from_shape_and_val(shape, op(&1.0)));
        }
        {
            let shape = [3, 2];
            let tensor_f32 = MutTensorDD::from_shape_and_val(shape, 1.0);
            let pattern = MutTensorDDR::from_map(&tensor_f32.view(), op);
            info!("p :{}", pattern.mut_array);
            info!("p :{}", pattern.view().scalar_view());
            assert_eq!(pattern, MutTensorDDR::from_shape_and_val(shape, op(&1.0)));
        }
        {
            let shape = [3, 2, 4];
            let tensor_f32 = MutTensorDDD::from_shape_and_val(shape, 1.0);
            let pattern = MutTensorDDDR::from_map(&tensor_f32.view(), op);
            info!("p :{}", pattern.mut_array);
            info!("p :{}", pattern.view().scalar_view());
            assert_eq!(pattern, MutTensorDDDR::from_shape_and_val(shape, op(&1.0)));
        }
    }

    // Linear-algebra element types.
    #[cfg(feature = "simd")]
    {
        let shape = [3];

        let _tensor_u8 = MutTensorD::from_shape_and_val(shape, 0u8);
        let _tensor_f64 = MutTensorDRC::from_shape_and_val(shape, SMat::<f64, 4, 4>::zeros());
        let _tensor_batched_f64 =
            MutTensorDRC::from_shape_and_val(shape, BatchMatF64::<2, 3, 4>::zeros());
    }

    // Matrix elements: scalars are addressed as [d0, row, col]; `SMat::from_vec` fills the
    // matrix column by column.
    {
        let shape = [1];
        let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let data_mat = SMat::<f32, 3, 2>::from_vec(data.to_vec());
        let tensor_f32 = MutTensorDRC::from_shape_and_val(shape, data_mat);
        assert_eq!(tensor_f32.dims(), shape);
        for col in 0..2 {
            for row in 0..3 {
                assert_eq!(
                    tensor_f32.view().scalar_get([0, row, col]),
                    data[col * 3 + row]
                );
            }
        }
    }
}