1use std::{any::TypeId, marker::PhantomData};
2
3use crate::{
4 device::{Device, DeviceBase},
5 dim::{cal_offset, default_stride, DimDyn, DimTrait, LessDimTrait},
6 index::{IndexAxisTrait, SliceTrait},
7 num::Num,
8 shape_stride::ShapeStride,
9 slice::Slice,
10};
11
12#[cfg(feature = "nvidia")]
13use crate::device::nvidia::Nvidia;
14
/// Describes how a matrix representation manages its underlying memory.
///
/// Implementors decide whether the buffer is owned (freed on drop,
/// deep-copied on clone) or merely borrowed (no-op drop, aliasing clone).
pub trait Repr: Default {
    /// Element type stored in the buffer.
    type Item: Num;

    /// Releases `ptr` according to the representation's ownership rules.
    fn drop_memory<D: DeviceBase>(ptr: *mut Self::Item, _: D);
    /// Clones `len` elements starting at `ptr`, returning the new pointer
    /// (or `ptr` itself for borrowed representations).
    fn clone_memory<D: DeviceBase>(ptr: *mut Self::Item, len: usize, _: D) -> *mut Self::Item;
}
21
/// Marker trait for representations that own their memory (see [`Owned`]).
pub trait OwnedRepr: Repr {}
23
/// Owning representation: the buffer is freed on drop and deep-copied on clone.
pub struct Owned<T: Num> {
    // Zero-sized marker tying the representation to its element type.
    _maker: PhantomData<T>,
}
27
/// Borrowed representation: `A` is `&T` for shared views, `&mut T` for mutable ones.
pub struct Ref<A> {
    // Zero-sized marker carrying the borrow type.
    _maker: PhantomData<A>,
}
31
32impl<T: Num> Default for Owned<T> {
33 fn default() -> Self {
34 Owned {
35 _maker: PhantomData,
36 }
37 }
38}
39
40impl<A> Default for Ref<A> {
41 fn default() -> Self {
42 Ref {
43 _maker: PhantomData,
44 }
45 }
46}
47
48impl<'a, T: Num> Repr for Ref<&'a T> {
49 type Item = T;
50
51 fn drop_memory<D: DeviceBase>(_ptr: *mut Self::Item, _: D) {}
52 fn clone_memory<D: DeviceBase>(ptr: *mut Self::Item, _len: usize, _: D) -> *mut Self::Item {
53 ptr
54 }
55}
56
57impl<'a, T: Num> Repr for Ref<&'a mut T> {
58 type Item = T;
59
60 fn drop_memory<D: DeviceBase>(_ptr: *mut Self::Item, _: D) {}
61 fn clone_memory<D: DeviceBase>(ptr: *mut Self::Item, _len: usize, _: D) -> *mut Self::Item {
62 ptr
63 }
64}
65
66impl<T: Num> Repr for Owned<T> {
67 type Item = T;
68
69 fn drop_memory<D: DeviceBase>(ptr: *mut Self::Item, _: D) {
70 D::drop_ptr(ptr);
71 }
72
73 fn clone_memory<D: DeviceBase>(ptr: *mut Self::Item, len: usize, _: D) -> *mut Self::Item {
74 D::clone_ptr(ptr, len)
75 }
76}
77
/// `Owned` is the owning representation.
impl<T: Num> OwnedRepr for Owned<T> {}
79
/// Raw device pointer plus bookkeeping shared by all matrix views.
pub struct Ptr<R, D>
where
    R: Repr,
    D: DeviceBase,
{
    // Base pointer of the allocation (NOT adjusted by `offset`).
    ptr: *mut R::Item,
    // Total number of elements in the allocation.
    len: usize,
    // Element offset of this view from `ptr`.
    offset: usize,
    repr: PhantomData<R>,
    device: PhantomData<D>,
}
91
92impl<R, D> Ptr<R, D>
93where
94 R: Repr,
95 D: DeviceBase,
96{
97 pub(crate) fn new(ptr: *mut R::Item, len: usize, offset: usize) -> Self {
98 Ptr {
99 ptr,
100 len,
101 offset,
102 repr: PhantomData,
103 device: PhantomData,
104 }
105 }
106
107 #[must_use]
108 pub fn offset_ptr(&self, offset: usize) -> Ptr<Ref<&R::Item>, D> {
109 Ptr {
110 ptr: self.ptr,
111 len: self.len,
112 offset: self.offset + offset,
113 repr: PhantomData,
114 device: PhantomData,
115 }
116 }
117
118 pub(crate) fn len(&self) -> usize {
119 self.len
120 }
121
122 #[expect(clippy::missing_panics_doc)]
123 #[must_use]
124 pub fn get_item(&self, offset: usize) -> R::Item {
125 assert!(offset < self.len, "Index out of bounds");
126 D::get_item(self.ptr, offset + self.offset)
127 }
128
129 fn to_ref<'a>(&self) -> Ptr<Ref<&'a R::Item>, D> {
130 Ptr {
131 ptr: self.ptr,
132 len: self.len,
133 offset: self.offset,
134 repr: PhantomData,
135 device: PhantomData,
136 }
137 }
138
139 fn to<Dout: DeviceBase>(&self) -> Ptr<Owned<R::Item>, Dout> {
140 #[cfg(feature = "nvidia")]
141 use crate::device::cpu::Cpu;
142
143 let self_raw_ptr = self.ptr;
144 let len = self.len;
145
146 let ptr = match (TypeId::of::<D>(), TypeId::of::<Dout>()) {
147 (a, b) if a == b => Owned::clone_memory(self_raw_ptr, len, D::default()),
148 #[cfg(feature = "nvidia")]
149 (a, b) if a == TypeId::of::<Cpu>() && b == TypeId::of::<Nvidia>() => {
150 zenu_cuda::runtime::copy_to_gpu(self_raw_ptr, len)
151 }
152 #[cfg(feature = "nvidia")]
153 (a, b) if a == TypeId::of::<Nvidia>() && b == TypeId::of::<Cpu>() => {
154 zenu_cuda::runtime::copy_to_cpu(self_raw_ptr, len)
155 }
156 _ => unreachable!(),
157 };
158
159 Ptr::new(ptr, len, self.offset)
160 }
161}
162
impl<R, D> Drop for Ptr<R, D>
where
    R: Repr,
    D: DeviceBase,
{
    fn drop(&mut self) {
        // Delegates to the representation: owned buffers free their device
        // allocation, borrowed views are a no-op.
        R::drop_memory(self.ptr, D::default());
    }
}
172
173impl<'a, T: Num, D: DeviceBase> Ptr<Ref<&'a mut T>, D> {
174 #[must_use]
175 pub fn offset_ptr_mut(self, offset: usize) -> Ptr<Ref<&'a mut T>, D> {
176 Ptr {
177 ptr: self.ptr,
178 len: self.len,
179 offset: self.offset + offset,
180 repr: PhantomData,
181 device: PhantomData,
182 }
183 }
184
185 #[expect(clippy::missing_panics_doc)]
186 pub fn assign_item(&self, offset: usize, value: T) {
187 assert!(offset < self.len, "Index out of bounds");
188 D::assign_item(self.ptr, offset + self.offset, value);
189 }
190}
191
192impl<R, D> Clone for Ptr<R, D>
193where
194 R: Repr,
195 D: DeviceBase,
196{
197 fn clone(&self) -> Self {
198 Ptr {
199 ptr: R::clone_memory(self.ptr, self.len, D::default()),
200 len: self.len,
201 offset: self.offset,
202 repr: PhantomData,
203 device: PhantomData,
204 }
205 }
206}
207
208impl<R, D> Ptr<R, D>
209where
210 R: OwnedRepr,
211 D: DeviceBase,
212{
213 fn to_ref_mut<'a>(&mut self) -> Ptr<Ref<&'a mut R::Item>, D> {
214 Ptr {
215 ptr: self.ptr,
216 len: self.len,
217 offset: self.offset,
218 repr: PhantomData,
219 device: PhantomData,
220 }
221 }
222}
223
/// N-dimensional matrix: a device buffer plus shape/stride metadata.
pub struct Matrix<R, S, D>
where
    R: Repr,
    S: DimTrait,
    D: DeviceBase,
{
    // Underlying (possibly offset) device allocation.
    ptr: Ptr<R, D>,
    // Extent of each dimension.
    shape: S,
    // Element step per dimension.
    stride: S,
}
234
235impl<R, S, D> Clone for Matrix<R, S, D>
236where
237 R: Repr,
238 S: DimTrait,
239 D: DeviceBase,
240{
241 fn clone(&self) -> Self {
242 Matrix {
243 ptr: self.ptr.clone(),
244 shape: self.shape,
245 stride: self.stride,
246 }
247 }
248}
249
impl<R, S, D> Matrix<R, S, D>
where
    R: Repr,
    S: DimTrait,
    D: DeviceBase,
{
    /// Builds a matrix from a pointer and explicit shape/stride metadata.
    pub(crate) fn new(ptr: Ptr<R, D>, shape: S, stride: S) -> Self {
        Matrix { ptr, shape, stride }
    }

    /// Borrows the raw pointer wrapper.
    ///
    /// # Safety
    /// Callers must not use the returned `Ptr` in ways that violate the
    /// aliasing/ownership rules of this matrix's representation `R`.
    pub(crate) unsafe fn ptr(&self) -> &Ptr<R, D> {
        &self.ptr
    }

    /// Element offset of this view within the underlying allocation.
    pub fn offset(&self) -> usize {
        self.ptr.offset
    }

    /// Shape and stride bundled into one value.
    pub fn shape_stride(&self) -> ShapeStride<S> {
        ShapeStride::new(self.shape, self.stride)
    }

    /// Extent of each dimension.
    pub fn shape(&self) -> S {
        self.shape
    }

    /// Element step of each dimension.
    pub fn stride(&self) -> S {
        self.stride
    }

    /// `true` when the stride is the dense default stride for the shape.
    pub fn is_default_stride(&self) -> bool {
        self.shape_stride().is_default_stride()
    }

    /// `true` when the stride matches the default stride of the transposed shape.
    pub fn is_transpose_default_stride(&self) -> bool {
        self.shape_stride().is_transposed_default_stride()
    }

    /// Raw pointer to the first element of this view (base pointer + offset).
    pub fn as_ptr(&self) -> *const R::Item {
        unsafe { self.ptr.ptr.add(self.offset()) }
    }

    /// Copies the *entire underlying allocation* into a host `Vec`, element
    /// by element — this view's shape, stride and offset are ignored.
    pub fn to_vec(&self) -> Vec<R::Item>
    where
        R::Item: Clone,
    {
        let ptr_len = self.ptr.len();
        let mut vec = Vec::with_capacity(ptr_len);
        // Re-anchor at offset 0 so the whole buffer is read, not just the view.
        let non_offset_ptr = Ptr::<Ref<&R::Item>, D>::new(self.ptr.ptr, ptr_len, 0);
        for i in 0..ptr_len {
            vec.push(non_offset_ptr.get_item(i));
        }
        vec
    }

    /// Converts a statically-dimensioned matrix into a `DimDyn` one by
    /// pushing each dimension's extent and stride.
    pub fn into_dyn_dim(self) -> Matrix<R, DimDyn, D> {
        let mut shape = DimDyn::default();
        let mut stride = DimDyn::default();

        for i in 0..self.shape.len() {
            shape.push_dim(self.shape[i]);
            stride.push_dim(self.stride[i]);
        }
        Matrix {
            ptr: self.ptr,
            shape,
            stride,
        }
    }

    /// Overwrites both shape and stride from `shape_stride`.
    pub fn update_shape_stride(&mut self, shape_stride: ShapeStride<S>) {
        self.shape = shape_stride.shape();
        self.stride = shape_stride.stride();
    }

    /// Sets a new shape and resets the stride to the dense default for it.
    pub fn update_shape(&mut self, shape: S) {
        self.shape = shape;
        self.stride = default_stride(shape);
    }

    /// Replaces the stride only; the shape is untouched.
    pub fn update_stride(&mut self, stride: S) {
        self.stride = stride;
    }

    /// Reinterprets the dimension type as `S2`, converting shape and stride.
    pub fn into_dim<S2>(self) -> Matrix<R, S2, D>
    where
        S2: DimTrait,
    {
        Matrix {
            ptr: self.ptr,
            shape: S2::from(self.shape.slice()),
            stride: S2::from(self.stride.slice()),
        }
    }

    /// Borrowing slice with a statically-dimensioned index; the result
    /// aliases the same buffer at an adjusted offset.
    pub fn slice<I>(&self, index: I) -> Matrix<Ref<&R::Item>, S, D>
    where
        I: SliceTrait<Dim = S>,
    {
        let shape = self.shape();
        let stride = self.stride();
        let new_shape_stride = index.sliced_shape_stride(shape, stride);
        let offset = index.sliced_offset(stride);
        Matrix {
            ptr: self.ptr.offset_ptr(offset),
            shape: new_shape_stride.shape(),
            stride: new_shape_stride.stride(),
        }
    }

    /// Borrowing slice with a dynamic (`Slice`) index; the result is `DimDyn`.
    pub fn slice_dyn(&self, index: Slice) -> Matrix<Ref<&R::Item>, DimDyn, D> {
        let shape_stride = self.shape_stride().into_dyn();
        let new_shape_stride =
            index.sliced_shape_stride(shape_stride.shape(), shape_stride.stride());
        let offset = index.sliced_offset(shape_stride.stride());
        Matrix {
            ptr: self.ptr.offset_ptr(offset),
            shape: new_shape_stride.shape(),
            stride: new_shape_stride.stride(),
        }
    }

    /// Borrowing view selected along one axis (the indexed axis is removed
    /// from the shape, per `I`'s `get_shape_stride`).
    pub fn index_axis<I>(&self, index: I) -> Matrix<Ref<&R::Item>, S, D>
    where
        I: IndexAxisTrait,
        S: LessDimTrait,
        S::LessDim: DimTrait,
    {
        let shape = self.shape();
        let stride = self.stride();
        let new_shape_stride = index.get_shape_stride(shape, stride);
        let offset = index.offset(stride);
        Matrix {
            ptr: self.ptr.offset_ptr(offset),
            shape: new_shape_stride.shape(),
            stride: new_shape_stride.stride(),
        }
    }

    /// Like [`Matrix::index_axis`] but yields a `DimDyn` view.
    pub fn index_axis_dyn<I>(&self, index: I) -> Matrix<Ref<&R::Item>, DimDyn, D>
    where
        I: IndexAxisTrait,
    {
        let shape_stride = self.shape_stride().into_dyn();
        let new_shape_stride = index.get_shape_stride(shape_stride.shape(), shape_stride.stride());
        let offset = index.offset(shape_stride.stride());
        Matrix {
            ptr: self.ptr.offset_ptr(offset),
            shape: new_shape_stride.shape(),
            stride: new_shape_stride.stride(),
        }
    }

    /// Reads the single element at `index`; panics when the index is out of
    /// bounds for the matrix shape.
    #[expect(clippy::missing_panics_doc)]
    pub fn index_item<I: Into<S>>(&self, index: I) -> R::Item {
        let index = index.into();
        assert!(!self.shape().is_overflow(index), "Index out of bounds");
        let offset = cal_offset(index, self.stride());
        self.ptr.get_item(offset)
    }

    /// Immutable borrow view of this matrix (same buffer, same metadata).
    pub fn to_ref<'a>(&self) -> Matrix<Ref<&'a R::Item>, S, D> {
        Matrix {
            ptr: self.ptr.to_ref(),
            shape: self.shape,
            stride: self.stride,
        }
    }

    /// Converts the dimension type to `Dout` without touching the data.
    pub fn convert_dim_type<Dout: DimTrait>(self) -> Matrix<R, Dout, D> {
        Matrix {
            ptr: self.ptr,
            shape: Dout::from(self.shape.slice()),
            stride: Dout::from(self.stride.slice()),
        }
    }

    /// Deep copy into a freshly allocated owned matrix of the same shape
    /// (allocated via `zeros`, then filled with `copy_from`).
    pub fn new_matrix(&self) -> Matrix<Owned<R::Item>, S, D>
    where
        D: Device,
    {
        let mut owned = Matrix::zeros(self.shape());
        owned.to_ref_mut().copy_from(self);
        owned
    }

    /// Returns the single element when the matrix is scalar-shaped, or an
    /// error string otherwise.
    #[expect(clippy::missing_errors_doc)]
    pub fn try_to_scalar(&self) -> Result<R::Item, String> {
        if self.shape().is_scalar() {
            let scalr = self.ptr.get_item(0);
            Ok(scalr)
        } else {
            Err("this matrix is not scalar".to_string())
        }
    }

    /// Like [`Matrix::try_to_scalar`] but panics instead of returning an error.
    #[expect(clippy::missing_panics_doc)]
    pub fn to_scalar(&self) -> R::Item {
        if let Ok(scalar) = self.try_to_scalar() {
            scalar
        } else {
            panic!("Matrix is not scalar");
        }
    }

    /// Borrows the data as a slice; panics unless the matrix is
    /// 0- or 1-dimensional.
    #[expect(clippy::missing_panics_doc)]
    pub fn as_slice(&self) -> &[R::Item] {
        if self.shape().len() <= 1 {
            self.as_slice_unchecked()
        } else {
            panic!("Invalid shape");
        }
    }

    /// Borrows `num_elm` elements starting at the view offset, with no
    /// dimensionality check.
    // NOTE(review): assumes the view is contiguous starting at `offset`; a
    // strided view would expose the wrong elements — confirm at call sites.
    pub fn as_slice_unchecked(&self) -> &[R::Item] {
        let num_elm = self.shape().num_elm();
        unsafe { std::slice::from_raw_parts(self.as_ptr(), num_elm) }
    }
}
471
472impl<T, S, D> Matrix<Owned<T>, S, D>
473where
474 T: Num,
475 D: DeviceBase,
476 S: DimTrait,
477{
478 pub fn to_ref_mut<'a>(&mut self) -> Matrix<Ref<&'a mut T>, S, D> {
479 Matrix {
480 ptr: self.ptr.to_ref_mut(),
481 shape: self.shape,
482 stride: self.stride,
483 }
484 }
485
486 pub fn to<Dout: DeviceBase>(self) -> Matrix<Owned<T>, S, Dout> {
487 let shape = self.shape();
488 let stride = self.stride();
489 let ptr = self.ptr.to::<Dout>();
490 Matrix::new(ptr, shape, stride)
491 }
492}
493
impl<'a, T, S, D> Matrix<Ref<&'a mut T>, S, D>
where
    T: Num,
    D: DeviceBase,
    S: DimTrait,
{
    /// Mutable pointer view advanced by `offset` elements past this view.
    pub(crate) fn offset_ptr_mut(&self, offset: usize) -> Ptr<Ref<&'a mut T>, D> {
        self.ptr.clone().offset_ptr_mut(offset)
    }

    /// Raw mutable pointer to the first element of this view (base + offset).
    pub fn as_mut_ptr(&self) -> *mut T {
        unsafe { self.ptr.ptr.add(self.offset()) }
    }

    /// Mutable slice over the data; panics unless the matrix is
    /// 0- or 1-dimensional.
    #[expect(clippy::missing_panics_doc)]
    pub fn as_mut_slice(&self) -> &mut [T] {
        if self.shape().len() <= 1 {
            self.as_mut_slice_unchecked()
        } else {
            panic!("Invalid shape");
        }
    }

    /// Mutable slice of `num_elm` elements starting at the view offset,
    /// with no dimensionality check.
    // NOTE(review): assumes the view is contiguous starting at `offset`, and
    // hands out `&mut` from `&self` (hence the `mut_from_ref` expect) —
    // callers must guarantee exclusive access; verify at call sites.
    #[expect(clippy::mut_from_ref)]
    pub fn as_mut_slice_unchecked(&self) -> &mut [T] {
        let num_elm = self.shape().num_elm();
        unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), num_elm) }
    }

    /// Applies `f` to every element in place; panics unless some axis has
    /// unit stride.
    #[expect(clippy::missing_panics_doc)]
    pub fn each_by<F>(&mut self, f: F)
    where
        F: FnMut(&mut T),
    {
        // A minimum stride of 1 is used as a cheap density check.
        // NOTE(review): this does not fully guarantee a contiguous layout —
        // confirm callers only pass dense views.
        assert_eq!(self.stride().into_iter().min(), Some(1), "Invalid stride");
        self.as_mut_slice_unchecked().iter_mut().for_each(f);
    }

    /// Mutable slice view with a statically-dimensioned index; aliases the
    /// same buffer at an adjusted offset.
    #[must_use]
    pub fn slice_mut<I>(&self, index: I) -> Matrix<Ref<&'a mut T>, S, D>
    where
        I: SliceTrait<Dim = S>,
    {
        let shape = self.shape();
        let stride = self.stride();
        let new_shape_stride = index.sliced_shape_stride(shape, stride);
        let offset = index.sliced_offset(stride);
        Matrix {
            ptr: self.ptr.clone().offset_ptr_mut(offset),
            shape: new_shape_stride.shape(),
            stride: new_shape_stride.stride(),
        }
    }

    /// Mutable slice view with a dynamic (`Slice`) index; result is `DimDyn`.
    pub fn slice_mut_dyn(&self, index: Slice) -> Matrix<Ref<&'a mut T>, DimDyn, D> {
        let shape_stride = self.shape_stride().into_dyn();
        let new_shape_stride =
            index.sliced_shape_stride(shape_stride.shape(), shape_stride.stride());
        let offset = index.sliced_offset(shape_stride.stride());
        Matrix {
            ptr: self.ptr.clone().offset_ptr_mut(offset),
            shape: new_shape_stride.shape(),
            stride: new_shape_stride.stride(),
        }
    }

    /// Mutable view selected along one axis.
    #[must_use]
    pub fn index_axis_mut<I>(&self, index: I) -> Matrix<Ref<&'a mut T>, S, D>
    where
        I: IndexAxisTrait,
        S: LessDimTrait,
        S::LessDim: DimTrait,
    {
        let shape = self.shape();
        let stride = self.stride();
        let new_shape_stride = index.get_shape_stride(shape, stride);
        let offset = index.offset(stride);
        Matrix {
            ptr: self.ptr.clone().offset_ptr_mut(offset),
            shape: new_shape_stride.shape(),
            stride: new_shape_stride.stride(),
        }
    }

    /// Like [`Matrix::index_axis_mut`] but yields a `DimDyn` view.
    pub fn index_axis_mut_dyn<I>(&self, index: I) -> Matrix<Ref<&'a mut T>, DimDyn, D>
    where
        I: IndexAxisTrait,
    {
        let shape_stride = self.shape_stride().into_dyn();
        let new_shape_stride = index.get_shape_stride(shape_stride.shape(), shape_stride.stride());
        let offset = index.offset(shape_stride.stride());
        Matrix {
            ptr: self.ptr.clone().offset_ptr_mut(offset),
            shape: new_shape_stride.shape(),
            stride: new_shape_stride.stride(),
        }
    }

    /// Writes `value` at `index`; panics when the index is out of bounds for
    /// the matrix shape.
    #[expect(clippy::missing_panics_doc)]
    pub fn index_item_assign<I: Into<S>>(&self, index: I, value: T) {
        let index = index.into();
        assert!(!self.shape().is_overflow(index), "Index out of bounds");
        let offset = cal_offset(index, self.stride());
        self.ptr.assign_item(offset, value);
    }
}
600
// Unit tests for indexing and slicing; each scenario is written as a generic
// helper over `DeviceBase` and instantiated once per enabled device backend
// (CPU always, Nvidia behind the `nvidia` feature).
#[expect(clippy::float_cmp)]
#[cfg(test)]
mod matrix_test {

    use crate::{
        device::DeviceBase,
        dim::{Dim1, Dim2, DimDyn, DimTrait},
        index::Index0D,
        slice, slice_dynamic,
    };

    use super::{Matrix, Owned};

    // Element access on a 1-D matrix.
    fn index_item_1d<D: DeviceBase>() {
        let m: Matrix<Owned<f32>, DimDyn, D> = Matrix::from_vec(vec![1.0, 2.0, 3.0], [3]);
        assert_eq!(m.index_item([0]), 1.0);
        assert_eq!(m.index_item([1]), 2.0);
        assert_eq!(m.index_item([2]), 3.0);
    }
    #[test]
    fn index_item_1d_cpu() {
        index_item_1d::<crate::device::cpu::Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn index_item_1d_nvidia() {
        index_item_1d::<crate::device::nvidia::Nvidia>();
    }

    // Element access on a 2-D matrix, both with static and dynamic dims.
    fn index_item_2d<D: DeviceBase>() {
        let m: Matrix<Owned<f32>, Dim2, D> = Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], [2, 2]);
        assert_eq!(m.index_item([0, 0]), 1.0);
        assert_eq!(m.index_item([0, 1]), 2.0);
        assert_eq!(m.index_item([1, 0]), 3.0);
        assert_eq!(m.index_item([1, 1]), 4.0);

        let m: Matrix<Owned<f32>, DimDyn, D> = Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], [2, 2]);
        assert_eq!(m.index_item([0, 0]), 1.0);
        assert_eq!(m.index_item([0, 1]), 2.0);
        assert_eq!(m.index_item([1, 0]), 3.0);
        assert_eq!(m.index_item([1, 1]), 4.0);
    }
    #[test]
    fn index_item_2d_cpu() {
        index_item_2d::<crate::device::cpu::Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn index_item_2d_nvidia() {
        index_item_2d::<crate::device::nvidia::Nvidia>();
    }

    // Slicing a 1-D matrix: view keeps unit stride and adjusts the offset.
    #[expect(clippy::cast_precision_loss)]
    fn slice_1d<D: DeviceBase>() {
        let v = (1..10).map(|x| x as f32).collect::<Vec<f32>>();
        let m: Matrix<Owned<f32>, Dim1, D> = Matrix::from_vec(v.clone(), [9]);
        let s = m.slice(slice!(1..4));
        assert_eq!(s.shape().slice(), [3]);
        assert_eq!(s.stride().slice(), [1]);
        assert_eq!(s.index_item([0]), 2.0);
        assert_eq!(s.index_item([1]), 3.0);
        assert_eq!(s.index_item([2]), 4.0);
    }
    #[test]
    fn slice_1d_cpu() {
        slice_1d::<crate::device::cpu::Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn slice_1d_nvidia() {
        slice_1d::<crate::device::nvidia::Nvidia>();
    }

    // Slicing a 2-D matrix: view keeps the parent's row stride.
    #[expect(clippy::cast_precision_loss)]
    fn slice_2d<D: DeviceBase>() {
        let v = (1..13).map(|x| x as f32).collect::<Vec<f32>>();
        let m: Matrix<Owned<f32>, Dim2, D> = Matrix::from_vec(v.clone(), [3, 4]);
        let s = m.slice(slice!(1..3, 1..4));
        assert_eq!(s.shape().slice(), [2, 3]);
        assert_eq!(s.stride().slice(), [4, 1]);

        assert_eq!(s.index_item([0, 0]), 6.);
        assert_eq!(s.index_item([0, 1]), 7.);
        assert_eq!(s.index_item([0, 2]), 8.);
        assert_eq!(s.index_item([1, 0]), 10.);
        assert_eq!(s.index_item([1, 1]), 11.);
        assert_eq!(s.index_item([1, 2]), 12.);
    }
    #[test]
    fn slice_2d_cpu() {
        slice_2d::<crate::device::cpu::Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn slice_2d_nvidia() {
        slice_2d::<crate::device::nvidia::Nvidia>();
    }

    // Dynamic slicing of a 4-D matrix: fixing one axis drops a dimension.
    #[expect(clippy::cast_precision_loss)]
    fn slice_dyn_4d<D: DeviceBase>() {
        let v = (1..65).map(|x| x as f32).collect::<Vec<f32>>();
        let m: Matrix<Owned<f32>, DimDyn, D> = Matrix::from_vec(v.clone(), [2, 2, 4, 4]);
        let s = m.slice_dyn(slice_dynamic!(.., .., 2, ..));

        assert_eq!(s.index_item([0, 0, 0]), 9.);
        assert_eq!(s.index_item([0, 0, 1]), 10.);
        assert_eq!(s.index_item([0, 0, 2]), 11.);
        assert_eq!(s.index_item([0, 0, 3]), 12.);
        assert_eq!(s.index_item([0, 1, 0]), 25.);
        assert_eq!(s.index_item([0, 1, 1]), 26.);
        assert_eq!(s.index_item([0, 1, 2]), 27.);
        assert_eq!(s.index_item([0, 1, 3]), 28.);
        assert_eq!(s.index_item([1, 0, 0]), 41.);
        assert_eq!(s.index_item([1, 0, 1]), 42.);
        assert_eq!(s.index_item([1, 0, 2]), 43.);
        assert_eq!(s.index_item([1, 0, 3]), 44.);
        assert_eq!(s.index_item([1, 1, 0]), 57.);
        assert_eq!(s.index_item([1, 1, 1]), 58.);
        assert_eq!(s.index_item([1, 1, 2]), 59.);
        assert_eq!(s.index_item([1, 1, 3]), 60.);
    }
    #[test]
    fn slice_dyn_4d_cpu() {
        slice_dyn_4d::<crate::device::cpu::Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn slice_dyn_4d_nvidia() {
        slice_dyn_4d::<crate::device::nvidia::Nvidia>();
    }

    // Axis indexing on a dynamic 2-D matrix selects the first row.
    #[expect(clippy::cast_precision_loss)]
    fn index_axis_dyn_2d<D: DeviceBase>() {
        let v = (1..13).map(|x| x as f32).collect::<Vec<f32>>();
        let m: Matrix<Owned<f32>, DimDyn, D> = Matrix::from_vec(v.clone(), [3, 4]);
        let s = m.index_axis_dyn(Index0D::new(0));

        assert_eq!(s.index_item([0]), 1.);
        assert_eq!(s.index_item([1]), 2.);
        assert_eq!(s.index_item([2]), 3.);
    }
    #[test]
    fn index_axis_dyn_2d_cpu() {
        index_axis_dyn_2d::<crate::device::cpu::Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn index_axis_dyn_2d_nvidia() {
        index_axis_dyn_2d::<crate::device::nvidia::Nvidia>();
    }
}