smartcore/linalg/basic/
arrays.rs

1use std::fmt;
2use std::fmt::{Debug, Display};
3use std::ops::Neg;
4use std::ops::Range;
5
6use crate::numbers::basenum::Number;
7use crate::numbers::realnum::RealNumber;
8
9use num::ToPrimitive;
10use num_traits::Signed;
11
/// Read-only interface implemented by every array type in this module.
///
/// `T` is the element type and `S` is the type used both for positions and for
/// the shape (`usize` for [`ArrayView1`], `(usize, usize)` for [`ArrayView2`]).
pub trait Array<T: Debug + Display + Copy + Sized, S>: Debug {
    /// Dimensions of the array, expressed with the position type `S`.
    fn shape(&self) -> S;
    /// `true` when the array contains no elements.
    fn is_empty(&self) -> bool;
    /// Borrow the element stored at `pos`.
    fn get(&self, pos: S) -> &T;
    /// Borrow every element through a boxed iterator.
    ///
    /// `axis` is forwarded to the implementation, which decides how it affects
    /// the traversal order.
    fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b>;
}
23
24/// Abstract methods for mutable Array
25pub trait MutArray<T: Debug + Display + Copy + Sized, S>: Array<T, S> {
26    /// assign value to a position
27    fn set(&mut self, pos: S, x: T);
28    /// iterate over mutable values
29    fn iterator_mut<'b>(&'b mut self, axis: u8) -> Box<dyn Iterator<Item = &'b mut T> + 'b>;
30    /// swap values between positions
31    fn swap(&mut self, a: S, b: S)
32    where
33        S: Copy,
34    {
35        let t = *self.get(a);
36        self.set(a, *self.get(b));
37        self.set(b, t);
38    }
39    /// divide element by a given value
40    fn div_element_mut(&mut self, pos: S, x: T)
41    where
42        T: Number,
43        S: Copy,
44    {
45        self.set(pos, *self.get(pos) / x);
46    }
47    /// multiply element for a given value
48    fn mul_element_mut(&mut self, pos: S, x: T)
49    where
50        T: Number,
51        S: Copy,
52    {
53        self.set(pos, *self.get(pos) * x);
54    }
55    /// add a given value to an element
56    fn add_element_mut(&mut self, pos: S, x: T)
57    where
58        T: Number,
59        S: Copy,
60    {
61        self.set(pos, *self.get(pos) + x);
62    }
63    /// subtract a given value to an element
64    fn sub_element_mut(&mut self, pos: S, x: T)
65    where
66        T: Number,
67        S: Copy,
68    {
69        self.set(pos, *self.get(pos) - x);
70    }
71    /// subtract a given value to all the elements
72    fn sub_scalar_mut(&mut self, x: T)
73    where
74        T: Number,
75    {
76        self.iterator_mut(0).for_each(|v| *v -= x);
77    }
78    /// add  a given value to all the elements
79    fn add_scalar_mut(&mut self, x: T)
80    where
81        T: Number,
82    {
83        self.iterator_mut(0).for_each(|v| *v += x);
84    }
85    /// multiply a given value to all the elements
86    fn mul_scalar_mut(&mut self, x: T)
87    where
88        T: Number,
89    {
90        self.iterator_mut(0).for_each(|v| *v *= x);
91    }
92    /// divide a given value to all the elements
93    fn div_scalar_mut(&mut self, x: T)
94    where
95        T: Number,
96    {
97        self.iterator_mut(0).for_each(|v| *v /= x);
98    }
99    /// add values from another array to the values of initial array
100    fn add_mut(&mut self, other: &dyn Array<T, S>)
101    where
102        T: Number,
103        S: Eq,
104    {
105        assert!(
106            self.shape() == other.shape(),
107            "A and B should have the same shape"
108        );
109        self.iterator_mut(0)
110            .zip(other.iterator(0))
111            .for_each(|(a, &b)| *a += b);
112    }
113    /// subtract values from another array to the values of initial array
114    fn sub_mut(&mut self, other: &dyn Array<T, S>)
115    where
116        T: Number,
117        S: Eq,
118    {
119        assert!(
120            self.shape() == other.shape(),
121            "A and B should have the same shape"
122        );
123        self.iterator_mut(0)
124            .zip(other.iterator(0))
125            .for_each(|(a, &b)| *a -= b);
126    }
127    /// multiply values from another array to the values of initial array
128    fn mul_mut(&mut self, other: &dyn Array<T, S>)
129    where
130        T: Number,
131        S: Eq,
132    {
133        assert!(
134            self.shape() == other.shape(),
135            "A and B should have the same shape"
136        );
137        self.iterator_mut(0)
138            .zip(other.iterator(0))
139            .for_each(|(a, &b)| *a *= b);
140    }
141    /// divide values from another array to the values of initial array
142    fn div_mut(&mut self, other: &dyn Array<T, S>)
143    where
144        T: Number,
145        S: Eq,
146    {
147        assert!(
148            self.shape() == other.shape(),
149            "A and B should have the same shape"
150        );
151        self.iterator_mut(0)
152            .zip(other.iterator(0))
153            .for_each(|(a, &b)| *a /= b);
154    }
155}
156
157/// Trait for 1D-arrays
158pub trait ArrayView1<T: Debug + Display + Copy + Sized>: Array<T, usize> {
159    /// return dot product with another array
160    fn dot(&self, other: &dyn ArrayView1<T>) -> T
161    where
162        T: Number,
163    {
164        assert!(
165            self.shape() == other.shape(),
166            "Can't take dot product. Arrays have different shapes"
167        );
168        self.iterator(0)
169            .zip(other.iterator(0))
170            .map(|(s, o)| *s * *o)
171            .sum()
172    }
173    /// return sum of all value of the view
174    fn sum(&self) -> T
175    where
176        T: Number,
177    {
178        self.iterator(0).copied().sum()
179    }
180    /// return max value from the view
181    fn max(&self) -> T
182    where
183        T: Number + PartialOrd,
184    {
185        let max_f = |max: T, v: &T| -> T {
186            match T::gt(v, &max) {
187                true => *v,
188                _ => max,
189            }
190        };
191        self.iterator(0).fold(T::min_value(), max_f)
192    }
193    /// return min value from  the view
194    fn min(&self) -> T
195    where
196        T: Number + PartialOrd,
197    {
198        let min_f = |min: T, v: &T| -> T {
199            match T::lt(v, &min) {
200                true => *v,
201                _ => min,
202            }
203        };
204        self.iterator(0).fold(T::max_value(), min_f)
205    }
206    /// return the position of the max value of the view
207    fn argmax(&self) -> usize
208    where
209        T: Number + PartialOrd,
210    {
211        // TODO: add check on shape, axis shall be axis < 1
212        let mut max = T::min_value();
213        let mut max_pos = 0usize;
214        for (i, v) in self.iterator(0).enumerate() {
215            if T::gt(v, &max) {
216                max = *v;
217                max_pos = i;
218            }
219        }
220        max_pos
221    }
222    /// sort the elements and remove duplicates
223    fn unique(&self) -> Vec<T>
224    where
225        T: Number + Ord,
226    {
227        let mut result: Vec<T> = self.iterator(0).copied().collect();
228        result.sort();
229        result.dedup();
230        result
231    }
232    /// return sorted unique elements
233    fn unique_with_indices(&self) -> (Vec<T>, Vec<usize>)
234    where
235        T: Number + Ord,
236    {
237        let mut unique: Vec<T> = self.iterator(0).copied().collect();
238        unique.sort();
239        unique.dedup();
240
241        let mut unique_index = Vec::with_capacity(self.shape());
242        for idx in 0..self.shape() {
243            unique_index.push(unique.iter().position(|v| self.get(idx) == v).unwrap());
244        }
245
246        (unique, unique_index)
247    }
248    /// return norm2
249    fn norm2(&self) -> f64
250    where
251        T: Number,
252    {
253        self.iterator(0)
254            .fold(0f64, |norm, xi| {
255                let xi = xi.to_f64().unwrap();
256                norm + xi * xi
257            })
258            .sqrt()
259    }
260    /// return norm
261    fn norm(&self, p: f64) -> f64
262    where
263        T: Number,
264    {
265        if p.is_infinite() && p.is_sign_positive() {
266            self.iterator(0)
267                .map(|x| x.to_f64().unwrap().abs())
268                .fold(f64::NEG_INFINITY, |a, b| a.max(b))
269        } else if p.is_infinite() && p.is_sign_negative() {
270            self.iterator(0)
271                .map(|x| x.to_f64().unwrap().abs())
272                .fold(f64::INFINITY, |a, b| a.min(b))
273        } else {
274            let mut norm = 0f64;
275
276            for xi in self.iterator(0) {
277                norm += xi.to_f64().unwrap().abs().powf(p);
278            }
279
280            norm.powf(1f64 / p)
281        }
282    }
283    /// return max differences in array
284    fn max_diff(&self, other: &dyn ArrayView1<T>) -> T
285    where
286        T: Number + Signed + PartialOrd,
287    {
288        assert!(
289            self.shape() == other.shape(),
290            "Both arrays should have the same shape ({})",
291            self.shape()
292        );
293        let max_f = |max: T, v: T| -> T {
294            match T::gt(&v, &max) {
295                true => v,
296                _ => max,
297            }
298        };
299        self.iterator(0)
300            .zip(other.iterator(0))
301            .map(|(&a, &b)| (a - b).abs())
302            .fold(T::min_value(), max_f)
303    }
304    /// return array variance
305    fn variance(&self) -> f64
306    where
307        T: Number,
308    {
309        let n = self.shape();
310
311        let mut mu = 0f64;
312        let mut sum = 0f64;
313        let div = n as f64;
314        for i in 0..n {
315            let xi = T::to_f64(self.get(i)).unwrap();
316            mu += xi;
317            sum += xi * xi;
318        }
319        mu /= div;
320        sum / div - mu.powi(2)
321    }
322    /// return variance
323    fn std_dev(&self) -> f64
324    where
325        T: Number,
326    {
327        self.variance().sqrt()
328    }
329    /// return mean of the array
330    fn mean_by(&self) -> f64
331    where
332        T: Number,
333    {
334        self.sum().to_f64().unwrap() / self.shape() as f64
335    }
336}
337
338/// Trait for 2D-array
339pub trait ArrayView2<T: Debug + Display + Copy + Sized>: Array<T, (usize, usize)> {
340    /// return max value in array
341    fn max(&self, axis: u8) -> Vec<T>
342    where
343        T: Number + PartialOrd,
344    {
345        let (nrows, ncols) = self.shape();
346        let max_f = |max: T, r: usize, c: usize| -> T {
347            let v = self.get((r, c));
348            match T::gt(v, &max) {
349                true => *v,
350                _ => max,
351            }
352        };
353        match axis {
354            0 => (0..ncols)
355                .map(move |c| (0..nrows).fold(T::min_value(), |max, r| max_f(max, r, c)))
356                .collect(),
357            _ => (0..nrows)
358                .map(move |r| (0..ncols).fold(T::min_value(), |max, c| max_f(max, r, c)))
359                .collect(),
360        }
361    }
362    /// return sum of element of array
363    fn sum(&self, axis: u8) -> Vec<T>
364    where
365        T: Number,
366    {
367        let (nrows, ncols) = self.shape();
368        match axis {
369            0 => (0..ncols)
370                .map(move |c| (0..nrows).map(|r| *self.get((r, c))).sum())
371                .collect(),
372            _ => (0..nrows)
373                .map(move |r| (0..ncols).map(|c| *self.get((r, c))).sum())
374                .collect(),
375        }
376    }
377    /// return min value of array
378    fn min(&self, axis: u8) -> Vec<T>
379    where
380        T: Number + PartialOrd,
381    {
382        let (nrows, ncols) = self.shape();
383        let min_f = |min: T, r: usize, c: usize| -> T {
384            let v = self.get((r, c));
385            match T::lt(v, &min) {
386                true => *v,
387                _ => min,
388            }
389        };
390        match axis {
391            0 => (0..ncols)
392                .map(move |c| (0..nrows).fold(T::max_value(), |min, r| min_f(min, r, c)))
393                .collect(),
394            _ => (0..nrows)
395                .map(move |r| (0..ncols).fold(T::max_value(), |min, c| min_f(min, r, c)))
396                .collect(),
397        }
398    }
399    /// return positions of max values in both rows
400    fn argmax(&self, axis: u8) -> Vec<usize>
401    where
402        T: Number + PartialOrd,
403    {
404        // TODO: add check on shape, axis value shall be < 2
405        let max_f = |max: (T, usize), v: (T, usize)| -> (T, usize) {
406            match T::gt(&v.0, &max.0) {
407                true => v,
408                _ => max,
409            }
410        };
411        let (nrows, ncols) = self.shape();
412        match axis {
413            0 => (0..ncols)
414                .map(move |c| {
415                    (0..nrows).fold((T::min_value(), 0), |max, r| {
416                        max_f(max, (*self.get((r, c)), r))
417                    })
418                })
419                .map(|(_, i)| i)
420                .collect(),
421            _ => (0..nrows)
422                .map(move |r| {
423                    (0..ncols).fold((T::min_value(), 0), |max, c| {
424                        max_f(max, (*self.get((r, c)), c))
425                    })
426                })
427                .map(|(_, i)| i)
428                .collect(),
429        }
430    }
431    /// return mean value
432    /// TODO: this can be made more readable and efficient using the
433    /// methods in `linalg::traits::stats`
434    fn mean_by(&self, axis: u8) -> Vec<f64>
435    where
436        T: Number,
437    {
438        let (n, m) = match axis {
439            0 => {
440                let (n, m) = self.shape();
441                (m, n)
442            }
443            _ => self.shape(),
444        };
445
446        let mut x: Vec<f64> = vec![0f64; n];
447
448        let div = m as f64;
449
450        for (i, x_i) in x.iter_mut().enumerate().take(n) {
451            for j in 0..m {
452                *x_i += match axis {
453                    0 => T::to_f64(self.get((j, i))).unwrap(),
454                    _ => T::to_f64(self.get((i, j))).unwrap(),
455                };
456            }
457            *x_i /= div;
458        }
459
460        x
461    }
462    /// return variance
463    fn variance(&self, axis: u8) -> Vec<f64>
464    where
465        T: Number + RealNumber,
466    {
467        let (n, m) = match axis {
468            0 => {
469                let (n, m) = self.shape();
470                (m, n)
471            }
472            _ => self.shape(),
473        };
474
475        let mut x: Vec<f64> = vec![0f64; n];
476
477        let div = m as f64;
478
479        for (i, x_i) in x.iter_mut().enumerate().take(n) {
480            let mut mu = 0f64;
481            let mut sum = 0f64;
482            for j in 0..m {
483                let a = match axis {
484                    0 => T::to_f64(self.get((j, i))).unwrap(),
485                    _ => T::to_f64(self.get((i, j))).unwrap(),
486                };
487                mu += a;
488                sum += a * a;
489            }
490            mu /= div;
491            *x_i = sum / div - mu.powi(2);
492        }
493
494        x
495    }
496    /// return standard deviation
497    fn std_dev(&self, axis: u8) -> Vec<f64>
498    where
499        T: Number + RealNumber,
500    {
501        let mut x = self.variance(axis);
502
503        let n = match axis {
504            0 => self.shape().1,
505            _ => self.shape().0,
506        };
507
508        for x_i in x.iter_mut().take(n) {
509            *x_i = x_i.sqrt();
510        }
511
512        x
513    }
514    /// return covariance
515    fn cov(&self, cov: &mut dyn MutArrayView2<f64>)
516    where
517        T: Number,
518    {
519        let (m, n) = self.shape();
520
521        let mu = self.mean_by(0);
522
523        for k in 0..m {
524            for i in 0..n {
525                for j in 0..=i {
526                    cov.add_element_mut(
527                        (i, j),
528                        (self.get((k, i)).to_f64().unwrap() - mu[i])
529                            * (self.get((k, j)).to_f64().unwrap() - mu[j]),
530                    );
531                }
532            }
533        }
534
535        let m = (m - 1) as f64;
536
537        for i in 0..n {
538            for j in 0..=i {
539                cov.div_element_mut((i, j), m);
540                cov.set((j, i), *cov.get((i, j)));
541            }
542        }
543    }
544    /// print out array
545    fn display(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
546        let (nrows, ncols) = self.shape();
547        for r in 0..nrows {
548            let row: Vec<T> = (0..ncols).map(|c| *self.get((r, c))).collect();
549            writeln!(f, "{row:?}")?
550        }
551        Ok(())
552    }
553    /// return norm
554    fn norm(&self, p: f64) -> f64
555    where
556        T: Number,
557    {
558        if p.is_infinite() && p.is_sign_positive() {
559            self.iterator(0)
560                .map(|x| x.to_f64().unwrap().abs())
561                .fold(f64::NEG_INFINITY, |a, b| a.max(b))
562        } else if p.is_infinite() && p.is_sign_negative() {
563            self.iterator(0)
564                .map(|x| x.to_f64().unwrap().abs())
565                .fold(f64::INFINITY, |a, b| a.min(b))
566        } else {
567            let mut norm = 0f64;
568
569            for xi in self.iterator(0) {
570                norm += xi.to_f64().unwrap().abs().powf(p);
571            }
572
573            norm.powf(1f64 / p)
574        }
575    }
576    /// return array diagonal
577    fn diag(&self) -> Vec<T> {
578        let (nrows, ncols) = self.shape();
579        let n = nrows.min(ncols);
580
581        (0..n).map(|i| *self.get((i, i))).collect()
582    }
583}
584
585/// Trait for mutable 1D-array
586pub trait MutArrayView1<T: Debug + Display + Copy + Sized>:
587    MutArray<T, usize> + ArrayView1<T>
588{
589    /// copy a mutable view from array
590    fn copy_from(&mut self, other: &dyn Array<T, usize>) {
591        self.iterator_mut(0)
592            .zip(other.iterator(0))
593            .for_each(|(s, o)| *s = *o);
594    }
595    /// return a mutable view of absolute values
596    fn abs_mut(&mut self)
597    where
598        T: Number + Signed,
599    {
600        self.iterator_mut(0).for_each(|v| *v = v.abs());
601    }
602    /// return a mutable view of values with opposite sign
603    fn neg_mut(&mut self)
604    where
605        T: Number + Neg<Output = T>,
606    {
607        self.iterator_mut(0).for_each(|v| *v = -*v);
608    }
609    /// return a mutable view of values at power `p`
610    fn pow_mut(&mut self, p: T)
611    where
612        T: RealNumber,
613    {
614        self.iterator_mut(0).for_each(|v| *v = v.powf(p));
615    }
616    /// return vector of indices for sorted elements
617    fn argsort_mut(&mut self) -> Vec<usize>
618    where
619        T: Number + PartialOrd,
620    {
621        let stack_size = 64;
622        let mut jstack = -1;
623        let mut l = 0;
624        let mut istack = vec![0; stack_size];
625        let mut ir = self.shape() - 1;
626        let mut index: Vec<usize> = (0..self.shape()).collect();
627
628        loop {
629            if ir - l < 7 {
630                for j in l + 1..=ir {
631                    let a = *self.get(j);
632                    let b = index[j];
633                    let mut i: i32 = (j - 1) as i32;
634                    while i >= l as i32 {
635                        if *self.get(i as usize) <= a {
636                            break;
637                        }
638                        self.set((i + 1) as usize, *self.get(i as usize));
639                        index[(i + 1) as usize] = index[i as usize];
640                        i -= 1;
641                    }
642                    self.set((i + 1) as usize, a);
643                    index[(i + 1) as usize] = b;
644                }
645                if jstack < 0 {
646                    break;
647                }
648                ir = istack[jstack as usize];
649                jstack -= 1;
650                l = istack[jstack as usize];
651                jstack -= 1;
652            } else {
653                let k = (l + ir) >> 1;
654                self.swap(k, l + 1);
655                index.swap(k, l + 1);
656                if self.get(l) > self.get(ir) {
657                    self.swap(l, ir);
658                    index.swap(l, ir);
659                }
660                if self.get(l + 1) > self.get(ir) {
661                    self.swap(l + 1, ir);
662                    index.swap(l + 1, ir);
663                }
664                if self.get(l) > self.get(l + 1) {
665                    self.swap(l, l + 1);
666                    index.swap(l, l + 1);
667                }
668                let mut i = l + 1;
669                let mut j = ir;
670                let a = *self.get(l + 1);
671                let b = index[l + 1];
672                loop {
673                    loop {
674                        i += 1;
675                        if *self.get(i) >= a {
676                            break;
677                        }
678                    }
679                    loop {
680                        j -= 1;
681                        if *self.get(j) <= a {
682                            break;
683                        }
684                    }
685                    if j < i {
686                        break;
687                    }
688                    self.swap(i, j);
689                    index.swap(i, j);
690                }
691                self.set(l + 1, *self.get(j));
692                self.set(j, a);
693                index[l + 1] = index[j];
694                index[j] = b;
695                jstack += 2;
696
697                if jstack >= 64 {
698                    panic!("stack size is too small.");
699                }
700
701                if ir - i + 1 >= j - l {
702                    istack[jstack as usize] = ir;
703                    istack[jstack as usize - 1] = i;
704                    ir = j - 1;
705                } else {
706                    istack[jstack as usize] = j - 1;
707                    istack[jstack as usize - 1] = l;
708                    l = i;
709                }
710            }
711        }
712
713        index
714    }
715    /// return softmax values
716    fn softmax_mut(&mut self)
717    where
718        T: RealNumber,
719    {
720        let max = self.max();
721        let mut z = T::zero();
722        self.iterator_mut(0).for_each(|v| {
723            *v = (*v - max).exp();
724            z += *v;
725        });
726        self.iterator_mut(0).for_each(|v| *v /= z);
727    }
728}
729
730/// Trait for mutable 2D-array views
731pub trait MutArrayView2<T: Debug + Display + Copy + Sized>:
732    MutArray<T, (usize, usize)> + ArrayView2<T>
733{
734    /// copy values from another array
735    fn copy_from(&mut self, other: &dyn Array<T, (usize, usize)>) {
736        self.iterator_mut(0)
737            .zip(other.iterator(0))
738            .for_each(|(s, o)| *s = *o);
739    }
740    /// update view with absolute values
741    fn abs_mut(&mut self)
742    where
743        T: Number + Signed,
744    {
745        self.iterator_mut(0).for_each(|v| *v = v.abs());
746    }
747    /// update view values with opposite sign
748    fn neg_mut(&mut self)
749    where
750        T: Number + Neg<Output = T>,
751    {
752        self.iterator_mut(0).for_each(|v| *v = -*v);
753    }
754    /// update view values at power `p`
755    fn pow_mut(&mut self, p: T)
756    where
757        T: RealNumber,
758    {
759        self.iterator_mut(0).for_each(|v| *v = v.powf(p));
760    }
761    /// scale view values
762    fn scale_mut(&mut self, mean: &[T], std: &[T], axis: u8)
763    where
764        T: Number,
765    {
766        let (n, m) = match axis {
767            0 => {
768                let (n, m) = self.shape();
769                (m, n)
770            }
771            _ => self.shape(),
772        };
773
774        for i in 0..n {
775            for j in 0..m {
776                match axis {
777                    0 => self.set((j, i), (*self.get((j, i)) - mean[i]) / std[i]),
778                    _ => self.set((i, j), (*self.get((i, j)) - mean[i]) / std[i]),
779                }
780            }
781        }
782    }
783}
784
785/// Trait for mutable 1D-array view
786pub trait Array1<T: Debug + Display + Copy + Sized>: MutArrayView1<T> + Sized + Clone {
787    /// return a view of the array
788    fn slice<'a>(&'a self, range: Range<usize>) -> Box<dyn ArrayView1<T> + 'a>;
789    /// return a mutable view of the array
790    fn slice_mut<'a>(&'a mut self, range: Range<usize>) -> Box<dyn MutArrayView1<T> + 'a>;
791    /// fill array with a given value
792    fn fill(len: usize, value: T) -> Self
793    where
794        Self: Sized;
795    /// create array from iterator
796    fn from_iterator<I: Iterator<Item = T>>(iter: I, len: usize) -> Self
797    where
798        Self: Sized;
799    /// create array from vector
800    fn from_vec_slice(slice: &[T]) -> Self
801    where
802        Self: Sized;
803    /// create array from slice
804    fn from_slice(slice: &'_ dyn ArrayView1<T>) -> Self
805    where
806        Self: Sized;
807    /// create a zero array
808    fn zeros(len: usize) -> Self
809    where
810        T: Number,
811        Self: Sized,
812    {
813        Self::fill(len, T::zero())
814    }
815    /// create an array of ones
816    fn ones(len: usize) -> Self
817    where
818        T: Number,
819        Self: Sized,
820    {
821        Self::fill(len, T::one())
822    }
823    /// create an array of random values
824    fn rand(len: usize) -> Self
825    where
826        T: RealNumber,
827        Self: Sized,
828    {
829        Self::from_iterator((0..len).map(|_| T::rand()), len)
830    }
831    /// add a scalar to the array
832    fn add_scalar(&self, x: T) -> Self
833    where
834        T: Number,
835        Self: Sized,
836    {
837        let mut result = self.clone();
838        result.add_scalar_mut(x);
839        result
840    }
841    /// subtract a scalar from the array
842    fn sub_scalar(&self, x: T) -> Self
843    where
844        T: Number,
845        Self: Sized,
846    {
847        let mut result = self.clone();
848        result.sub_scalar_mut(x);
849        result
850    }
851    /// divide a scalar from the array
852    fn div_scalar(&self, x: T) -> Self
853    where
854        T: Number,
855        Self: Sized,
856    {
857        let mut result = self.clone();
858        result.div_scalar_mut(x);
859        result
860    }
861    /// multiply a scalar to the array
862    fn mul_scalar(&self, x: T) -> Self
863    where
864        T: Number,
865        Self: Sized,
866    {
867        let mut result = self.clone();
868        result.mul_scalar_mut(x);
869        result
870    }
871    /// sum of two arrays
872    fn add(&self, other: &dyn Array<T, usize>) -> Self
873    where
874        T: Number,
875        Self: Sized,
876    {
877        let mut result = self.clone();
878        result.add_mut(other);
879        result
880    }
881    /// subtract two arrays
882    fn sub(&self, other: &impl Array1<T>) -> Self
883    where
884        T: Number,
885        Self: Sized,
886    {
887        let mut result = self.clone();
888        result.sub_mut(other);
889        result
890    }
891    /// multiply two arrays
892    fn mul(&self, other: &dyn Array<T, usize>) -> Self
893    where
894        T: Number,
895        Self: Sized,
896    {
897        let mut result = self.clone();
898        result.mul_mut(other);
899        result
900    }
901    /// divide two arrays
902    fn div(&self, other: &dyn Array<T, usize>) -> Self
903    where
904        T: Number,
905        Self: Sized,
906    {
907        let mut result = self.clone();
908        result.div_mut(other);
909        result
910    }
911    /// replace values with another array
912    fn take(&self, index: &[usize]) -> Self
913    where
914        Self: Sized,
915    {
916        let len = self.shape();
917        assert!(
918            index.iter().all(|&i| i < len),
919            "All indices in `take` should be < {len}"
920        );
921        Self::from_iterator(index.iter().map(move |&i| *self.get(i)), index.len())
922    }
923    /// create a view of the array with absolute values
924    fn abs(&self) -> Self
925    where
926        T: Number + Signed,
927        Self: Sized,
928    {
929        let mut result = self.clone();
930        result.abs_mut();
931        result
932    }
933    /// create a view of the array with opposite sign
934    fn neg(&self) -> Self
935    where
936        T: Number + Neg<Output = T>,
937        Self: Sized,
938    {
939        let mut result = self.clone();
940        result.neg_mut();
941        result
942    }
943    /// create a view of the array with values at power `p`
944    fn pow(&self, p: T) -> Self
945    where
946        T: RealNumber,
947        Self: Sized,
948    {
949        let mut result = self.clone();
950        result.pow_mut(p);
951        result
952    }
953    /// apply argsort to the array
954    fn argsort(&self) -> Vec<usize>
955    where
956        T: Number + PartialOrd,
957    {
958        let mut v = self.clone();
959        v.argsort_mut()
960    }
961    /// map values of the array
962    fn map<O: Debug + Display + Copy + Sized, A: Array1<O>, F: FnMut(&T) -> O>(self, f: F) -> A {
963        let len = self.shape();
964        A::from_iterator(self.iterator(0).map(f), len)
965    }
966    /// apply softmax to the array
967    fn softmax(&self) -> Self
968    where
969        T: RealNumber,
970        Self: Sized,
971    {
972        let mut result = self.clone();
973        result.softmax_mut();
974        result
975    }
976    /// multiply array by matrix
977    fn xa(&self, a_transpose: bool, a: &dyn ArrayView2<T>) -> Self
978    where
979        T: Number,
980        Self: Sized,
981    {
982        let (nrows, ncols) = a.shape();
983        let len = self.shape();
984        let (d1, d2) = match a_transpose {
985            true => (ncols, nrows),
986            _ => (nrows, ncols),
987        };
988        assert!(
989            d1 == len,
990            "Can not multiply {nrows}x{ncols} matrix by {len} vector"
991        );
992        let mut result = Self::zeros(d2);
993        for i in 0..d2 {
994            let mut s = T::zero();
995            for j in 0..d1 {
996                match a_transpose {
997                    true => s += *a.get((i, j)) * *self.get(j),
998                    _ => s += *a.get((j, i)) * *self.get(j),
999                }
1000            }
1001            result.set(i, s);
1002        }
1003        result
1004    }
1005
1006    /// check if two arrays are approximately equal
1007    fn approximate_eq(&self, other: &Self, error: T) -> bool
1008    where
1009        T: Number + RealNumber,
1010        Self: Sized,
1011    {
1012        (self.sub(other)).iterator(0).all(|v| v.abs() <= error)
1013    }
1014}
1015
1016/// Trait for mutable 2D-array view
1017pub trait Array2<T: Debug + Display + Copy + Sized>: MutArrayView2<T> + Sized + Clone {
1018    /// fill 2d array with a given value
1019    fn fill(nrows: usize, ncols: usize, value: T) -> Self;
1020    /// get a view of the 2d array
1021    fn slice<'a>(&'a self, rows: Range<usize>, cols: Range<usize>) -> Box<dyn ArrayView2<T> + 'a>
1022    where
1023        Self: Sized;
1024    /// get a mutable view of the 2d array
1025    fn slice_mut<'a>(
1026        &'a mut self,
1027        rows: Range<usize>,
1028        cols: Range<usize>,
1029    ) -> Box<dyn MutArrayView2<T> + 'a>
1030    where
1031        Self: Sized;
1032    /// create 2d array from iterator
1033    fn from_iterator<I: Iterator<Item = T>>(iter: I, nrows: usize, ncols: usize, axis: u8) -> Self;
1034    /// get row from 2d array
1035    fn get_row<'a>(&'a self, row: usize) -> Box<dyn ArrayView1<T> + 'a>
1036    where
1037        Self: Sized;
1038    /// get column from 2d array
1039    fn get_col<'a>(&'a self, col: usize) -> Box<dyn ArrayView1<T> + 'a>
1040    where
1041        Self: Sized;
1042    /// create a zero 2d array
1043    fn zeros(nrows: usize, ncols: usize) -> Self
1044    where
1045        T: Number,
1046    {
1047        Self::fill(nrows, ncols, T::zero())
1048    }
1049    /// create a 2d array of ones
1050    fn ones(nrows: usize, ncols: usize) -> Self
1051    where
1052        T: Number,
1053    {
1054        Self::fill(nrows, ncols, T::one())
1055    }
1056    /// create an identity matrix
1057    fn eye(size: usize) -> Self
1058    where
1059        T: Number,
1060    {
1061        let mut matrix = Self::zeros(size, size);
1062
1063        for i in 0..size {
1064            matrix.set((i, i), T::one());
1065        }
1066
1067        matrix
1068    }
1069    /// create a 2d array of random values
1070    fn rand(nrows: usize, ncols: usize) -> Self
1071    where
1072        T: RealNumber,
1073    {
1074        Self::from_iterator((0..nrows * ncols).map(|_| T::rand()), nrows, ncols, 0)
1075    }
1076    /// crate from 2d slice
1077    fn from_slice(slice: &dyn ArrayView2<T>) -> Self {
1078        let (nrows, ncols) = slice.shape();
1079        Self::from_iterator(slice.iterator(0).cloned(), nrows, ncols, 0)
1080    }
1081    /// create from row
1082    fn from_row(slice: &dyn ArrayView1<T>) -> Self {
1083        let ncols = slice.shape();
1084        Self::from_iterator(slice.iterator(0).cloned(), 1, ncols, 0)
1085    }
1086    /// create from column
1087    fn from_column(slice: &dyn ArrayView1<T>) -> Self {
1088        let nrows = slice.shape();
1089        Self::from_iterator(slice.iterator(0).cloned(), nrows, 1, 0)
1090    }
1091    /// transpose 2d array
1092    fn transpose(&self) -> Self {
1093        let (nrows, ncols) = self.shape();
1094        let mut m = Self::fill(ncols, nrows, *self.get((0, 0)));
1095        for c in 0..ncols {
1096            for r in 0..nrows {
1097                m.set((c, r), *self.get((r, c)));
1098            }
1099        }
1100        m
1101    }
1102    /// change shape of 2d array
1103    fn reshape(&self, nrows: usize, ncols: usize, axis: u8) -> Self {
1104        let (onrows, oncols) = self.shape();
1105
1106        assert!(
1107            nrows * ncols == onrows * oncols,
1108            "Can't reshape {onrows}x{oncols} array into a {nrows}x{ncols} array"
1109        );
1110
1111        Self::from_iterator(self.iterator(0).cloned(), nrows, ncols, axis)
1112    }
1113    /// multiply two 2d arrays
1114    fn matmul(&self, other: &dyn ArrayView2<T>) -> Self
1115    where
1116        T: Number,
1117    {
1118        let (nrows, ncols) = self.shape();
1119        let (o_nrows, o_ncols) = other.shape();
1120        assert!(
1121            ncols == o_nrows,
1122            "Can't multiply {nrows}x{ncols} and {o_nrows}x{o_ncols} matrices"
1123        );
1124        let inner_d = ncols;
1125        let mut result = Self::zeros(nrows, o_ncols);
1126
1127        for r in 0..nrows {
1128            for c in 0..o_ncols {
1129                let mut s = T::zero();
1130                for i in 0..inner_d {
1131                    s += *self.get((r, i)) * *other.get((i, c));
1132                }
1133                result.set((r, c), s);
1134            }
1135        }
1136
1137        result
1138    }
1139    /// matrix multiplication
1140    fn ab(&self, a_transpose: bool, b: &dyn ArrayView2<T>, b_transpose: bool) -> Self
1141    where
1142        T: Number,
1143    {
1144        if !a_transpose && !b_transpose {
1145            self.matmul(b)
1146        } else {
1147            let (nrows, ncols) = self.shape();
1148            let (o_nrows, o_ncols) = b.shape();
1149            let (d1, d2, d3, d4) = match (a_transpose, b_transpose) {
1150                (true, false) => (nrows, ncols, o_ncols, o_nrows),
1151                (false, true) => (ncols, nrows, o_nrows, o_ncols),
1152                _ => (nrows, ncols, o_nrows, o_ncols),
1153            };
1154            if d1 != d4 {
1155                panic!("Can not multiply {d2}x{d1} by {d4}x{d3} matrices");
1156            }
1157            let mut result = Self::zeros(d2, d3);
1158            for r in 0..d2 {
1159                for c in 0..d3 {
1160                    let mut s = T::zero();
1161                    for i in 0..d1 {
1162                        match (a_transpose, b_transpose) {
1163                            (true, false) => s += *self.get((i, r)) * *b.get((i, c)),
1164                            (false, true) => s += *self.get((r, i)) * *b.get((c, i)),
1165                            _ => s += *self.get((i, r)) * *b.get((c, i)),
1166                        }
1167                    }
1168                    result.set((r, c), s);
1169                }
1170            }
1171            result
1172        }
1173    }
1174    /// matrix vector multiplication
1175    fn ax(&self, a_transpose: bool, x: &dyn ArrayView1<T>) -> Self
1176    where
1177        T: Number,
1178    {
1179        let (nrows, ncols) = self.shape();
1180        let len = x.shape();
1181        let (d1, d2) = match a_transpose {
1182            true => (ncols, nrows),
1183            _ => (nrows, ncols),
1184        };
1185        assert!(
1186            d2 == len,
1187            "Can not multiply {nrows}x{ncols} matrix by {len} vector"
1188        );
1189        let mut result = Self::zeros(d1, 1);
1190        for i in 0..d1 {
1191            let mut s = T::zero();
1192            for j in 0..d2 {
1193                match a_transpose {
1194                    true => s += *self.get((j, i)) * *x.get(j),
1195                    _ => s += *self.get((i, j)) * *x.get(j),
1196                }
1197            }
1198            result.set((i, 0), s);
1199        }
1200        result
1201    }
1202    /// concatenate 1d array
1203    fn concatenate_1d<'a>(arrays: &'a [&'a dyn ArrayView1<T>], axis: u8) -> Self {
1204        assert!(
1205            axis == 1 || axis == 0,
1206            "For two dimensional array `axis` should be either 0 or 1"
1207        );
1208        assert!(!arrays.is_empty(), "Can't concatenate an empty array");
1209        assert!(
1210            arrays.windows(2).all(|w| w[0].shape() == w[1].shape()),
1211            "Can't concatenate arrays of different sizes"
1212        );
1213
1214        let first = &arrays[0];
1215        let tail = &arrays[1..];
1216
1217        match axis {
1218            0 => Self::from_iterator(
1219                tail.iter()
1220                    .fold(first.iterator(0), |acc, i| {
1221                        Box::new(acc.chain(i.iterator(0)))
1222                    })
1223                    .cloned(),
1224                arrays.len(),
1225                arrays[0].shape(),
1226                axis,
1227            ),
1228            _ => Self::from_iterator(
1229                tail.iter()
1230                    .fold(first.iterator(0), |acc, i| {
1231                        Box::new(acc.chain(i.iterator(0)))
1232                    })
1233                    .cloned(),
1234                arrays[0].shape(),
1235                arrays.len(),
1236                axis,
1237            ),
1238        }
1239    }
1240    /// concatenate 2d array
1241    fn concatenate_2d<'a>(arrays: &'a [&'a dyn ArrayView2<T>], axis: u8) -> Self {
1242        assert!(
1243            axis == 1 || axis == 0,
1244            "For two dimensional array `axis` should be either 0 or 1"
1245        );
1246        assert!(!arrays.is_empty(), "Can't concatenate an empty array");
1247        if axis == 0 {
1248            assert!(
1249                arrays.windows(2).all(|w| w[0].shape().1 == w[1].shape().1),
1250                "Number of columns in all arrays should match"
1251            );
1252        } else {
1253            assert!(
1254                arrays.windows(2).all(|w| w[0].shape().0 == w[1].shape().0),
1255                "Number of rows in all arrays should match"
1256            );
1257        }
1258
1259        let first = &arrays[0];
1260        let tail = &arrays[1..];
1261
1262        match axis {
1263            0 => {
1264                let (nrows, ncols) = (
1265                    arrays.iter().map(|a| a.shape().0).sum(),
1266                    arrays[0].shape().1,
1267                );
1268                Self::from_iterator(
1269                    tail.iter()
1270                        .fold(first.iterator(0), |acc, i| {
1271                            Box::new(acc.chain(i.iterator(0)))
1272                        })
1273                        .cloned(),
1274                    nrows,
1275                    ncols,
1276                    axis,
1277                )
1278            }
1279            _ => {
1280                let (nrows, ncols) = (
1281                    arrays[0].shape().0,
1282                    (arrays.iter().map(|a| a.shape().1).sum()),
1283                );
1284                Self::from_iterator(
1285                    tail.iter()
1286                        .fold(first.iterator(1), |acc, i| {
1287                            Box::new(acc.chain(i.iterator(1)))
1288                        })
1289                        .cloned(),
1290                    nrows,
1291                    ncols,
1292                    axis,
1293                )
1294            }
1295        }
1296    }
1297    /// merge 1d arrays
1298    fn merge_1d<'a>(&'a self, arrays: &'a [&'a dyn ArrayView1<T>], axis: u8, append: bool) -> Self {
1299        assert!(
1300            axis == 1 || axis == 0,
1301            "For two dimensional array `axis` should be either 0 or 1"
1302        );
1303        assert!(!arrays.is_empty(), "Can't merge with an empty array");
1304
1305        let first = &arrays[0];
1306        let tail = &arrays[1..];
1307
1308        match (append, axis) {
1309            (true, 0) => {
1310                let (nrows, ncols) = (self.shape().0 + arrays.len(), self.shape().1);
1311                Self::from_iterator(
1312                    self.iterator(0)
1313                        .chain(tail.iter().fold(first.iterator(0), |acc, i| {
1314                            Box::new(acc.chain(i.iterator(0)))
1315                        }))
1316                        .cloned(),
1317                    nrows,
1318                    ncols,
1319                    axis,
1320                )
1321            }
1322            (true, 1) => {
1323                let (nrows, ncols) = (self.shape().0, self.shape().1 + arrays.len());
1324                Self::from_iterator(
1325                    self.iterator(1)
1326                        .chain(tail.iter().fold(first.iterator(0), |acc, i| {
1327                            Box::new(acc.chain(i.iterator(0)))
1328                        }))
1329                        .cloned(),
1330                    nrows,
1331                    ncols,
1332                    axis,
1333                )
1334            }
1335            (false, 0) => {
1336                let (nrows, ncols) = (self.shape().0 + arrays.len(), self.shape().1);
1337                Self::from_iterator(
1338                    tail.iter()
1339                        .fold(first.iterator(0), |acc, i| {
1340                            Box::new(acc.chain(i.iterator(0)))
1341                        })
1342                        .chain(self.iterator(0))
1343                        .cloned(),
1344                    nrows,
1345                    ncols,
1346                    axis,
1347                )
1348            }
1349            _ => {
1350                let (nrows, ncols) = (self.shape().0, self.shape().1 + arrays.len());
1351                Self::from_iterator(
1352                    tail.iter()
1353                        .fold(first.iterator(0), |acc, i| {
1354                            Box::new(acc.chain(i.iterator(0)))
1355                        })
1356                        .chain(self.iterator(1))
1357                        .cloned(),
1358                    nrows,
1359                    ncols,
1360                    axis,
1361                )
1362            }
1363        }
1364    }
1365    /// Stack arrays in sequence vertically
1366    fn v_stack(&self, other: &dyn ArrayView2<T>) -> Self {
1367        let (nrows, ncols) = self.shape();
1368        let (other_nrows, other_ncols) = other.shape();
1369
1370        assert!(
1371            ncols == other_ncols,
1372            "For vertical stack number of rows in both arrays should match"
1373        );
1374        Self::from_iterator(
1375            self.iterator(0).chain(other.iterator(0)).cloned(),
1376            nrows + other_nrows,
1377            ncols,
1378            0,
1379        )
1380    }
1381    /// Stack arrays in sequence horizontally
1382    fn h_stack(&self, other: &dyn ArrayView2<T>) -> Self {
1383        let (nrows, ncols) = self.shape();
1384        let (other_nrows, other_ncols) = other.shape();
1385
1386        assert!(
1387            nrows == other_nrows,
1388            "For horizontal stack number of rows in both arrays should match"
1389        );
1390        Self::from_iterator(
1391            self.iterator(1).chain(other.iterator(1)).cloned(),
1392            nrows,
1393            other_ncols + ncols,
1394            1,
1395        )
1396    }
1397    /// map  array values
1398    fn map<O: Debug + Display + Copy + Sized, A: Array2<O>, F: FnMut(&T) -> O>(self, f: F) -> A {
1399        let (nrows, ncols) = self.shape();
1400        A::from_iterator(self.iterator(0).map(f), nrows, ncols, 0)
1401    }
1402    /// iter rows
1403    fn row_iter<'a>(&'a self) -> Box<dyn Iterator<Item = Box<dyn ArrayView1<T> + 'a>> + 'a> {
1404        Box::new((0..self.shape().0).map(move |r| self.get_row(r)))
1405    }
1406    /// iter cols
1407    fn col_iter<'a>(&'a self) -> Box<dyn Iterator<Item = Box<dyn ArrayView1<T> + 'a>> + 'a> {
1408        Box::new((0..self.shape().1).map(move |r| self.get_col(r)))
1409    }
1410    /// take elements from 2d array
1411    fn take(&self, index: &[usize], axis: u8) -> Self {
1412        let (nrows, ncols) = self.shape();
1413
1414        match axis {
1415            0 => {
1416                assert!(
1417                    index.iter().all(|&i| i < nrows),
1418                    "All indices in `take` should be < {nrows}"
1419                );
1420                Self::from_iterator(
1421                    index
1422                        .iter()
1423                        .flat_map(move |&r| (0..ncols).map(move |c| self.get((r, c))))
1424                        .cloned(),
1425                    index.len(),
1426                    ncols,
1427                    0,
1428                )
1429            }
1430            _ => {
1431                assert!(
1432                    index.iter().all(|&i| i < ncols),
1433                    "All indices in `take` should be < {ncols}"
1434                );
1435                Self::from_iterator(
1436                    (0..nrows)
1437                        .flat_map(move |r| index.iter().map(move |&c| self.get((r, c))))
1438                        .cloned(),
1439                    nrows,
1440                    index.len(),
1441                    0,
1442                )
1443            }
1444        }
1445    }
1446    /// Take an individual column from the matrix.
1447    fn take_column(&self, column_index: usize) -> Self {
1448        self.take(&[column_index], 1)
1449    }
1450    /// add a scalar to the array
1451    fn add_scalar(&self, x: T) -> Self
1452    where
1453        T: Number,
1454    {
1455        let mut result = self.clone();
1456        result.add_scalar_mut(x);
1457        result
1458    }
1459    /// subtract a scalar from the array
1460    fn sub_scalar(&self, x: T) -> Self
1461    where
1462        T: Number,
1463    {
1464        let mut result = self.clone();
1465        result.sub_scalar_mut(x);
1466        result
1467    }
1468    /// divide a scalar from the array
1469    fn div_scalar(&self, x: T) -> Self
1470    where
1471        T: Number,
1472    {
1473        let mut result = self.clone();
1474        result.div_scalar_mut(x);
1475        result
1476    }
1477    /// multiply a scalar to the array
1478    fn mul_scalar(&self, x: T) -> Self
1479    where
1480        T: Number,
1481    {
1482        let mut result = self.clone();
1483        result.mul_scalar_mut(x);
1484        result
1485    }
1486    /// sum of two arrays
1487    fn add(&self, other: &dyn Array<T, (usize, usize)>) -> Self
1488    where
1489        T: Number,
1490    {
1491        let mut result = self.clone();
1492        result.add_mut(other);
1493        result
1494    }
1495    /// subtract two arrays
1496    fn sub(&self, other: &dyn Array<T, (usize, usize)>) -> Self
1497    where
1498        T: Number,
1499    {
1500        let mut result = self.clone();
1501        result.sub_mut(other);
1502        result
1503    }
1504    /// multiply two arrays
1505    fn mul(&self, other: &dyn Array<T, (usize, usize)>) -> Self
1506    where
1507        T: Number,
1508    {
1509        let mut result = self.clone();
1510        result.mul_mut(other);
1511        result
1512    }
1513    /// divide two arrays
1514    fn div(&self, other: &dyn Array<T, (usize, usize)>) -> Self
1515    where
1516        T: Number,
1517    {
1518        let mut result = self.clone();
1519        result.div_mut(other);
1520        result
1521    }
1522    /// absolute values of the array
1523    fn abs(&self) -> Self
1524    where
1525        T: Number + Signed,
1526    {
1527        let mut result = self.clone();
1528        result.abs_mut();
1529        result
1530    }
1531    /// negation of the array
1532    fn neg(&self) -> Self
1533    where
1534        T: Number + Neg<Output = T>,
1535    {
1536        let mut result = self.clone();
1537        result.neg_mut();
1538        result
1539    }
1540    /// values at power `p`
1541    fn pow(&self, p: T) -> Self
1542    where
1543        T: RealNumber,
1544    {
1545        let mut result = self.clone();
1546        result.pow_mut(p);
1547        result
1548    }
1549
1550    /// compute mean for each column
1551    fn column_mean(&self) -> Vec<f64>
1552    where
1553        T: Number + ToPrimitive,
1554    {
1555        let mut mean = vec![0f64; self.shape().1];
1556
1557        for r in 0..self.shape().0 {
1558            for (c, mean_c) in mean.iter_mut().enumerate().take(self.shape().1) {
1559                let value: f64 = self.get((r, c)).to_f64().unwrap();
1560                *mean_c += value;
1561            }
1562        }
1563
1564        for mean_i in mean.iter_mut() {
1565            *mean_i /= self.shape().0 as f64;
1566        }
1567
1568        mean
1569    }
1570
1571    /// copy column as a vector
1572    fn copy_col_as_vec(&self, col: usize, result: &mut Vec<T>) {
1573        for (r, result_r) in result.iter_mut().enumerate().take(self.shape().0) {
1574            *result_r = *self.get((r, col));
1575        }
1576    }
1577
1578    /// approximate equality of the elements of a matrix according to a given error
1579    fn approximate_eq(&self, other: &Self, error: T) -> bool
1580    where
1581        T: Number + RealNumber,
1582    {
1583        (self.sub(other)).iterator(0).all(|v| v.abs() <= error)
1584            && (self.sub(other)).iterator(1).all(|v| v.abs() <= error)
1585    }
1586}
1587
1588#[cfg(test)]
1589mod tests {
1590    use super::*;
1591    use crate::linalg::basic::arrays::{Array, Array2, ArrayView2, MutArrayView2};
1592    use crate::linalg::basic::matrix::DenseMatrix;
1593    use approx::relative_eq;
1594
1595    #[test]
1596    fn test_dot() {
1597        let a = vec![1, 2, 3];
1598        let b = vec![1.0, 2.0, 3.0];
1599        let c = vec![4.0, 5.0, 6.0];
1600
1601        assert_eq!(b.slice(0..2).dot(c.slice(0..2).as_ref()), 14.);
1602        assert_eq!(b.slice(0..3).dot(&c), 32.);
1603        assert_eq!(b.dot(&c), 32.);
1604        assert_eq!(a.dot(&a), 14);
1605    }
1606
1607    #[test]
1608    #[should_panic]
1609    fn test_failed_dot() {
1610        let a = vec![1, 2, 3];
1611
1612        a.slice(0..2).dot(a.slice(0..3).as_ref());
1613    }
1614
1615    #[test]
1616    fn test_vec_chaining() {
1617        let mut x: Vec<i32> = Vec::zeros(6);
1618
1619        x.add_scalar(5);
1620        assert_eq!(vec!(5, 5, 5, 5, 5, 5), x.add_scalar(5));
1621        {
1622            let mut x_s = x.slice_mut(0..3);
1623            x_s.add_scalar_mut(1);
1624        }
1625
1626        assert_eq!(vec!(1, 1, 1, 0, 0, 0), x);
1627    }
1628
1629    #[test]
1630    fn test_vec_norm() {
1631        let v = vec![3., -2., 6.];
1632        assert_eq!(v.norm(1.), 11.);
1633        assert_eq!(v.norm(2.), 7.);
1634        assert_eq!(v.norm(f64::INFINITY), 6.);
1635        assert_eq!(v.norm(f64::NEG_INFINITY), 2.);
1636    }
1637
1638    #[test]
1639    fn test_vec_unique() {
1640        let n = vec![1, 2, 2, 3, 4, 5, 3, 2];
1641        assert_eq!(
1642            n.unique_with_indices(),
1643            (vec!(1, 2, 3, 4, 5), vec!(0, 1, 1, 2, 3, 4, 2, 1))
1644        );
1645        assert_eq!(n.unique(), vec!(1, 2, 3, 4, 5));
1646        assert_eq!(Vec::<i32>::zeros(100).unique(), vec![0]);
1647        assert_eq!(Vec::<i32>::zeros(100).slice(0..10).unique(), vec![0]);
1648    }
1649
1650    #[test]
1651    fn test_vec_var_std() {
1652        assert_eq!(vec![1., 2., 3., 4., 5.].variance(), 2.);
1653        assert_eq!(vec![1., 2.].std_dev(), 0.5);
1654        assert_eq!(vec![1.].variance(), 0.0);
1655        assert_eq!(vec![1.].std_dev(), 0.0);
1656    }
1657
1658    #[test]
1659    fn test_vec_abs() {
1660        let mut x = vec![-1, 2, -3];
1661        x.abs_mut();
1662        assert_eq!(x, vec![1, 2, 3]);
1663    }
1664
1665    #[test]
1666    fn test_vec_neg() {
1667        let mut x = vec![-1, 2, -3];
1668        x.neg_mut();
1669        assert_eq!(x, vec![1, -2, 3]);
1670    }
1671
1672    #[test]
1673    fn test_vec_copy_from() {
1674        let x = vec![1, 2, 3];
1675        let mut y = Vec::<i32>::zeros(3);
1676        y.copy_from(&x);
1677        assert_eq!(y, vec![1, 2, 3]);
1678    }
1679
1680    #[test]
1681    fn test_vec_element_ops() {
1682        let mut x = vec![1, 2, 3, 4];
1683        x.slice_mut(0..1).mul_element_mut(0, 4);
1684        x.slice_mut(1..2).add_element_mut(0, 1);
1685        x.slice_mut(2..3).sub_element_mut(0, 1);
1686        x.slice_mut(3..4).div_element_mut(0, 4);
1687        assert_eq!(x, vec![4, 3, 2, 1]);
1688    }
1689
1690    #[test]
1691    fn test_vec_ops() {
1692        assert_eq!(vec![1, 2, 3, 4].mul_scalar(2), vec![2, 4, 6, 8]);
1693        assert_eq!(vec![1, 2, 3, 4].add_scalar(2), vec![3, 4, 5, 6]);
1694        assert_eq!(vec![1, 2, 3, 4].sub_scalar(1), vec![0, 1, 2, 3]);
1695        assert_eq!(vec![1, 2, 3, 4].div_scalar(2), vec![0, 1, 1, 2]);
1696    }
1697
1698    #[test]
1699    fn test_vec_init() {
1700        assert_eq!(Vec::<i32>::ones(3), vec![1, 1, 1]);
1701        assert_eq!(Vec::<i32>::zeros(3), vec![0, 0, 0]);
1702    }
1703
1704    #[test]
1705    fn test_vec_min_max() {
1706        assert_eq!(ArrayView1::min(&vec![1, 2, 3, 4, 5, 6]), 1);
1707        assert_eq!(ArrayView1::max(&vec![1, 2, 3, 4, 5, 6]), 6);
1708    }
1709
1710    #[test]
1711    fn test_vec_take() {
1712        assert_eq!(vec![1, 2, 3, 4, 5, 6].take(&[0, 4, 5]), vec![1, 5, 6]);
1713    }
1714
1715    #[test]
1716    fn test_vec_rand() {
1717        let r = Vec::<f32>::rand(4);
1718        assert!(r.iterator(0).all(|&e| e <= 1f32));
1719        assert!(r.iterator(0).all(|&e| e >= 0f32));
1720        assert!(r.iterator(0).copied().sum::<f32>() > 0f32);
1721    }
1722
1723    #[test]
1724    #[should_panic]
1725    fn test_failed_vec_take() {
1726        assert_eq!(vec![1, 2, 3, 4, 5, 6].take(&[10, 4, 5]), vec![1, 5, 6]);
1727    }
1728
1729    #[test]
1730    fn test_vec_quicksort() {
1731        let arr1 = vec![0.3, 0.1, 0.2, 0.4, 0.9, 0.5, 0.7, 0.6, 0.8];
1732        assert_eq!(vec![1, 2, 0, 3, 5, 7, 6, 8, 4], arr1.argsort());
1733
1734        let arr2 = vec![
1735            0.2, 0.2, 0.2, 0.2, 0.2, 0.4, 0.3, 0.2, 0.2, 0.1, 1.4, 1.5, 1.5, 1.3, 1.5, 1.3, 1.6,
1736            1.0, 1.3, 1.4,
1737        ];
1738        assert_eq!(
1739            vec![9, 7, 1, 8, 0, 2, 4, 3, 6, 5, 17, 18, 15, 13, 19, 10, 14, 11, 12, 16],
1740            arr2.argsort()
1741        );
1742    }
1743
1744    #[test]
1745    fn test_vec_map() {
1746        let a = vec![1.0, 2.0, 3.0, 4.0];
1747        let expected = vec![2, 4, 6, 8];
1748        let result: Vec<i32> = a.map(|&v| v as i32 * 2);
1749        assert_eq!(result, expected);
1750    }
1751
1752    #[test]
1753    fn test_vec_mean() {
1754        let m = vec![1, 2, 3];
1755
1756        assert_eq!(m.mean_by(), 2.0);
1757    }
1758
1759    #[test]
1760    fn test_vec_max_diff() {
1761        let a = vec![1, 2, 3, 4, -5, 6];
1762        let b = vec![2, 3, 4, 1, 0, -12];
1763        assert_eq!(a.max_diff(&b), 18);
1764        assert_eq!(b.max_diff(&b), 0);
1765    }
1766
1767    #[test]
1768    fn test_vec_softmax() {
1769        let mut prob = vec![1., 2., 3.];
1770        prob.softmax_mut();
1771        assert!((prob[0] - 0.09).abs() < 0.01);
1772        assert!((prob[1] - 0.24).abs() < 0.01);
1773        assert!((prob[2] - 0.66).abs() < 0.01);
1774    }
1775
1776    #[test]
1777    fn test_xa() {
1778        let a = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]]).unwrap();
1779        assert_eq!(vec![7, 8].xa(false, &a), vec![39, 54, 69]);
1780        assert_eq!(vec![7, 8, 9].xa(true, &a), vec![50, 122]);
1781    }
1782
1783    #[test]
1784    fn test_min_max() {
1785        assert_eq!(
1786            DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]])
1787                .unwrap()
1788                .max(0),
1789            vec!(4, 5, 6)
1790        );
1791        assert_eq!(
1792            DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]])
1793                .unwrap()
1794                .max(1),
1795            vec!(3, 6)
1796        );
1797        assert_eq!(
1798            DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]])
1799                .unwrap()
1800                .min(0),
1801            vec!(1., 2., 3.)
1802        );
1803        assert_eq!(
1804            DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]])
1805                .unwrap()
1806                .min(1),
1807            vec!(1., 4.)
1808        );
1809    }
1810
1811    #[test]
1812    fn test_argmax() {
1813        assert_eq!(
1814            DenseMatrix::from_2d_array(&[&[1, 5, 3], &[4, 2, 6]])
1815                .unwrap()
1816                .argmax(0),
1817            vec!(1, 0, 1)
1818        );
1819        assert_eq!(
1820            DenseMatrix::from_2d_array(&[&[4, 2, 3], &[1, 5, 6]])
1821                .unwrap()
1822                .argmax(1),
1823            vec!(0, 2)
1824        );
1825    }
1826
1827    #[test]
1828    fn test_sum() {
1829        assert_eq!(
1830            DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]])
1831                .unwrap()
1832                .sum(0),
1833            vec!(5, 7, 9)
1834        );
1835        assert_eq!(
1836            DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]])
1837                .unwrap()
1838                .sum(1),
1839            vec!(6., 15.)
1840        );
1841    }
1842
1843    #[test]
1844    fn test_abs() {
1845        let mut x = DenseMatrix::from_2d_array(&[&[-1, 2, -3], &[4, -5, 6]]).unwrap();
1846        x.abs_mut();
1847        assert_eq!(
1848            x,
1849            DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]]).unwrap()
1850        );
1851    }
1852
1853    #[test]
1854    fn test_neg() {
1855        let mut x = DenseMatrix::from_2d_array(&[&[-1, 2, -3], &[4, -5, 6]]).unwrap();
1856        x.neg_mut();
1857        assert_eq!(
1858            x,
1859            DenseMatrix::from_2d_array(&[&[1, -2, 3], &[-4, 5, -6]]).unwrap()
1860        );
1861    }
1862
1863    #[test]
1864    fn test_copy_from() {
1865        let x = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]]).unwrap();
1866        let mut y = DenseMatrix::<i32>::zeros(2, 3);
1867        y.copy_from(&x);
1868        assert_eq!(
1869            y,
1870            DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]]).unwrap()
1871        );
1872    }
1873
1874    #[test]
1875    fn test_init() {
1876        let x = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]]).unwrap();
1877        assert_eq!(
1878            DenseMatrix::<i32>::zeros(2, 2),
1879            DenseMatrix::from_2d_array(&[&[0, 0], &[0, 0]]).unwrap()
1880        );
1881        assert_eq!(
1882            DenseMatrix::<i32>::ones(2, 2),
1883            DenseMatrix::from_2d_array(&[&[1, 1], &[1, 1]]).unwrap()
1884        );
1885        assert_eq!(
1886            DenseMatrix::<i32>::eye(3),
1887            DenseMatrix::from_2d_array(&[&[1, 0, 0], &[0, 1, 0], &[0, 0, 1]]).unwrap()
1888        );
1889        assert_eq!(
1890            DenseMatrix::from_slice(x.slice(0..2, 0..2).as_ref()), // internal only?
1891            DenseMatrix::from_2d_array(&[&[1, 2], &[4, 5]]).unwrap()
1892        );
1893        assert_eq!(
1894            DenseMatrix::from_row(x.get_row(0).as_ref()), // internal only?
1895            DenseMatrix::from_2d_array(&[&[1, 2, 3]]).unwrap()
1896        );
1897        assert_eq!(
1898            DenseMatrix::from_column(x.get_col(0).as_ref()), // internal only?
1899            DenseMatrix::from_2d_array(&[&[1], &[4]]).unwrap()
1900        );
1901    }
1902
1903    #[test]
1904    fn test_transpose() {
1905        let x = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]]).unwrap();
1906        assert_eq!(
1907            x.transpose(),
1908            DenseMatrix::from_2d_array(&[&[1, 4], &[2, 5], &[3, 6]]).unwrap()
1909        );
1910    }
1911
1912    #[test]
1913    fn test_reshape() {
1914        let x = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]]).unwrap();
1915        assert_eq!(
1916            x.reshape(3, 2, 0),
1917            DenseMatrix::from_2d_array(&[&[1, 2], &[3, 4], &[5, 6]]).unwrap()
1918        );
1919        assert_eq!(
1920            x.reshape(3, 2, 1),
1921            DenseMatrix::from_2d_array(&[&[1, 4], &[2, 5], &[3, 6]]).unwrap()
1922        );
1923    }
1924
1925    #[test]
1926    #[should_panic]
1927    fn test_failed_reshape() {
1928        let x = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]]).unwrap();
1929        assert_eq!(
1930            x.reshape(4, 2, 0),
1931            DenseMatrix::from_2d_array(&[&[1, 2], &[3, 4], &[5, 6]]).unwrap()
1932        );
1933    }
1934
1935    #[test]
1936    fn test_matmul() {
1937        let a = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]]).unwrap();
1938        let b = DenseMatrix::from_2d_array(&[&[1, 2], &[3, 4], &[5, 6]]).unwrap();
1939        assert_eq!(
1940            a.matmul(&(*b.slice(0..3, 0..2))),
1941            DenseMatrix::from_2d_array(&[&[22, 28], &[49, 64]]).unwrap()
1942        );
1943        assert_eq!(
1944            a.matmul(&b),
1945            DenseMatrix::from_2d_array(&[&[22, 28], &[49, 64]]).unwrap()
1946        );
1947    }
1948
1949    #[test]
1950    fn test_concat() {
1951        let a = DenseMatrix::from_2d_array(&[&[1, 2], &[3, 4]]).unwrap();
1952        let b = DenseMatrix::from_2d_array(&[&[5, 6], &[7, 8]]).unwrap();
1953
1954        assert_eq!(
1955            DenseMatrix::concatenate_1d(&[&vec!(1, 2, 3), &vec!(4, 5, 6)], 0),
1956            DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]]).unwrap()
1957        );
1958        assert_eq!(
1959            DenseMatrix::concatenate_1d(&[&vec!(1, 2), &vec!(3, 4)], 1),
1960            DenseMatrix::from_2d_array(&[&[1, 3], &[2, 4]]).unwrap()
1961        );
1962        assert_eq!(
1963            DenseMatrix::concatenate_2d(&[&a, &b], 0),
1964            DenseMatrix::from_2d_array(&[&[1, 2], &[3, 4], &[5, 6], &[7, 8]]).unwrap()
1965        );
1966        assert_eq!(
1967            DenseMatrix::concatenate_2d(&[&a, &b], 1),
1968            DenseMatrix::from_2d_array(&[&[1, 2, 5, 6], &[3, 4, 7, 8]]).unwrap()
1969        );
1970    }
1971
1972    #[test]
1973    fn test_take() {
1974        let a = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]]).unwrap();
1975        let b = DenseMatrix::from_2d_array(&[&[1, 2], &[3, 4], &[5, 6]]).unwrap();
1976
1977        assert_eq!(
1978            a.take(&[0, 2], 1),
1979            DenseMatrix::from_2d_array(&[&[1, 3], &[4, 6]]).unwrap()
1980        );
1981        assert_eq!(
1982            b.take(&[0, 2], 0),
1983            DenseMatrix::from_2d_array(&[&[1, 2], &[5, 6]]).unwrap()
1984        );
1985    }
1986
1987    #[test]
1988    fn test_merge() {
1989        let a = DenseMatrix::from_2d_array(&[&[1, 2], &[3, 4]]).unwrap();
1990
1991        assert_eq!(
1992            DenseMatrix::from_2d_array(&[&[1, 2], &[3, 4], &[5, 6], &[7, 8]]).unwrap(),
1993            a.merge_1d(&[&vec!(5, 6), &vec!(7, 8)], 0, true)
1994        );
1995        assert_eq!(
1996            DenseMatrix::from_2d_array(&[&[5, 6], &[7, 8], &[1, 2], &[3, 4]]).unwrap(),
1997            a.merge_1d(&[&vec!(5, 6), &vec!(7, 8)], 0, false)
1998        );
1999        assert_eq!(
2000            DenseMatrix::from_2d_array(&[&[1, 2, 5, 7], &[3, 4, 6, 8]]).unwrap(),
2001            a.merge_1d(&[&vec!(5, 6), &vec!(7, 8)], 1, true)
2002        );
2003        assert_eq!(
2004            DenseMatrix::from_2d_array(&[&[5, 7, 1, 2], &[6, 8, 3, 4]]).unwrap(),
2005            a.merge_1d(&[&vec!(5, 6), &vec!(7, 8)], 1, false)
2006        );
2007    }
2008
2009    #[test]
2010    fn test_ops() {
2011        assert_eq!(
2012            DenseMatrix::from_2d_array(&[&[1, 2], &[3, 4]])
2013                .unwrap()
2014                .mul_scalar(2),
2015            DenseMatrix::from_2d_array(&[&[2, 4], &[6, 8]]).unwrap()
2016        );
2017        assert_eq!(
2018            DenseMatrix::from_2d_array(&[&[1, 2], &[3, 4]])
2019                .unwrap()
2020                .add_scalar(2),
2021            DenseMatrix::from_2d_array(&[&[3, 4], &[5, 6]]).unwrap()
2022        );
2023        assert_eq!(
2024            DenseMatrix::from_2d_array(&[&[1, 2], &[3, 4]])
2025                .unwrap()
2026                .sub_scalar(1),
2027            DenseMatrix::from_2d_array(&[&[0, 1], &[2, 3]]).unwrap()
2028        );
2029        assert_eq!(
2030            DenseMatrix::from_2d_array(&[&[1, 2], &[3, 4]])
2031                .unwrap()
2032                .div_scalar(2),
2033            DenseMatrix::from_2d_array(&[&[0, 1], &[1, 2]]).unwrap()
2034        );
2035    }
2036
2037    #[test]
2038    fn test_rand() {
2039        let r = DenseMatrix::<f32>::rand(2, 2);
2040        assert!(r.iterator(0).all(|&e| e <= 1f32));
2041        assert!(r.iterator(0).all(|&e| e >= 0f32));
2042        assert!(r.iterator(0).copied().sum::<f32>() > 0f32);
2043    }
2044
2045    #[test]
2046    fn test_vstack() {
2047        let a = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6], &[7, 8, 9]]).unwrap();
2048        let b = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]]).unwrap();
2049        let expected = DenseMatrix::from_2d_array(&[
2050            &[1, 2, 3],
2051            &[4, 5, 6],
2052            &[7, 8, 9],
2053            &[1, 2, 3],
2054            &[4, 5, 6],
2055        ])
2056        .unwrap();
2057        let result = a.v_stack(&b);
2058        assert_eq!(result, expected);
2059    }
2060
2061    #[test]
2062    fn test_hstack() {
2063        let a = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6], &[7, 8, 9]]).unwrap();
2064        let b = DenseMatrix::from_2d_array(&[&[1, 2], &[3, 4], &[5, 6]]).unwrap();
2065        let expected =
2066            DenseMatrix::from_2d_array(&[&[1, 2, 3, 1, 2], &[4, 5, 6, 3, 4], &[7, 8, 9, 5, 6]])
2067                .unwrap();
2068        let result = a.h_stack(&b);
2069        assert_eq!(result, expected);
2070    }
2071
2072    #[test]
2073    fn test_map() {
2074        let a = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]]).unwrap();
2075        let expected = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]]).unwrap();
2076        let result: DenseMatrix<f64> = a.map(|&v| v as f64);
2077        assert_eq!(result, expected);
2078    }
2079
2080    #[test]
2081    fn scale() {
2082        let mut m = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]).unwrap();
2083        let expected_0 = DenseMatrix::from_2d_array(&[&[-1., -1., -1.], &[1., 1., 1.]]).unwrap();
2084        let expected_1 =
2085            DenseMatrix::from_2d_array(&[&[-1.22, 0.0, 1.22], &[-1.22, 0.0, 1.22]]).unwrap();
2086
2087        {
2088            let mut m = m.clone();
2089            m.scale_mut(&m.mean_by(0), &m.std_dev(0), 0);
2090            assert!(relative_eq!(m, expected_0));
2091        }
2092
2093        m.scale_mut(&m.mean_by(1), &m.std_dev(1), 1);
2094        assert!(relative_eq!(m, expected_1, epsilon = 1e-2));
2095    }
2096
2097    #[test]
2098    fn test_pow_mut() {
2099        let mut a = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]]).unwrap();
2100        a.pow_mut(2.0);
2101        assert_eq!(
2102            a,
2103            DenseMatrix::from_2d_array(&[&[1.0, 4.0, 9.0], &[16.0, 25.0, 36.0]]).unwrap()
2104        );
2105    }
2106
2107    #[test]
2108    fn test_ab() {
2109        let a = DenseMatrix::from_2d_array(&[&[1, 2], &[3, 4]]).unwrap();
2110        let b = DenseMatrix::from_2d_array(&[&[5, 6], &[7, 8]]).unwrap();
2111        assert_eq!(
2112            a.ab(false, &b, false),
2113            DenseMatrix::from_2d_array(&[&[19, 22], &[43, 50]]).unwrap()
2114        );
2115        assert_eq!(
2116            a.ab(true, &b, false),
2117            DenseMatrix::from_2d_array(&[&[26, 30], &[38, 44]]).unwrap()
2118        );
2119        assert_eq!(
2120            a.ab(false, &b, true),
2121            DenseMatrix::from_2d_array(&[&[17, 23], &[39, 53]]).unwrap()
2122        );
2123        assert_eq!(
2124            a.ab(true, &b, true),
2125            DenseMatrix::from_2d_array(&[&[23, 31], &[34, 46]]).unwrap()
2126        );
2127    }
2128
2129    #[test]
2130    fn test_ax() {
2131        let a = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]]).unwrap();
2132        assert_eq!(
2133            a.ax(false, &vec![7, 8, 9]).transpose(),
2134            DenseMatrix::from_2d_array(&[&[50, 122]]).unwrap()
2135        );
2136        assert_eq!(
2137            a.ax(true, &vec![7, 8]).transpose(),
2138            DenseMatrix::from_2d_array(&[&[39, 54, 69]]).unwrap()
2139        );
2140    }
2141
2142    #[test]
2143    fn diag() {
2144        let x = DenseMatrix::from_2d_array(&[&[0, 1, 2], &[3, 4, 5], &[6, 7, 8]]).unwrap();
2145        assert_eq!(x.diag(), vec![0, 4, 8]);
2146    }
2147
2148    #[test]
2149    fn test_cov() {
2150        let a = DenseMatrix::from_2d_array(&[
2151            &[64, 580, 29],
2152            &[66, 570, 33],
2153            &[68, 590, 37],
2154            &[69, 660, 46],
2155            &[73, 600, 55],
2156        ])
2157        .unwrap();
2158        let mut result = DenseMatrix::zeros(3, 3);
2159        let expected = DenseMatrix::from_2d_array(&[
2160            &[11.5, 50.0, 34.75],
2161            &[50.0, 1250.0, 205.0],
2162            &[34.75, 205.0, 110.0],
2163        ])
2164        .unwrap();
2165
2166        a.cov(&mut result);
2167
2168        assert_eq!(result, expected);
2169    }
2170
2171    #[test]
2172    fn test_from_iter() {
2173        let vec_a = Vec::from([64, 580, 29, 66, 570, 33]);
2174        let vec_a_len = vec_a.len();
2175        let mut a: Vec<i32> = Array1::<i32>::from_iterator(vec_a.into_iter(), vec_a_len);
2176
2177        let vec_b = vec![1, 1, 1, 1, 1, 1];
2178        a.sub_mut(&vec_b);
2179
2180        assert_eq!(a, [63, 579, 28, 65, 569, 32])
2181    }
2182
2183    #[test]
2184    fn test_from_vec_slice() {
2185        let vec_a = Vec::from([64, 580, 29, 66, 570, 33]);
2186        let a: Vec<i32> = Array1::<i32>::from_vec_slice(&vec_a[0..3]);
2187
2188        let vec_b = vec![1, 1, 1];
2189        let result = a.add(&vec_b);
2190
2191        assert_eq!(result, [65, 581, 30])
2192    }
2193}