smartcore/linalg/basic/
matrix.rs

1use std::fmt;
2use std::fmt::{Debug, Display};
3use std::ops::Range;
4use std::slice::Iter;
5
6use approx::{AbsDiffEq, RelativeEq};
7#[cfg(feature = "serde")]
8use serde::{Deserialize, Serialize};
9
10use crate::linalg::basic::arrays::{
11    Array, Array2, ArrayView1, ArrayView2, MutArray, MutArrayView2,
12};
13use crate::linalg::traits::cholesky::CholeskyDecomposable;
14use crate::linalg::traits::evd::EVDDecomposable;
15use crate::linalg::traits::lu::LUDecomposable;
16use crate::linalg::traits::qr::QRDecomposable;
17use crate::linalg::traits::stats::{MatrixPreprocessing, MatrixStats};
18use crate::linalg::traits::svd::SVDDecomposable;
19use crate::numbers::basenum::Number;
20use crate::numbers::realnum::RealNumber;
21
22use crate::error::Failed;
23
/// Dense matrix that owns all of its elements in one contiguous buffer.
///
/// `column_major` decides how `values` is laid out: column after column when `true`,
/// row after row when `false`.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct DenseMatrix<T> {
    /// Number of columns.
    ncols: usize,
    /// Number of rows.
    nrows: usize,
    /// The `nrows * ncols` elements, ordered according to `column_major`.
    values: Vec<T>,
    /// Storage order of `values`: `true` = column by column, `false` = row by row.
    column_major: bool,
}
33
/// Read-only window onto a rectangular region of a dense matrix, borrowing its storage.
#[derive(Debug, Clone)]
pub struct DenseMatrixView<'a, T: Debug + Display + Copy + Sized> {
    /// Part of the parent buffer that begins at the view's first element.
    values: &'a [T],
    /// Distance in `values` between neighbouring columns (column-major) or rows (row-major),
    /// i.e. the parent's full column height / row width.
    stride: usize,
    /// Number of rows visible through the view.
    nrows: usize,
    /// Number of columns visible through the view.
    ncols: usize,
    /// Storage order inherited from the parent matrix.
    column_major: bool,
}
43
/// Mutable window onto a rectangular region of a dense matrix, borrowing its storage.
#[derive(Debug)]
pub struct DenseMatrixMutView<'a, T: Debug + Display + Copy + Sized> {
    /// Part of the parent buffer that begins at the view's first element.
    values: &'a mut [T],
    /// Distance in `values` between neighbouring columns (column-major) or rows (row-major),
    /// i.e. the parent's full column height / row width.
    stride: usize,
    /// Number of rows visible through the view.
    nrows: usize,
    /// Number of columns visible through the view.
    ncols: usize,
    /// Storage order inherited from the parent matrix.
    column_major: bool,
}
53
54impl<'a, T: Debug + Display + Copy + Sized> DenseMatrixView<'a, T> {
55    fn new(
56        m: &'a DenseMatrix<T>,
57        vrows: Range<usize>,
58        vcols: Range<usize>,
59    ) -> Result<Self, Failed> {
60        if m.is_valid_view(m.shape().0, m.shape().1, &vrows, &vcols) {
61            Err(Failed::input(
62                "The specified view is outside of the matrix range",
63            ))
64        } else {
65            let (start, end, stride) =
66                m.stride_range(m.shape().0, m.shape().1, &vrows, &vcols, m.column_major);
67
68            Ok(DenseMatrixView {
69                values: &m.values[start..end],
70                stride,
71                nrows: vrows.end - vrows.start,
72                ncols: vcols.end - vcols.start,
73                column_major: m.column_major,
74            })
75        }
76    }
77
78    fn iter<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
79        assert!(
80            axis == 1 || axis == 0,
81            "For two dimensional array `axis` should be either 0 or 1"
82        );
83        match axis {
84            0 => Box::new(
85                (0..self.nrows).flat_map(move |r| (0..self.ncols).map(move |c| self.get((r, c)))),
86            ),
87            _ => Box::new(
88                (0..self.ncols).flat_map(move |c| (0..self.nrows).map(move |r| self.get((r, c)))),
89            ),
90        }
91    }
92}
93
94impl<T: Debug + Display + Copy + Sized> fmt::Display for DenseMatrixView<'_, T> {
95    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
96        writeln!(
97            f,
98            "DenseMatrix: nrows: {:?}, ncols: {:?}",
99            self.nrows, self.ncols
100        )?;
101        writeln!(f, "column_major: {:?}", self.column_major)?;
102        self.display(f)
103    }
104}
105
106impl<'a, T: Debug + Display + Copy + Sized> DenseMatrixMutView<'a, T> {
107    fn new(
108        m: &'a mut DenseMatrix<T>,
109        vrows: Range<usize>,
110        vcols: Range<usize>,
111    ) -> Result<Self, Failed> {
112        if m.is_valid_view(m.shape().0, m.shape().1, &vrows, &vcols) {
113            Err(Failed::input(
114                "The specified view is outside of the matrix range",
115            ))
116        } else {
117            let (start, end, stride) =
118                m.stride_range(m.shape().0, m.shape().1, &vrows, &vcols, m.column_major);
119
120            Ok(DenseMatrixMutView {
121                values: &mut m.values[start..end],
122                stride,
123                nrows: vrows.end - vrows.start,
124                ncols: vcols.end - vcols.start,
125                column_major: m.column_major,
126            })
127        }
128    }
129
130    fn iter<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
131        assert!(
132            axis == 1 || axis == 0,
133            "For two dimensional array `axis` should be either 0 or 1"
134        );
135        match axis {
136            0 => Box::new(
137                (0..self.nrows).flat_map(move |r| (0..self.ncols).map(move |c| self.get((r, c)))),
138            ),
139            _ => Box::new(
140                (0..self.ncols).flat_map(move |c| (0..self.nrows).map(move |r| self.get((r, c)))),
141            ),
142        }
143    }
144
145    fn iter_mut<'b>(&'b mut self, axis: u8) -> Box<dyn Iterator<Item = &'b mut T> + 'b> {
146        let column_major = self.column_major;
147        let stride = self.stride;
148        let ptr = self.values.as_mut_ptr();
149        match axis {
150            0 => Box::new((0..self.nrows).flat_map(move |r| {
151                (0..self.ncols).map(move |c| unsafe {
152                    &mut *ptr.add(if column_major {
153                        r + c * stride
154                    } else {
155                        r * stride + c
156                    })
157                })
158            })),
159            _ => Box::new((0..self.ncols).flat_map(move |c| {
160                (0..self.nrows).map(move |r| unsafe {
161                    &mut *ptr.add(if column_major {
162                        r + c * stride
163                    } else {
164                        r * stride + c
165                    })
166                })
167            })),
168        }
169    }
170}
171
172impl<T: Debug + Display + Copy + Sized> fmt::Display for DenseMatrixMutView<'_, T> {
173    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
174        writeln!(
175            f,
176            "DenseMatrix: nrows: {:?}, ncols: {:?}",
177            self.nrows, self.ncols
178        )?;
179        writeln!(f, "column_major: {:?}", self.column_major)?;
180        self.display(f)
181    }
182}
183
184impl<T: Debug + Display + Copy + Sized> DenseMatrix<T> {
185    /// Create new instance of `DenseMatrix` without copying data.
186    /// `values` should be in column-major order.
187    pub fn new(
188        nrows: usize,
189        ncols: usize,
190        values: Vec<T>,
191        column_major: bool,
192    ) -> Result<Self, Failed> {
193        let data_len = values.len();
194        if nrows * ncols != values.len() {
195            Err(Failed::input(&format!(
196                "The specified shape: (cols: {ncols}, rows: {nrows}) does not align with data len: {data_len}"
197            )))
198        } else {
199            Ok(DenseMatrix {
200                ncols,
201                nrows,
202                values,
203                column_major,
204            })
205        }
206    }
207
208    /// New instance of `DenseMatrix` from 2d array.
209    pub fn from_2d_array(values: &[&[T]]) -> Result<Self, Failed> {
210        DenseMatrix::from_2d_vec(&values.iter().map(|row| Vec::from(*row)).collect())
211    }
212
213    /// New instance of `DenseMatrix` from 2d vector.
214    #[allow(clippy::ptr_arg)]
215    pub fn from_2d_vec(values: &Vec<Vec<T>>) -> Result<Self, Failed> {
216        if values.is_empty() || values[0].is_empty() {
217            Err(Failed::input(
218                "The 2d vec provided is empty; cannot instantiate the matrix",
219            ))
220        } else {
221            let nrows = values.len();
222            let ncols = values
223                .first()
224                .unwrap_or_else(|| {
225                    panic!("Invalid state: Cannot create 2d matrix from an empty vector")
226                })
227                .len();
228            let mut m_values = Vec::with_capacity(nrows * ncols);
229
230            for c in 0..ncols {
231                for r in values.iter().take(nrows) {
232                    m_values.push(r[c])
233                }
234            }
235
236            DenseMatrix::new(nrows, ncols, m_values, true)
237        }
238    }
239
240    /// Iterate over values of matrix
241    pub fn iter(&self) -> Iter<'_, T> {
242        self.values.iter()
243    }
244
245    ///  Check if the size of the requested view is bounded to matrix rows/cols count
246    fn is_valid_view(
247        &self,
248        n_rows: usize,
249        n_cols: usize,
250        vrows: &Range<usize>,
251        vcols: &Range<usize>,
252    ) -> bool {
253        !(vrows.end <= n_rows
254            && vcols.end <= n_cols
255            && vrows.start <= n_rows
256            && vcols.start <= n_cols)
257    }
258
259    ///  Compute the range of the requested view: start, end, size of the slice
260    fn stride_range(
261        &self,
262        n_rows: usize,
263        n_cols: usize,
264        vrows: &Range<usize>,
265        vcols: &Range<usize>,
266        column_major: bool,
267    ) -> (usize, usize, usize) {
268        let (start, end, stride) = if column_major {
269            (
270                vrows.start + vcols.start * n_rows,
271                vrows.end + (vcols.end - 1) * n_rows,
272                n_rows,
273            )
274        } else {
275            (
276                vrows.start * n_cols + vcols.start,
277                (vrows.end - 1) * n_cols + vcols.end,
278                n_cols,
279            )
280        };
281        (start, end, stride)
282    }
283}
284
285impl<T: Debug + Display + Copy + Sized> fmt::Display for DenseMatrix<T> {
286    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
287        writeln!(
288            f,
289            "DenseMatrix: nrows: {:?}, ncols: {:?}",
290            self.nrows, self.ncols
291        )?;
292        writeln!(f, "column_major: {:?}", self.column_major)?;
293        self.display(f)
294    }
295}
296
297impl<T: Debug + Display + Copy + Sized + PartialEq> PartialEq for DenseMatrix<T> {
298    fn eq(&self, other: &Self) -> bool {
299        if self.ncols != other.ncols || self.nrows != other.nrows {
300            return false;
301        }
302
303        let len = self.values.len();
304        let other_len = other.values.len();
305
306        if len != other_len {
307            return false;
308        }
309
310        match self.column_major == other.column_major {
311            true => self
312                .values
313                .iter()
314                .zip(other.values.iter())
315                .all(|(&v1, v2)| v1.eq(v2)),
316            false => self
317                .iterator(0)
318                .zip(other.iterator(0))
319                .all(|(&v1, v2)| v1.eq(v2)),
320        }
321    }
322}
323
324impl<T: Number + RealNumber + AbsDiffEq> AbsDiffEq for DenseMatrix<T>
325where
326    T::Epsilon: Copy,
327{
328    type Epsilon = T::Epsilon;
329
330    fn default_epsilon() -> T::Epsilon {
331        T::default_epsilon()
332    }
333
334    // equality in differences in absolute values, according to an epsilon
335    fn abs_diff_eq(&self, other: &Self, epsilon: T::Epsilon) -> bool {
336        if self.ncols != other.ncols || self.nrows != other.nrows {
337            false
338        } else {
339            self.values
340                .iter()
341                .zip(other.values.iter())
342                .all(|(v1, v2)| T::abs_diff_eq(v1, v2, epsilon))
343        }
344    }
345}
346
347impl<T: Number + RealNumber + RelativeEq> RelativeEq for DenseMatrix<T>
348where
349    T::Epsilon: Copy,
350{
351    fn default_max_relative() -> T::Epsilon {
352        T::default_max_relative()
353    }
354
355    fn relative_eq(&self, other: &Self, epsilon: T::Epsilon, max_relative: T::Epsilon) -> bool {
356        if self.ncols != other.ncols || self.nrows != other.nrows {
357            false
358        } else {
359            self.iterator(0)
360                .zip(other.iterator(0))
361                .all(|(v1, v2)| T::relative_eq(v1, v2, epsilon, max_relative))
362        }
363    }
364}
365
366impl<T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrix<T> {
367    fn get(&self, pos: (usize, usize)) -> &T {
368        let (row, col) = pos;
369
370        if row >= self.nrows || col >= self.ncols {
371            panic!(
372                "Invalid index ({},{}) for {}x{} matrix",
373                row, col, self.nrows, self.ncols
374            );
375        }
376        if self.column_major {
377            &self.values[col * self.nrows + row]
378        } else {
379            &self.values[col + self.ncols * row]
380        }
381    }
382
383    fn shape(&self) -> (usize, usize) {
384        (self.nrows, self.ncols)
385    }
386
387    fn is_empty(&self) -> bool {
388        self.ncols > 0 && self.nrows > 0
389    }
390
391    fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
392        assert!(
393            axis == 1 || axis == 0,
394            "For two dimensional array `axis` should be either 0 or 1"
395        );
396        match axis {
397            0 => Box::new(
398                (0..self.nrows).flat_map(move |r| (0..self.ncols).map(move |c| self.get((r, c)))),
399            ),
400            _ => Box::new(
401                (0..self.ncols).flat_map(move |c| (0..self.nrows).map(move |r| self.get((r, c)))),
402            ),
403        }
404    }
405}
406
407impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)> for DenseMatrix<T> {
408    fn set(&mut self, pos: (usize, usize), x: T) {
409        if self.column_major {
410            self.values[pos.1 * self.nrows + pos.0] = x;
411        } else {
412            self.values[pos.1 + pos.0 * self.ncols] = x;
413        }
414    }
415
416    fn iterator_mut<'b>(&'b mut self, axis: u8) -> Box<dyn Iterator<Item = &'b mut T> + 'b> {
417        let ptr = self.values.as_mut_ptr();
418        let column_major = self.column_major;
419        let (nrows, ncols) = self.shape();
420        match axis {
421            0 => Box::new((0..self.nrows).flat_map(move |r| {
422                (0..self.ncols).map(move |c| unsafe {
423                    &mut *ptr.add(if column_major {
424                        r + c * nrows
425                    } else {
426                        r * ncols + c
427                    })
428                })
429            })),
430            _ => Box::new((0..self.ncols).flat_map(move |c| {
431                (0..self.nrows).map(move |r| unsafe {
432                    &mut *ptr.add(if column_major {
433                        r + c * nrows
434                    } else {
435                        r * ncols + c
436                    })
437                })
438            })),
439        }
440    }
441}
442
443impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for DenseMatrix<T> {}
444
445impl<T: Debug + Display + Copy + Sized> MutArrayView2<T> for DenseMatrix<T> {}
446
447impl<T: Debug + Display + Copy + Sized> Array2<T> for DenseMatrix<T> {
448    fn get_row<'a>(&'a self, row: usize) -> Box<dyn ArrayView1<T> + 'a> {
449        Box::new(DenseMatrixView::new(self, row..row + 1, 0..self.ncols).unwrap())
450    }
451
452    fn get_col<'a>(&'a self, col: usize) -> Box<dyn ArrayView1<T> + 'a> {
453        Box::new(DenseMatrixView::new(self, 0..self.nrows, col..col + 1).unwrap())
454    }
455
456    fn slice<'a>(&'a self, rows: Range<usize>, cols: Range<usize>) -> Box<dyn ArrayView2<T> + 'a> {
457        Box::new(DenseMatrixView::new(self, rows, cols).unwrap())
458    }
459
460    fn slice_mut<'a>(
461        &'a mut self,
462        rows: Range<usize>,
463        cols: Range<usize>,
464    ) -> Box<dyn MutArrayView2<T> + 'a>
465    where
466        Self: Sized,
467    {
468        Box::new(DenseMatrixMutView::new(self, rows, cols).unwrap())
469    }
470
471    // private function so for now assume infalible
472    fn fill(nrows: usize, ncols: usize, value: T) -> Self {
473        DenseMatrix::new(nrows, ncols, vec![value; nrows * ncols], true).unwrap()
474    }
475
476    // private function so for now assume infalible
477    fn from_iterator<I: Iterator<Item = T>>(iter: I, nrows: usize, ncols: usize, axis: u8) -> Self {
478        DenseMatrix::new(nrows, ncols, iter.collect(), axis != 0).unwrap()
479    }
480
481    fn transpose(&self) -> Self {
482        let mut m = self.clone();
483        m.ncols = self.nrows;
484        m.nrows = self.ncols;
485        m.column_major = !self.column_major;
486        m
487    }
488}
489
490impl<T: Number + RealNumber> QRDecomposable<T> for DenseMatrix<T> {}
491impl<T: Number + RealNumber> CholeskyDecomposable<T> for DenseMatrix<T> {}
492impl<T: Number + RealNumber> EVDDecomposable<T> for DenseMatrix<T> {}
493impl<T: Number + RealNumber> LUDecomposable<T> for DenseMatrix<T> {}
494impl<T: Number + RealNumber> SVDDecomposable<T> for DenseMatrix<T> {}
495
496impl<T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrixView<'_, T> {
497    fn get(&self, pos: (usize, usize)) -> &T {
498        if self.column_major {
499            &self.values[pos.0 + pos.1 * self.stride]
500        } else {
501            &self.values[pos.0 * self.stride + pos.1]
502        }
503    }
504
505    fn shape(&self) -> (usize, usize) {
506        (self.nrows, self.ncols)
507    }
508
509    fn is_empty(&self) -> bool {
510        self.nrows * self.ncols > 0
511    }
512
513    fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
514        self.iter(axis)
515    }
516}
517
518impl<T: Debug + Display + Copy + Sized> Array<T, usize> for DenseMatrixView<'_, T> {
519    fn get(&self, i: usize) -> &T {
520        if self.nrows == 1 {
521            if self.column_major {
522                &self.values[i * self.stride]
523            } else {
524                &self.values[i]
525            }
526        } else if self.ncols == 1 || (!self.column_major && self.nrows == 1) {
527            if self.column_major {
528                &self.values[i]
529            } else {
530                &self.values[i * self.stride]
531            }
532        } else {
533            panic!("This is neither a column nor a row");
534        }
535    }
536
537    fn shape(&self) -> usize {
538        if self.nrows == 1 {
539            self.ncols
540        } else if self.ncols == 1 {
541            self.nrows
542        } else {
543            panic!("This is neither a column nor a row");
544        }
545    }
546
547    fn is_empty(&self) -> bool {
548        self.nrows * self.ncols > 0
549    }
550
551    fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
552        self.iter(axis)
553    }
554}
555
556impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for DenseMatrixView<'_, T> {}
557
558impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for DenseMatrixView<'_, T> {}
559
560impl<T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrixMutView<'_, T> {
561    fn get(&self, pos: (usize, usize)) -> &T {
562        if self.column_major {
563            &self.values[pos.0 + pos.1 * self.stride]
564        } else {
565            &self.values[pos.0 * self.stride + pos.1]
566        }
567    }
568
569    fn shape(&self) -> (usize, usize) {
570        (self.nrows, self.ncols)
571    }
572
573    fn is_empty(&self) -> bool {
574        self.nrows * self.ncols > 0
575    }
576
577    fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
578        self.iter(axis)
579    }
580}
581
582impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)> for DenseMatrixMutView<'_, T> {
583    fn set(&mut self, pos: (usize, usize), x: T) {
584        if self.column_major {
585            self.values[pos.0 + pos.1 * self.stride] = x;
586        } else {
587            self.values[pos.0 * self.stride + pos.1] = x;
588        }
589    }
590
591    fn iterator_mut<'b>(&'b mut self, axis: u8) -> Box<dyn Iterator<Item = &'b mut T> + 'b> {
592        self.iter_mut(axis)
593    }
594}
595
596impl<T: Debug + Display + Copy + Sized> MutArrayView2<T> for DenseMatrixMutView<'_, T> {}
597
598impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for DenseMatrixMutView<'_, T> {}
599
600impl<T: RealNumber> MatrixStats<T> for DenseMatrix<T> {}
601
602impl<T: RealNumber> MatrixPreprocessing<T> for DenseMatrix<T> {}
603
// Unit tests for `DenseMatrix` and its views. `from_2d_array` stores data column-major, so
// the raw `values` buffers asserted below list column 0 first, then column 1, and so on.
#[cfg(test)]
#[warn(clippy::reversed_empty_ranges)]
mod tests {
    use super::*;
    use approx::relative_eq;

    // A rectangular 3x3 input builds successfully.
    #[test]
    fn test_instantiate_from_2d() {
        let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]);
        assert!(x.is_ok());
    }
    // A single empty row is rejected.
    #[test]
    fn test_instantiate_from_2d_empty() {
        let input: &[&[f64]] = &[&[]];
        let x = DenseMatrix::from_2d_array(input);
        assert!(x.is_err());
    }
    // Several empty rows are rejected too.
    #[test]
    fn test_instantiate_from_2d_empty2() {
        let input: &[&[f64]] = &[&[], &[]];
        let x = DenseMatrix::from_2d_array(input);
        assert!(x.is_err());
    }
    // Top-left 2x2 block.
    #[test]
    fn test_instantiate_ok_view1() {
        let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
        let v = DenseMatrixView::new(&x, 0..2, 0..2);
        assert!(v.is_ok());
    }
    // The whole matrix.
    #[test]
    fn test_instantiate_ok_view2() {
        let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
        let v = DenseMatrixView::new(&x, 0..3, 0..3);
        assert!(v.is_ok());
    }
    // The last row.
    #[test]
    fn test_instantiate_ok_view3() {
        let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
        let v = DenseMatrixView::new(&x, 2..3, 0..3);
        assert!(v.is_ok());
    }
    // An empty row range at the bottom edge (`3..3`) is still an acceptable, empty view.
    #[test]
    fn test_instantiate_ok_view4() {
        let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
        let v = DenseMatrixView::new(&x, 3..3, 0..3);
        assert!(v.is_ok());
    }
    // Row range ends past the last row.
    #[test]
    fn test_instantiate_err_view1() {
        let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
        let v = DenseMatrixView::new(&x, 3..4, 0..3);
        assert!(v.is_err());
    }
    // Column range ends past the last column.
    #[test]
    fn test_instantiate_err_view2() {
        let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
        let v = DenseMatrixView::new(&x, 0..3, 3..4);
        assert!(v.is_err());
    }
    // Reversed column range whose start is already out of bounds.
    #[test]
    fn test_instantiate_err_view3() {
        let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
        #[allow(clippy::reversed_empty_ranges)]
        let v = DenseMatrixView::new(&x, 0..3, 4..3);
        assert!(v.is_err());
    }
    // Smoke test: formatting must not panic (the output itself is not checked).
    #[test]
    fn test_display() {
        let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();

        println!("{}", &x);
    }

    // Column 1 is [2, 5, 8] and row 1 is [4, 5, 6]: both sum to 15, and their dot product
    // is 2*4 + 5*5 + 8*6 = 81.
    #[test]
    fn test_get_row_col() {
        let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();

        assert_eq!(15.0, x.get_col(1).sum());
        assert_eq!(15.0, x.get_row(1).sum());
        assert_eq!(81.0, x.get_col(1).dot(&(*x.get_row(1))));
    }

    // Row-major 2x3 matrix [[1, 2, 3], [4, 5, 6]]: views must honour the row-major layout.
    #[test]
    fn test_row_major() {
        let mut x = DenseMatrix::new(2, 3, vec![1, 2, 3, 4, 5, 6], false).unwrap();

        assert_eq!(5, *x.get_col(1).get(1));
        assert_eq!(7, x.get_col(1).sum());
        assert_eq!(5, *x.get_row(1).get(1));
        assert_eq!(15, x.get_row(1).sum());
        // Add 2 to every element of column 1 (2 -> 4, 5 -> 7).
        x.slice_mut(0..2, 1..2)
            .iterator_mut(0)
            .for_each(|v| *v += 2);
        assert_eq!(vec![1, 4, 3, 4, 7, 6], *x.values);
    }

    // Slices of a 4x3 matrix: row 1 is [4, 5, 6] and the first three rows of column 1 are
    // [2, 5, 8].
    #[test]
    fn test_get_slice() {
        let x = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6], &[7, 8, 9], &[10, 11, 12]])
            .unwrap();

        assert_eq!(
            vec![4, 5, 6],
            DenseMatrix::from_slice(&(*x.slice(1..2, 0..3))).values
        );
        let second_row: Vec<i32> = x.slice(1..2, 0..3).iterator(0).copied().collect();
        assert_eq!(vec![4, 5, 6], second_row);
        let second_col: Vec<i32> = x.slice(0..3, 1..2).iterator(0).copied().collect();
        assert_eq!(vec![2, 5, 8], second_col);
    }

    // Mutation through views and whole-matrix iterators. Each assertion depends on every
    // earlier mutation, so the statement order matters.
    #[test]
    fn test_iter_mut() {
        let mut x = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6], &[7, 8, 9]]).unwrap();

        assert_eq!(vec![1, 4, 7, 2, 5, 8, 3, 6, 9], x.values);
        // add +2 to every element of row 1
        x.slice_mut(1..2, 0..3)
            .iterator_mut(0)
            .for_each(|v| *v += 2);
        assert_eq!(vec![1, 6, 7, 2, 7, 8, 3, 8, 9], x.values);
        // add +1 to every element of column 1
        x.slice_mut(0..3, 1..2)
            .iterator_mut(0)
            .for_each(|v| *v += 1);
        assert_eq!(vec![1, 6, 7, 3, 8, 9, 3, 8, 9], x.values);

        // axis 1 walks column by column, which for this column-major matrix is storage order,
        // so the buffer becomes 0..9
        x.iterator_mut(1).enumerate().for_each(|(a, b)| *b = a);
        assert_eq!(vec![0, 1, 2, 3, 4, 5, 6, 7, 8], x.values);
        // axis 0 walks row by row, so each row gets consecutive numbers
        x.iterator_mut(0).enumerate().for_each(|(a, b)| *b = a);
        assert_eq!(vec![0, 3, 6, 1, 4, 7, 2, 5, 8], x.values);
        // number the first two columns row by row; column 2 is untouched
        x.slice_mut(0..3, 0..2)
            .iterator_mut(0)
            .enumerate()
            .for_each(|(a, b)| *b = a);
        assert_eq!(vec![0, 2, 4, 1, 3, 5, 2, 5, 8], x.values);
        // number the first two rows column by column; row 2 is untouched
        x.slice_mut(0..2, 0..3)
            .iterator_mut(1)
            .enumerate()
            .for_each(|(a, b)| *b = a);
        assert_eq!(vec![0, 1, 4, 2, 3, 5, 4, 5, 8], x.values);
    }

    // Non-numeric element types work too.
    #[test]
    fn test_str_array() {
        let mut x =
            DenseMatrix::from_2d_array(&[&["1", "2", "3"], &["4", "5", "6"], &["7", "8", "9"]])
                .unwrap();

        assert_eq!(vec!["1", "4", "7", "2", "5", "8", "3", "6", "9"], x.values);
        x.iterator_mut(0).for_each(|v| *v = "str");
        assert_eq!(
            vec!["str", "str", "str", "str", "str", "str", "str", "str", "str"],
            x.values
        );
    }

    // `transpose` keeps the buffer as-is and only flips the storage-order flag (and shape).
    #[test]
    fn test_transpose() {
        let x = DenseMatrix::<&str>::from_2d_array(&[&["1", "2", "3"], &["4", "5", "6"]]).unwrap();

        assert_eq!(vec!["1", "4", "2", "5", "3", "6"], x.values);
        assert!(x.column_major);

        // transpose
        let x = x.transpose();
        assert_eq!(vec!["1", "4", "2", "5", "3", "6"], x.values);
        assert!(!x.column_major); // should change column_major
    }

    // axis 0 means the iterator is read row by row, producing a row-major matrix.
    #[test]
    fn test_from_iterator() {
        let data = [1, 2, 3, 4, 5, 6];

        let m = DenseMatrix::from_iterator(data.iter(), 2, 3, 0);

        // make a vector into a 2x3 matrix.
        assert_eq!(
            vec![1, 2, 3, 4, 5, 6],
            m.values.iter().map(|e| **e).collect::<Vec<i32>>()
        );
        assert!(!m.column_major);
    }

    #[test]
    fn test_take() {
        let a = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]]).unwrap();
        let b = DenseMatrix::from_2d_array(&[&[1, 2], &[3, 4], &[5, 6]]).unwrap();

        println!("{a}");
        // take column 0 and 2
        assert_eq!(vec![1, 3, 4, 6], a.take(&[0, 2], 1).values);
        println!("{b}");
        // take rows 0 and 2
        assert_eq!(vec![1, 2, 5, 6], b.take(&[0, 2], 0).values);
    }

    // Element-wise `abs` followed by `neg`; values are asserted in column-major order.
    #[test]
    fn test_mut() {
        let a = DenseMatrix::from_2d_array(&[&[1.3, -2.1, 3.4], &[-4., -5.3, 6.1]]).unwrap();

        let a = a.abs();
        assert_eq!(vec![1.3, 4.0, 2.1, 5.3, 3.4, 6.1], a.values);

        let a = a.neg();
        assert_eq!(vec![-1.3, -4.0, -2.1, -5.3, -3.4, -6.1], a.values);
    }

    // Reshaping a 4x3 matrix along axis 0 and then axis 1.
    #[test]
    fn test_reshape() {
        let a = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6], &[7, 8, 9], &[10, 11, 12]])
            .unwrap();

        let a = a.reshape(2, 6, 0);
        assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], a.values);
        assert!(a.ncols == 6 && a.nrows == 2 && !a.column_major);

        let a = a.reshape(3, 4, 1);
        assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], a.values);
        assert!(a.ncols == 4 && a.nrows == 3 && a.column_major);
    }

    // Approximate equality: a shape mismatch or a 0.5 difference fails; a one-epsilon
    // difference passes.
    #[test]
    fn test_eq() {
        let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]).unwrap();
        let b = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
        let c = DenseMatrix::from_2d_array(&[
            &[1. + f32::EPSILON, 2., 3.],
            &[4., 5., 6. + f32::EPSILON],
        ])
        .unwrap();
        let d = DenseMatrix::from_2d_array(&[&[1. + 0.5, 2., 3.], &[4., 5., 6. + f32::EPSILON]])
            .unwrap();

        assert!(!relative_eq!(a, b));
        assert!(!relative_eq!(a, d));
        assert!(relative_eq!(a, c));
    }
}