1use std::fmt;
2use std::fmt::{Debug, Display};
3use std::ops::Neg;
4use std::ops::Range;
5
6use crate::numbers::basenum::Number;
7use crate::numbers::realnum::RealNumber;
8
9use num::ToPrimitive;
10use num_traits::Signed;
11
/// Read-only access to an array holding elements of type `T`.
///
/// `S` describes both a position and the shape of the array: a single `usize`
/// for one-dimensional arrays, `(usize, usize)` for two-dimensional ones.
pub trait Array<T: Debug + Display + Copy + Sized, S>: Debug {
    /// Borrows the element stored at position `pos`.
    fn get(&self, pos: S) -> &T;
    /// Reports the shape (extent) of the array.
    fn shape(&self) -> S;
    /// Reports whether the array holds no elements.
    fn is_empty(&self) -> bool;
    /// Produces a boxed iterator over borrowed elements; `axis` selects the
    /// traversal order, whose exact meaning is up to the implementor.
    fn iterator<'a>(&'a self, axis: u8) -> Box<dyn Iterator<Item = &'a T> + 'a>;
}
23
24pub trait MutArray<T: Debug + Display + Copy + Sized, S>: Array<T, S> {
26 fn set(&mut self, pos: S, x: T);
28 fn iterator_mut<'b>(&'b mut self, axis: u8) -> Box<dyn Iterator<Item = &'b mut T> + 'b>;
30 fn swap(&mut self, a: S, b: S)
32 where
33 S: Copy,
34 {
35 let t = *self.get(a);
36 self.set(a, *self.get(b));
37 self.set(b, t);
38 }
39 fn div_element_mut(&mut self, pos: S, x: T)
41 where
42 T: Number,
43 S: Copy,
44 {
45 self.set(pos, *self.get(pos) / x);
46 }
47 fn mul_element_mut(&mut self, pos: S, x: T)
49 where
50 T: Number,
51 S: Copy,
52 {
53 self.set(pos, *self.get(pos) * x);
54 }
55 fn add_element_mut(&mut self, pos: S, x: T)
57 where
58 T: Number,
59 S: Copy,
60 {
61 self.set(pos, *self.get(pos) + x);
62 }
63 fn sub_element_mut(&mut self, pos: S, x: T)
65 where
66 T: Number,
67 S: Copy,
68 {
69 self.set(pos, *self.get(pos) - x);
70 }
71 fn sub_scalar_mut(&mut self, x: T)
73 where
74 T: Number,
75 {
76 self.iterator_mut(0).for_each(|v| *v -= x);
77 }
78 fn add_scalar_mut(&mut self, x: T)
80 where
81 T: Number,
82 {
83 self.iterator_mut(0).for_each(|v| *v += x);
84 }
85 fn mul_scalar_mut(&mut self, x: T)
87 where
88 T: Number,
89 {
90 self.iterator_mut(0).for_each(|v| *v *= x);
91 }
92 fn div_scalar_mut(&mut self, x: T)
94 where
95 T: Number,
96 {
97 self.iterator_mut(0).for_each(|v| *v /= x);
98 }
99 fn add_mut(&mut self, other: &dyn Array<T, S>)
101 where
102 T: Number,
103 S: Eq,
104 {
105 assert!(
106 self.shape() == other.shape(),
107 "A and B should have the same shape"
108 );
109 self.iterator_mut(0)
110 .zip(other.iterator(0))
111 .for_each(|(a, &b)| *a += b);
112 }
113 fn sub_mut(&mut self, other: &dyn Array<T, S>)
115 where
116 T: Number,
117 S: Eq,
118 {
119 assert!(
120 self.shape() == other.shape(),
121 "A and B should have the same shape"
122 );
123 self.iterator_mut(0)
124 .zip(other.iterator(0))
125 .for_each(|(a, &b)| *a -= b);
126 }
127 fn mul_mut(&mut self, other: &dyn Array<T, S>)
129 where
130 T: Number,
131 S: Eq,
132 {
133 assert!(
134 self.shape() == other.shape(),
135 "A and B should have the same shape"
136 );
137 self.iterator_mut(0)
138 .zip(other.iterator(0))
139 .for_each(|(a, &b)| *a *= b);
140 }
141 fn div_mut(&mut self, other: &dyn Array<T, S>)
143 where
144 T: Number,
145 S: Eq,
146 {
147 assert!(
148 self.shape() == other.shape(),
149 "A and B should have the same shape"
150 );
151 self.iterator_mut(0)
152 .zip(other.iterator(0))
153 .for_each(|(a, &b)| *a /= b);
154 }
155}
156
157pub trait ArrayView1<T: Debug + Display + Copy + Sized>: Array<T, usize> {
159 fn dot(&self, other: &dyn ArrayView1<T>) -> T
161 where
162 T: Number,
163 {
164 assert!(
165 self.shape() == other.shape(),
166 "Can't take dot product. Arrays have different shapes"
167 );
168 self.iterator(0)
169 .zip(other.iterator(0))
170 .map(|(s, o)| *s * *o)
171 .sum()
172 }
173 fn sum(&self) -> T
175 where
176 T: Number,
177 {
178 self.iterator(0).copied().sum()
179 }
180 fn max(&self) -> T
182 where
183 T: Number + PartialOrd,
184 {
185 let max_f = |max: T, v: &T| -> T {
186 match T::gt(v, &max) {
187 true => *v,
188 _ => max,
189 }
190 };
191 self.iterator(0).fold(T::min_value(), max_f)
192 }
193 fn min(&self) -> T
195 where
196 T: Number + PartialOrd,
197 {
198 let min_f = |min: T, v: &T| -> T {
199 match T::lt(v, &min) {
200 true => *v,
201 _ => min,
202 }
203 };
204 self.iterator(0).fold(T::max_value(), min_f)
205 }
206 fn argmax(&self) -> usize
208 where
209 T: Number + PartialOrd,
210 {
211 let mut max = T::min_value();
213 let mut max_pos = 0usize;
214 for (i, v) in self.iterator(0).enumerate() {
215 if T::gt(v, &max) {
216 max = *v;
217 max_pos = i;
218 }
219 }
220 max_pos
221 }
222 fn unique(&self) -> Vec<T>
224 where
225 T: Number + Ord,
226 {
227 let mut result: Vec<T> = self.iterator(0).copied().collect();
228 result.sort();
229 result.dedup();
230 result
231 }
232 fn unique_with_indices(&self) -> (Vec<T>, Vec<usize>)
234 where
235 T: Number + Ord,
236 {
237 let mut unique: Vec<T> = self.iterator(0).copied().collect();
238 unique.sort();
239 unique.dedup();
240
241 let mut unique_index = Vec::with_capacity(self.shape());
242 for idx in 0..self.shape() {
243 unique_index.push(unique.iter().position(|v| self.get(idx) == v).unwrap());
244 }
245
246 (unique, unique_index)
247 }
248 fn norm2(&self) -> f64
250 where
251 T: Number,
252 {
253 self.iterator(0)
254 .fold(0f64, |norm, xi| {
255 let xi = xi.to_f64().unwrap();
256 norm + xi * xi
257 })
258 .sqrt()
259 }
260 fn norm(&self, p: f64) -> f64
262 where
263 T: Number,
264 {
265 if p.is_infinite() && p.is_sign_positive() {
266 self.iterator(0)
267 .map(|x| x.to_f64().unwrap().abs())
268 .fold(f64::NEG_INFINITY, |a, b| a.max(b))
269 } else if p.is_infinite() && p.is_sign_negative() {
270 self.iterator(0)
271 .map(|x| x.to_f64().unwrap().abs())
272 .fold(f64::INFINITY, |a, b| a.min(b))
273 } else {
274 let mut norm = 0f64;
275
276 for xi in self.iterator(0) {
277 norm += xi.to_f64().unwrap().abs().powf(p);
278 }
279
280 norm.powf(1f64 / p)
281 }
282 }
283 fn max_diff(&self, other: &dyn ArrayView1<T>) -> T
285 where
286 T: Number + Signed + PartialOrd,
287 {
288 assert!(
289 self.shape() == other.shape(),
290 "Both arrays should have the same shape ({})",
291 self.shape()
292 );
293 let max_f = |max: T, v: T| -> T {
294 match T::gt(&v, &max) {
295 true => v,
296 _ => max,
297 }
298 };
299 self.iterator(0)
300 .zip(other.iterator(0))
301 .map(|(&a, &b)| (a - b).abs())
302 .fold(T::min_value(), max_f)
303 }
304 fn variance(&self) -> f64
306 where
307 T: Number,
308 {
309 let n = self.shape();
310
311 let mut mu = 0f64;
312 let mut sum = 0f64;
313 let div = n as f64;
314 for i in 0..n {
315 let xi = T::to_f64(self.get(i)).unwrap();
316 mu += xi;
317 sum += xi * xi;
318 }
319 mu /= div;
320 sum / div - mu.powi(2)
321 }
322 fn std_dev(&self) -> f64
324 where
325 T: Number,
326 {
327 self.variance().sqrt()
328 }
329 fn mean_by(&self) -> f64
331 where
332 T: Number,
333 {
334 self.sum().to_f64().unwrap() / self.shape() as f64
335 }
336}
337
338pub trait ArrayView2<T: Debug + Display + Copy + Sized>: Array<T, (usize, usize)> {
340 fn max(&self, axis: u8) -> Vec<T>
342 where
343 T: Number + PartialOrd,
344 {
345 let (nrows, ncols) = self.shape();
346 let max_f = |max: T, r: usize, c: usize| -> T {
347 let v = self.get((r, c));
348 match T::gt(v, &max) {
349 true => *v,
350 _ => max,
351 }
352 };
353 match axis {
354 0 => (0..ncols)
355 .map(move |c| (0..nrows).fold(T::min_value(), |max, r| max_f(max, r, c)))
356 .collect(),
357 _ => (0..nrows)
358 .map(move |r| (0..ncols).fold(T::min_value(), |max, c| max_f(max, r, c)))
359 .collect(),
360 }
361 }
362 fn sum(&self, axis: u8) -> Vec<T>
364 where
365 T: Number,
366 {
367 let (nrows, ncols) = self.shape();
368 match axis {
369 0 => (0..ncols)
370 .map(move |c| (0..nrows).map(|r| *self.get((r, c))).sum())
371 .collect(),
372 _ => (0..nrows)
373 .map(move |r| (0..ncols).map(|c| *self.get((r, c))).sum())
374 .collect(),
375 }
376 }
377 fn min(&self, axis: u8) -> Vec<T>
379 where
380 T: Number + PartialOrd,
381 {
382 let (nrows, ncols) = self.shape();
383 let min_f = |min: T, r: usize, c: usize| -> T {
384 let v = self.get((r, c));
385 match T::lt(v, &min) {
386 true => *v,
387 _ => min,
388 }
389 };
390 match axis {
391 0 => (0..ncols)
392 .map(move |c| (0..nrows).fold(T::max_value(), |min, r| min_f(min, r, c)))
393 .collect(),
394 _ => (0..nrows)
395 .map(move |r| (0..ncols).fold(T::max_value(), |min, c| min_f(min, r, c)))
396 .collect(),
397 }
398 }
399 fn argmax(&self, axis: u8) -> Vec<usize>
401 where
402 T: Number + PartialOrd,
403 {
404 let max_f = |max: (T, usize), v: (T, usize)| -> (T, usize) {
406 match T::gt(&v.0, &max.0) {
407 true => v,
408 _ => max,
409 }
410 };
411 let (nrows, ncols) = self.shape();
412 match axis {
413 0 => (0..ncols)
414 .map(move |c| {
415 (0..nrows).fold((T::min_value(), 0), |max, r| {
416 max_f(max, (*self.get((r, c)), r))
417 })
418 })
419 .map(|(_, i)| i)
420 .collect(),
421 _ => (0..nrows)
422 .map(move |r| {
423 (0..ncols).fold((T::min_value(), 0), |max, c| {
424 max_f(max, (*self.get((r, c)), c))
425 })
426 })
427 .map(|(_, i)| i)
428 .collect(),
429 }
430 }
431 fn mean_by(&self, axis: u8) -> Vec<f64>
435 where
436 T: Number,
437 {
438 let (n, m) = match axis {
439 0 => {
440 let (n, m) = self.shape();
441 (m, n)
442 }
443 _ => self.shape(),
444 };
445
446 let mut x: Vec<f64> = vec![0f64; n];
447
448 let div = m as f64;
449
450 for (i, x_i) in x.iter_mut().enumerate().take(n) {
451 for j in 0..m {
452 *x_i += match axis {
453 0 => T::to_f64(self.get((j, i))).unwrap(),
454 _ => T::to_f64(self.get((i, j))).unwrap(),
455 };
456 }
457 *x_i /= div;
458 }
459
460 x
461 }
462 fn variance(&self, axis: u8) -> Vec<f64>
464 where
465 T: Number + RealNumber,
466 {
467 let (n, m) = match axis {
468 0 => {
469 let (n, m) = self.shape();
470 (m, n)
471 }
472 _ => self.shape(),
473 };
474
475 let mut x: Vec<f64> = vec![0f64; n];
476
477 let div = m as f64;
478
479 for (i, x_i) in x.iter_mut().enumerate().take(n) {
480 let mut mu = 0f64;
481 let mut sum = 0f64;
482 for j in 0..m {
483 let a = match axis {
484 0 => T::to_f64(self.get((j, i))).unwrap(),
485 _ => T::to_f64(self.get((i, j))).unwrap(),
486 };
487 mu += a;
488 sum += a * a;
489 }
490 mu /= div;
491 *x_i = sum / div - mu.powi(2);
492 }
493
494 x
495 }
496 fn std_dev(&self, axis: u8) -> Vec<f64>
498 where
499 T: Number + RealNumber,
500 {
501 let mut x = self.variance(axis);
502
503 let n = match axis {
504 0 => self.shape().1,
505 _ => self.shape().0,
506 };
507
508 for x_i in x.iter_mut().take(n) {
509 *x_i = x_i.sqrt();
510 }
511
512 x
513 }
514 fn cov(&self, cov: &mut dyn MutArrayView2<f64>)
516 where
517 T: Number,
518 {
519 let (m, n) = self.shape();
520
521 let mu = self.mean_by(0);
522
523 for k in 0..m {
524 for i in 0..n {
525 for j in 0..=i {
526 cov.add_element_mut(
527 (i, j),
528 (self.get((k, i)).to_f64().unwrap() - mu[i])
529 * (self.get((k, j)).to_f64().unwrap() - mu[j]),
530 );
531 }
532 }
533 }
534
535 let m = (m - 1) as f64;
536
537 for i in 0..n {
538 for j in 0..=i {
539 cov.div_element_mut((i, j), m);
540 cov.set((j, i), *cov.get((i, j)));
541 }
542 }
543 }
544 fn display(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
546 let (nrows, ncols) = self.shape();
547 for r in 0..nrows {
548 let row: Vec<T> = (0..ncols).map(|c| *self.get((r, c))).collect();
549 writeln!(f, "{row:?}")?
550 }
551 Ok(())
552 }
553 fn norm(&self, p: f64) -> f64
555 where
556 T: Number,
557 {
558 if p.is_infinite() && p.is_sign_positive() {
559 self.iterator(0)
560 .map(|x| x.to_f64().unwrap().abs())
561 .fold(f64::NEG_INFINITY, |a, b| a.max(b))
562 } else if p.is_infinite() && p.is_sign_negative() {
563 self.iterator(0)
564 .map(|x| x.to_f64().unwrap().abs())
565 .fold(f64::INFINITY, |a, b| a.min(b))
566 } else {
567 let mut norm = 0f64;
568
569 for xi in self.iterator(0) {
570 norm += xi.to_f64().unwrap().abs().powf(p);
571 }
572
573 norm.powf(1f64 / p)
574 }
575 }
576 fn diag(&self) -> Vec<T> {
578 let (nrows, ncols) = self.shape();
579 let n = nrows.min(ncols);
580
581 (0..n).map(|i| *self.get((i, i))).collect()
582 }
583}
584
585pub trait MutArrayView1<T: Debug + Display + Copy + Sized>:
587 MutArray<T, usize> + ArrayView1<T>
588{
589 fn copy_from(&mut self, other: &dyn Array<T, usize>) {
591 self.iterator_mut(0)
592 .zip(other.iterator(0))
593 .for_each(|(s, o)| *s = *o);
594 }
595 fn abs_mut(&mut self)
597 where
598 T: Number + Signed,
599 {
600 self.iterator_mut(0).for_each(|v| *v = v.abs());
601 }
602 fn neg_mut(&mut self)
604 where
605 T: Number + Neg<Output = T>,
606 {
607 self.iterator_mut(0).for_each(|v| *v = -*v);
608 }
609 fn pow_mut(&mut self, p: T)
611 where
612 T: RealNumber,
613 {
614 self.iterator_mut(0).for_each(|v| *v = v.powf(p));
615 }
616 fn argsort_mut(&mut self) -> Vec<usize>
618 where
619 T: Number + PartialOrd,
620 {
621 let stack_size = 64;
622 let mut jstack = -1;
623 let mut l = 0;
624 let mut istack = vec![0; stack_size];
625 let mut ir = self.shape() - 1;
626 let mut index: Vec<usize> = (0..self.shape()).collect();
627
628 loop {
629 if ir - l < 7 {
630 for j in l + 1..=ir {
631 let a = *self.get(j);
632 let b = index[j];
633 let mut i: i32 = (j - 1) as i32;
634 while i >= l as i32 {
635 if *self.get(i as usize) <= a {
636 break;
637 }
638 self.set((i + 1) as usize, *self.get(i as usize));
639 index[(i + 1) as usize] = index[i as usize];
640 i -= 1;
641 }
642 self.set((i + 1) as usize, a);
643 index[(i + 1) as usize] = b;
644 }
645 if jstack < 0 {
646 break;
647 }
648 ir = istack[jstack as usize];
649 jstack -= 1;
650 l = istack[jstack as usize];
651 jstack -= 1;
652 } else {
653 let k = (l + ir) >> 1;
654 self.swap(k, l + 1);
655 index.swap(k, l + 1);
656 if self.get(l) > self.get(ir) {
657 self.swap(l, ir);
658 index.swap(l, ir);
659 }
660 if self.get(l + 1) > self.get(ir) {
661 self.swap(l + 1, ir);
662 index.swap(l + 1, ir);
663 }
664 if self.get(l) > self.get(l + 1) {
665 self.swap(l, l + 1);
666 index.swap(l, l + 1);
667 }
668 let mut i = l + 1;
669 let mut j = ir;
670 let a = *self.get(l + 1);
671 let b = index[l + 1];
672 loop {
673 loop {
674 i += 1;
675 if *self.get(i) >= a {
676 break;
677 }
678 }
679 loop {
680 j -= 1;
681 if *self.get(j) <= a {
682 break;
683 }
684 }
685 if j < i {
686 break;
687 }
688 self.swap(i, j);
689 index.swap(i, j);
690 }
691 self.set(l + 1, *self.get(j));
692 self.set(j, a);
693 index[l + 1] = index[j];
694 index[j] = b;
695 jstack += 2;
696
697 if jstack >= 64 {
698 panic!("stack size is too small.");
699 }
700
701 if ir - i + 1 >= j - l {
702 istack[jstack as usize] = ir;
703 istack[jstack as usize - 1] = i;
704 ir = j - 1;
705 } else {
706 istack[jstack as usize] = j - 1;
707 istack[jstack as usize - 1] = l;
708 l = i;
709 }
710 }
711 }
712
713 index
714 }
715 fn softmax_mut(&mut self)
717 where
718 T: RealNumber,
719 {
720 let max = self.max();
721 let mut z = T::zero();
722 self.iterator_mut(0).for_each(|v| {
723 *v = (*v - max).exp();
724 z += *v;
725 });
726 self.iterator_mut(0).for_each(|v| *v /= z);
727 }
728}
729
730pub trait MutArrayView2<T: Debug + Display + Copy + Sized>:
732 MutArray<T, (usize, usize)> + ArrayView2<T>
733{
734 fn copy_from(&mut self, other: &dyn Array<T, (usize, usize)>) {
736 self.iterator_mut(0)
737 .zip(other.iterator(0))
738 .for_each(|(s, o)| *s = *o);
739 }
740 fn abs_mut(&mut self)
742 where
743 T: Number + Signed,
744 {
745 self.iterator_mut(0).for_each(|v| *v = v.abs());
746 }
747 fn neg_mut(&mut self)
749 where
750 T: Number + Neg<Output = T>,
751 {
752 self.iterator_mut(0).for_each(|v| *v = -*v);
753 }
754 fn pow_mut(&mut self, p: T)
756 where
757 T: RealNumber,
758 {
759 self.iterator_mut(0).for_each(|v| *v = v.powf(p));
760 }
761 fn scale_mut(&mut self, mean: &[T], std: &[T], axis: u8)
763 where
764 T: Number,
765 {
766 let (n, m) = match axis {
767 0 => {
768 let (n, m) = self.shape();
769 (m, n)
770 }
771 _ => self.shape(),
772 };
773
774 for i in 0..n {
775 for j in 0..m {
776 match axis {
777 0 => self.set((j, i), (*self.get((j, i)) - mean[i]) / std[i]),
778 _ => self.set((i, j), (*self.get((i, j)) - mean[i]) / std[i]),
779 }
780 }
781 }
782 }
783}
784
785pub trait Array1<T: Debug + Display + Copy + Sized>: MutArrayView1<T> + Sized + Clone {
787 fn slice<'a>(&'a self, range: Range<usize>) -> Box<dyn ArrayView1<T> + 'a>;
789 fn slice_mut<'a>(&'a mut self, range: Range<usize>) -> Box<dyn MutArrayView1<T> + 'a>;
791 fn fill(len: usize, value: T) -> Self
793 where
794 Self: Sized;
795 fn from_iterator<I: Iterator<Item = T>>(iter: I, len: usize) -> Self
797 where
798 Self: Sized;
799 fn from_vec_slice(slice: &[T]) -> Self
801 where
802 Self: Sized;
803 fn from_slice(slice: &'_ dyn ArrayView1<T>) -> Self
805 where
806 Self: Sized;
807 fn zeros(len: usize) -> Self
809 where
810 T: Number,
811 Self: Sized,
812 {
813 Self::fill(len, T::zero())
814 }
815 fn ones(len: usize) -> Self
817 where
818 T: Number,
819 Self: Sized,
820 {
821 Self::fill(len, T::one())
822 }
823 fn rand(len: usize) -> Self
825 where
826 T: RealNumber,
827 Self: Sized,
828 {
829 Self::from_iterator((0..len).map(|_| T::rand()), len)
830 }
831 fn add_scalar(&self, x: T) -> Self
833 where
834 T: Number,
835 Self: Sized,
836 {
837 let mut result = self.clone();
838 result.add_scalar_mut(x);
839 result
840 }
841 fn sub_scalar(&self, x: T) -> Self
843 where
844 T: Number,
845 Self: Sized,
846 {
847 let mut result = self.clone();
848 result.sub_scalar_mut(x);
849 result
850 }
851 fn div_scalar(&self, x: T) -> Self
853 where
854 T: Number,
855 Self: Sized,
856 {
857 let mut result = self.clone();
858 result.div_scalar_mut(x);
859 result
860 }
861 fn mul_scalar(&self, x: T) -> Self
863 where
864 T: Number,
865 Self: Sized,
866 {
867 let mut result = self.clone();
868 result.mul_scalar_mut(x);
869 result
870 }
871 fn add(&self, other: &dyn Array<T, usize>) -> Self
873 where
874 T: Number,
875 Self: Sized,
876 {
877 let mut result = self.clone();
878 result.add_mut(other);
879 result
880 }
881 fn sub(&self, other: &impl Array1<T>) -> Self
883 where
884 T: Number,
885 Self: Sized,
886 {
887 let mut result = self.clone();
888 result.sub_mut(other);
889 result
890 }
891 fn mul(&self, other: &dyn Array<T, usize>) -> Self
893 where
894 T: Number,
895 Self: Sized,
896 {
897 let mut result = self.clone();
898 result.mul_mut(other);
899 result
900 }
901 fn div(&self, other: &dyn Array<T, usize>) -> Self
903 where
904 T: Number,
905 Self: Sized,
906 {
907 let mut result = self.clone();
908 result.div_mut(other);
909 result
910 }
911 fn take(&self, index: &[usize]) -> Self
913 where
914 Self: Sized,
915 {
916 let len = self.shape();
917 assert!(
918 index.iter().all(|&i| i < len),
919 "All indices in `take` should be < {len}"
920 );
921 Self::from_iterator(index.iter().map(move |&i| *self.get(i)), index.len())
922 }
923 fn abs(&self) -> Self
925 where
926 T: Number + Signed,
927 Self: Sized,
928 {
929 let mut result = self.clone();
930 result.abs_mut();
931 result
932 }
933 fn neg(&self) -> Self
935 where
936 T: Number + Neg<Output = T>,
937 Self: Sized,
938 {
939 let mut result = self.clone();
940 result.neg_mut();
941 result
942 }
943 fn pow(&self, p: T) -> Self
945 where
946 T: RealNumber,
947 Self: Sized,
948 {
949 let mut result = self.clone();
950 result.pow_mut(p);
951 result
952 }
953 fn argsort(&self) -> Vec<usize>
955 where
956 T: Number + PartialOrd,
957 {
958 let mut v = self.clone();
959 v.argsort_mut()
960 }
961 fn map<O: Debug + Display + Copy + Sized, A: Array1<O>, F: FnMut(&T) -> O>(self, f: F) -> A {
963 let len = self.shape();
964 A::from_iterator(self.iterator(0).map(f), len)
965 }
966 fn softmax(&self) -> Self
968 where
969 T: RealNumber,
970 Self: Sized,
971 {
972 let mut result = self.clone();
973 result.softmax_mut();
974 result
975 }
976 fn xa(&self, a_transpose: bool, a: &dyn ArrayView2<T>) -> Self
978 where
979 T: Number,
980 Self: Sized,
981 {
982 let (nrows, ncols) = a.shape();
983 let len = self.shape();
984 let (d1, d2) = match a_transpose {
985 true => (ncols, nrows),
986 _ => (nrows, ncols),
987 };
988 assert!(
989 d1 == len,
990 "Can not multiply {nrows}x{ncols} matrix by {len} vector"
991 );
992 let mut result = Self::zeros(d2);
993 for i in 0..d2 {
994 let mut s = T::zero();
995 for j in 0..d1 {
996 match a_transpose {
997 true => s += *a.get((i, j)) * *self.get(j),
998 _ => s += *a.get((j, i)) * *self.get(j),
999 }
1000 }
1001 result.set(i, s);
1002 }
1003 result
1004 }
1005
1006 fn approximate_eq(&self, other: &Self, error: T) -> bool
1008 where
1009 T: Number + RealNumber,
1010 Self: Sized,
1011 {
1012 (self.sub(other)).iterator(0).all(|v| v.abs() <= error)
1013 }
1014}
1015
1016pub trait Array2<T: Debug + Display + Copy + Sized>: MutArrayView2<T> + Sized + Clone {
1018 fn fill(nrows: usize, ncols: usize, value: T) -> Self;
1020 fn slice<'a>(&'a self, rows: Range<usize>, cols: Range<usize>) -> Box<dyn ArrayView2<T> + 'a>
1022 where
1023 Self: Sized;
1024 fn slice_mut<'a>(
1026 &'a mut self,
1027 rows: Range<usize>,
1028 cols: Range<usize>,
1029 ) -> Box<dyn MutArrayView2<T> + 'a>
1030 where
1031 Self: Sized;
1032 fn from_iterator<I: Iterator<Item = T>>(iter: I, nrows: usize, ncols: usize, axis: u8) -> Self;
1034 fn get_row<'a>(&'a self, row: usize) -> Box<dyn ArrayView1<T> + 'a>
1036 where
1037 Self: Sized;
1038 fn get_col<'a>(&'a self, col: usize) -> Box<dyn ArrayView1<T> + 'a>
1040 where
1041 Self: Sized;
1042 fn zeros(nrows: usize, ncols: usize) -> Self
1044 where
1045 T: Number,
1046 {
1047 Self::fill(nrows, ncols, T::zero())
1048 }
1049 fn ones(nrows: usize, ncols: usize) -> Self
1051 where
1052 T: Number,
1053 {
1054 Self::fill(nrows, ncols, T::one())
1055 }
1056 fn eye(size: usize) -> Self
1058 where
1059 T: Number,
1060 {
1061 let mut matrix = Self::zeros(size, size);
1062
1063 for i in 0..size {
1064 matrix.set((i, i), T::one());
1065 }
1066
1067 matrix
1068 }
1069 fn rand(nrows: usize, ncols: usize) -> Self
1071 where
1072 T: RealNumber,
1073 {
1074 Self::from_iterator((0..nrows * ncols).map(|_| T::rand()), nrows, ncols, 0)
1075 }
1076 fn from_slice(slice: &dyn ArrayView2<T>) -> Self {
1078 let (nrows, ncols) = slice.shape();
1079 Self::from_iterator(slice.iterator(0).cloned(), nrows, ncols, 0)
1080 }
1081 fn from_row(slice: &dyn ArrayView1<T>) -> Self {
1083 let ncols = slice.shape();
1084 Self::from_iterator(slice.iterator(0).cloned(), 1, ncols, 0)
1085 }
1086 fn from_column(slice: &dyn ArrayView1<T>) -> Self {
1088 let nrows = slice.shape();
1089 Self::from_iterator(slice.iterator(0).cloned(), nrows, 1, 0)
1090 }
1091 fn transpose(&self) -> Self {
1093 let (nrows, ncols) = self.shape();
1094 let mut m = Self::fill(ncols, nrows, *self.get((0, 0)));
1095 for c in 0..ncols {
1096 for r in 0..nrows {
1097 m.set((c, r), *self.get((r, c)));
1098 }
1099 }
1100 m
1101 }
1102 fn reshape(&self, nrows: usize, ncols: usize, axis: u8) -> Self {
1104 let (onrows, oncols) = self.shape();
1105
1106 assert!(
1107 nrows * ncols == onrows * oncols,
1108 "Can't reshape {onrows}x{oncols} array into a {nrows}x{ncols} array"
1109 );
1110
1111 Self::from_iterator(self.iterator(0).cloned(), nrows, ncols, axis)
1112 }
1113 fn matmul(&self, other: &dyn ArrayView2<T>) -> Self
1115 where
1116 T: Number,
1117 {
1118 let (nrows, ncols) = self.shape();
1119 let (o_nrows, o_ncols) = other.shape();
1120 assert!(
1121 ncols == o_nrows,
1122 "Can't multiply {nrows}x{ncols} and {o_nrows}x{o_ncols} matrices"
1123 );
1124 let inner_d = ncols;
1125 let mut result = Self::zeros(nrows, o_ncols);
1126
1127 for r in 0..nrows {
1128 for c in 0..o_ncols {
1129 let mut s = T::zero();
1130 for i in 0..inner_d {
1131 s += *self.get((r, i)) * *other.get((i, c));
1132 }
1133 result.set((r, c), s);
1134 }
1135 }
1136
1137 result
1138 }
1139 fn ab(&self, a_transpose: bool, b: &dyn ArrayView2<T>, b_transpose: bool) -> Self
1141 where
1142 T: Number,
1143 {
1144 if !a_transpose && !b_transpose {
1145 self.matmul(b)
1146 } else {
1147 let (nrows, ncols) = self.shape();
1148 let (o_nrows, o_ncols) = b.shape();
1149 let (d1, d2, d3, d4) = match (a_transpose, b_transpose) {
1150 (true, false) => (nrows, ncols, o_ncols, o_nrows),
1151 (false, true) => (ncols, nrows, o_nrows, o_ncols),
1152 _ => (nrows, ncols, o_nrows, o_ncols),
1153 };
1154 if d1 != d4 {
1155 panic!("Can not multiply {d2}x{d1} by {d4}x{d3} matrices");
1156 }
1157 let mut result = Self::zeros(d2, d3);
1158 for r in 0..d2 {
1159 for c in 0..d3 {
1160 let mut s = T::zero();
1161 for i in 0..d1 {
1162 match (a_transpose, b_transpose) {
1163 (true, false) => s += *self.get((i, r)) * *b.get((i, c)),
1164 (false, true) => s += *self.get((r, i)) * *b.get((c, i)),
1165 _ => s += *self.get((i, r)) * *b.get((c, i)),
1166 }
1167 }
1168 result.set((r, c), s);
1169 }
1170 }
1171 result
1172 }
1173 }
1174 fn ax(&self, a_transpose: bool, x: &dyn ArrayView1<T>) -> Self
1176 where
1177 T: Number,
1178 {
1179 let (nrows, ncols) = self.shape();
1180 let len = x.shape();
1181 let (d1, d2) = match a_transpose {
1182 true => (ncols, nrows),
1183 _ => (nrows, ncols),
1184 };
1185 assert!(
1186 d2 == len,
1187 "Can not multiply {nrows}x{ncols} matrix by {len} vector"
1188 );
1189 let mut result = Self::zeros(d1, 1);
1190 for i in 0..d1 {
1191 let mut s = T::zero();
1192 for j in 0..d2 {
1193 match a_transpose {
1194 true => s += *self.get((j, i)) * *x.get(j),
1195 _ => s += *self.get((i, j)) * *x.get(j),
1196 }
1197 }
1198 result.set((i, 0), s);
1199 }
1200 result
1201 }
1202 fn concatenate_1d<'a>(arrays: &'a [&'a dyn ArrayView1<T>], axis: u8) -> Self {
1204 assert!(
1205 axis == 1 || axis == 0,
1206 "For two dimensional array `axis` should be either 0 or 1"
1207 );
1208 assert!(!arrays.is_empty(), "Can't concatenate an empty array");
1209 assert!(
1210 arrays.windows(2).all(|w| w[0].shape() == w[1].shape()),
1211 "Can't concatenate arrays of different sizes"
1212 );
1213
1214 let first = &arrays[0];
1215 let tail = &arrays[1..];
1216
1217 match axis {
1218 0 => Self::from_iterator(
1219 tail.iter()
1220 .fold(first.iterator(0), |acc, i| {
1221 Box::new(acc.chain(i.iterator(0)))
1222 })
1223 .cloned(),
1224 arrays.len(),
1225 arrays[0].shape(),
1226 axis,
1227 ),
1228 _ => Self::from_iterator(
1229 tail.iter()
1230 .fold(first.iterator(0), |acc, i| {
1231 Box::new(acc.chain(i.iterator(0)))
1232 })
1233 .cloned(),
1234 arrays[0].shape(),
1235 arrays.len(),
1236 axis,
1237 ),
1238 }
1239 }
1240 fn concatenate_2d<'a>(arrays: &'a [&'a dyn ArrayView2<T>], axis: u8) -> Self {
1242 assert!(
1243 axis == 1 || axis == 0,
1244 "For two dimensional array `axis` should be either 0 or 1"
1245 );
1246 assert!(!arrays.is_empty(), "Can't concatenate an empty array");
1247 if axis == 0 {
1248 assert!(
1249 arrays.windows(2).all(|w| w[0].shape().1 == w[1].shape().1),
1250 "Number of columns in all arrays should match"
1251 );
1252 } else {
1253 assert!(
1254 arrays.windows(2).all(|w| w[0].shape().0 == w[1].shape().0),
1255 "Number of rows in all arrays should match"
1256 );
1257 }
1258
1259 let first = &arrays[0];
1260 let tail = &arrays[1..];
1261
1262 match axis {
1263 0 => {
1264 let (nrows, ncols) = (
1265 arrays.iter().map(|a| a.shape().0).sum(),
1266 arrays[0].shape().1,
1267 );
1268 Self::from_iterator(
1269 tail.iter()
1270 .fold(first.iterator(0), |acc, i| {
1271 Box::new(acc.chain(i.iterator(0)))
1272 })
1273 .cloned(),
1274 nrows,
1275 ncols,
1276 axis,
1277 )
1278 }
1279 _ => {
1280 let (nrows, ncols) = (
1281 arrays[0].shape().0,
1282 (arrays.iter().map(|a| a.shape().1).sum()),
1283 );
1284 Self::from_iterator(
1285 tail.iter()
1286 .fold(first.iterator(1), |acc, i| {
1287 Box::new(acc.chain(i.iterator(1)))
1288 })
1289 .cloned(),
1290 nrows,
1291 ncols,
1292 axis,
1293 )
1294 }
1295 }
1296 }
1297 fn merge_1d<'a>(&'a self, arrays: &'a [&'a dyn ArrayView1<T>], axis: u8, append: bool) -> Self {
1299 assert!(
1300 axis == 1 || axis == 0,
1301 "For two dimensional array `axis` should be either 0 or 1"
1302 );
1303 assert!(!arrays.is_empty(), "Can't merge with an empty array");
1304
1305 let first = &arrays[0];
1306 let tail = &arrays[1..];
1307
1308 match (append, axis) {
1309 (true, 0) => {
1310 let (nrows, ncols) = (self.shape().0 + arrays.len(), self.shape().1);
1311 Self::from_iterator(
1312 self.iterator(0)
1313 .chain(tail.iter().fold(first.iterator(0), |acc, i| {
1314 Box::new(acc.chain(i.iterator(0)))
1315 }))
1316 .cloned(),
1317 nrows,
1318 ncols,
1319 axis,
1320 )
1321 }
1322 (true, 1) => {
1323 let (nrows, ncols) = (self.shape().0, self.shape().1 + arrays.len());
1324 Self::from_iterator(
1325 self.iterator(1)
1326 .chain(tail.iter().fold(first.iterator(0), |acc, i| {
1327 Box::new(acc.chain(i.iterator(0)))
1328 }))
1329 .cloned(),
1330 nrows,
1331 ncols,
1332 axis,
1333 )
1334 }
1335 (false, 0) => {
1336 let (nrows, ncols) = (self.shape().0 + arrays.len(), self.shape().1);
1337 Self::from_iterator(
1338 tail.iter()
1339 .fold(first.iterator(0), |acc, i| {
1340 Box::new(acc.chain(i.iterator(0)))
1341 })
1342 .chain(self.iterator(0))
1343 .cloned(),
1344 nrows,
1345 ncols,
1346 axis,
1347 )
1348 }
1349 _ => {
1350 let (nrows, ncols) = (self.shape().0, self.shape().1 + arrays.len());
1351 Self::from_iterator(
1352 tail.iter()
1353 .fold(first.iterator(0), |acc, i| {
1354 Box::new(acc.chain(i.iterator(0)))
1355 })
1356 .chain(self.iterator(1))
1357 .cloned(),
1358 nrows,
1359 ncols,
1360 axis,
1361 )
1362 }
1363 }
1364 }
1365 fn v_stack(&self, other: &dyn ArrayView2<T>) -> Self {
1367 let (nrows, ncols) = self.shape();
1368 let (other_nrows, other_ncols) = other.shape();
1369
1370 assert!(
1371 ncols == other_ncols,
1372 "For vertical stack number of rows in both arrays should match"
1373 );
1374 Self::from_iterator(
1375 self.iterator(0).chain(other.iterator(0)).cloned(),
1376 nrows + other_nrows,
1377 ncols,
1378 0,
1379 )
1380 }
1381 fn h_stack(&self, other: &dyn ArrayView2<T>) -> Self {
1383 let (nrows, ncols) = self.shape();
1384 let (other_nrows, other_ncols) = other.shape();
1385
1386 assert!(
1387 nrows == other_nrows,
1388 "For horizontal stack number of rows in both arrays should match"
1389 );
1390 Self::from_iterator(
1391 self.iterator(1).chain(other.iterator(1)).cloned(),
1392 nrows,
1393 other_ncols + ncols,
1394 1,
1395 )
1396 }
1397 fn map<O: Debug + Display + Copy + Sized, A: Array2<O>, F: FnMut(&T) -> O>(self, f: F) -> A {
1399 let (nrows, ncols) = self.shape();
1400 A::from_iterator(self.iterator(0).map(f), nrows, ncols, 0)
1401 }
1402 fn row_iter<'a>(&'a self) -> Box<dyn Iterator<Item = Box<dyn ArrayView1<T> + 'a>> + 'a> {
1404 Box::new((0..self.shape().0).map(move |r| self.get_row(r)))
1405 }
1406 fn col_iter<'a>(&'a self) -> Box<dyn Iterator<Item = Box<dyn ArrayView1<T> + 'a>> + 'a> {
1408 Box::new((0..self.shape().1).map(move |r| self.get_col(r)))
1409 }
1410 fn take(&self, index: &[usize], axis: u8) -> Self {
1412 let (nrows, ncols) = self.shape();
1413
1414 match axis {
1415 0 => {
1416 assert!(
1417 index.iter().all(|&i| i < nrows),
1418 "All indices in `take` should be < {nrows}"
1419 );
1420 Self::from_iterator(
1421 index
1422 .iter()
1423 .flat_map(move |&r| (0..ncols).map(move |c| self.get((r, c))))
1424 .cloned(),
1425 index.len(),
1426 ncols,
1427 0,
1428 )
1429 }
1430 _ => {
1431 assert!(
1432 index.iter().all(|&i| i < ncols),
1433 "All indices in `take` should be < {ncols}"
1434 );
1435 Self::from_iterator(
1436 (0..nrows)
1437 .flat_map(move |r| index.iter().map(move |&c| self.get((r, c))))
1438 .cloned(),
1439 nrows,
1440 index.len(),
1441 0,
1442 )
1443 }
1444 }
1445 }
1446 fn take_column(&self, column_index: usize) -> Self {
1448 self.take(&[column_index], 1)
1449 }
1450 fn add_scalar(&self, x: T) -> Self
1452 where
1453 T: Number,
1454 {
1455 let mut result = self.clone();
1456 result.add_scalar_mut(x);
1457 result
1458 }
1459 fn sub_scalar(&self, x: T) -> Self
1461 where
1462 T: Number,
1463 {
1464 let mut result = self.clone();
1465 result.sub_scalar_mut(x);
1466 result
1467 }
1468 fn div_scalar(&self, x: T) -> Self
1470 where
1471 T: Number,
1472 {
1473 let mut result = self.clone();
1474 result.div_scalar_mut(x);
1475 result
1476 }
1477 fn mul_scalar(&self, x: T) -> Self
1479 where
1480 T: Number,
1481 {
1482 let mut result = self.clone();
1483 result.mul_scalar_mut(x);
1484 result
1485 }
1486 fn add(&self, other: &dyn Array<T, (usize, usize)>) -> Self
1488 where
1489 T: Number,
1490 {
1491 let mut result = self.clone();
1492 result.add_mut(other);
1493 result
1494 }
1495 fn sub(&self, other: &dyn Array<T, (usize, usize)>) -> Self
1497 where
1498 T: Number,
1499 {
1500 let mut result = self.clone();
1501 result.sub_mut(other);
1502 result
1503 }
1504 fn mul(&self, other: &dyn Array<T, (usize, usize)>) -> Self
1506 where
1507 T: Number,
1508 {
1509 let mut result = self.clone();
1510 result.mul_mut(other);
1511 result
1512 }
1513 fn div(&self, other: &dyn Array<T, (usize, usize)>) -> Self
1515 where
1516 T: Number,
1517 {
1518 let mut result = self.clone();
1519 result.div_mut(other);
1520 result
1521 }
1522 fn abs(&self) -> Self
1524 where
1525 T: Number + Signed,
1526 {
1527 let mut result = self.clone();
1528 result.abs_mut();
1529 result
1530 }
1531 fn neg(&self) -> Self
1533 where
1534 T: Number + Neg<Output = T>,
1535 {
1536 let mut result = self.clone();
1537 result.neg_mut();
1538 result
1539 }
1540 fn pow(&self, p: T) -> Self
1542 where
1543 T: RealNumber,
1544 {
1545 let mut result = self.clone();
1546 result.pow_mut(p);
1547 result
1548 }
1549
1550 fn column_mean(&self) -> Vec<f64>
1552 where
1553 T: Number + ToPrimitive,
1554 {
1555 let mut mean = vec![0f64; self.shape().1];
1556
1557 for r in 0..self.shape().0 {
1558 for (c, mean_c) in mean.iter_mut().enumerate().take(self.shape().1) {
1559 let value: f64 = self.get((r, c)).to_f64().unwrap();
1560 *mean_c += value;
1561 }
1562 }
1563
1564 for mean_i in mean.iter_mut() {
1565 *mean_i /= self.shape().0 as f64;
1566 }
1567
1568 mean
1569 }
1570
1571 fn copy_col_as_vec(&self, col: usize, result: &mut Vec<T>) {
1573 for (r, result_r) in result.iter_mut().enumerate().take(self.shape().0) {
1574 *result_r = *self.get((r, col));
1575 }
1576 }
1577
1578 fn approximate_eq(&self, other: &Self, error: T) -> bool
1580 where
1581 T: Number + RealNumber,
1582 {
1583 (self.sub(other)).iterator(0).all(|v| v.abs() <= error)
1584 && (self.sub(other)).iterator(1).all(|v| v.abs() <= error)
1585 }
1586}
1587
#[cfg(test)]
mod tests {
    // Unit tests for the 1-D (`Vec<T>`) and 2-D (`DenseMatrix<T>`) array traits
    // defined in this file. Expected values are small hand-computed results.
    use super::*;
    use crate::linalg::basic::arrays::{Array, Array2, ArrayView2, MutArrayView2};
    use crate::linalg::basic::matrix::DenseMatrix;
    use approx::relative_eq;

    #[test]
    fn test_dot() {
        let a = vec![1, 2, 3];
        let b = vec![1.0, 2.0, 3.0];
        let c = vec![4.0, 5.0, 6.0];

        // 1*4 + 2*5 = 14 over the first two elements; 14 + 3*6 = 32 over all.
        assert_eq!(b.slice(0..2).dot(c.slice(0..2).as_ref()), 14.);
        assert_eq!(b.slice(0..3).dot(&c), 32.);
        assert_eq!(b.dot(&c), 32.);
        assert_eq!(a.dot(&a), 14);
    }

    // `dot` between operands of different lengths (2 vs 3) must panic.
    #[test]
    #[should_panic]
    fn test_failed_dot() {
        let a = vec![1, 2, 3];

        a.slice(0..2).dot(a.slice(0..3).as_ref());
    }

    #[test]
    fn test_vec_chaining() {
        let mut x: Vec<i32> = Vec::zeros(6);

        // `add_scalar` returns a new vector; its result is discarded here, so
        // `x` itself stays all zeros (confirmed by the final assertion).
        x.add_scalar(5);
        assert_eq!(vec!(5, 5, 5, 5, 5, 5), x.add_scalar(5));
        {
            // Mutating through a slice view writes back into `x`.
            let mut x_s = x.slice_mut(0..3);
            x_s.add_scalar_mut(1);
        }

        assert_eq!(vec!(1, 1, 1, 0, 0, 0), x);
    }

    #[test]
    fn test_vec_norm() {
        let v = vec![3., -2., 6.];
        // L1 = 3 + 2 + 6, L2 = sqrt(9 + 4 + 36), +inf -> max |v|, -inf -> min |v|.
        assert_eq!(v.norm(1.), 11.);
        assert_eq!(v.norm(2.), 7.);
        assert_eq!(v.norm(f64::INFINITY), 6.);
        assert_eq!(v.norm(f64::NEG_INFINITY), 2.);
    }

    #[test]
    fn test_vec_unique() {
        let n = vec![1, 2, 2, 3, 4, 5, 3, 2];
        // The second vector maps each input element to the position of its
        // value in the sorted unique list.
        assert_eq!(
            n.unique_with_indices(),
            (vec!(1, 2, 3, 4, 5), vec!(0, 1, 1, 2, 3, 4, 2, 1))
        );
        assert_eq!(n.unique(), vec!(1, 2, 3, 4, 5));
        assert_eq!(Vec::<i32>::zeros(100).unique(), vec![0]);
        assert_eq!(Vec::<i32>::zeros(100).slice(0..10).unique(), vec![0]);
    }

    #[test]
    fn test_vec_var_std() {
        // These values correspond to the population variance (divide by n):
        // [1..5] has squared deviations summing to 10, and 10 / 5 = 2.
        assert_eq!(vec![1., 2., 3., 4., 5.].variance(), 2.);
        assert_eq!(vec![1., 2.].std_dev(), 0.5);
        assert_eq!(vec![1.].variance(), 0.0);
        assert_eq!(vec![1.].std_dev(), 0.0);
    }

    #[test]
    fn test_vec_abs() {
        let mut x = vec![-1, 2, -3];
        x.abs_mut();
        assert_eq!(x, vec![1, 2, 3]);
    }

    #[test]
    fn test_vec_neg() {
        let mut x = vec![-1, 2, -3];
        x.neg_mut();
        assert_eq!(x, vec![1, -2, 3]);
    }

    #[test]
    fn test_vec_copy_from() {
        let x = vec![1, 2, 3];
        let mut y = Vec::<i32>::zeros(3);
        y.copy_from(&x);
        assert_eq!(y, vec![1, 2, 3]);
    }

    #[test]
    fn test_vec_element_ops() {
        // Each one-element mutable slice updates its element in place:
        // 1 * 4, 2 + 1, 3 - 1, 4 / 4.
        let mut x = vec![1, 2, 3, 4];
        x.slice_mut(0..1).mul_element_mut(0, 4);
        x.slice_mut(1..2).add_element_mut(0, 1);
        x.slice_mut(2..3).sub_element_mut(0, 1);
        x.slice_mut(3..4).div_element_mut(0, 4);
        assert_eq!(x, vec![4, 3, 2, 1]);
    }

    #[test]
    fn test_vec_ops() {
        assert_eq!(vec![1, 2, 3, 4].mul_scalar(2), vec![2, 4, 6, 8]);
        assert_eq!(vec![1, 2, 3, 4].add_scalar(2), vec![3, 4, 5, 6]);
        assert_eq!(vec![1, 2, 3, 4].sub_scalar(1), vec![0, 1, 2, 3]);
        // Integer division truncates.
        assert_eq!(vec![1, 2, 3, 4].div_scalar(2), vec![0, 1, 1, 2]);
    }

    #[test]
    fn test_vec_init() {
        assert_eq!(Vec::<i32>::ones(3), vec![1, 1, 1]);
        assert_eq!(Vec::<i32>::zeros(3), vec![0, 0, 0]);
    }

    #[test]
    fn test_vec_min_max() {
        // Fully qualified to pick the `ArrayView1` methods over `Ord::min/max`.
        assert_eq!(ArrayView1::min(&vec![1, 2, 3, 4, 5, 6]), 1);
        assert_eq!(ArrayView1::max(&vec![1, 2, 3, 4, 5, 6]), 6);
    }

    #[test]
    fn test_vec_take() {
        assert_eq!(vec![1, 2, 3, 4, 5, 6].take(&[0, 4, 5]), vec![1, 5, 6]);
    }

    #[test]
    fn test_vec_rand() {
        // Random values must lie in [0, 1] and not all be zero.
        let r = Vec::<f32>::rand(4);
        assert!(r.iterator(0).all(|&e| e <= 1f32));
        assert!(r.iterator(0).all(|&e| e >= 0f32));
        assert!(r.iterator(0).copied().sum::<f32>() > 0f32);
    }

    // Index 10 is out of range for a 6-element vector, so `take` must panic.
    #[test]
    #[should_panic]
    fn test_failed_vec_take() {
        assert_eq!(vec![1, 2, 3, 4, 5, 6].take(&[10, 4, 5]), vec![1, 5, 6]);
    }

    #[test]
    fn test_vec_quicksort() {
        let arr1 = vec![0.3, 0.1, 0.2, 0.4, 0.9, 0.5, 0.7, 0.6, 0.8];
        assert_eq!(vec![1, 2, 0, 3, 5, 7, 6, 8, 4], arr1.argsort());

        // `arr2` contains many ties (e.g. seven 0.2 values). The expected
        // indices for equal values are not in ascending order, so this pins the
        // current sort's tie ordering rather than a stable sort.
        let arr2 = vec![
            0.2, 0.2, 0.2, 0.2, 0.2, 0.4, 0.3, 0.2, 0.2, 0.1, 1.4, 1.5, 1.5, 1.3, 1.5, 1.3, 1.6,
            1.0, 1.3, 1.4,
        ];
        assert_eq!(
            vec![9, 7, 1, 8, 0, 2, 4, 3, 6, 5, 17, 18, 15, 13, 19, 10, 14, 11, 12, 16],
            arr2.argsort()
        );
    }

    #[test]
    fn test_vec_map() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let expected = vec![2, 4, 6, 8];
        let result: Vec<i32> = a.map(|&v| v as i32 * 2);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_vec_mean() {
        let m = vec![1, 2, 3];

        assert_eq!(m.mean_by(), 2.0);
    }

    #[test]
    fn test_vec_max_diff() {
        // Largest absolute element-wise difference: |6 - (-12)| = 18.
        let a = vec![1, 2, 3, 4, -5, 6];
        let b = vec![2, 3, 4, 1, 0, -12];
        assert_eq!(a.max_diff(&b), 18);
        assert_eq!(b.max_diff(&b), 0);
    }

    #[test]
    fn test_vec_softmax() {
        // exp(1..=3) normalized: roughly 0.090, 0.245, 0.665.
        let mut prob = vec![1., 2., 3.];
        prob.softmax_mut();
        assert!((prob[0] - 0.09).abs() < 0.01);
        assert!((prob[1] - 0.24).abs() < 0.01);
        assert!((prob[2] - 0.66).abs() < 0.01);
    }

    #[test]
    fn test_xa() {
        let a = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]]).unwrap();
        // Without transpose: x (len 2) times `a` (2x3) gives a length-3 result.
        // With transpose: x (len 3) against `a` gives one value per row of `a`.
        assert_eq!(vec![7, 8].xa(false, &a), vec![39, 54, 69]);
        assert_eq!(vec![7, 8, 9].xa(true, &a), vec![50, 122]);
    }

    #[test]
    fn test_min_max() {
        // Axis 0 reduces over rows (one result per column); axis 1 reduces
        // over columns (one result per row).
        assert_eq!(
            DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]])
                .unwrap()
                .max(0),
            vec!(4, 5, 6)
        );
        assert_eq!(
            DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]])
                .unwrap()
                .max(1),
            vec!(3, 6)
        );
        assert_eq!(
            DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]])
                .unwrap()
                .min(0),
            vec!(1., 2., 3.)
        );
        assert_eq!(
            DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]])
                .unwrap()
                .min(1),
            vec!(1., 4.)
        );
    }

    #[test]
    fn test_argmax() {
        // Axis 0: row index of each column's maximum. Axis 1: column index of
        // each row's maximum.
        assert_eq!(
            DenseMatrix::from_2d_array(&[&[1, 5, 3], &[4, 2, 6]])
                .unwrap()
                .argmax(0),
            vec!(1, 0, 1)
        );
        assert_eq!(
            DenseMatrix::from_2d_array(&[&[4, 2, 3], &[1, 5, 6]])
                .unwrap()
                .argmax(1),
            vec!(0, 2)
        );
    }

    #[test]
    fn test_sum() {
        assert_eq!(
            DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]])
                .unwrap()
                .sum(0),
            vec!(5, 7, 9)
        );
        assert_eq!(
            DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]])
                .unwrap()
                .sum(1),
            vec!(6., 15.)
        );
    }

    #[test]
    fn test_abs() {
        let mut x = DenseMatrix::from_2d_array(&[&[-1, 2, -3], &[4, -5, 6]]).unwrap();
        x.abs_mut();
        assert_eq!(
            x,
            DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]]).unwrap()
        );
    }

    #[test]
    fn test_neg() {
        let mut x = DenseMatrix::from_2d_array(&[&[-1, 2, -3], &[4, -5, 6]]).unwrap();
        x.neg_mut();
        assert_eq!(
            x,
            DenseMatrix::from_2d_array(&[&[1, -2, 3], &[-4, 5, -6]]).unwrap()
        );
    }

    #[test]
    fn test_copy_from() {
        let x = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]]).unwrap();
        let mut y = DenseMatrix::<i32>::zeros(2, 3);
        y.copy_from(&x);
        assert_eq!(
            y,
            DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]]).unwrap()
        );
    }

    #[test]
    fn test_init() {
        let x = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]]).unwrap();
        assert_eq!(
            DenseMatrix::<i32>::zeros(2, 2),
            DenseMatrix::from_2d_array(&[&[0, 0], &[0, 0]]).unwrap()
        );
        assert_eq!(
            DenseMatrix::<i32>::ones(2, 2),
            DenseMatrix::from_2d_array(&[&[1, 1], &[1, 1]]).unwrap()
        );
        assert_eq!(
            DenseMatrix::<i32>::eye(3),
            DenseMatrix::from_2d_array(&[&[1, 0, 0], &[0, 1, 0], &[0, 0, 1]]).unwrap()
        );
        // Constructing matrices from views: a 2x2 slice, a single row (giving a
        // 1xN matrix) and a single column (giving an Nx1 matrix).
        assert_eq!(
            DenseMatrix::from_slice(x.slice(0..2, 0..2).as_ref()),
            DenseMatrix::from_2d_array(&[&[1, 2], &[4, 5]]).unwrap()
        );
        assert_eq!(
            DenseMatrix::from_row(x.get_row(0).as_ref()),
            DenseMatrix::from_2d_array(&[&[1, 2, 3]]).unwrap()
        );
        assert_eq!(
            DenseMatrix::from_column(x.get_col(0).as_ref()),
            DenseMatrix::from_2d_array(&[&[1], &[4]]).unwrap()
        );
    }

    #[test]
    fn test_transpose() {
        let x = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]]).unwrap();
        assert_eq!(
            x.transpose(),
            DenseMatrix::from_2d_array(&[&[1, 4], &[2, 5], &[3, 6]]).unwrap()
        );
    }

    #[test]
    fn test_reshape() {
        // The third argument selects the element order used when reshaping;
        // 0 and 1 give the two different 3x2 layouts below.
        let x = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]]).unwrap();
        assert_eq!(
            x.reshape(3, 2, 0),
            DenseMatrix::from_2d_array(&[&[1, 2], &[3, 4], &[5, 6]]).unwrap()
        );
        assert_eq!(
            x.reshape(3, 2, 1),
            DenseMatrix::from_2d_array(&[&[1, 4], &[2, 5], &[3, 6]]).unwrap()
        );
    }

    // A 4x2 target needs 8 elements but `x` has only 6, so `reshape` must panic.
    #[test]
    #[should_panic]
    fn test_failed_reshape() {
        let x = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]]).unwrap();
        assert_eq!(
            x.reshape(4, 2, 0),
            DenseMatrix::from_2d_array(&[&[1, 2], &[3, 4], &[5, 6]]).unwrap()
        );
    }

    #[test]
    fn test_matmul() {
        let a = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]]).unwrap();
        let b = DenseMatrix::from_2d_array(&[&[1, 2], &[3, 4], &[5, 6]]).unwrap();
        // The product must be the same whether `b` is passed as a full-range
        // view or as the owned matrix.
        assert_eq!(
            a.matmul(&(*b.slice(0..3, 0..2))),
            DenseMatrix::from_2d_array(&[&[22, 28], &[49, 64]]).unwrap()
        );
        assert_eq!(
            a.matmul(&b),
            DenseMatrix::from_2d_array(&[&[22, 28], &[49, 64]]).unwrap()
        );
    }

    #[test]
    fn test_concat() {
        let a = DenseMatrix::from_2d_array(&[&[1, 2], &[3, 4]]).unwrap();
        let b = DenseMatrix::from_2d_array(&[&[5, 6], &[7, 8]]).unwrap();

        // 1-D inputs become rows with axis 0 and columns with axis 1. 2-D
        // inputs are stacked vertically (axis 0) or horizontally (axis 1).
        assert_eq!(
            DenseMatrix::concatenate_1d(&[&vec!(1, 2, 3), &vec!(4, 5, 6)], 0),
            DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]]).unwrap()
        );
        assert_eq!(
            DenseMatrix::concatenate_1d(&[&vec!(1, 2), &vec!(3, 4)], 1),
            DenseMatrix::from_2d_array(&[&[1, 3], &[2, 4]]).unwrap()
        );
        assert_eq!(
            DenseMatrix::concatenate_2d(&[&a, &b], 0),
            DenseMatrix::from_2d_array(&[&[1, 2], &[3, 4], &[5, 6], &[7, 8]]).unwrap()
        );
        assert_eq!(
            DenseMatrix::concatenate_2d(&[&a, &b], 1),
            DenseMatrix::from_2d_array(&[&[1, 2, 5, 6], &[3, 4, 7, 8]]).unwrap()
        );
    }

    #[test]
    fn test_take() {
        let a = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]]).unwrap();
        let b = DenseMatrix::from_2d_array(&[&[1, 2], &[3, 4], &[5, 6]]).unwrap();

        // Axis 1 selects columns 0 and 2 of `a`; axis 0 selects rows 0 and 2 of `b`.
        assert_eq!(
            a.take(&[0, 2], 1),
            DenseMatrix::from_2d_array(&[&[1, 3], &[4, 6]]).unwrap()
        );
        assert_eq!(
            b.take(&[0, 2], 0),
            DenseMatrix::from_2d_array(&[&[1, 2], &[5, 6]]).unwrap()
        );
    }

    #[test]
    fn test_merge() {
        let a = DenseMatrix::from_2d_array(&[&[1, 2], &[3, 4]]).unwrap();

        // Axis 0 adds the vectors as rows, axis 1 as columns. The boolean
        // chooses whether they go after (`true`) or before (`false`) `a`.
        assert_eq!(
            DenseMatrix::from_2d_array(&[&[1, 2], &[3, 4], &[5, 6], &[7, 8]]).unwrap(),
            a.merge_1d(&[&vec!(5, 6), &vec!(7, 8)], 0, true)
        );
        assert_eq!(
            DenseMatrix::from_2d_array(&[&[5, 6], &[7, 8], &[1, 2], &[3, 4]]).unwrap(),
            a.merge_1d(&[&vec!(5, 6), &vec!(7, 8)], 0, false)
        );
        assert_eq!(
            DenseMatrix::from_2d_array(&[&[1, 2, 5, 7], &[3, 4, 6, 8]]).unwrap(),
            a.merge_1d(&[&vec!(5, 6), &vec!(7, 8)], 1, true)
        );
        assert_eq!(
            DenseMatrix::from_2d_array(&[&[5, 7, 1, 2], &[6, 8, 3, 4]]).unwrap(),
            a.merge_1d(&[&vec!(5, 6), &vec!(7, 8)], 1, false)
        );
    }

    #[test]
    fn test_ops() {
        assert_eq!(
            DenseMatrix::from_2d_array(&[&[1, 2], &[3, 4]])
                .unwrap()
                .mul_scalar(2),
            DenseMatrix::from_2d_array(&[&[2, 4], &[6, 8]]).unwrap()
        );
        assert_eq!(
            DenseMatrix::from_2d_array(&[&[1, 2], &[3, 4]])
                .unwrap()
                .add_scalar(2),
            DenseMatrix::from_2d_array(&[&[3, 4], &[5, 6]]).unwrap()
        );
        assert_eq!(
            DenseMatrix::from_2d_array(&[&[1, 2], &[3, 4]])
                .unwrap()
                .sub_scalar(1),
            DenseMatrix::from_2d_array(&[&[0, 1], &[2, 3]]).unwrap()
        );
        // Integer division truncates.
        assert_eq!(
            DenseMatrix::from_2d_array(&[&[1, 2], &[3, 4]])
                .unwrap()
                .div_scalar(2),
            DenseMatrix::from_2d_array(&[&[0, 1], &[1, 2]]).unwrap()
        );
    }

    #[test]
    fn test_rand() {
        // Random values must lie in [0, 1] and not all be zero.
        let r = DenseMatrix::<f32>::rand(2, 2);
        assert!(r.iterator(0).all(|&e| e <= 1f32));
        assert!(r.iterator(0).all(|&e| e >= 0f32));
        assert!(r.iterator(0).copied().sum::<f32>() > 0f32);
    }

    #[test]
    fn test_vstack() {
        let a = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6], &[7, 8, 9]]).unwrap();
        let b = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]]).unwrap();
        let expected = DenseMatrix::from_2d_array(&[
            &[1, 2, 3],
            &[4, 5, 6],
            &[7, 8, 9],
            &[1, 2, 3],
            &[4, 5, 6],
        ])
        .unwrap();
        let result = a.v_stack(&b);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_hstack() {
        let a = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6], &[7, 8, 9]]).unwrap();
        let b = DenseMatrix::from_2d_array(&[&[1, 2], &[3, 4], &[5, 6]]).unwrap();
        let expected =
            DenseMatrix::from_2d_array(&[&[1, 2, 3, 1, 2], &[4, 5, 6, 3, 4], &[7, 8, 9, 5, 6]])
                .unwrap();
        let result = a.h_stack(&b);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_map() {
        let a = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]]).unwrap();
        let expected = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]]).unwrap();
        let result: DenseMatrix<f64> = a.map(|&v| v as f64);
        assert_eq!(result, expected);
    }

    #[test]
    fn scale() {
        // Standardize (subtract mean, divide by std dev) along each axis. For
        // columns [1, 4] the mean is 2.5 and the population std dev is 1.5,
        // giving -1 / +1. For rows [1, 2, 3] the std dev is sqrt(2/3), giving
        // about +/-1.22.
        let mut m = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]).unwrap();
        let expected_0 = DenseMatrix::from_2d_array(&[&[-1., -1., -1.], &[1., 1., 1.]]).unwrap();
        let expected_1 =
            DenseMatrix::from_2d_array(&[&[-1.22, 0.0, 1.22], &[-1.22, 0.0, 1.22]]).unwrap();

        {
            let mut m = m.clone();
            m.scale_mut(&m.mean_by(0), &m.std_dev(0), 0);
            assert!(relative_eq!(m, expected_0));
        }

        m.scale_mut(&m.mean_by(1), &m.std_dev(1), 1);
        assert!(relative_eq!(m, expected_1, epsilon = 1e-2));
    }

    #[test]
    fn test_pow_mut() {
        let mut a = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]]).unwrap();
        a.pow_mut(2.0);
        assert_eq!(
            a,
            DenseMatrix::from_2d_array(&[&[1.0, 4.0, 9.0], &[16.0, 25.0, 36.0]]).unwrap()
        );
    }

    #[test]
    fn test_ab() {
        // The boolean flags transpose the left and right operands respectively:
        // A*B, A^T*B, A*B^T, A^T*B^T.
        let a = DenseMatrix::from_2d_array(&[&[1, 2], &[3, 4]]).unwrap();
        let b = DenseMatrix::from_2d_array(&[&[5, 6], &[7, 8]]).unwrap();
        assert_eq!(
            a.ab(false, &b, false),
            DenseMatrix::from_2d_array(&[&[19, 22], &[43, 50]]).unwrap()
        );
        assert_eq!(
            a.ab(true, &b, false),
            DenseMatrix::from_2d_array(&[&[26, 30], &[38, 44]]).unwrap()
        );
        assert_eq!(
            a.ab(false, &b, true),
            DenseMatrix::from_2d_array(&[&[17, 23], &[39, 53]]).unwrap()
        );
        assert_eq!(
            a.ab(true, &b, true),
            DenseMatrix::from_2d_array(&[&[23, 31], &[34, 46]]).unwrap()
        );
    }

    #[test]
    fn test_ax() {
        // `ax` returns a column; it is transposed to a single row for comparison.
        let a = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]]).unwrap();
        assert_eq!(
            a.ax(false, &vec![7, 8, 9]).transpose(),
            DenseMatrix::from_2d_array(&[&[50, 122]]).unwrap()
        );
        assert_eq!(
            a.ax(true, &vec![7, 8]).transpose(),
            DenseMatrix::from_2d_array(&[&[39, 54, 69]]).unwrap()
        );
    }

    #[test]
    fn diag() {
        let x = DenseMatrix::from_2d_array(&[&[0, 1, 2], &[3, 4, 5], &[6, 7, 8]]).unwrap();
        assert_eq!(x.diag(), vec![0, 4, 8]);
    }

    #[test]
    fn test_cov() {
        // Expected values are the sample covariance (divide by n - 1). For
        // column 0 the squared deviations sum to 46, and 46 / 4 = 11.5.
        let a = DenseMatrix::from_2d_array(&[
            &[64, 580, 29],
            &[66, 570, 33],
            &[68, 590, 37],
            &[69, 660, 46],
            &[73, 600, 55],
        ])
        .unwrap();
        let mut result = DenseMatrix::zeros(3, 3);
        let expected = DenseMatrix::from_2d_array(&[
            &[11.5, 50.0, 34.75],
            &[50.0, 1250.0, 205.0],
            &[34.75, 205.0, 110.0],
        ])
        .unwrap();

        a.cov(&mut result);

        assert_eq!(result, expected);
    }

    #[test]
    fn test_from_iter() {
        let vec_a = Vec::from([64, 580, 29, 66, 570, 33]);
        let vec_a_len = vec_a.len();
        let mut a: Vec<i32> = Array1::<i32>::from_iterator(vec_a.into_iter(), vec_a_len);

        let vec_b = vec![1, 1, 1, 1, 1, 1];
        a.sub_mut(&vec_b);

        assert_eq!(a, [63, 579, 28, 65, 569, 32])
    }

    #[test]
    fn test_from_vec_slice() {
        let vec_a = Vec::from([64, 580, 29, 66, 570, 33]);
        let a: Vec<i32> = Array1::<i32>::from_vec_slice(&vec_a[0..3]);

        let vec_b = vec![1, 1, 1];
        let result = a.add(&vec_b);

        assert_eq!(result, [65, 581, 30])
    }
}