1use std::fmt;
2use std::fmt::{Debug, Display};
3use std::ops::Range;
4use std::slice::Iter;
5
6use approx::{AbsDiffEq, RelativeEq};
7#[cfg(feature = "serde")]
8use serde::{Deserialize, Serialize};
9
10use crate::linalg::basic::arrays::{
11 Array, Array2, ArrayView1, ArrayView2, MutArray, MutArrayView2,
12};
13use crate::linalg::traits::cholesky::CholeskyDecomposable;
14use crate::linalg::traits::evd::EVDDecomposable;
15use crate::linalg::traits::lu::LUDecomposable;
16use crate::linalg::traits::qr::QRDecomposable;
17use crate::linalg::traits::stats::{MatrixPreprocessing, MatrixStats};
18use crate::linalg::traits::svd::SVDDecomposable;
19use crate::numbers::basenum::Number;
20use crate::numbers::realnum::RealNumber;
21
22use crate::error::Failed;
23
/// Dense 2D matrix backed by a single flat `Vec<T>`.
///
/// Element `(row, col)` is stored at `values[col * nrows + row]` when
/// `column_major` is `true`, and at `values[row * ncols + col]` otherwise.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct DenseMatrix<T> {
    /// Number of columns.
    ncols: usize,
    /// Number of rows.
    nrows: usize,
    /// Flat element storage. `DenseMatrix::new` requires its length to be
    /// `nrows * ncols`. NOTE(review): the derived `Deserialize` does not
    /// re-check this.
    values: Vec<T>,
    /// Storage order of `values`: `true` = column-major, `false` = row-major.
    column_major: bool,
}
33
/// Read-only rectangular window into a [`DenseMatrix`].
///
/// Element `(r, c)` of the window is `values[r + c * stride]` when
/// `column_major` is `true` and `values[r * stride + c]` otherwise.
#[derive(Debug, Clone)]
pub struct DenseMatrixView<'a, T: Debug + Display + Copy + Sized> {
    /// Sub-slice of the parent matrix's storage that covers the window.
    values: &'a [T],
    /// Step in `values` between consecutive columns (column-major) or rows
    /// (row-major); equal to the parent's row count or column count.
    stride: usize,
    /// Number of rows in the window.
    nrows: usize,
    /// Number of columns in the window.
    ncols: usize,
    /// Storage order inherited from the parent matrix.
    column_major: bool,
}
43
/// Mutable rectangular window into a [`DenseMatrix`].
///
/// Uses the same addressing as [`DenseMatrixView`]: element `(r, c)` is
/// `values[r + c * stride]` when `column_major` is `true` and
/// `values[r * stride + c]` otherwise.
#[derive(Debug)]
pub struct DenseMatrixMutView<'a, T: Debug + Display + Copy + Sized> {
    /// Mutable sub-slice of the parent matrix's storage covering the window.
    values: &'a mut [T],
    /// Step in `values` between consecutive columns (column-major) or rows
    /// (row-major); equal to the parent's row count or column count.
    stride: usize,
    /// Number of rows in the window.
    nrows: usize,
    /// Number of columns in the window.
    ncols: usize,
    /// Storage order inherited from the parent matrix.
    column_major: bool,
}
53
54impl<'a, T: Debug + Display + Copy + Sized> DenseMatrixView<'a, T> {
55 fn new(
56 m: &'a DenseMatrix<T>,
57 vrows: Range<usize>,
58 vcols: Range<usize>,
59 ) -> Result<Self, Failed> {
60 if m.is_valid_view(m.shape().0, m.shape().1, &vrows, &vcols) {
61 Err(Failed::input(
62 "The specified view is outside of the matrix range",
63 ))
64 } else {
65 let (start, end, stride) =
66 m.stride_range(m.shape().0, m.shape().1, &vrows, &vcols, m.column_major);
67
68 Ok(DenseMatrixView {
69 values: &m.values[start..end],
70 stride,
71 nrows: vrows.end - vrows.start,
72 ncols: vcols.end - vcols.start,
73 column_major: m.column_major,
74 })
75 }
76 }
77
78 fn iter<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
79 assert!(
80 axis == 1 || axis == 0,
81 "For two dimensional array `axis` should be either 0 or 1"
82 );
83 match axis {
84 0 => Box::new(
85 (0..self.nrows).flat_map(move |r| (0..self.ncols).map(move |c| self.get((r, c)))),
86 ),
87 _ => Box::new(
88 (0..self.ncols).flat_map(move |c| (0..self.nrows).map(move |r| self.get((r, c)))),
89 ),
90 }
91 }
92}
93
94impl<T: Debug + Display + Copy + Sized> fmt::Display for DenseMatrixView<'_, T> {
95 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
96 writeln!(
97 f,
98 "DenseMatrix: nrows: {:?}, ncols: {:?}",
99 self.nrows, self.ncols
100 )?;
101 writeln!(f, "column_major: {:?}", self.column_major)?;
102 self.display(f)
103 }
104}
105
106impl<'a, T: Debug + Display + Copy + Sized> DenseMatrixMutView<'a, T> {
107 fn new(
108 m: &'a mut DenseMatrix<T>,
109 vrows: Range<usize>,
110 vcols: Range<usize>,
111 ) -> Result<Self, Failed> {
112 if m.is_valid_view(m.shape().0, m.shape().1, &vrows, &vcols) {
113 Err(Failed::input(
114 "The specified view is outside of the matrix range",
115 ))
116 } else {
117 let (start, end, stride) =
118 m.stride_range(m.shape().0, m.shape().1, &vrows, &vcols, m.column_major);
119
120 Ok(DenseMatrixMutView {
121 values: &mut m.values[start..end],
122 stride,
123 nrows: vrows.end - vrows.start,
124 ncols: vcols.end - vcols.start,
125 column_major: m.column_major,
126 })
127 }
128 }
129
130 fn iter<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
131 assert!(
132 axis == 1 || axis == 0,
133 "For two dimensional array `axis` should be either 0 or 1"
134 );
135 match axis {
136 0 => Box::new(
137 (0..self.nrows).flat_map(move |r| (0..self.ncols).map(move |c| self.get((r, c)))),
138 ),
139 _ => Box::new(
140 (0..self.ncols).flat_map(move |c| (0..self.nrows).map(move |r| self.get((r, c)))),
141 ),
142 }
143 }
144
145 fn iter_mut<'b>(&'b mut self, axis: u8) -> Box<dyn Iterator<Item = &'b mut T> + 'b> {
146 let column_major = self.column_major;
147 let stride = self.stride;
148 let ptr = self.values.as_mut_ptr();
149 match axis {
150 0 => Box::new((0..self.nrows).flat_map(move |r| {
151 (0..self.ncols).map(move |c| unsafe {
152 &mut *ptr.add(if column_major {
153 r + c * stride
154 } else {
155 r * stride + c
156 })
157 })
158 })),
159 _ => Box::new((0..self.ncols).flat_map(move |c| {
160 (0..self.nrows).map(move |r| unsafe {
161 &mut *ptr.add(if column_major {
162 r + c * stride
163 } else {
164 r * stride + c
165 })
166 })
167 })),
168 }
169 }
170}
171
172impl<T: Debug + Display + Copy + Sized> fmt::Display for DenseMatrixMutView<'_, T> {
173 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
174 writeln!(
175 f,
176 "DenseMatrix: nrows: {:?}, ncols: {:?}",
177 self.nrows, self.ncols
178 )?;
179 writeln!(f, "column_major: {:?}", self.column_major)?;
180 self.display(f)
181 }
182}
183
184impl<T: Debug + Display + Copy + Sized> DenseMatrix<T> {
185 pub fn new(
188 nrows: usize,
189 ncols: usize,
190 values: Vec<T>,
191 column_major: bool,
192 ) -> Result<Self, Failed> {
193 let data_len = values.len();
194 if nrows * ncols != values.len() {
195 Err(Failed::input(&format!(
196 "The specified shape: (cols: {ncols}, rows: {nrows}) does not align with data len: {data_len}"
197 )))
198 } else {
199 Ok(DenseMatrix {
200 ncols,
201 nrows,
202 values,
203 column_major,
204 })
205 }
206 }
207
208 pub fn from_2d_array(values: &[&[T]]) -> Result<Self, Failed> {
210 DenseMatrix::from_2d_vec(&values.iter().map(|row| Vec::from(*row)).collect())
211 }
212
213 #[allow(clippy::ptr_arg)]
215 pub fn from_2d_vec(values: &Vec<Vec<T>>) -> Result<Self, Failed> {
216 if values.is_empty() || values[0].is_empty() {
217 Err(Failed::input(
218 "The 2d vec provided is empty; cannot instantiate the matrix",
219 ))
220 } else {
221 let nrows = values.len();
222 let ncols = values
223 .first()
224 .unwrap_or_else(|| {
225 panic!("Invalid state: Cannot create 2d matrix from an empty vector")
226 })
227 .len();
228 let mut m_values = Vec::with_capacity(nrows * ncols);
229
230 for c in 0..ncols {
231 for r in values.iter().take(nrows) {
232 m_values.push(r[c])
233 }
234 }
235
236 DenseMatrix::new(nrows, ncols, m_values, true)
237 }
238 }
239
240 pub fn iter(&self) -> Iter<'_, T> {
242 self.values.iter()
243 }
244
245 fn is_valid_view(
247 &self,
248 n_rows: usize,
249 n_cols: usize,
250 vrows: &Range<usize>,
251 vcols: &Range<usize>,
252 ) -> bool {
253 !(vrows.end <= n_rows
254 && vcols.end <= n_cols
255 && vrows.start <= n_rows
256 && vcols.start <= n_cols)
257 }
258
259 fn stride_range(
261 &self,
262 n_rows: usize,
263 n_cols: usize,
264 vrows: &Range<usize>,
265 vcols: &Range<usize>,
266 column_major: bool,
267 ) -> (usize, usize, usize) {
268 let (start, end, stride) = if column_major {
269 (
270 vrows.start + vcols.start * n_rows,
271 vrows.end + (vcols.end - 1) * n_rows,
272 n_rows,
273 )
274 } else {
275 (
276 vrows.start * n_cols + vcols.start,
277 (vrows.end - 1) * n_cols + vcols.end,
278 n_cols,
279 )
280 };
281 (start, end, stride)
282 }
283}
284
285impl<T: Debug + Display + Copy + Sized> fmt::Display for DenseMatrix<T> {
286 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
287 writeln!(
288 f,
289 "DenseMatrix: nrows: {:?}, ncols: {:?}",
290 self.nrows, self.ncols
291 )?;
292 writeln!(f, "column_major: {:?}", self.column_major)?;
293 self.display(f)
294 }
295}
296
297impl<T: Debug + Display + Copy + Sized + PartialEq> PartialEq for DenseMatrix<T> {
298 fn eq(&self, other: &Self) -> bool {
299 if self.ncols != other.ncols || self.nrows != other.nrows {
300 return false;
301 }
302
303 let len = self.values.len();
304 let other_len = other.values.len();
305
306 if len != other_len {
307 return false;
308 }
309
310 match self.column_major == other.column_major {
311 true => self
312 .values
313 .iter()
314 .zip(other.values.iter())
315 .all(|(&v1, v2)| v1.eq(v2)),
316 false => self
317 .iterator(0)
318 .zip(other.iterator(0))
319 .all(|(&v1, v2)| v1.eq(v2)),
320 }
321 }
322}
323
324impl<T: Number + RealNumber + AbsDiffEq> AbsDiffEq for DenseMatrix<T>
325where
326 T::Epsilon: Copy,
327{
328 type Epsilon = T::Epsilon;
329
330 fn default_epsilon() -> T::Epsilon {
331 T::default_epsilon()
332 }
333
334 fn abs_diff_eq(&self, other: &Self, epsilon: T::Epsilon) -> bool {
336 if self.ncols != other.ncols || self.nrows != other.nrows {
337 false
338 } else {
339 self.values
340 .iter()
341 .zip(other.values.iter())
342 .all(|(v1, v2)| T::abs_diff_eq(v1, v2, epsilon))
343 }
344 }
345}
346
347impl<T: Number + RealNumber + RelativeEq> RelativeEq for DenseMatrix<T>
348where
349 T::Epsilon: Copy,
350{
351 fn default_max_relative() -> T::Epsilon {
352 T::default_max_relative()
353 }
354
355 fn relative_eq(&self, other: &Self, epsilon: T::Epsilon, max_relative: T::Epsilon) -> bool {
356 if self.ncols != other.ncols || self.nrows != other.nrows {
357 false
358 } else {
359 self.iterator(0)
360 .zip(other.iterator(0))
361 .all(|(v1, v2)| T::relative_eq(v1, v2, epsilon, max_relative))
362 }
363 }
364}
365
366impl<T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrix<T> {
367 fn get(&self, pos: (usize, usize)) -> &T {
368 let (row, col) = pos;
369
370 if row >= self.nrows || col >= self.ncols {
371 panic!(
372 "Invalid index ({},{}) for {}x{} matrix",
373 row, col, self.nrows, self.ncols
374 );
375 }
376 if self.column_major {
377 &self.values[col * self.nrows + row]
378 } else {
379 &self.values[col + self.ncols * row]
380 }
381 }
382
383 fn shape(&self) -> (usize, usize) {
384 (self.nrows, self.ncols)
385 }
386
387 fn is_empty(&self) -> bool {
388 self.ncols > 0 && self.nrows > 0
389 }
390
391 fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
392 assert!(
393 axis == 1 || axis == 0,
394 "For two dimensional array `axis` should be either 0 or 1"
395 );
396 match axis {
397 0 => Box::new(
398 (0..self.nrows).flat_map(move |r| (0..self.ncols).map(move |c| self.get((r, c)))),
399 ),
400 _ => Box::new(
401 (0..self.ncols).flat_map(move |c| (0..self.nrows).map(move |r| self.get((r, c)))),
402 ),
403 }
404 }
405}
406
407impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)> for DenseMatrix<T> {
408 fn set(&mut self, pos: (usize, usize), x: T) {
409 if self.column_major {
410 self.values[pos.1 * self.nrows + pos.0] = x;
411 } else {
412 self.values[pos.1 + pos.0 * self.ncols] = x;
413 }
414 }
415
416 fn iterator_mut<'b>(&'b mut self, axis: u8) -> Box<dyn Iterator<Item = &'b mut T> + 'b> {
417 let ptr = self.values.as_mut_ptr();
418 let column_major = self.column_major;
419 let (nrows, ncols) = self.shape();
420 match axis {
421 0 => Box::new((0..self.nrows).flat_map(move |r| {
422 (0..self.ncols).map(move |c| unsafe {
423 &mut *ptr.add(if column_major {
424 r + c * nrows
425 } else {
426 r * ncols + c
427 })
428 })
429 })),
430 _ => Box::new((0..self.ncols).flat_map(move |c| {
431 (0..self.nrows).map(move |r| unsafe {
432 &mut *ptr.add(if column_major {
433 r + c * nrows
434 } else {
435 r * ncols + c
436 })
437 })
438 })),
439 }
440 }
441}
442
443impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for DenseMatrix<T> {}
444
445impl<T: Debug + Display + Copy + Sized> MutArrayView2<T> for DenseMatrix<T> {}
446
447impl<T: Debug + Display + Copy + Sized> Array2<T> for DenseMatrix<T> {
448 fn get_row<'a>(&'a self, row: usize) -> Box<dyn ArrayView1<T> + 'a> {
449 Box::new(DenseMatrixView::new(self, row..row + 1, 0..self.ncols).unwrap())
450 }
451
452 fn get_col<'a>(&'a self, col: usize) -> Box<dyn ArrayView1<T> + 'a> {
453 Box::new(DenseMatrixView::new(self, 0..self.nrows, col..col + 1).unwrap())
454 }
455
456 fn slice<'a>(&'a self, rows: Range<usize>, cols: Range<usize>) -> Box<dyn ArrayView2<T> + 'a> {
457 Box::new(DenseMatrixView::new(self, rows, cols).unwrap())
458 }
459
460 fn slice_mut<'a>(
461 &'a mut self,
462 rows: Range<usize>,
463 cols: Range<usize>,
464 ) -> Box<dyn MutArrayView2<T> + 'a>
465 where
466 Self: Sized,
467 {
468 Box::new(DenseMatrixMutView::new(self, rows, cols).unwrap())
469 }
470
471 fn fill(nrows: usize, ncols: usize, value: T) -> Self {
473 DenseMatrix::new(nrows, ncols, vec![value; nrows * ncols], true).unwrap()
474 }
475
476 fn from_iterator<I: Iterator<Item = T>>(iter: I, nrows: usize, ncols: usize, axis: u8) -> Self {
478 DenseMatrix::new(nrows, ncols, iter.collect(), axis != 0).unwrap()
479 }
480
481 fn transpose(&self) -> Self {
482 let mut m = self.clone();
483 m.ncols = self.nrows;
484 m.nrows = self.ncols;
485 m.column_major = !self.column_major;
486 m
487 }
488}
489
490impl<T: Number + RealNumber> QRDecomposable<T> for DenseMatrix<T> {}
491impl<T: Number + RealNumber> CholeskyDecomposable<T> for DenseMatrix<T> {}
492impl<T: Number + RealNumber> EVDDecomposable<T> for DenseMatrix<T> {}
493impl<T: Number + RealNumber> LUDecomposable<T> for DenseMatrix<T> {}
494impl<T: Number + RealNumber> SVDDecomposable<T> for DenseMatrix<T> {}
495
496impl<T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrixView<'_, T> {
497 fn get(&self, pos: (usize, usize)) -> &T {
498 if self.column_major {
499 &self.values[pos.0 + pos.1 * self.stride]
500 } else {
501 &self.values[pos.0 * self.stride + pos.1]
502 }
503 }
504
505 fn shape(&self) -> (usize, usize) {
506 (self.nrows, self.ncols)
507 }
508
509 fn is_empty(&self) -> bool {
510 self.nrows * self.ncols > 0
511 }
512
513 fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
514 self.iter(axis)
515 }
516}
517
518impl<T: Debug + Display + Copy + Sized> Array<T, usize> for DenseMatrixView<'_, T> {
519 fn get(&self, i: usize) -> &T {
520 if self.nrows == 1 {
521 if self.column_major {
522 &self.values[i * self.stride]
523 } else {
524 &self.values[i]
525 }
526 } else if self.ncols == 1 || (!self.column_major && self.nrows == 1) {
527 if self.column_major {
528 &self.values[i]
529 } else {
530 &self.values[i * self.stride]
531 }
532 } else {
533 panic!("This is neither a column nor a row");
534 }
535 }
536
537 fn shape(&self) -> usize {
538 if self.nrows == 1 {
539 self.ncols
540 } else if self.ncols == 1 {
541 self.nrows
542 } else {
543 panic!("This is neither a column nor a row");
544 }
545 }
546
547 fn is_empty(&self) -> bool {
548 self.nrows * self.ncols > 0
549 }
550
551 fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
552 self.iter(axis)
553 }
554}
555
556impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for DenseMatrixView<'_, T> {}
557
558impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for DenseMatrixView<'_, T> {}
559
560impl<T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrixMutView<'_, T> {
561 fn get(&self, pos: (usize, usize)) -> &T {
562 if self.column_major {
563 &self.values[pos.0 + pos.1 * self.stride]
564 } else {
565 &self.values[pos.0 * self.stride + pos.1]
566 }
567 }
568
569 fn shape(&self) -> (usize, usize) {
570 (self.nrows, self.ncols)
571 }
572
573 fn is_empty(&self) -> bool {
574 self.nrows * self.ncols > 0
575 }
576
577 fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
578 self.iter(axis)
579 }
580}
581
582impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)> for DenseMatrixMutView<'_, T> {
583 fn set(&mut self, pos: (usize, usize), x: T) {
584 if self.column_major {
585 self.values[pos.0 + pos.1 * self.stride] = x;
586 } else {
587 self.values[pos.0 * self.stride + pos.1] = x;
588 }
589 }
590
591 fn iterator_mut<'b>(&'b mut self, axis: u8) -> Box<dyn Iterator<Item = &'b mut T> + 'b> {
592 self.iter_mut(axis)
593 }
594}
595
596impl<T: Debug + Display + Copy + Sized> MutArrayView2<T> for DenseMatrixMutView<'_, T> {}
597
598impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for DenseMatrixMutView<'_, T> {}
599
600impl<T: RealNumber> MatrixStats<T> for DenseMatrix<T> {}
601
602impl<T: RealNumber> MatrixPreprocessing<T> for DenseMatrix<T> {}
603
// Unit tests for DenseMatrix and its views. Keep the reversed-range lint at
// warn level module-wide; the one test that builds a reversed range on purpose
// opts out locally with `#[allow]`.
#[cfg(test)]
#[warn(clippy::reversed_empty_ranges)]
mod tests {
    use super::*;
    use approx::relative_eq;

    // A well-formed 3x3 nested slice converts successfully.
    #[test]
    fn test_instantiate_from_2d() {
        let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]);
        assert!(x.is_ok());
    }
    // A single empty row is rejected.
    #[test]
    fn test_instantiate_from_2d_empty() {
        let input: &[&[f64]] = &[&[]];
        let x = DenseMatrix::from_2d_array(input);
        assert!(x.is_err());
    }
    // Several empty rows are rejected as well.
    #[test]
    fn test_instantiate_from_2d_empty2() {
        let input: &[&[f64]] = &[&[], &[]];
        let x = DenseMatrix::from_2d_array(input);
        assert!(x.is_err());
    }
    // A 2x2 top-left window is accepted.
    #[test]
    fn test_instantiate_ok_view1() {
        let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
        let v = DenseMatrixView::new(&x, 0..2, 0..2);
        assert!(v.is_ok());
    }
    // A window covering the whole matrix is accepted.
    #[test]
    fn test_instantiate_ok_view2() {
        let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
        let v = DenseMatrixView::new(&x, 0..3, 0..3);
        assert!(v.is_ok());
    }
    // The last row on its own is accepted.
    #[test]
    fn test_instantiate_ok_view3() {
        let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
        let v = DenseMatrixView::new(&x, 2..3, 0..3);
        assert!(v.is_ok());
    }
    // An empty row range starting exactly at the bottom edge is accepted.
    #[test]
    fn test_instantiate_ok_view4() {
        let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
        let v = DenseMatrixView::new(&x, 3..3, 0..3);
        assert!(v.is_ok());
    }
    // A row range extending past the last row is rejected.
    #[test]
    fn test_instantiate_err_view1() {
        let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
        let v = DenseMatrixView::new(&x, 3..4, 0..3);
        assert!(v.is_err());
    }
    // A column range extending past the last column is rejected.
    #[test]
    fn test_instantiate_err_view2() {
        let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
        let v = DenseMatrixView::new(&x, 0..3, 3..4);
        assert!(v.is_err());
    }
    // A column range starting past the edge (and reversed) is rejected.
    #[test]
    fn test_instantiate_err_view3() {
        let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
        // The reversed range is intentional here.
        #[allow(clippy::reversed_empty_ranges)]
        let v = DenseMatrixView::new(&x, 0..3, 4..3);
        assert!(v.is_err());
    }
    // Smoke test only: formatting must not panic (no output is asserted).
    #[test]
    fn test_display() {
        let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();

        println!("{}", &x);
    }

    // Column 1 is [2, 5, 8] and row 1 is [4, 5, 6]; both sum to 15, and their
    // dot product is 2*4 + 5*5 + 8*6 = 81.
    #[test]
    fn test_get_row_col() {
        let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();

        assert_eq!(15.0, x.get_col(1).sum());
        assert_eq!(15.0, x.get_row(1).sum());
        assert_eq!(81.0, x.get_col(1).dot(&(*x.get_row(1))));
    }

    // Row-major 2x3 matrix [[1, 2, 3], [4, 5, 6]]: row/column views read the
    // right cells, and a mutable column view writes back into storage.
    #[test]
    fn test_row_major() {
        let mut x = DenseMatrix::new(2, 3, vec![1, 2, 3, 4, 5, 6], false).unwrap();

        assert_eq!(5, *x.get_col(1).get(1));
        assert_eq!(7, x.get_col(1).sum());
        assert_eq!(5, *x.get_row(1).get(1));
        assert_eq!(15, x.get_row(1).sum());
        // Add 2 to every element of column 1.
        x.slice_mut(0..2, 1..2)
            .iterator_mut(0)
            .for_each(|v| *v += 2);
        assert_eq!(vec![1, 4, 3, 4, 7, 6], *x.values);
    }

    // Row and column slices of a 4x3 matrix yield the expected elements.
    // NOTE(review): `from_slice` is defined outside this file.
    #[test]
    fn test_get_slice() {
        let x = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6], &[7, 8, 9], &[10, 11, 12]])
            .unwrap();

        assert_eq!(
            vec![4, 5, 6],
            DenseMatrix::from_slice(&(*x.slice(1..2, 0..3))).values
        );
        let second_row: Vec<i32> = x.slice(1..2, 0..3).iterator(0).copied().collect();
        assert_eq!(vec![4, 5, 6], second_row);
        let second_col: Vec<i32> = x.slice(0..3, 1..2).iterator(0).copied().collect();
        assert_eq!(vec![2, 5, 8], second_col);
    }

    // Mutable iteration over views and the whole matrix. The expected
    // vectors are the raw column-major storage produced by `from_2d_array`.
    #[test]
    fn test_iter_mut() {
        let mut x = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6], &[7, 8, 9]]).unwrap();

        assert_eq!(vec![1, 4, 7, 2, 5, 8, 3, 6, 9], x.values);
        // Row 1 += 2.
        x.slice_mut(1..2, 0..3)
            .iterator_mut(0)
            .for_each(|v| *v += 2);
        assert_eq!(vec![1, 6, 7, 2, 7, 8, 3, 8, 9], x.values);
        // Column 1 += 1.
        x.slice_mut(0..3, 1..2)
            .iterator_mut(0)
            .for_each(|v| *v += 1);
        assert_eq!(vec![1, 6, 7, 3, 8, 9, 3, 8, 9], x.values);

        // Axis 1 visits elements column by column, i.e. in storage order here.
        x.iterator_mut(1).enumerate().for_each(|(a, b)| *b = a);
        assert_eq!(vec![0, 1, 2, 3, 4, 5, 6, 7, 8], x.values);
        // Axis 0 visits elements row by row.
        x.iterator_mut(0).enumerate().for_each(|(a, b)| *b = a);
        assert_eq!(vec![0, 3, 6, 1, 4, 7, 2, 5, 8], x.values);
        // Row-wise numbering of the first two columns only.
        x.slice_mut(0..3, 0..2)
            .iterator_mut(0)
            .enumerate()
            .for_each(|(a, b)| *b = a);
        assert_eq!(vec![0, 2, 4, 1, 3, 5, 2, 5, 8], x.values);
        // Column-wise numbering of the first two rows only.
        x.slice_mut(0..2, 0..3)
            .iterator_mut(1)
            .enumerate()
            .for_each(|(a, b)| *b = a);
        assert_eq!(vec![0, 1, 4, 2, 3, 5, 4, 5, 8], x.values);
    }

    // Non-numeric element types work; every element can be overwritten.
    #[test]
    fn test_str_array() {
        let mut x =
            DenseMatrix::from_2d_array(&[&["1", "2", "3"], &["4", "5", "6"], &["7", "8", "9"]])
                .unwrap();

        assert_eq!(vec!["1", "4", "7", "2", "5", "8", "3", "6", "9"], x.values);
        x.iterator_mut(0).for_each(|v| *v = "str");
        assert_eq!(
            vec!["str", "str", "str", "str", "str", "str", "str", "str", "str"],
            x.values
        );
    }

    // Transposing keeps the storage untouched and flips the layout flag.
    #[test]
    fn test_transpose() {
        let x = DenseMatrix::<&str>::from_2d_array(&[&["1", "2", "3"], &["4", "5", "6"]]).unwrap();

        assert_eq!(vec!["1", "4", "2", "5", "3", "6"], x.values);
        assert!(x.column_major);

        let x = x.transpose();
        assert_eq!(vec!["1", "4", "2", "5", "3", "6"], x.values);
        assert!(!x.column_major);
    }

    // Axis 0 in `from_iterator` means the items are stored row-major as given.
    #[test]
    fn test_from_iterator() {
        let data = [1, 2, 3, 4, 5, 6];

        let m = DenseMatrix::from_iterator(data.iter(), 2, 3, 0);

        assert_eq!(
            vec![1, 2, 3, 4, 5, 6],
            m.values.iter().map(|e| **e).collect::<Vec<i32>>()
        );
        assert!(!m.column_major);
    }

    // Selecting columns 0 and 2 (axis 1) and rows 0 and 2 (axis 0); the
    // expected vectors are the raw storage of the result.
    // NOTE(review): `take` is defined outside this file.
    #[test]
    fn test_take() {
        let a = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]]).unwrap();
        let b = DenseMatrix::from_2d_array(&[&[1, 2], &[3, 4], &[5, 6]]).unwrap();

        println!("{a}");
        assert_eq!(vec![1, 3, 4, 6], a.take(&[0, 2], 1).values);
        println!("{b}");
        assert_eq!(vec![1, 2, 5, 6], b.take(&[0, 2], 0).values);
    }

    // Element-wise `abs` and `neg` (defined outside this file) on column-major storage.
    #[test]
    fn test_mut() {
        let a = DenseMatrix::from_2d_array(&[&[1.3, -2.1, 3.4], &[-4., -5.3, 6.1]]).unwrap();

        let a = a.abs();
        assert_eq!(vec![1.3, 4.0, 2.1, 5.3, 3.4, 6.1], a.values);

        let a = a.neg();
        assert_eq!(vec![-1.3, -4.0, -2.1, -5.3, -3.4, -6.1], a.values);
    }

    // `reshape` (defined outside this file) changes shape and storage order.
    #[test]
    fn test_reshape() {
        let a = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6], &[7, 8, 9], &[10, 11, 12]])
            .unwrap();

        let a = a.reshape(2, 6, 0);
        assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], a.values);
        assert!(a.ncols == 6 && a.nrows == 2 && !a.column_major);

        let a = a.reshape(3, 4, 1);
        assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], a.values);
        assert!(a.ncols == 4 && a.nrows == 3 && a.column_major);
    }

    // Approximate equality: shape mismatch and large differences fail, while
    // epsilon-sized differences pass.
    #[test]
    fn test_eq() {
        let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]).unwrap();
        let b = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
        let c = DenseMatrix::from_2d_array(&[
            &[1. + f32::EPSILON, 2., 3.],
            &[4., 5., 6. + f32::EPSILON],
        ])
        .unwrap();
        let d = DenseMatrix::from_2d_array(&[&[1. + 0.5, 2., 3.], &[4., 5., 6. + f32::EPSILON]])
            .unwrap();

        assert!(!relative_eq!(a, b));
        assert!(!relative_eq!(a, d));
        assert!(relative_eq!(a, c));
    }
}
845}