1use crate::{
2 device::{cpu::Cpu, Device, DeviceBase},
3 dim::{DimDyn, DimTrait},
4 index::index_dyn_impl::Index,
5 matrix::{Matrix, Owned, Repr},
6 num::Num,
7};
8
9#[cfg(feature = "nvidia")]
10use crate::device::nvidia::Nvidia;
11
12#[cfg(feature = "nvidia")]
13use zenu_cuda::kernel::array_max_idx;
14
/// Device-specific search for the largest element of a strided sequence.
///
/// `Matrix::max_idx` dispatches to this trait through the matrix's device type,
/// so each backend supplies its own scan over the data it stores.
pub trait MaxIdx: DeviceBase {
    /// Returns the position `i` (in `0..size`) of the largest of the `size`
    /// elements located at `input + k * stride` for `k` in `0..size`.
    ///
    /// The returned value is a position in that strided sequence, not a raw
    /// memory offset; the element it names lives at `input + i * stride`.
    ///
    /// Callers must ensure `input` is valid for reads at every one of those
    /// offsets. NOTE(review): for GPU backends `input` is presumably a device
    /// pointer rather than host memory — confirm per implementation.
    fn max_idx<T: Num>(input: *const T, size: usize, stride: usize) -> usize;
}
18
19impl MaxIdx for Cpu {
20 #[expect(clippy::not_unsafe_ptr_arg_deref)]
21 fn max_idx<T: Num>(input: *const T, size: usize, stride: usize) -> usize {
22 let tmep_v = unsafe { std::slice::from_raw_parts(input, size * stride) };
23 let mut max_idx = 0;
24 let mut max_val = tmep_v[0];
25
26 for i in 1..size {
27 if tmep_v[i * stride] > max_val {
28 max_val = tmep_v[i * stride];
29 max_idx = i;
30 }
31 }
32 max_idx
33 }
34}
35
#[cfg(feature = "nvidia")]
impl MaxIdx for Nvidia {
    /// Forwards the strided max search unchanged to `array_max_idx` from
    /// `zenu_cuda::kernel`.
    ///
    /// NOTE(review): `input` is presumably a device pointer, and the kernel is
    /// assumed to honour the same contract as the CPU version (a position in
    /// `0..size`). Tie-breaking and empty-input behaviour are not visible from
    /// here — confirm against `zenu_cuda`.
    fn max_idx<T: Num>(input: *const T, size: usize, stride: usize) -> usize {
        array_max_idx(input, size, stride)
    }
}
42
43impl<T: Num, R: Repr<Item = T>, D: Device> Matrix<R, DimDyn, D> {
44 #[must_use]
45 pub fn max_idx(&self) -> DimDyn {
46 if self.shape().is_empty() {
47 return DimDyn::from(&[] as &[usize]);
48 }
49 let default_stride = self.to_default_stride();
50 let idx = <D as MaxIdx>::max_idx(
51 default_stride.as_ptr(),
52 default_stride.shape().num_elm(),
53 default_stride.stride()[default_stride.shape().len() - 1],
54 );
55 default_stride.shape_stride().get_dim_by_offset(idx)
56 }
57
58 #[must_use]
59 pub fn max_item(&self) -> T {
60 let idx = self.max_idx();
61 self.index_item(idx)
62 }
63
64 #[expect(clippy::missing_panics_doc)]
66 #[must_use]
67 pub fn max_axis(&self, axis: usize, keep_dim: bool) -> Matrix<Owned<T>, DimDyn, D> {
68 assert!(axis < self.shape().len(), "max_axis: Axis out of bounds");
69
70 let mut output_shape = Vec::new();
71 for i in 0..self.shape().len() {
72 if i == axis {
73 continue;
74 }
75 output_shape.push(self.shape()[i]);
76 }
77
78 let output_shape = DimDyn::from(&output_shape as &[usize]);
79 let mut output = Matrix::<Owned<T>, DimDyn, D>::alloc(output_shape);
80
81 if axis == 0 {
82 let output_flatten = output.reshape_mut([output.shape().num_elm()]);
83 let s = self.reshape_new_matrix([self.shape()[0], output_shape.num_elm()]);
84 for i in 0..output_shape.num_elm() {
85 output_flatten.index_item_assign([i], s.index_axis(Index::new(1, i)).max_item());
86 }
87 } else {
88 for i in 0..self.shape()[0] {
89 let s = self.index_axis(Index::new(0, i));
90 let output = output.to_ref_mut().index_axis_mut(Index::new(0, i));
91 output.copy_from(&s.max_axis(axis - 1, false));
92 }
93 }
94
95 if keep_dim {
96 output.add_axis(axis);
97 }
98 output
99 }
100
101 #[expect(clippy::missing_panics_doc)]
102 #[must_use]
103 pub fn max_axis_idx_ravel(&self, axis: usize) -> Vec<usize> {
104 assert!(axis < self.shape().len(), "max_axis: Axis out of bounds");
105
106 let mut output_shape = Vec::new();
107 for i in 0..self.shape().len() {
108 if i == axis {
109 continue;
110 }
111 output_shape.push(self.shape()[i]);
112 }
113
114 let output_shape = DimDyn::from(&output_shape as &[usize]);
115 let mut output = Vec::with_capacity(output_shape.num_elm());
116
117 if axis == 0 {
118 let s = self.reshape_new_matrix([self.shape()[0], output_shape.num_elm()]);
119 for i in 0..output_shape.num_elm() {
120 let idx = s.index_axis(Index::new(1, i)).max_idx()[0];
121 output.push(idx);
122 }
123 } else {
124 for i in 0..self.shape()[0] {
125 let s = self.index_axis(Index::new(0, i));
126 let tmp = s.max_axis_idx_ravel(axis - 1);
127 for tmp_idx in tmp {
128 output.push(tmp_idx);
129 }
130 }
131 }
132
133 output
134 }
135}
136
#[cfg(test)]
mod max_idx {
    #![expect(
        clippy::float_cmp,
        clippy::unreadable_literal,
        clippy::cast_precision_loss,
        clippy::cast_possible_truncation,
        clippy::too_many_lines,
        clippy::excessive_precision
    )]

    use crate::{
        device::Device,
        dim::DimDyn,
        matrix::{Matrix, Owned},
        slice_dynamic,
    };

    use zenu_test::{assert_mat_eq_epsilon, run_mat_test};

    // Contiguous 1-D input: the largest value sits at the last position.
    fn default_1d<D: Device>() {
        let a: Matrix<Owned<f32>, DimDyn, D> = Matrix::from_vec(vec![0., 1., 2., 3.], [4]);
        assert_eq!(a.to_ref().max_idx(), [3].into());
    }
    run_mat_test!(default_1d, default_1d_cpu, default_1d_gpu);

    // Contiguous 2-D input: the flat maximum (offset 3) maps back to [1, 1].
    fn default_2d<D: Device>() {
        let a: Matrix<Owned<f32>, DimDyn, D> = Matrix::from_vec(vec![0., 1., 2., 3.], [2, 2]);
        assert_eq!(a.to_ref().max_idx(), [1, 1].into());
    }
    run_mat_test!(default_2d, default_2d_cpu, default_2d_gpu);

    // Strided view: `..;3`, `..;4`, `..;2` keep 3, 2 and 4 entries per axis.
    // Values grow with the flat index, so the maximum is the view's last
    // element, and the index must be reported in view coordinates.
    fn sliced_3d<D: Device>() {
        let mut v = Vec::new();
        for i in 0..8 * 8 * 8 {
            v.push(i as f32);
        }
        let a: Matrix<Owned<f32>, DimDyn, D> = Matrix::from_vec(v, [8, 8, 8]);
        let sliced = a.slice(slice_dynamic!(..;3, ..;4, ..;2));
        assert_eq!(sliced.max_idx(), [2, 1, 3].into());
    }
    run_mat_test!(sliced_3d, sliced_3d_cpu, sliced_3d_gpu);

    // Argmax along each axis of a fixed [2, 3, 4] tensor, flattened in
    // row-major order of the remaining axes.
    fn max_axis_ravel<D: Device>() {
        let input = vec![
            0.04881350392732475,
            0.21518936637241948,
            0.10276337607164387,
            0.044883182996896864,
            -0.07634520066109529,
            0.14589411306665612,
            -0.06241278873730749,
            0.39177300078207977,
            0.4636627605010293,
            -0.1165584811742223,
            0.2917250380826646,
            0.02889491975290448,
            0.06804456109393231,
            0.42559663829266103,
            -0.42896394180211306,
            -0.4128707002984593,
            -0.4797816025596743,
            0.332619845547938,
            0.2781567509498505,
            0.37001214824681916,
            0.478618342232764,
            0.2991585642167236,
            -0.038520637747068154,
            0.28052917628645546,
        ];

        let input = Matrix::<Owned<f32>, DimDyn, D>::from_vec(input, [2, 3, 4]);
        let ans_0d = vec![1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1];
        let ans_1d = vec![2, 0, 2, 1, 2, 0, 1, 1];
        let ans_2d = vec![1, 3, 0, 1, 3, 0];

        let result_0d = input.max_axis_idx_ravel(0);
        let result_1d = input.max_axis_idx_ravel(1);
        let result_2d = input.max_axis_idx_ravel(2);

        assert_eq!(result_0d, ans_0d);
        assert_eq!(result_1d, ans_1d);
        assert_eq!(result_2d, ans_2d);
    }
    run_mat_test!(max_axis_ravel, max_axis_ravel_cpu, max_axis_ravel_gpu);

    // Max reduction along every axis of a fixed [2, 3, 4, 5] f32 tensor.
    fn max_axis_4d<D: Device>() {
        let input: Vec<f32> = vec![
            0.5488135039273248,
            0.7151893663724195,
            0.6027633760716439,
            0.5448831829968969,
            0.4236547993389047,
            0.6458941130666561,
            0.4375872112626925,
            0.8917730007820798,
            0.9636627605010293,
            0.3834415188257777,
            0.7917250380826646,
            0.5288949197529045,
            0.5680445610939323,
            0.925596638292661,
            0.07103605819788694,
            0.08712929970154071,
            0.02021839744032572,
            0.832619845547938,
            0.7781567509498505,
            0.8700121482468192,
            0.978618342232764,
            0.7991585642167236,
            0.46147936225293185,
            0.7805291762864555,
            0.11827442586893322,
            0.6399210213275238,
            0.1433532874090464,
            0.9446689170495839,
            0.5218483217500717,
            0.4146619399905236,
            0.26455561210462697,
            0.7742336894342167,
            0.45615033221654855,
            0.5684339488686485,
            0.018789800436355142,
            0.6176354970758771,
            0.6120957227224214,
            0.6169339968747569,
            0.9437480785146242,
            0.6818202991034834,
            0.359507900573786,
            0.43703195379934145,
            0.6976311959272649,
            0.06022547162926983,
            0.6667667154456677,
            0.6706378696181594,
            0.2103825610738409,
            0.1289262976548533,
            0.31542835092418386,
            0.3637107709426226,
            0.5701967704178796,
            0.43860151346232035,
            0.9883738380592262,
            0.10204481074802807,
            0.2088767560948347,
            0.16130951788499626,
            0.6531083254653984,
            0.2532916025397821,
            0.4663107728563063,
            0.24442559200160274,
            0.15896958364551972,
            0.11037514116430513,
            0.6563295894652734,
            0.1381829513486138,
            0.1965823616800535,
            0.3687251706609641,
            0.8209932298479351,
            0.09710127579306127,
            0.8379449074988039,
            0.09609840789396307,
            0.9764594650133958,
            0.4686512016477016,
            0.9767610881903371,
            0.604845519745046,
            0.7392635793983017,
            0.039187792254320675,
            0.2828069625764096,
            0.1201965612131689,
            0.29614019752214493,
            0.11872771895424405,
            0.317983179393976,
            0.41426299451466997,
            0.06414749634878436,
            0.6924721193700198,
            0.5666014542065752,
            0.2653894909394454,
            0.5232480534666997,
            0.09394051075844168,
            0.5759464955561793,
            0.9292961975762141,
            0.31856895245132366,
            0.6674103799636817,
            0.13179786240439217,
            0.7163272041185655,
            0.2894060929472011,
            0.18319136200711683,
            0.5865129348100832,
            0.020107546187493552,
            0.8289400292173631,
            0.004695476192547066,
            0.6778165367962301,
            0.27000797319216485,
            0.7351940221225949,
            0.9621885451174382,
            0.24875314351995803,
            0.5761573344178369,
            0.592041931271839,
            0.5722519057908734,
            0.2230816326406183,
            0.952749011516985,
            0.44712537861762736,
            0.8464086724711278,
            0.6994792753175043,
            0.29743695085513366,
            0.8137978197024772,
            0.39650574084698464,
            0.8811031971111616,
            0.5812728726358587,
            0.8817353618548528,
            0.6925315900777659,
        ];
        let ans_0d = vec![
            0.5488135039273248,
            0.7151893663724195,
            0.6563295894652734,
            0.5448831829968969,
            0.4236547993389047,
            0.6458941130666561,
            0.8209932298479351,
            0.8917730007820798,
            0.9636627605010293,
            0.3834415188257777,
            0.9764594650133958,
            0.5288949197529045,
            0.9767610881903371,
            0.925596638292661,
            0.7392635793983017,
            0.08712929970154071,
            0.2828069625764096,
            0.832619845547938,
            0.7781567509498505,
            0.8700121482468192,
            0.978618342232764,
            0.7991585642167236,
            0.46147936225293185,
            0.7805291762864555,
            0.5666014542065752,
            0.6399210213275238,
            0.5232480534666997,
            0.9446689170495839,
            0.5759464955561793,
            0.9292961975762141,
            0.31856895245132366,
            0.7742336894342167,
            0.45615033221654855,
            0.7163272041185655,
            0.2894060929472011,
            0.6176354970758771,
            0.6120957227224214,
            0.6169339968747569,
            0.9437480785146242,
            0.6818202991034834,
            0.6778165367962301,
            0.43703195379934145,
            0.7351940221225949,
            0.9621885451174382,
            0.6667667154456677,
            0.6706378696181594,
            0.592041931271839,
            0.5722519057908734,
            0.31542835092418386,
            0.952749011516985,
            0.5701967704178796,
            0.8464086724711278,
            0.9883738380592262,
            0.29743695085513366,
            0.8137978197024772,
            0.39650574084698464,
            0.8811031971111616,
            0.5812728726358587,
            0.8817353618548528,
            0.6925315900777659,
        ];
        let ans_1d = vec![
            0.978618342232764,
            0.7991585642167236,
            0.6976311959272649,
            0.7805291762864555,
            0.6667667154456677,
            0.6706378696181594,
            0.4375872112626925,
            0.9446689170495839,
            0.9636627605010293,
            0.4146619399905236,
            0.7917250380826646,
            0.7742336894342167,
            0.9883738380592262,
            0.925596638292661,
            0.2088767560948347,
            0.6176354970758771,
            0.6531083254653984,
            0.832619845547938,
            0.9437480785146242,
            0.8700121482468192,
            0.6778165367962301,
            0.41426299451466997,
            0.7351940221225949,
            0.9621885451174382,
            0.5666014542065752,
            0.5761573344178369,
            0.8209932298479351,
            0.5722519057908734,
            0.8379449074988039,
            0.952749011516985,
            0.9764594650133958,
            0.8464086724711278,
            0.9767610881903371,
            0.7163272041185655,
            0.8137978197024772,
            0.39650574084698464,
            0.8811031971111616,
            0.5812728726358587,
            0.8817353618548528,
            0.6925315900777659,
        ];
        let ans_2d = vec![
            0.7917250380826646,
            0.7151893663724195,
            0.8917730007820798,
            0.9636627605010293,
            0.8700121482468192,
            0.978618342232764,
            0.7991585642167236,
            0.9446689170495839,
            0.9437480785146242,
            0.6818202991034834,
            0.6706378696181594,
            0.6531083254653984,
            0.9883738380592262,
            0.4663107728563063,
            0.6667667154456677,
            0.9764594650133958,
            0.8209932298479351,
            0.9767610881903371,
            0.8379449074988039,
            0.7392635793983017,
            0.31856895245132366,
            0.6674103799636817,
            0.13179786240439217,
            0.8289400292173631,
            0.9292961975762141,
            0.6778165367962301,
            0.8811031971111616,
            0.7351940221225949,
            0.9621885451174382,
            0.952749011516985,
        ];

        let ans_3d = vec![
            0.7151893663724195,
            0.9636627605010293,
            0.925596638292661,
            0.8700121482468192,
            0.978618342232764,
            0.9446689170495839,
            0.7742336894342167,
            0.9437480785146242,
            0.6976311959272649,
            0.6706378696181594,
            0.9883738380592262,
            0.6531083254653984,
            0.6563295894652734,
            0.8379449074988039,
            0.9767610881903371,
            0.29614019752214493,
            0.6924721193700198,
            0.9292961975762141,
            0.7163272041185655,
            0.8289400292173631,
            0.9621885451174382,
            0.952749011516985,
            0.8464086724711278,
            0.8817353618548528,
        ];

        let input = Matrix::<Owned<f32>, DimDyn, D>::from_vec(input, [2, 3, 4, 5]);
        let ans_0d = Matrix::<Owned<f32>, DimDyn, D>::from_vec(ans_0d, [3, 4, 5]);
        let ans_1d = Matrix::<Owned<f32>, DimDyn, D>::from_vec(ans_1d, [2, 4, 5]);
        let ans_2d = Matrix::<Owned<f32>, DimDyn, D>::from_vec(ans_2d, [2, 3, 5]);
        let ans_3d = Matrix::<Owned<f32>, DimDyn, D>::from_vec(ans_3d, [2, 3, 4]);

        assert_mat_eq_epsilon!(input.max_axis(0, false), ans_0d, 1e-6);
        assert_mat_eq_epsilon!(input.max_axis(1, false), ans_1d, 1e-6);
        assert_mat_eq_epsilon!(input.max_axis(2, false), ans_2d, 1e-6);
        assert_mat_eq_epsilon!(input.max_axis(3, false), ans_3d, 1e-6);
    }
    run_mat_test!(max_axis_4d, max_axis_4d_cpu, max_axis_4d_gpu);

    // Same data and expectations as `max_axis_4d`, exercised with f64 elements.
    fn max_axis_4d_f64<D: Device>() {
        let input: Vec<f64> = vec![
            0.5488135039273248,
            0.7151893663724195,
            0.6027633760716439,
            0.5448831829968969,
            0.4236547993389047,
            0.6458941130666561,
            0.4375872112626925,
            0.8917730007820798,
            0.9636627605010293,
            0.3834415188257777,
            0.7917250380826646,
            0.5288949197529045,
            0.5680445610939323,
            0.925596638292661,
            0.07103605819788694,
            0.08712929970154071,
            0.02021839744032572,
            0.832619845547938,
            0.7781567509498505,
            0.8700121482468192,
            0.978618342232764,
            0.7991585642167236,
            0.46147936225293185,
            0.7805291762864555,
            0.11827442586893322,
            0.6399210213275238,
            0.1433532874090464,
            0.9446689170495839,
            0.5218483217500717,
            0.4146619399905236,
            0.26455561210462697,
            0.7742336894342167,
            0.45615033221654855,
            0.5684339488686485,
            0.018789800436355142,
            0.6176354970758771,
            0.6120957227224214,
            0.6169339968747569,
            0.9437480785146242,
            0.6818202991034834,
            0.359507900573786,
            0.43703195379934145,
            0.6976311959272649,
            0.06022547162926983,
            0.6667667154456677,
            0.6706378696181594,
            0.2103825610738409,
            0.1289262976548533,
            0.31542835092418386,
            0.3637107709426226,
            0.5701967704178796,
            0.43860151346232035,
            0.9883738380592262,
            0.10204481074802807,
            0.2088767560948347,
            0.16130951788499626,
            0.6531083254653984,
            0.2532916025397821,
            0.4663107728563063,
            0.24442559200160274,
            0.15896958364551972,
            0.11037514116430513,
            0.6563295894652734,
            0.1381829513486138,
            0.1965823616800535,
            0.3687251706609641,
            0.8209932298479351,
            0.09710127579306127,
            0.8379449074988039,
            0.09609840789396307,
            0.9764594650133958,
            0.4686512016477016,
            0.9767610881903371,
            0.604845519745046,
            0.7392635793983017,
            0.039187792254320675,
            0.2828069625764096,
            0.1201965612131689,
            0.29614019752214493,
            0.11872771895424405,
            0.317983179393976,
            0.41426299451466997,
            0.06414749634878436,
            0.6924721193700198,
            0.5666014542065752,
            0.2653894909394454,
            0.5232480534666997,
            0.09394051075844168,
            0.5759464955561793,
            0.9292961975762141,
            0.31856895245132366,
            0.6674103799636817,
            0.13179786240439217,
            0.7163272041185655,
            0.2894060929472011,
            0.18319136200711683,
            0.5865129348100832,
            0.020107546187493552,
            0.8289400292173631,
            0.004695476192547066,
            0.6778165367962301,
            0.27000797319216485,
            0.7351940221225949,
            0.9621885451174382,
            0.24875314351995803,
            0.5761573344178369,
            0.592041931271839,
            0.5722519057908734,
            0.2230816326406183,
            0.952749011516985,
            0.44712537861762736,
            0.8464086724711278,
            0.6994792753175043,
            0.29743695085513366,
            0.8137978197024772,
            0.39650574084698464,
            0.8811031971111616,
            0.5812728726358587,
            0.8817353618548528,
            0.6925315900777659,
        ];
        let ans_0d = vec![
            0.5488135039273248,
            0.7151893663724195,
            0.6563295894652734,
            0.5448831829968969,
            0.4236547993389047,
            0.6458941130666561,
            0.8209932298479351,
            0.8917730007820798,
            0.9636627605010293,
            0.3834415188257777,
            0.9764594650133958,
            0.5288949197529045,
            0.9767610881903371,
            0.925596638292661,
            0.7392635793983017,
            0.08712929970154071,
            0.2828069625764096,
            0.832619845547938,
            0.7781567509498505,
            0.8700121482468192,
            0.978618342232764,
            0.7991585642167236,
            0.46147936225293185,
            0.7805291762864555,
            0.5666014542065752,
            0.6399210213275238,
            0.5232480534666997,
            0.9446689170495839,
            0.5759464955561793,
            0.9292961975762141,
            0.31856895245132366,
            0.7742336894342167,
            0.45615033221654855,
            0.7163272041185655,
            0.2894060929472011,
            0.6176354970758771,
            0.6120957227224214,
            0.6169339968747569,
            0.9437480785146242,
            0.6818202991034834,
            0.6778165367962301,
            0.43703195379934145,
            0.7351940221225949,
            0.9621885451174382,
            0.6667667154456677,
            0.6706378696181594,
            0.592041931271839,
            0.5722519057908734,
            0.31542835092418386,
            0.952749011516985,
            0.5701967704178796,
            0.8464086724711278,
            0.9883738380592262,
            0.29743695085513366,
            0.8137978197024772,
            0.39650574084698464,
            0.8811031971111616,
            0.5812728726358587,
            0.8817353618548528,
            0.6925315900777659,
        ];
        let ans_1d = vec![
            0.978618342232764,
            0.7991585642167236,
            0.6976311959272649,
            0.7805291762864555,
            0.6667667154456677,
            0.6706378696181594,
            0.4375872112626925,
            0.9446689170495839,
            0.9636627605010293,
            0.4146619399905236,
            0.7917250380826646,
            0.7742336894342167,
            0.9883738380592262,
            0.925596638292661,
            0.2088767560948347,
            0.6176354970758771,
            0.6531083254653984,
            0.832619845547938,
            0.9437480785146242,
            0.8700121482468192,
            0.6778165367962301,
            0.41426299451466997,
            0.7351940221225949,
            0.9621885451174382,
            0.5666014542065752,
            0.5761573344178369,
            0.8209932298479351,
            0.5722519057908734,
            0.8379449074988039,
            0.952749011516985,
            0.9764594650133958,
            0.8464086724711278,
            0.9767610881903371,
            0.7163272041185655,
            0.8137978197024772,
            0.39650574084698464,
            0.8811031971111616,
            0.5812728726358587,
            0.8817353618548528,
            0.6925315900777659,
        ];
        let ans_2d = vec![
            0.7917250380826646,
            0.7151893663724195,
            0.8917730007820798,
            0.9636627605010293,
            0.8700121482468192,
            0.978618342232764,
            0.7991585642167236,
            0.9446689170495839,
            0.9437480785146242,
            0.6818202991034834,
            0.6706378696181594,
            0.6531083254653984,
            0.9883738380592262,
            0.4663107728563063,
            0.6667667154456677,
            0.9764594650133958,
            0.8209932298479351,
            0.9767610881903371,
            0.8379449074988039,
            0.7392635793983017,
            0.31856895245132366,
            0.6674103799636817,
            0.13179786240439217,
            0.8289400292173631,
            0.9292961975762141,
            0.6778165367962301,
            0.8811031971111616,
            0.7351940221225949,
            0.9621885451174382,
            0.952749011516985,
        ];

        let ans_3d = vec![
            0.7151893663724195,
            0.9636627605010293,
            0.925596638292661,
            0.8700121482468192,
            0.978618342232764,
            0.9446689170495839,
            0.7742336894342167,
            0.9437480785146242,
            0.6976311959272649,
            0.6706378696181594,
            0.9883738380592262,
            0.6531083254653984,
            0.6563295894652734,
            0.8379449074988039,
            0.9767610881903371,
            0.29614019752214493,
            0.6924721193700198,
            0.9292961975762141,
            0.7163272041185655,
            0.8289400292173631,
            0.9621885451174382,
            0.952749011516985,
            0.8464086724711278,
            0.8817353618548528,
        ];

        let input = Matrix::<Owned<f64>, DimDyn, D>::from_vec(input, [2, 3, 4, 5]);
        let ans_0d = Matrix::<Owned<f64>, DimDyn, D>::from_vec(ans_0d, [3, 4, 5]);
        let ans_1d = Matrix::<Owned<f64>, DimDyn, D>::from_vec(ans_1d, [2, 4, 5]);
        let ans_2d = Matrix::<Owned<f64>, DimDyn, D>::from_vec(ans_2d, [2, 3, 5]);
        let ans_3d = Matrix::<Owned<f64>, DimDyn, D>::from_vec(ans_3d, [2, 3, 4]);

        assert_mat_eq_epsilon!(input.max_axis(0, false), ans_0d, 1e-6);
        assert_mat_eq_epsilon!(input.max_axis(1, false), ans_1d, 1e-6);
        assert_mat_eq_epsilon!(input.max_axis(2, false), ans_2d, 1e-6);
        assert_mat_eq_epsilon!(input.max_axis(3, false), ans_3d, 1e-6);
    }
    run_mat_test!(max_axis_4d_f64, max_axis_4d_cpu_f64, max_axis_4d_gpu_f64);
}