zenu_matrix/operation/max.rs

1use crate::{
2    device::{cpu::Cpu, Device, DeviceBase},
3    dim::{DimDyn, DimTrait},
4    index::index_dyn_impl::Index,
5    matrix::{Matrix, Owned, Repr},
6    num::Num,
7};
8
9#[cfg(feature = "nvidia")]
10use crate::device::nvidia::Nvidia;
11
12#[cfg(feature = "nvidia")]
13use zenu_cuda::kernel::array_max_idx;
14
15pub trait MaxIdx: DeviceBase {
16    fn max_idx<T: Num>(input: *const T, size: usize, stride: usize) -> usize;
17}
18
19impl MaxIdx for Cpu {
20    #[expect(clippy::not_unsafe_ptr_arg_deref)]
21    fn max_idx<T: Num>(input: *const T, size: usize, stride: usize) -> usize {
22        let tmep_v = unsafe { std::slice::from_raw_parts(input, size * stride) };
23        let mut max_idx = 0;
24        let mut max_val = tmep_v[0];
25
26        for i in 1..size {
27            if tmep_v[i * stride] > max_val {
28                max_val = tmep_v[i * stride];
29                max_idx = i;
30            }
31        }
32        max_idx
33    }
34}
35
#[cfg(feature = "nvidia")]
impl MaxIdx for Nvidia {
    /// Forwards the search to the `array_max_idx` kernel wrapper from `zenu_cuda`.
    ///
    /// `input` is handed over unchanged; it presumably has to point to GPU
    /// memory (confirm against `array_max_idx`).
    fn max_idx<T: Num>(input: *const T, size: usize, stride: usize) -> usize {
        array_max_idx(input, size, stride)
    }
}
42
43impl<T: Num, R: Repr<Item = T>, D: Device> Matrix<R, DimDyn, D> {
44    #[must_use]
45    pub fn max_idx(&self) -> DimDyn {
46        if self.shape().is_empty() {
47            return DimDyn::from(&[] as &[usize]);
48        }
49        let default_stride = self.to_default_stride();
50        let idx = <D as MaxIdx>::max_idx(
51            default_stride.as_ptr(),
52            default_stride.shape().num_elm(),
53            default_stride.stride()[default_stride.shape().len() - 1],
54        );
55        default_stride.shape_stride().get_dim_by_offset(idx)
56    }
57
58    #[must_use]
59    pub fn max_item(&self) -> T {
60        let idx = self.max_idx();
61        self.index_item(idx)
62    }
63
64    /// selfはdefault stride
65    #[expect(clippy::missing_panics_doc)]
66    #[must_use]
67    pub fn max_axis(&self, axis: usize, keep_dim: bool) -> Matrix<Owned<T>, DimDyn, D> {
68        assert!(axis < self.shape().len(), "max_axis: Axis out of bounds");
69
70        let mut output_shape = Vec::new();
71        for i in 0..self.shape().len() {
72            if i == axis {
73                continue;
74            }
75            output_shape.push(self.shape()[i]);
76        }
77
78        let output_shape = DimDyn::from(&output_shape as &[usize]);
79        let mut output = Matrix::<Owned<T>, DimDyn, D>::alloc(output_shape);
80
81        if axis == 0 {
82            let output_flatten = output.reshape_mut([output.shape().num_elm()]);
83            let s = self.reshape_new_matrix([self.shape()[0], output_shape.num_elm()]);
84            for i in 0..output_shape.num_elm() {
85                output_flatten.index_item_assign([i], s.index_axis(Index::new(1, i)).max_item());
86            }
87        } else {
88            for i in 0..self.shape()[0] {
89                let s = self.index_axis(Index::new(0, i));
90                let output = output.to_ref_mut().index_axis_mut(Index::new(0, i));
91                output.copy_from(&s.max_axis(axis - 1, false));
92            }
93        }
94
95        if keep_dim {
96            output.add_axis(axis);
97        }
98        output
99    }
100
101    #[expect(clippy::missing_panics_doc)]
102    #[must_use]
103    pub fn max_axis_idx_ravel(&self, axis: usize) -> Vec<usize> {
104        assert!(axis < self.shape().len(), "max_axis: Axis out of bounds");
105
106        let mut output_shape = Vec::new();
107        for i in 0..self.shape().len() {
108            if i == axis {
109                continue;
110            }
111            output_shape.push(self.shape()[i]);
112        }
113
114        let output_shape = DimDyn::from(&output_shape as &[usize]);
115        let mut output = Vec::with_capacity(output_shape.num_elm());
116
117        if axis == 0 {
118            let s = self.reshape_new_matrix([self.shape()[0], output_shape.num_elm()]);
119            for i in 0..output_shape.num_elm() {
120                let idx = s.index_axis(Index::new(1, i)).max_idx()[0];
121                output.push(idx);
122            }
123        } else {
124            for i in 0..self.shape()[0] {
125                let s = self.index_axis(Index::new(0, i));
126                let tmp = s.max_axis_idx_ravel(axis - 1);
127                for tmp_idx in tmp {
128                    output.push(tmp_idx);
129                }
130            }
131        }
132
133        output
134    }
135}
136
#[cfg(test)]
mod max_idx {
    // Tests for `max_idx`, `max_item`, `max_axis` and `max_axis_idx_ravel` on every
    // enabled device.
    #![expect(
        clippy::float_cmp,
        clippy::unreadable_literal,
        clippy::cast_precision_loss,
        clippy::cast_possible_truncation,
        clippy::too_many_lines,
        clippy::excessive_precision
    )]

    use crate::{
        device::Device,
        dim::DimDyn,
        matrix::{Matrix, Owned},
        slice_dynamic,
    };

    use zenu_test::{assert_mat_eq_epsilon, run_mat_test};

    fn default_1d<D: Device>() {
        let a: Matrix<Owned<f32>, DimDyn, D> = Matrix::from_vec(vec![0., 1., 2., 3.], [4]);
        assert_eq!(a.to_ref().max_idx(), [3].into());
    }
    #[test]
    fn default_1d_cpu() {
        default_1d::<crate::device::cpu::Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn default_1d_gpu() {
        default_1d::<crate::device::nvidia::Nvidia>();
    }

    fn default_2d<D: Device>() {
        let a: Matrix<Owned<f32>, DimDyn, D> = Matrix::from_vec(vec![0., 1., 2., 3.], [2, 2]);
        assert_eq!(a.to_ref().max_idx(), [1, 1].into());
    }
    #[test]
    fn default_2d_cpu() {
        default_2d::<crate::device::cpu::Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn default_2d_gpu() {
        default_2d::<crate::device::nvidia::Nvidia>();
    }

    // The maximum is not at position 0 and every value is negative, so a scan that
    // started from zero instead of the first element would fail here.
    fn max_idx_negative<D: Device>() {
        let a: Matrix<Owned<f32>, DimDyn, D> = Matrix::from_vec(vec![-3., -1., -2.], [3]);
        assert_eq!(a.to_ref().max_idx(), [1].into());
    }
    run_mat_test!(max_idx_negative, max_idx_negative_cpu, max_idx_negative_gpu);

    fn max_item_2d<D: Device>() {
        let a: Matrix<Owned<f32>, DimDyn, D> = Matrix::from_vec(vec![0., 5., 2., 3.], [2, 2]);
        assert_eq!(a.to_ref().max_item(), 5.);
    }
    run_mat_test!(max_item_2d, max_item_2d_cpu, max_item_2d_gpu);

    fn sliced_3d<D: Device>() {
        let mut v = Vec::new();
        for i in 0..8 * 8 * 8 {
            v.push(i as f32);
        }
        let a: Matrix<Owned<f32>, DimDyn, D> = Matrix::from_vec(v, [8, 8, 8]);
        let sliced = a.slice(slice_dynamic!(..;3, ..;4, ..;2));
        assert_eq!(sliced.max_idx(), [2, 1, 3].into());
    }
    #[test]
    fn sliced_3d_cpu() {
        sliced_3d::<crate::device::cpu::Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn sliced_3d_gpu() {
        sliced_3d::<crate::device::nvidia::Nvidia>();
    }

    fn max_axis_ravel<D: Device>() {
        let input = vec![
            0.04881350392732475,
            0.21518936637241948,
            0.10276337607164387,
            0.044883182996896864,
            -0.07634520066109529,
            0.14589411306665612,
            -0.06241278873730749,
            0.39177300078207977,
            0.4636627605010293,
            -0.1165584811742223,
            0.2917250380826646,
            0.02889491975290448,
            0.06804456109393231,
            0.42559663829266103,
            -0.42896394180211306,
            -0.4128707002984593,
            -0.4797816025596743,
            0.332619845547938,
            0.2781567509498505,
            0.37001214824681916,
            0.478618342232764,
            0.2991585642167236,
            -0.038520637747068154,
            0.28052917628645546,
        ];

        let input = Matrix::<Owned<f32>, DimDyn, D>::from_vec(input, [2, 3, 4]);
        let ans_0d = vec![1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1];
        let ans_1d = vec![2, 0, 2, 1, 2, 0, 1, 1];
        let ans_2d = vec![1, 3, 0, 1, 3, 0];

        let result_0d = input.max_axis_idx_ravel(0);
        let result_1d = input.max_axis_idx_ravel(1);
        let result_2d = input.max_axis_idx_ravel(2);

        assert_eq!(result_0d, ans_0d);
        assert_eq!(result_1d, ans_1d);
        assert_eq!(result_2d, ans_2d);
    }
    run_mat_test!(max_axis_ravel, max_axis_ravel_cpu, max_axis_ravel_gpu);

    #[test]
    #[should_panic(expected = "Axis out of bounds")]
    fn max_axis_idx_ravel_axis_out_of_bounds_cpu() {
        let a: Matrix<Owned<f32>, DimDyn, crate::device::cpu::Cpu> =
            Matrix::from_vec(vec![0., 1.], [2]);
        let _ = a.max_axis_idx_ravel(1);
    }

    fn max_axis_keep_dim<D: Device>() {
        let a: Matrix<Owned<f32>, DimDyn, D> =
            Matrix::from_vec(vec![0., 5., 2., 3., 4., 1.], [2, 3]);
        let ans_0: Matrix<Owned<f32>, DimDyn, D> = Matrix::from_vec(vec![3., 5., 2.], [1, 3]);
        let ans_1: Matrix<Owned<f32>, DimDyn, D> = Matrix::from_vec(vec![5., 4.], [2, 1]);

        let result_0 = a.max_axis(0, true);
        let result_1 = a.max_axis(1, true);
        assert_eq!(result_0.shape(), ans_0.shape());
        assert_eq!(result_1.shape(), ans_1.shape());
        assert_mat_eq_epsilon!(result_0, ans_0, 1e-6);
        assert_mat_eq_epsilon!(result_1, ans_1, 1e-6);
    }
    run_mat_test!(max_axis_keep_dim, max_axis_keep_dim_cpu, max_axis_keep_dim_gpu);

    fn max_axis_4d<D: Device>() {
        let input: Vec<f32> = vec![
            0.5488135039273248,
            0.7151893663724195,
            0.6027633760716439,
            0.5448831829968969,
            0.4236547993389047,
            0.6458941130666561,
            0.4375872112626925,
            0.8917730007820798,
            0.9636627605010293,
            0.3834415188257777,
            0.7917250380826646,
            0.5288949197529045,
            0.5680445610939323,
            0.925596638292661,
            0.07103605819788694,
            0.08712929970154071,
            0.02021839744032572,
            0.832619845547938,
            0.7781567509498505,
            0.8700121482468192,
            0.978618342232764,
            0.7991585642167236,
            0.46147936225293185,
            0.7805291762864555,
            0.11827442586893322,
            0.6399210213275238,
            0.1433532874090464,
            0.9446689170495839,
            0.5218483217500717,
            0.4146619399905236,
            0.26455561210462697,
            0.7742336894342167,
            0.45615033221654855,
            0.5684339488686485,
            0.018789800436355142,
            0.6176354970758771,
            0.6120957227224214,
            0.6169339968747569,
            0.9437480785146242,
            0.6818202991034834,
            0.359507900573786,
            0.43703195379934145,
            0.6976311959272649,
            0.06022547162926983,
            0.6667667154456677,
            0.6706378696181594,
            0.2103825610738409,
            0.1289262976548533,
            0.31542835092418386,
            0.3637107709426226,
            0.5701967704178796,
            0.43860151346232035,
            0.9883738380592262,
            0.10204481074802807,
            0.2088767560948347,
            0.16130951788499626,
            0.6531083254653984,
            0.2532916025397821,
            0.4663107728563063,
            0.24442559200160274,
            0.15896958364551972,
            0.11037514116430513,
            0.6563295894652734,
            0.1381829513486138,
            0.1965823616800535,
            0.3687251706609641,
            0.8209932298479351,
            0.09710127579306127,
            0.8379449074988039,
            0.09609840789396307,
            0.9764594650133958,
            0.4686512016477016,
            0.9767610881903371,
            0.604845519745046,
            0.7392635793983017,
            0.039187792254320675,
            0.2828069625764096,
            0.1201965612131689,
            0.29614019752214493,
            0.11872771895424405,
            0.317983179393976,
            0.41426299451466997,
            0.06414749634878436,
            0.6924721193700198,
            0.5666014542065752,
            0.2653894909394454,
            0.5232480534666997,
            0.09394051075844168,
            0.5759464955561793,
            0.9292961975762141,
            0.31856895245132366,
            0.6674103799636817,
            0.13179786240439217,
            0.7163272041185655,
            0.2894060929472011,
            0.18319136200711683,
            0.5865129348100832,
            0.020107546187493552,
            0.8289400292173631,
            0.004695476192547066,
            0.6778165367962301,
            0.27000797319216485,
            0.7351940221225949,
            0.9621885451174382,
            0.24875314351995803,
            0.5761573344178369,
            0.592041931271839,
            0.5722519057908734,
            0.2230816326406183,
            0.952749011516985,
            0.44712537861762736,
            0.8464086724711278,
            0.6994792753175043,
            0.29743695085513366,
            0.8137978197024772,
            0.39650574084698464,
            0.8811031971111616,
            0.5812728726358587,
            0.8817353618548528,
            0.6925315900777659,
        ];
        let ans_0d = vec![
            0.5488135039273248,
            0.7151893663724195,
            0.6563295894652734,
            0.5448831829968969,
            0.4236547993389047,
            0.6458941130666561,
            0.8209932298479351,
            0.8917730007820798,
            0.9636627605010293,
            0.3834415188257777,
            0.9764594650133958,
            0.5288949197529045,
            0.9767610881903371,
            0.925596638292661,
            0.7392635793983017,
            0.08712929970154071,
            0.2828069625764096,
            0.832619845547938,
            0.7781567509498505,
            0.8700121482468192,
            0.978618342232764,
            0.7991585642167236,
            0.46147936225293185,
            0.7805291762864555,
            0.5666014542065752,
            0.6399210213275238,
            0.5232480534666997,
            0.9446689170495839,
            0.5759464955561793,
            0.9292961975762141,
            0.31856895245132366,
            0.7742336894342167,
            0.45615033221654855,
            0.7163272041185655,
            0.2894060929472011,
            0.6176354970758771,
            0.6120957227224214,
            0.6169339968747569,
            0.9437480785146242,
            0.6818202991034834,
            0.6778165367962301,
            0.43703195379934145,
            0.7351940221225949,
            0.9621885451174382,
            0.6667667154456677,
            0.6706378696181594,
            0.592041931271839,
            0.5722519057908734,
            0.31542835092418386,
            0.952749011516985,
            0.5701967704178796,
            0.8464086724711278,
            0.9883738380592262,
            0.29743695085513366,
            0.8137978197024772,
            0.39650574084698464,
            0.8811031971111616,
            0.5812728726358587,
            0.8817353618548528,
            0.6925315900777659,
        ];
        let ans_1d = vec![
            0.978618342232764,
            0.7991585642167236,
            0.6976311959272649,
            0.7805291762864555,
            0.6667667154456677,
            0.6706378696181594,
            0.4375872112626925,
            0.9446689170495839,
            0.9636627605010293,
            0.4146619399905236,
            0.7917250380826646,
            0.7742336894342167,
            0.9883738380592262,
            0.925596638292661,
            0.2088767560948347,
            0.6176354970758771,
            0.6531083254653984,
            0.832619845547938,
            0.9437480785146242,
            0.8700121482468192,
            0.6778165367962301,
            0.41426299451466997,
            0.7351940221225949,
            0.9621885451174382,
            0.5666014542065752,
            0.5761573344178369,
            0.8209932298479351,
            0.5722519057908734,
            0.8379449074988039,
            0.952749011516985,
            0.9764594650133958,
            0.8464086724711278,
            0.9767610881903371,
            0.7163272041185655,
            0.8137978197024772,
            0.39650574084698464,
            0.8811031971111616,
            0.5812728726358587,
            0.8817353618548528,
            0.6925315900777659,
        ];
        let ans_2d = vec![
            0.7917250380826646,
            0.7151893663724195,
            0.8917730007820798,
            0.9636627605010293,
            0.8700121482468192,
            0.978618342232764,
            0.7991585642167236,
            0.9446689170495839,
            0.9437480785146242,
            0.6818202991034834,
            0.6706378696181594,
            0.6531083254653984,
            0.9883738380592262,
            0.4663107728563063,
            0.6667667154456677,
            0.9764594650133958,
            0.8209932298479351,
            0.9767610881903371,
            0.8379449074988039,
            0.7392635793983017,
            0.31856895245132366,
            0.6674103799636817,
            0.13179786240439217,
            0.8289400292173631,
            0.9292961975762141,
            0.6778165367962301,
            0.8811031971111616,
            0.7351940221225949,
            0.9621885451174382,
            0.952749011516985,
        ];

        let ans_3d = vec![
            0.7151893663724195,
            0.9636627605010293,
            0.925596638292661,
            0.8700121482468192,
            0.978618342232764,
            0.9446689170495839,
            0.7742336894342167,
            0.9437480785146242,
            0.6976311959272649,
            0.6706378696181594,
            0.9883738380592262,
            0.6531083254653984,
            0.6563295894652734,
            0.8379449074988039,
            0.9767610881903371,
            0.29614019752214493,
            0.6924721193700198,
            0.9292961975762141,
            0.7163272041185655,
            0.8289400292173631,
            0.9621885451174382,
            0.952749011516985,
            0.8464086724711278,
            0.8817353618548528,
        ];

        let input = Matrix::<Owned<f32>, DimDyn, D>::from_vec(input, [2, 3, 4, 5]);
        let ans_0d = Matrix::<Owned<f32>, DimDyn, D>::from_vec(ans_0d, [3, 4, 5]);
        let ans_1d = Matrix::<Owned<f32>, DimDyn, D>::from_vec(ans_1d, [2, 4, 5]);
        let ans_2d = Matrix::<Owned<f32>, DimDyn, D>::from_vec(ans_2d, [2, 3, 5]);
        let ans_3d = Matrix::<Owned<f32>, DimDyn, D>::from_vec(ans_3d, [2, 3, 4]);

        assert_mat_eq_epsilon!(input.max_axis(0, false), ans_0d, 1e-6);
        assert_mat_eq_epsilon!(input.max_axis(1, false), ans_1d, 1e-6);
        assert_mat_eq_epsilon!(input.max_axis(2, false), ans_2d, 1e-6);
        assert_mat_eq_epsilon!(input.max_axis(3, false), ans_3d, 1e-6);
    }
    run_mat_test!(max_axis_4d, max_axis_4d_cpu, max_axis_4d_gpu);

    fn max_axis_4d_f64<D: Device>() {
        let input: Vec<f64> = vec![
            0.5488135039273248,
            0.7151893663724195,
            0.6027633760716439,
            0.5448831829968969,
            0.4236547993389047,
            0.6458941130666561,
            0.4375872112626925,
            0.8917730007820798,
            0.9636627605010293,
            0.3834415188257777,
            0.7917250380826646,
            0.5288949197529045,
            0.5680445610939323,
            0.925596638292661,
            0.07103605819788694,
            0.08712929970154071,
            0.02021839744032572,
            0.832619845547938,
            0.7781567509498505,
            0.8700121482468192,
            0.978618342232764,
            0.7991585642167236,
            0.46147936225293185,
            0.7805291762864555,
            0.11827442586893322,
            0.6399210213275238,
            0.1433532874090464,
            0.9446689170495839,
            0.5218483217500717,
            0.4146619399905236,
            0.26455561210462697,
            0.7742336894342167,
            0.45615033221654855,
            0.5684339488686485,
            0.018789800436355142,
            0.6176354970758771,
            0.6120957227224214,
            0.6169339968747569,
            0.9437480785146242,
            0.6818202991034834,
            0.359507900573786,
            0.43703195379934145,
            0.6976311959272649,
            0.06022547162926983,
            0.6667667154456677,
            0.6706378696181594,
            0.2103825610738409,
            0.1289262976548533,
            0.31542835092418386,
            0.3637107709426226,
            0.5701967704178796,
            0.43860151346232035,
            0.9883738380592262,
            0.10204481074802807,
            0.2088767560948347,
            0.16130951788499626,
            0.6531083254653984,
            0.2532916025397821,
            0.4663107728563063,
            0.24442559200160274,
            0.15896958364551972,
            0.11037514116430513,
            0.6563295894652734,
            0.1381829513486138,
            0.1965823616800535,
            0.3687251706609641,
            0.8209932298479351,
            0.09710127579306127,
            0.8379449074988039,
            0.09609840789396307,
            0.9764594650133958,
            0.4686512016477016,
            0.9767610881903371,
            0.604845519745046,
            0.7392635793983017,
            0.039187792254320675,
            0.2828069625764096,
            0.1201965612131689,
            0.29614019752214493,
            0.11872771895424405,
            0.317983179393976,
            0.41426299451466997,
            0.06414749634878436,
            0.6924721193700198,
            0.5666014542065752,
            0.2653894909394454,
            0.5232480534666997,
            0.09394051075844168,
            0.5759464955561793,
            0.9292961975762141,
            0.31856895245132366,
            0.6674103799636817,
            0.13179786240439217,
            0.7163272041185655,
            0.2894060929472011,
            0.18319136200711683,
            0.5865129348100832,
            0.020107546187493552,
            0.8289400292173631,
            0.004695476192547066,
            0.6778165367962301,
            0.27000797319216485,
            0.7351940221225949,
            0.9621885451174382,
            0.24875314351995803,
            0.5761573344178369,
            0.592041931271839,
            0.5722519057908734,
            0.2230816326406183,
            0.952749011516985,
            0.44712537861762736,
            0.8464086724711278,
            0.6994792753175043,
            0.29743695085513366,
            0.8137978197024772,
            0.39650574084698464,
            0.8811031971111616,
            0.5812728726358587,
            0.8817353618548528,
            0.6925315900777659,
        ];
        let ans_0d = vec![
            0.5488135039273248,
            0.7151893663724195,
            0.6563295894652734,
            0.5448831829968969,
            0.4236547993389047,
            0.6458941130666561,
            0.8209932298479351,
            0.8917730007820798,
            0.9636627605010293,
            0.3834415188257777,
            0.9764594650133958,
            0.5288949197529045,
            0.9767610881903371,
            0.925596638292661,
            0.7392635793983017,
            0.08712929970154071,
            0.2828069625764096,
            0.832619845547938,
            0.7781567509498505,
            0.8700121482468192,
            0.978618342232764,
            0.7991585642167236,
            0.46147936225293185,
            0.7805291762864555,
            0.5666014542065752,
            0.6399210213275238,
            0.5232480534666997,
            0.9446689170495839,
            0.5759464955561793,
            0.9292961975762141,
            0.31856895245132366,
            0.7742336894342167,
            0.45615033221654855,
            0.7163272041185655,
            0.2894060929472011,
            0.6176354970758771,
            0.6120957227224214,
            0.6169339968747569,
            0.9437480785146242,
            0.6818202991034834,
            0.6778165367962301,
            0.43703195379934145,
            0.7351940221225949,
            0.9621885451174382,
            0.6667667154456677,
            0.6706378696181594,
            0.592041931271839,
            0.5722519057908734,
            0.31542835092418386,
            0.952749011516985,
            0.5701967704178796,
            0.8464086724711278,
            0.9883738380592262,
            0.29743695085513366,
            0.8137978197024772,
            0.39650574084698464,
            0.8811031971111616,
            0.5812728726358587,
            0.8817353618548528,
            0.6925315900777659,
        ];
        let ans_1d = vec![
            0.978618342232764,
            0.7991585642167236,
            0.6976311959272649,
            0.7805291762864555,
            0.6667667154456677,
            0.6706378696181594,
            0.4375872112626925,
            0.9446689170495839,
            0.9636627605010293,
            0.4146619399905236,
            0.7917250380826646,
            0.7742336894342167,
            0.9883738380592262,
            0.925596638292661,
            0.2088767560948347,
            0.6176354970758771,
            0.6531083254653984,
            0.832619845547938,
            0.9437480785146242,
            0.8700121482468192,
            0.6778165367962301,
            0.41426299451466997,
            0.7351940221225949,
            0.9621885451174382,
            0.5666014542065752,
            0.5761573344178369,
            0.8209932298479351,
            0.5722519057908734,
            0.8379449074988039,
            0.952749011516985,
            0.9764594650133958,
            0.8464086724711278,
            0.9767610881903371,
            0.7163272041185655,
            0.8137978197024772,
            0.39650574084698464,
            0.8811031971111616,
            0.5812728726358587,
            0.8817353618548528,
            0.6925315900777659,
        ];
        let ans_2d = vec![
            0.7917250380826646,
            0.7151893663724195,
            0.8917730007820798,
            0.9636627605010293,
            0.8700121482468192,
            0.978618342232764,
            0.7991585642167236,
            0.9446689170495839,
            0.9437480785146242,
            0.6818202991034834,
            0.6706378696181594,
            0.6531083254653984,
            0.9883738380592262,
            0.4663107728563063,
            0.6667667154456677,
            0.9764594650133958,
            0.8209932298479351,
            0.9767610881903371,
            0.8379449074988039,
            0.7392635793983017,
            0.31856895245132366,
            0.6674103799636817,
            0.13179786240439217,
            0.8289400292173631,
            0.9292961975762141,
            0.6778165367962301,
            0.8811031971111616,
            0.7351940221225949,
            0.9621885451174382,
            0.952749011516985,
        ];

        let ans_3d = vec![
            0.7151893663724195,
            0.9636627605010293,
            0.925596638292661,
            0.8700121482468192,
            0.978618342232764,
            0.9446689170495839,
            0.7742336894342167,
            0.9437480785146242,
            0.6976311959272649,
            0.6706378696181594,
            0.9883738380592262,
            0.6531083254653984,
            0.6563295894652734,
            0.8379449074988039,
            0.9767610881903371,
            0.29614019752214493,
            0.6924721193700198,
            0.9292961975762141,
            0.7163272041185655,
            0.8289400292173631,
            0.9621885451174382,
            0.952749011516985,
            0.8464086724711278,
            0.8817353618548528,
        ];

        let input = Matrix::<Owned<f64>, DimDyn, D>::from_vec(input, [2, 3, 4, 5]);
        let ans_0d = Matrix::<Owned<f64>, DimDyn, D>::from_vec(ans_0d, [3, 4, 5]);
        let ans_1d = Matrix::<Owned<f64>, DimDyn, D>::from_vec(ans_1d, [2, 4, 5]);
        let ans_2d = Matrix::<Owned<f64>, DimDyn, D>::from_vec(ans_2d, [2, 3, 5]);
        let ans_3d = Matrix::<Owned<f64>, DimDyn, D>::from_vec(ans_3d, [2, 3, 4]);

        assert_mat_eq_epsilon!(input.max_axis(0, false), ans_0d, 1e-6);
        assert_mat_eq_epsilon!(input.max_axis(1, false), ans_1d, 1e-6);
        assert_mat_eq_epsilon!(input.max_axis(2, false), ans_2d, 1e-6);
        assert_mat_eq_epsilon!(input.max_axis(3, false), ans_3d, 1e-6);
    }
    run_mat_test!(max_axis_4d_f64, max_axis_4d_cpu_f64, max_axis_4d_gpu_f64);
}