zenu_matrix/
matrix.rs

1use std::{any::TypeId, marker::PhantomData};
2
3use crate::{
4    device::{Device, DeviceBase},
5    dim::{cal_offset, default_stride, DimDyn, DimTrait, LessDimTrait},
6    index::{IndexAxisTrait, SliceTrait},
7    num::Num,
8    shape_stride::ShapeStride,
9    slice::Slice,
10};
11
12#[cfg(feature = "nvidia")]
13use crate::device::nvidia::Nvidia;
14
15pub trait Repr: Default {
16    type Item: Num;
17
18    fn drop_memory<D: DeviceBase>(ptr: *mut Self::Item, _: D);
19    fn clone_memory<D: DeviceBase>(ptr: *mut Self::Item, len: usize, _: D) -> *mut Self::Item;
20}
21
22pub trait OwnedRepr: Repr {}
23
24pub struct Owned<T: Num> {
25    _maker: PhantomData<T>,
26}
27
28pub struct Ref<A> {
29    _maker: PhantomData<A>,
30}
31
32impl<T: Num> Default for Owned<T> {
33    fn default() -> Self {
34        Owned {
35            _maker: PhantomData,
36        }
37    }
38}
39
40impl<A> Default for Ref<A> {
41    fn default() -> Self {
42        Ref {
43            _maker: PhantomData,
44        }
45    }
46}
47
48impl<'a, T: Num> Repr for Ref<&'a T> {
49    type Item = T;
50
51    fn drop_memory<D: DeviceBase>(_ptr: *mut Self::Item, _: D) {}
52    fn clone_memory<D: DeviceBase>(ptr: *mut Self::Item, _len: usize, _: D) -> *mut Self::Item {
53        ptr
54    }
55}
56
57impl<'a, T: Num> Repr for Ref<&'a mut T> {
58    type Item = T;
59
60    fn drop_memory<D: DeviceBase>(_ptr: *mut Self::Item, _: D) {}
61    fn clone_memory<D: DeviceBase>(ptr: *mut Self::Item, _len: usize, _: D) -> *mut Self::Item {
62        ptr
63    }
64}
65
66impl<T: Num> Repr for Owned<T> {
67    type Item = T;
68
69    fn drop_memory<D: DeviceBase>(ptr: *mut Self::Item, _: D) {
70        D::drop_ptr(ptr);
71    }
72
73    fn clone_memory<D: DeviceBase>(ptr: *mut Self::Item, len: usize, _: D) -> *mut Self::Item {
74        D::clone_ptr(ptr, len)
75    }
76}
77
78impl<T: Num> OwnedRepr for Owned<T> {}
79
80pub struct Ptr<R, D>
81where
82    R: Repr,
83    D: DeviceBase,
84{
85    ptr: *mut R::Item,
86    len: usize,
87    offset: usize,
88    repr: PhantomData<R>,
89    device: PhantomData<D>,
90}
91
92impl<R, D> Ptr<R, D>
93where
94    R: Repr,
95    D: DeviceBase,
96{
97    pub(crate) fn new(ptr: *mut R::Item, len: usize, offset: usize) -> Self {
98        Ptr {
99            ptr,
100            len,
101            offset,
102            repr: PhantomData,
103            device: PhantomData,
104        }
105    }
106
107    #[must_use]
108    pub fn offset_ptr(&self, offset: usize) -> Ptr<Ref<&R::Item>, D> {
109        Ptr {
110            ptr: self.ptr,
111            len: self.len,
112            offset: self.offset + offset,
113            repr: PhantomData,
114            device: PhantomData,
115        }
116    }
117
118    pub(crate) fn len(&self) -> usize {
119        self.len
120    }
121
122    #[expect(clippy::missing_panics_doc)]
123    #[must_use]
124    pub fn get_item(&self, offset: usize) -> R::Item {
125        assert!(offset < self.len, "Index out of bounds");
126        D::get_item(self.ptr, offset + self.offset)
127    }
128
129    fn to_ref<'a>(&self) -> Ptr<Ref<&'a R::Item>, D> {
130        Ptr {
131            ptr: self.ptr,
132            len: self.len,
133            offset: self.offset,
134            repr: PhantomData,
135            device: PhantomData,
136        }
137    }
138
139    fn to<Dout: DeviceBase>(&self) -> Ptr<Owned<R::Item>, Dout> {
140        #[cfg(feature = "nvidia")]
141        use crate::device::cpu::Cpu;
142
143        let self_raw_ptr = self.ptr;
144        let len = self.len;
145
146        let ptr = match (TypeId::of::<D>(), TypeId::of::<Dout>()) {
147            (a, b) if a == b => Owned::clone_memory(self_raw_ptr, len, D::default()),
148            #[cfg(feature = "nvidia")]
149            (a, b) if a == TypeId::of::<Cpu>() && b == TypeId::of::<Nvidia>() => {
150                zenu_cuda::runtime::copy_to_gpu(self_raw_ptr, len)
151            }
152            #[cfg(feature = "nvidia")]
153            (a, b) if a == TypeId::of::<Nvidia>() && b == TypeId::of::<Cpu>() => {
154                zenu_cuda::runtime::copy_to_cpu(self_raw_ptr, len)
155            }
156            _ => unreachable!(),
157        };
158
159        Ptr::new(ptr, len, self.offset)
160    }
161}
162
163impl<R, D> Drop for Ptr<R, D>
164where
165    R: Repr,
166    D: DeviceBase,
167{
168    fn drop(&mut self) {
169        R::drop_memory(self.ptr, D::default());
170    }
171}
172
173impl<'a, T: Num, D: DeviceBase> Ptr<Ref<&'a mut T>, D> {
174    #[must_use]
175    pub fn offset_ptr_mut(self, offset: usize) -> Ptr<Ref<&'a mut T>, D> {
176        Ptr {
177            ptr: self.ptr,
178            len: self.len,
179            offset: self.offset + offset,
180            repr: PhantomData,
181            device: PhantomData,
182        }
183    }
184
185    #[expect(clippy::missing_panics_doc)]
186    pub fn assign_item(&self, offset: usize, value: T) {
187        assert!(offset < self.len, "Index out of bounds");
188        D::assign_item(self.ptr, offset + self.offset, value);
189    }
190}
191
192impl<R, D> Clone for Ptr<R, D>
193where
194    R: Repr,
195    D: DeviceBase,
196{
197    fn clone(&self) -> Self {
198        Ptr {
199            ptr: R::clone_memory(self.ptr, self.len, D::default()),
200            len: self.len,
201            offset: self.offset,
202            repr: PhantomData,
203            device: PhantomData,
204        }
205    }
206}
207
208impl<R, D> Ptr<R, D>
209where
210    R: OwnedRepr,
211    D: DeviceBase,
212{
213    fn to_ref_mut<'a>(&mut self) -> Ptr<Ref<&'a mut R::Item>, D> {
214        Ptr {
215            ptr: self.ptr,
216            len: self.len,
217            offset: self.offset,
218            repr: PhantomData,
219            device: PhantomData,
220        }
221    }
222}
223
224pub struct Matrix<R, S, D>
225where
226    R: Repr,
227    S: DimTrait,
228    D: DeviceBase,
229{
230    ptr: Ptr<R, D>,
231    shape: S,
232    stride: S,
233}
234
235impl<R, S, D> Clone for Matrix<R, S, D>
236where
237    R: Repr,
238    S: DimTrait,
239    D: DeviceBase,
240{
241    fn clone(&self) -> Self {
242        Matrix {
243            ptr: self.ptr.clone(),
244            shape: self.shape,
245            stride: self.stride,
246        }
247    }
248}
249
250impl<R, S, D> Matrix<R, S, D>
251where
252    R: Repr,
253    S: DimTrait,
254    D: DeviceBase,
255{
256    pub(crate) fn new(ptr: Ptr<R, D>, shape: S, stride: S) -> Self {
257        Matrix { ptr, shape, stride }
258    }
259
260    pub(crate) unsafe fn ptr(&self) -> &Ptr<R, D> {
261        &self.ptr
262    }
263
264    pub fn offset(&self) -> usize {
265        self.ptr.offset
266    }
267
268    pub fn shape_stride(&self) -> ShapeStride<S> {
269        ShapeStride::new(self.shape, self.stride)
270    }
271
272    pub fn shape(&self) -> S {
273        self.shape
274    }
275
276    pub fn stride(&self) -> S {
277        self.stride
278    }
279
280    pub fn is_default_stride(&self) -> bool {
281        self.shape_stride().is_default_stride()
282    }
283
284    pub fn is_transpose_default_stride(&self) -> bool {
285        self.shape_stride().is_transposed_default_stride()
286    }
287
288    pub fn as_ptr(&self) -> *const R::Item {
289        unsafe { self.ptr.ptr.add(self.offset()) }
290    }
291
292    /// this code retunrs a slice of the matrix
293    /// WARNING: even if the matrix has offset, the slice will be created from the original pointer
294    pub fn to_vec(&self) -> Vec<R::Item>
295    where
296        R::Item: Clone,
297    {
298        let ptr_len = self.ptr.len();
299        let mut vec = Vec::with_capacity(ptr_len);
300        let non_offset_ptr = Ptr::<Ref<&R::Item>, D>::new(self.ptr.ptr, ptr_len, 0);
301        for i in 0..ptr_len {
302            vec.push(non_offset_ptr.get_item(i));
303        }
304        vec
305    }
306
307    pub fn into_dyn_dim(self) -> Matrix<R, DimDyn, D> {
308        let mut shape = DimDyn::default();
309        let mut stride = DimDyn::default();
310
311        for i in 0..self.shape.len() {
312            shape.push_dim(self.shape[i]);
313            stride.push_dim(self.stride[i]);
314        }
315        Matrix {
316            ptr: self.ptr,
317            shape,
318            stride,
319        }
320    }
321
322    pub fn update_shape_stride(&mut self, shape_stride: ShapeStride<S>) {
323        self.shape = shape_stride.shape();
324        self.stride = shape_stride.stride();
325    }
326
327    pub fn update_shape(&mut self, shape: S) {
328        self.shape = shape;
329        self.stride = default_stride(shape);
330    }
331
332    pub fn update_stride(&mut self, stride: S) {
333        self.stride = stride;
334    }
335
336    pub fn into_dim<S2>(self) -> Matrix<R, S2, D>
337    where
338        S2: DimTrait,
339    {
340        Matrix {
341            ptr: self.ptr,
342            shape: S2::from(self.shape.slice()),
343            stride: S2::from(self.stride.slice()),
344        }
345    }
346
347    pub fn slice<I>(&self, index: I) -> Matrix<Ref<&R::Item>, S, D>
348    where
349        I: SliceTrait<Dim = S>,
350    {
351        let shape = self.shape();
352        let stride = self.stride();
353        let new_shape_stride = index.sliced_shape_stride(shape, stride);
354        let offset = index.sliced_offset(stride);
355        Matrix {
356            ptr: self.ptr.offset_ptr(offset),
357            shape: new_shape_stride.shape(),
358            stride: new_shape_stride.stride(),
359        }
360    }
361
362    pub fn slice_dyn(&self, index: Slice) -> Matrix<Ref<&R::Item>, DimDyn, D> {
363        let shape_stride = self.shape_stride().into_dyn();
364        let new_shape_stride =
365            index.sliced_shape_stride(shape_stride.shape(), shape_stride.stride());
366        let offset = index.sliced_offset(shape_stride.stride());
367        Matrix {
368            ptr: self.ptr.offset_ptr(offset),
369            shape: new_shape_stride.shape(),
370            stride: new_shape_stride.stride(),
371        }
372    }
373
374    pub fn index_axis<I>(&self, index: I) -> Matrix<Ref<&R::Item>, S, D>
375    where
376        I: IndexAxisTrait,
377        S: LessDimTrait,
378        S::LessDim: DimTrait,
379    {
380        let shape = self.shape();
381        let stride = self.stride();
382        let new_shape_stride = index.get_shape_stride(shape, stride);
383        let offset = index.offset(stride);
384        Matrix {
385            ptr: self.ptr.offset_ptr(offset),
386            shape: new_shape_stride.shape(),
387            stride: new_shape_stride.stride(),
388        }
389    }
390
391    pub fn index_axis_dyn<I>(&self, index: I) -> Matrix<Ref<&R::Item>, DimDyn, D>
392    where
393        I: IndexAxisTrait,
394    {
395        let shape_stride = self.shape_stride().into_dyn();
396        let new_shape_stride = index.get_shape_stride(shape_stride.shape(), shape_stride.stride());
397        let offset = index.offset(shape_stride.stride());
398        Matrix {
399            ptr: self.ptr.offset_ptr(offset),
400            shape: new_shape_stride.shape(),
401            stride: new_shape_stride.stride(),
402        }
403    }
404
405    #[expect(clippy::missing_panics_doc)]
406    pub fn index_item<I: Into<S>>(&self, index: I) -> R::Item {
407        let index = index.into();
408        assert!(!self.shape().is_overflow(index), "Index out of bounds");
409        let offset = cal_offset(index, self.stride());
410        self.ptr.get_item(offset)
411    }
412
413    pub fn to_ref<'a>(&self) -> Matrix<Ref<&'a R::Item>, S, D> {
414        Matrix {
415            ptr: self.ptr.to_ref(),
416            shape: self.shape,
417            stride: self.stride,
418        }
419    }
420
421    pub fn convert_dim_type<Dout: DimTrait>(self) -> Matrix<R, Dout, D> {
422        Matrix {
423            ptr: self.ptr,
424            shape: Dout::from(self.shape.slice()),
425            stride: Dout::from(self.stride.slice()),
426        }
427    }
428
429    pub fn new_matrix(&self) -> Matrix<Owned<R::Item>, S, D>
430    where
431        D: Device,
432    {
433        let mut owned = Matrix::zeros(self.shape());
434        owned.to_ref_mut().copy_from(self);
435        owned
436    }
437
438    #[expect(clippy::missing_errors_doc)]
439    pub fn try_to_scalar(&self) -> Result<R::Item, String> {
440        if self.shape().is_scalar() {
441            let scalr = self.ptr.get_item(0);
442            Ok(scalr)
443        } else {
444            Err("this matrix is not scalar".to_string())
445        }
446    }
447
448    #[expect(clippy::missing_panics_doc)]
449    pub fn to_scalar(&self) -> R::Item {
450        if let Ok(scalar) = self.try_to_scalar() {
451            scalar
452        } else {
453            panic!("Matrix is not scalar");
454        }
455    }
456
457    #[expect(clippy::missing_panics_doc)]
458    pub fn as_slice(&self) -> &[R::Item] {
459        if self.shape().len() <= 1 {
460            self.as_slice_unchecked()
461        } else {
462            panic!("Invalid shape");
463        }
464    }
465
466    pub fn as_slice_unchecked(&self) -> &[R::Item] {
467        let num_elm = self.shape().num_elm();
468        unsafe { std::slice::from_raw_parts(self.as_ptr(), num_elm) }
469    }
470}
471
472impl<T, S, D> Matrix<Owned<T>, S, D>
473where
474    T: Num,
475    D: DeviceBase,
476    S: DimTrait,
477{
478    pub fn to_ref_mut<'a>(&mut self) -> Matrix<Ref<&'a mut T>, S, D> {
479        Matrix {
480            ptr: self.ptr.to_ref_mut(),
481            shape: self.shape,
482            stride: self.stride,
483        }
484    }
485
486    pub fn to<Dout: DeviceBase>(self) -> Matrix<Owned<T>, S, Dout> {
487        let shape = self.shape();
488        let stride = self.stride();
489        let ptr = self.ptr.to::<Dout>();
490        Matrix::new(ptr, shape, stride)
491    }
492}
493
494impl<'a, T, S, D> Matrix<Ref<&'a mut T>, S, D>
495where
496    T: Num,
497    D: DeviceBase,
498    S: DimTrait,
499{
500    pub(crate) fn offset_ptr_mut(&self, offset: usize) -> Ptr<Ref<&'a mut T>, D> {
501        self.ptr.clone().offset_ptr_mut(offset)
502    }
503
504    pub fn as_mut_ptr(&self) -> *mut T {
505        unsafe { self.ptr.ptr.add(self.offset()) }
506    }
507
508    #[expect(clippy::missing_panics_doc)]
509    pub fn as_mut_slice(&self) -> &mut [T] {
510        if self.shape().len() <= 1 {
511            self.as_mut_slice_unchecked()
512        } else {
513            panic!("Invalid shape");
514        }
515    }
516
517    #[expect(clippy::mut_from_ref)]
518    pub fn as_mut_slice_unchecked(&self) -> &mut [T] {
519        let num_elm = self.shape().num_elm();
520        unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), num_elm) }
521    }
522
523    #[expect(clippy::missing_panics_doc)]
524    pub fn each_by<F>(&mut self, f: F)
525    where
526        F: FnMut(&mut T),
527    {
528        assert_eq!(self.stride().into_iter().min(), Some(1), "Invalid stride");
529        self.as_mut_slice_unchecked().iter_mut().for_each(f);
530    }
531
532    #[must_use]
533    pub fn slice_mut<I>(&self, index: I) -> Matrix<Ref<&'a mut T>, S, D>
534    where
535        I: SliceTrait<Dim = S>,
536    {
537        let shape = self.shape();
538        let stride = self.stride();
539        let new_shape_stride = index.sliced_shape_stride(shape, stride);
540        let offset = index.sliced_offset(stride);
541        Matrix {
542            ptr: self.ptr.clone().offset_ptr_mut(offset),
543            shape: new_shape_stride.shape(),
544            stride: new_shape_stride.stride(),
545        }
546    }
547
548    pub fn slice_mut_dyn(&self, index: Slice) -> Matrix<Ref<&'a mut T>, DimDyn, D> {
549        let shape_stride = self.shape_stride().into_dyn();
550        let new_shape_stride =
551            index.sliced_shape_stride(shape_stride.shape(), shape_stride.stride());
552        let offset = index.sliced_offset(shape_stride.stride());
553        Matrix {
554            ptr: self.ptr.clone().offset_ptr_mut(offset),
555            shape: new_shape_stride.shape(),
556            stride: new_shape_stride.stride(),
557        }
558    }
559
560    #[must_use]
561    pub fn index_axis_mut<I>(&self, index: I) -> Matrix<Ref<&'a mut T>, S, D>
562    where
563        I: IndexAxisTrait,
564        S: LessDimTrait,
565        S::LessDim: DimTrait,
566    {
567        let shape = self.shape();
568        let stride = self.stride();
569        let new_shape_stride = index.get_shape_stride(shape, stride);
570        let offset = index.offset(stride);
571        Matrix {
572            ptr: self.ptr.clone().offset_ptr_mut(offset),
573            shape: new_shape_stride.shape(),
574            stride: new_shape_stride.stride(),
575        }
576    }
577
578    pub fn index_axis_mut_dyn<I>(&self, index: I) -> Matrix<Ref<&'a mut T>, DimDyn, D>
579    where
580        I: IndexAxisTrait,
581    {
582        let shape_stride = self.shape_stride().into_dyn();
583        let new_shape_stride = index.get_shape_stride(shape_stride.shape(), shape_stride.stride());
584        let offset = index.offset(shape_stride.stride());
585        Matrix {
586            ptr: self.ptr.clone().offset_ptr_mut(offset),
587            shape: new_shape_stride.shape(),
588            stride: new_shape_stride.stride(),
589        }
590    }
591
592    #[expect(clippy::missing_panics_doc)]
593    pub fn index_item_assign<I: Into<S>>(&self, index: I, value: T) {
594        let index = index.into();
595        assert!(!self.shape().is_overflow(index), "Index out of bounds");
596        let offset = cal_offset(index, self.stride());
597        self.ptr.assign_item(offset, value);
598    }
599}
600
#[expect(clippy::float_cmp)]
#[cfg(test)]
mod matrix_test {

    use crate::{
        device::DeviceBase,
        dim::{Dim1, Dim2, DimDyn, DimTrait},
        index::Index0D,
        slice, slice_dynamic,
    };

    use super::{Matrix, Owned};

    /// Every element of a 1-D dynamic matrix is readable through `index_item`.
    fn index_item_1d<D: DeviceBase>() {
        let m: Matrix<Owned<f32>, DimDyn, D> = Matrix::from_vec(vec![1.0, 2.0, 3.0], [3]);
        let expected = [1.0, 2.0, 3.0];
        for (i, &want) in expected.iter().enumerate() {
            assert_eq!(m.index_item([i]), want);
        }
    }
    #[test]
    fn index_item_1d_cpu() {
        index_item_1d::<crate::device::cpu::Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn index_item_1d_nvidia() {
        index_item_1d::<crate::device::nvidia::Nvidia>();
    }

    /// Row-major 2x2 data reads back identically through `Dim2` and `DimDyn`.
    fn index_item_2d<D: DeviceBase>() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let fixed: Matrix<Owned<f32>, Dim2, D> = Matrix::from_vec(data.clone(), [2, 2]);
        let dynamic: Matrix<Owned<f32>, DimDyn, D> = Matrix::from_vec(data.clone(), [2, 2]);
        for r in 0..2 {
            for c in 0..2 {
                let want = data[r * 2 + c];
                assert_eq!(fixed.index_item([r, c]), want);
                assert_eq!(dynamic.index_item([r, c]), want);
            }
        }
    }
    #[test]
    fn index_item_2d_cpu() {
        index_item_2d::<crate::device::cpu::Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn index_item_2d_nvidia() {
        index_item_2d::<crate::device::nvidia::Nvidia>();
    }

    /// A contiguous 1-D slice keeps stride 1 and starts at the slice offset.
    #[expect(clippy::cast_precision_loss)]
    fn slice_1d<D: DeviceBase>() {
        let values: Vec<f32> = (1..10).map(|x| x as f32).collect();
        let m: Matrix<Owned<f32>, Dim1, D> = Matrix::from_vec(values.clone(), [9]);
        let s = m.slice(slice!(1..4));
        assert_eq!(s.shape().slice(), [3]);
        assert_eq!(s.stride().slice(), [1]);
        for (k, &want) in values[1..4].iter().enumerate() {
            assert_eq!(s.index_item([k]), want);
        }
    }
    #[test]
    fn slice_1d_cpu() {
        slice_1d::<crate::device::cpu::Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn slice_1d_nvidia() {
        slice_1d::<crate::device::nvidia::Nvidia>();
    }

    /// A 2-D slice keeps the parent row stride and maps back to parent elements.
    #[expect(clippy::cast_precision_loss)]
    fn slice_2d<D: DeviceBase>() {
        let values: Vec<f32> = (1..13).map(|x| x as f32).collect();
        let m: Matrix<Owned<f32>, Dim2, D> = Matrix::from_vec(values.clone(), [3, 4]);
        let s = m.slice(slice!(1..3, 1..4));
        assert_eq!(s.shape().slice(), [2, 3]);
        assert_eq!(s.stride().slice(), [4, 1]);

        for i in 0..2 {
            for j in 0..3 {
                assert_eq!(s.index_item([i, j]), values[(i + 1) * 4 + (j + 1)]);
            }
        }
    }
    #[test]
    fn slice_2d_cpu() {
        slice_2d::<crate::device::cpu::Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn slice_2d_nvidia() {
        slice_2d::<crate::device::nvidia::Nvidia>();
    }

    /// Fixing axis 2 of a [2, 2, 4, 4] tensor at 2 yields a [2, 2, 4] view.
    #[expect(clippy::cast_precision_loss)]
    fn slice_dyn_4d<D: DeviceBase>() {
        let values: Vec<f32> = (1..65).map(|x| x as f32).collect();
        let m: Matrix<Owned<f32>, DimDyn, D> = Matrix::from_vec(values.clone(), [2, 2, 4, 4]);
        let s = m.slice_dyn(slice_dynamic!(.., .., 2, ..));

        // Parent strides are [32, 16, 4, 1]; the fixed index contributes 2 * 4.
        for a in 0..2 {
            for b in 0..2 {
                for c in 0..4 {
                    let parent = a * 32 + b * 16 + 2 * 4 + c;
                    assert_eq!(s.index_item([a, b, c]), values[parent]);
                }
            }
        }
    }
    #[test]
    fn slice_dyn_4d_cpu() {
        slice_dyn_4d::<crate::device::cpu::Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn slice_dyn_4d_nvidia() {
        slice_dyn_4d::<crate::device::nvidia::Nvidia>();
    }

    /// Indexing axis 0 at 0 yields the first row.
    #[expect(clippy::cast_precision_loss)]
    fn index_axis_dyn_2d<D: DeviceBase>() {
        let values: Vec<f32> = (1..13).map(|x| x as f32).collect();
        let m: Matrix<Owned<f32>, DimDyn, D> = Matrix::from_vec(values.clone(), [3, 4]);
        let s = m.index_axis_dyn(Index0D::new(0));

        for (k, &want) in values[..3].iter().enumerate() {
            assert_eq!(s.index_item([k]), want);
        }
    }
    #[test]
    fn index_axis_dyn_2d_cpu() {
        index_axis_dyn_2d::<crate::device::cpu::Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn index_axis_dyn_2d_nvidia() {
        index_axis_dyn_2d::<crate::device::nvidia::Nvidia>();
    }
}