scirs2_core/ndarray_ext/matrix.rs
1//! Matrix operations for ndarray
2//!
3//! This module provides matrix creation and manipulation operations similar to
4//! those found in `NumPy`/SciPy, such as identity, diagonal, block, and other
5//! specialized matrix operations.
6
7use ndarray::{Array, ArrayView, Ix1, Ix2};
8use num_traits::{One, Zero};
9
10/// Create an identity matrix
11///
12/// # Arguments
13///
14/// * `n` - Number of rows and columns in the square identity matrix
15///
16/// # Returns
17///
18/// An nxn identity matrix
19///
20/// # Examples
21///
22/// ```
23/// use scirs2_core::ndarray_ext::matrix::eye;
24///
25/// let id3 = eye::<f64>(3);
26/// assert_eq!(id3.shape(), &[3, 3]);
27/// assert_eq!(id3[[0, 0]], 1.0);
28/// assert_eq!(id3[[1, 1]], 1.0);
29/// assert_eq!(id3[[2, 2]], 1.0);
30/// assert_eq!(id3[[0, 1]], 0.0);
31/// ```
32#[allow(dead_code)]
33pub fn eye<T>(n: usize) -> Array<T, Ix2>
34where
35 T: Clone + Zero + One,
36{
37 let mut result = Array::<T, Ix2>::zeros((n, n));
38
39 for i in 0..n {
40 result[[i, i]] = T::one();
41 }
42
43 result
44}
45
46/// Create a matrix with ones on the given diagonal and zeros elsewhere
47///
48/// # Arguments
49///
50/// * `n` - Number of rows
51/// * `m` - Number of columns
52/// * `k` - Diagonal offset (0 for main diagonal, positive for above, negative for below)
53///
54/// # Returns
55///
56/// An n x m matrix with ones on the specified diagonal
57///
58/// # Examples
59///
60/// ```
61/// use scirs2_core::ndarray_ext::matrix::eye_offset;
62///
63/// // 3x3 matrix with ones on the main diagonal (k=0)
64/// let id3 = eye_offset::<f64>(3, 3, 0);
65/// assert_eq!(id3.shape(), &[3, 3]);
66/// assert_eq!(id3[[0, 0]], 1.0);
67/// assert_eq!(id3[[1, 1]], 1.0);
68/// assert_eq!(id3[[2, 2]], 1.0);
69///
70/// // 3x4 matrix with ones on the first superdiagonal (k=1)
71/// let super_diag = eye_offset::<f64>(3, 4, 1);
72/// assert_eq!(super_diag.shape(), &[3, 4]);
73/// assert_eq!(super_diag[[0, 1]], 1.0);
74/// assert_eq!(super_diag[[1, 2]], 1.0);
75/// assert_eq!(super_diag[[2, 3]], 1.0);
76/// ```
77#[allow(dead_code)]
78pub fn eye_offset<T>(n: usize, m: usize, k: isize) -> Array<T, Ix2>
79where
80 T: Clone + Zero + One,
81{
82 let mut result = Array::<T, Ix2>::zeros((n, m));
83
84 // Determine the start and end points for the diagonal
85 let start_i = if k > 0 { 0 } else { (-k) as usize };
86 let start_j = if k < 0 { 0 } else { k as usize };
87
88 let diag_len = std::cmp::min(n - start_i, m - start_j);
89
90 for d in 0..diag_len {
91 result[[start_i + d, start_j + d]] = T::one();
92 }
93
94 result
95}
96
97/// Construct a diagonal matrix from a 1D array
98///
99/// # Arguments
100///
101/// * `_diagvalues` - The values to place on the diagonal
102///
103/// # Returns
104///
105/// A square matrix with the input values on the main diagonal and zeros elsewhere
106///
107/// # Examples
108///
109/// ```
110/// use ndarray::array;
111/// use scirs2_core::ndarray_ext::matrix::_diag;
112///
113/// let values = array![1, 2, 3];
114/// let diagmatrix = _diag(values.view());
115/// assert_eq!(diagmatrix.shape(), &[3, 3]);
116/// assert_eq!(diagmatrix[[0, 0]], 1);
117/// assert_eq!(diagmatrix[[1, 1]], 2);
118/// assert_eq!(diagmatrix[[2, 2]], 3);
119/// assert_eq!(diagmatrix[[0, 1]], 0);
120/// ```
121#[allow(dead_code)]
122pub fn _diag<T>(_diagvalues: ArrayView<T, Ix1>) -> Array<T, Ix2>
123where
124 T: Clone + Zero,
125{
126 let n = _diagvalues.len();
127 let mut result = Array::<T, Ix2>::zeros((n, n));
128
129 for i in 0..n {
130 result[[i, i]] = _diagvalues[i].clone();
131 }
132
133 result
134}
135
136/// Extract a diagonal from a 2D array
137///
138/// # Arguments
139///
140/// * `array` - The input 2D array
141/// * `k` - Diagonal offset (0 for main diagonal, positive for above, negative for below)
142///
143/// # Returns
144///
145/// A 1D array containing the specified diagonal
146///
147/// # Examples
148///
149/// ```
150/// use ndarray::array;
151/// use scirs2_core::ndarray_ext::matrix::diagonal;
152///
153/// let a = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
154///
155/// // Extract main diagonal
156/// let main_diag = diagonal(a.view(), 0).unwrap();
157/// assert_eq!(main_diag, array![1, 5, 9]);
158///
159/// // Extract superdiagonal
160/// let super_diag = diagonal(a.view(), 1).unwrap();
161/// assert_eq!(super_diag, array![2, 6]);
162///
163/// // Extract subdiagonal
164/// let sub_diag = diagonal(a.view(), -1).unwrap();
165/// assert_eq!(sub_diag, array![4, 8]);
166/// ```
167#[allow(dead_code)]
168pub fn diagonal<T>(array: ArrayView<T, Ix2>, k: isize) -> Result<Array<T, Ix1>, &'static str>
169where
170 T: Clone + Zero,
171{
172 let (rows, cols) = (array.shape()[0], array.shape()[1]);
173
174 // Calculate the length of the diagonal
175 let diag_len = if k >= 0 {
176 std::cmp::min(rows, cols.saturating_sub(k as usize))
177 } else {
178 std::cmp::min(cols, rows.saturating_sub((-k) as usize))
179 };
180
181 if diag_len == 0 {
182 return Err("No diagonal elements for the given offset");
183 }
184
185 // Create the result array directly
186 let mut result = Array::<T, Ix1>::zeros(diag_len);
187
188 // Extract the diagonal elements
189 for i in 0..diag_len {
190 let row = if k < 0 { i + (-k) as usize } else { i };
191
192 let col = if k > 0 { i + k as usize } else { i };
193
194 result[i] = array[[row, col]].clone();
195 }
196
197 Ok(result)
198}
199
200/// Create a matrix filled with a given value
201///
202/// # Arguments
203///
204/// * `rows` - Number of rows
205/// * `cols` - Number of columns
206/// * `value` - Value to fill the matrix with
207///
208/// # Returns
209///
210/// A matrix filled with the specified value
211///
212/// # Examples
213///
214/// ```
215/// use scirs2_core::ndarray_ext::matrix::full;
216///
217/// let filled = full(2, 3, 7);
218/// assert_eq!(filled.shape(), &[2, 3]);
219/// assert_eq!(filled[[0, 0]], 7);
220/// assert_eq!(filled[[1, 2]], 7);
221/// ```
222#[allow(dead_code)]
223pub fn full<T>(rows: usize, cols: usize, value: T) -> Array<T, Ix2>
224where
225 T: Clone,
226{
227 Array::<T, Ix2>::from_elem((rows, cols), value)
228}
229
230/// Create a matrix filled with ones
231///
232/// # Arguments
233///
234/// * `rows` - Number of rows
235/// * `cols` - Number of columns
236///
237/// # Returns
238///
239/// A matrix filled with ones
240///
241/// # Examples
242///
243/// ```
244/// use scirs2_core::ndarray_ext::matrix::ones;
245///
246/// let ones_mat = ones::<f64>(2, 3);
247/// assert_eq!(ones_mat.shape(), &[2, 3]);
248/// assert_eq!(ones_mat[[0, 0]], 1.0);
249/// assert_eq!(ones_mat[[1, 2]], 1.0);
250/// ```
251#[allow(dead_code)]
252pub fn ones<T>(rows: usize, cols: usize) -> Array<T, Ix2>
253where
254 T: Clone + One,
255{
256 Array::<T, Ix2>::from_elem((rows, cols), T::one())
257}
258
259/// Create a matrix filled with zeros
260///
261/// # Arguments
262///
263/// * `rows` - Number of rows
264/// * `cols` - Number of columns
265///
266/// # Returns
267///
268/// A matrix filled with zeros
269///
270/// # Examples
271///
272/// ```ignore
273/// use scirs2_core::ndarray_ext::matrix::zeros;
274///
275/// let zeros_mat = zeros::<f64>(2, 3);
276/// assert_eq!(zeros_mat.shape(), &[2, 3]);
277/// assert_eq!(zeros_mat[[0, 0]], 0.0);
278/// assert_eq!(zeros_mat[[1, 2]], 0.0);
279/// ```
280#[allow(dead_code)]
281pub fn zeros<T>(rows: usize, cols: usize) -> Array<T, Ix2>
282where
283 T: Clone + Zero,
284{
285 Array::<T, Ix2>::zeros((rows, cols))
286}
287
288/// Compute the Kronecker product of two 2D arrays
289///
290/// # Arguments
291///
292/// * `a` - First input array
293/// * `b` - Second input array
294///
295/// # Returns
296///
297/// The Kronecker product of the input arrays
298///
299/// # Examples
300///
301/// ```
302/// use ndarray::array;
303/// use scirs2_core::ndarray_ext::matrix::kron;
304///
305/// let a = array![[1, 2], [3, 4]];
306/// let b = array![[0, 5], [6, 7]];
307///
308/// let result = kron(a.view(), b.view());
309/// assert_eq!(result.shape(), &[4, 4]);
310/// assert_eq!(result, array![
311/// [0, 5, 0, 10],
312/// [6, 7, 12, 14],
313/// [0, 15, 0, 20],
314/// [18, 21, 24, 28]
315/// ]);
316/// ```
317#[allow(dead_code)]
318pub fn kron<T>(a: ArrayView<T, Ix2>, b: ArrayView<T, Ix2>) -> Array<T, Ix2>
319where
320 T: Clone + Zero + std::ops::Mul<Output = T>,
321{
322 let (a_rows, a_cols) = (a.shape()[0], a.shape()[1]);
323 let (b_rows, b_cols) = (b.shape()[0], b.shape()[1]);
324
325 let result_rows = a_rows * b_rows;
326 let result_cols = a_cols * b_cols;
327
328 let mut result = Array::<T, Ix2>::zeros((result_rows, result_cols));
329
330 for i in 0..a_rows {
331 for j in 0..a_cols {
332 for k in 0..b_rows {
333 for l in 0..b_cols {
334 result[[i * b_rows + k, j * b_cols + l]] =
335 a[[i, j]].clone() * b[[k, l]].clone();
336 }
337 }
338 }
339 }
340
341 result
342}
343
344/// Create a Toeplitz matrix from first row and first column
345///
346/// # Arguments
347///
348/// * `first_row` - First row of the Toeplitz matrix
349/// * `first_col` - First column of the Toeplitz matrix (first element must match first row's first element)
350///
351/// # Returns
352///
353/// A Toeplitz matrix with the specified first row and column
354///
355/// # Examples
356///
357/// ```
358/// use ndarray::array;
359/// use scirs2_core::ndarray_ext::matrix::toeplitz;
360///
361/// let first_row = array![1, 2, 3];
362/// let first_col = array![1, 4, 5];
363/// let result = toeplitz(first_row.view(), first_col.view()).unwrap();
364/// assert_eq!(result.shape(), &[3, 3]);
365/// assert_eq!(result, array![
366/// [1, 2, 3],
367/// [4, 1, 2],
368/// [5, 4, 1]
369/// ]);
370/// ```
371#[allow(dead_code)]
372pub fn toeplitz<T>(
373 first_row: ArrayView<T, Ix1>,
374 first_col: ArrayView<T, Ix1>,
375) -> Result<Array<T, Ix2>, &'static str>
376where
377 T: Clone + PartialEq + Zero,
378{
379 // First elements of first_row and first_col must match
380 if first_row.is_empty() || first_col.is_empty() {
381 return Err("Input arrays must not be empty");
382 }
383
384 if first_row[0] != first_col[0] {
385 return Err("First element of _row and column must match");
386 }
387
388 let n = first_col.len(); // Number of rows
389 let m = first_row.len(); // Number of columns
390
391 let mut result = Array::<T, Ix2>::zeros((n, m));
392
393 for i in 0..n {
394 for j in 0..m {
395 if i <= j {
396 // Upper triangle and main diagonal from first_row
397 result[[i, j]] = first_row[j - i].clone();
398 } else {
399 // Lower triangle from first_col
400 result[[i, j]] = first_col[i - j].clone();
401 }
402 }
403 }
404
405 Ok(result)
406}
407
408/// Create a block diagonal matrix from a sequence of 2D arrays
409///
410/// # Arguments
411///
412/// * `arrays` - A slice of 2D arrays to form the blocks on the diagonal
413///
414/// # Returns
415///
416/// A block diagonal matrix with the input arrays on the diagonal
417///
418/// # Examples
419///
420/// ```
421/// use ndarray::array;
422/// use scirs2_core::ndarray_ext::matrix::block_diag;
423///
424/// let a = array![[1, 2], [3, 4]];
425/// let b = array![[5, 6], [7, 8]];
426///
427/// let result = block_diag(&[a.view(), b.view()]);
428/// assert_eq!(result.shape(), &[4, 4]);
429/// assert_eq!(result, array![
430/// [1, 2, 0, 0],
431/// [3, 4, 0, 0],
432/// [0, 0, 5, 6],
433/// [0, 0, 7, 8]
434/// ]);
435/// ```
436#[allow(dead_code)]
437pub fn block_diag<T>(arrays: &[ArrayView<T, Ix2>]) -> Array<T, Ix2>
438where
439 T: Clone + Zero,
440{
441 if arrays.is_empty() {
442 return Array::<T, Ix2>::zeros((0, 0));
443 }
444
445 // Calculate total dimensions
446 let mut total_rows = 0;
447 let mut total_cols = 0;
448
449 for array in arrays {
450 total_rows += array.shape()[0];
451 total_cols += array.shape()[1];
452 }
453
454 let mut result = Array::<T, Ix2>::zeros((total_rows, total_cols));
455
456 let mut row_offset = 0;
457 let mut col_offset = 0;
458
459 // Place each array on the diagonal
460 for array in arrays {
461 let (rows, cols) = (array.shape()[0], array.shape()[1]);
462
463 for i in 0..rows {
464 for j in 0..cols {
465 result[[row_offset + i, col_offset + j]] = array[[i, j]].clone();
466 }
467 }
468
469 row_offset += rows;
470 col_offset += cols;
471 }
472
473 result
474}
475
476/// Create a tri-diagonal matrix from the three diagonals
477///
478/// # Arguments
479///
480/// * `diag` - Main diagonal
481/// * `lower_diag` - Lower diagonal
482/// * `upper_diag` - Upper diagonal
483///
484/// # Returns
485///
486/// A tri-diagonal matrix with the specified diagonals
487///
488/// # Examples
489///
490/// ```
491/// use ndarray::array;
492/// use scirs2_core::ndarray_ext::matrix::tridiagonal;
493///
494/// let diag = array![1, 2, 3];
495/// let lower = array![4, 5];
496/// let upper = array![6, 7];
497///
498/// let result = tridiagonal(diag.view(), lower.view(), upper.view()).unwrap();
499/// assert_eq!(result.shape(), &[3, 3]);
500/// assert_eq!(result, array![
501/// [1, 6, 0],
502/// [4, 2, 7],
503/// [0, 5, 3]
504/// ]);
505/// ```
506#[allow(dead_code)]
507pub fn tridiagonal<T>(
508 diag: ArrayView<T, Ix1>,
509 lower_diag: ArrayView<T, Ix1>,
510 upper_diag: ArrayView<T, Ix1>,
511) -> Result<Array<T, Ix2>, &'static str>
512where
513 T: Clone + Zero,
514{
515 let n = diag.len();
516
517 // Check that the diagonals have correct sizes
518 if lower_diag.len() != n - 1 || upper_diag.len() != n - 1 {
519 return Err("Lower and upper diagonals must have length n-1 where n is the length of the main diagonal");
520 }
521
522 let mut result = Array::<T, Ix2>::zeros((n, n));
523
524 // Set main diagonal
525 for i in 0..n {
526 result[[i, i]] = diag[i].clone();
527 }
528
529 // Set lower diagonal
530 for i in 1..n {
531 result[[i, i - 1]] = lower_diag[i - 1].clone();
532 }
533
534 // Set upper diagonal
535 for i in 0..n - 1 {
536 result[[i, i + 1]] = upper_diag[i].clone();
537 }
538
539 Ok(result)
540}
541
542/// Create a Hankel matrix from its first column and last row
543///
544/// # Arguments
545///
546/// * `first_col` - First column of the Hankel matrix
547/// * `last_row` - Last row of the Hankel matrix (first element must match last element of first_col)
548///
549/// # Returns
550///
551/// A Hankel matrix with the specified first column and last row
552///
553/// # Examples
554///
555/// ```
556/// use ndarray::array;
557/// use scirs2_core::ndarray_ext::matrix::hankel;
558///
559/// let first_col = array![1, 2, 3];
560/// let last_row = array![3, 4, 5];
561///
562/// let result = hankel(first_col.view(), last_row.view()).unwrap();
563/// assert_eq!(result.shape(), &[3, 3]);
564/// assert_eq!(result, array![
565/// [1, 2, 3],
566/// [2, 3, 4],
567/// [3, 4, 5]
568/// ]);
569/// ```
570#[allow(dead_code)]
571pub fn hankel<T>(
572 first_col: ArrayView<T, Ix1>,
573 last_row: ArrayView<T, Ix1>,
574) -> Result<Array<T, Ix2>, &'static str>
575where
576 T: Clone + PartialEq + Zero,
577{
578 if first_col.is_empty() || last_row.is_empty() {
579 return Err("Input arrays must not be empty");
580 }
581
582 // Last element of first_col must match first element of last_row
583 if first_col[first_col.len() - 1] != last_row[0] {
584 return Err("Last element of first column must match first element of last _row");
585 }
586
587 let n = first_col.len(); // Number of rows
588 let m = last_row.len(); // Number of columns
589
590 let mut result = Array::<T, Ix2>::zeros((n, m));
591
592 // Combine first_col and last_row (minus first element of last_row) to form the "data" array
593 let data_len = n + m - 1;
594 let mut data = Vec::with_capacity(data_len);
595
596 // Fill data with first_col elements
597 for i in 0..n {
598 data.push(first_col[i].clone());
599 }
600
601 // Append last_row elements (skipping first element which should match last element of first_col)
602 for i in 1..m {
603 data.push(last_row[i].clone());
604 }
605
606 // Fill the Hankel matrix
607 for i in 0..n {
608 for j in 0..m {
609 result[[i, j]] = data[i + j].clone();
610 }
611 }
612
613 Ok(result)
614}
615
616/// Calculate the trace of a square matrix (sum of diagonal elements)
617///
618/// # Arguments
619///
620/// * `array` - Input square matrix
621///
622/// # Returns
623///
624/// The trace of the matrix
625///
626/// # Examples
627///
628/// ```
629/// use ndarray::array;
630/// use scirs2_core::ndarray_ext::matrix::trace;
631///
632/// let a = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
633/// let tr = trace(a.view()).unwrap();
634/// assert_eq!(tr, 15); // 1 + 5 + 9 = 15
635/// ```
636#[allow(dead_code)]
637pub fn trace<T>(array: ArrayView<T, Ix2>) -> Result<T, &'static str>
638where
639 T: Clone + Zero + std::ops::Add<Output = T>,
640{
641 let (rows, cols) = (array.shape()[0], array.shape()[1]);
642
643 if rows != cols {
644 return Err("Trace is only defined for square matrices");
645 }
646
647 let mut result = T::zero();
648
649 for i in 0..rows {
650 result = result + array[[i, i]].clone();
651 }
652
653 Ok(result)
654}
655
656/// Create a matrix vander from a 1D array
657///
658/// # Arguments
659///
660/// * `x` - Input 1D array
661/// * `n` - Optional number of columns in the output (defaults to x.len())
662/// * `increasing` - Optional boolean to determine order (defaults to false)
663///
664/// # Returns
665///
666/// A Vandermonde matrix where each column is a power of the input array
667///
668/// # Examples
669///
670/// ```ignore
671/// use ndarray::array;
672/// use scirs2_core::ndarray_ext::matrix::vander;
673///
674/// let x = array![1.0, 2.0, 3.0];
675///
676/// // Default behavior: decreasing powers from n-1 to 0
677/// let v1 = vander(x.view(), None, None).unwrap();
678/// assert_eq!(v1.shape(), &[3, 3]);
679/// // Powers: x^2, x^1, x^0
680/// assert_eq!(v1, array![
681/// [1.0, 1.0, 1.0],
682/// [4.0, 2.0, 1.0],
683/// [9.0, 3.0, 1.0]
684/// ]);
685///
686/// // Increasing powers: 0 to n-1
687/// let v2 = vander(x.view(), None, Some(true)).unwrap();
688/// assert_eq!(v2.shape(), &[3, 3]);
689/// // Powers: x^0, x^1, x^2
690/// assert_eq!(v2, array![
691/// [1.0, 1.0, 1.0],
692/// [1.0, 2.0, 4.0],
693/// [1.0, 3.0, 9.0]
694/// ]);
695///
696/// // Specify 4 columns
697/// let v3 = vander(x.view(), Some(4), None).unwrap();
698/// assert_eq!(v3.shape(), &[3, 4]);
699/// // Powers: x^3, x^2, x^1, x^0
700/// assert_eq!(v3, array![
701/// [1.0, 1.0, 1.0, 1.0],
702/// [8.0, 4.0, 2.0, 1.0],
703/// [27.0, 9.0, 3.0, 1.0]
704/// ]);
705/// ```
706#[allow(dead_code)]
707pub fn vander<T>(
708 x: ArrayView<T, Ix1>,
709 n: Option<usize>,
710 increasing: Option<bool>,
711) -> Result<Array<T, Ix2>, &'static str>
712where
713 T: Clone + Zero + One + std::ops::Mul<Output = T> + std::ops::Div<Output = T>,
714{
715 let x_len = x.len();
716
717 if x_len == 0 {
718 return Err("Input array must not be empty");
719 }
720
721 let n = n.unwrap_or(x_len);
722 let increasing = increasing.unwrap_or(false);
723
724 let mut result = Array::<T, Ix2>::zeros((x_len, n));
725
726 // Fill in the Vandermonde matrix with powers of x elements
727 for i in 0..x_len {
728 // Initialize accumulator with 1 (x^0)
729 let mut power = T::one();
730
731 if increasing {
732 // First column is x^0 (all ones)
733 for j in 0..n {
734 result[[i, j]] = power.clone();
735
736 if j < n - 1 {
737 power = power.clone() * x[i].clone();
738 }
739 }
740 } else {
741 // Decreasing powers (last column is x^0)
742 // Calculate highest power first: x^(n-1)
743 for _p in 0..n - 1 {
744 power = power.clone() * x[i].clone();
745 }
746
747 for j in 0..n {
748 result[[i, j]] = power.clone();
749
750 if j < n - 1 {
751 // For non-increasing powers, we need Div trait for T to handle division
752 // power = power.clone() / x[i].clone(); // This requires Div trait
753
754 // Use multiplication by reciprocal instead as a safer approach
755 power = power.clone() * (T::one() / x[i].clone());
756 }
757 }
758 }
759 }
760
761 Ok(result)
762}
763
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    #[test]
    fn test_eye() {
        let id3 = eye::<f64>(3);
        assert_eq!(id3.shape(), &[3, 3]);
        assert_eq!(id3[[0, 0]], 1.0);
        assert_eq!(id3[[1, 1]], 1.0);
        assert_eq!(id3[[2, 2]], 1.0);
        assert_eq!(id3[[0, 1]], 0.0);
        assert_eq!(id3[[1, 0]], 0.0);
    }

    #[test]
    fn test_eye_offset() {
        // 3x3 matrix with ones on the main diagonal (k=0)
        let id3 = eye_offset::<f64>(3, 3, 0);
        assert_eq!(id3.shape(), &[3, 3]);
        assert_eq!(id3[[0, 0]], 1.0);
        assert_eq!(id3[[1, 1]], 1.0);
        assert_eq!(id3[[2, 2]], 1.0);

        // 3x4 matrix with ones on the first superdiagonal (k=1)
        let super_diag = eye_offset::<f64>(3, 4, 1);
        assert_eq!(super_diag.shape(), &[3, 4]);
        assert_eq!(super_diag[[0, 1]], 1.0);
        assert_eq!(super_diag[[1, 2]], 1.0);
        assert_eq!(super_diag[[2, 3]], 1.0);

        // 4x3 matrix with ones on the first subdiagonal (k=-1)
        let sub_diag = eye_offset::<f64>(4, 3, -1);
        assert_eq!(sub_diag.shape(), &[4, 3]);
        assert_eq!(sub_diag[[1, 0]], 1.0);
        assert_eq!(sub_diag[[2, 1]], 1.0);
        assert_eq!(sub_diag[[3, 2]], 1.0);
    }

    #[test]
    fn test_diag() {
        let values = array![1, 2, 3];
        let diagmatrix = _diag(values.view());
        assert_eq!(diagmatrix.shape(), &[3, 3]);
        assert_eq!(diagmatrix[[0, 0]], 1);
        assert_eq!(diagmatrix[[1, 1]], 2);
        assert_eq!(diagmatrix[[2, 2]], 3);
        assert_eq!(diagmatrix[[0, 1]], 0);
    }

    #[test]
    fn test_diagonal() {
        let a = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];

        // Extract main diagonal
        let main_diag = diagonal(a.view(), 0).unwrap();
        assert_eq!(main_diag, array![1, 5, 9]);

        // Extract superdiagonal
        let super_diag = diagonal(a.view(), 1).unwrap();
        assert_eq!(super_diag, array![2, 6]);

        // Extract subdiagonal
        let sub_diag = diagonal(a.view(), -1).unwrap();
        assert_eq!(sub_diag, array![4, 8]);

        // Test out of bounds (should return error)
        assert!(diagonal(a.view(), 3).is_err());
        assert!(diagonal(a.view(), -3).is_err());
    }

    #[test]
    fn test_full_ones_zeros() {
        // Test full
        let filled = full(2, 3, 7);
        assert_eq!(filled.shape(), &[2, 3]);
        assert_eq!(filled[[0, 0]], 7);
        assert_eq!(filled[[1, 2]], 7);

        // Test ones
        let ones_mat = ones::<f64>(2, 3);
        assert_eq!(ones_mat.shape(), &[2, 3]);
        assert_eq!(ones_mat[[0, 0]], 1.0);
        assert_eq!(ones_mat[[1, 2]], 1.0);

        // Test zeros
        let zeros_mat = zeros::<f64>(2, 3);
        assert_eq!(zeros_mat.shape(), &[2, 3]);
        assert_eq!(zeros_mat[[0, 0]], 0.0);
        assert_eq!(zeros_mat[[1, 2]], 0.0);
    }

    #[test]
    fn test_kron() {
        let a = array![[1, 2], [3, 4]];
        let b = array![[0, 5], [6, 7]];

        let result = kron(a.view(), b.view());
        assert_eq!(result.shape(), &[4, 4]);

        // Check results
        assert_eq!(
            result,
            array![
                [0, 5, 0, 10],
                [6, 7, 12, 14],
                [0, 15, 0, 20],
                [18, 21, 24, 28]
            ]
        );
    }

    #[test]
    fn test_toeplitz() {
        let first_row = array![1, 2, 3];
        let first_col = array![1, 4, 5];

        let result = toeplitz(first_row.view(), first_col.view()).unwrap();
        assert_eq!(result.shape(), &[3, 3]);
        assert_eq!(result, array![[1, 2, 3], [4, 1, 2], [5, 4, 1]]);

        // Test with mismatched first elements (should return error)
        let bad_row = array![9, 2, 3];
        assert!(toeplitz(bad_row.view(), first_col.view()).is_err());
    }

    #[test]
    fn test_block_diag() {
        let a = array![[1, 2], [3, 4]];
        let b = array![[5, 6], [7, 8]];

        let result = block_diag(&[a.view(), b.view()]);
        assert_eq!(result.shape(), &[4, 4]);
        assert_eq!(
            result,
            array![[1, 2, 0, 0], [3, 4, 0, 0], [0, 0, 5, 6], [0, 0, 7, 8]]
        );

        // Test with different size blocks
        let c = array![[9]];
        let result2 = block_diag(&[a.view(), c.view()]);
        assert_eq!(result2.shape(), &[3, 3]);
        assert_eq!(result2, array![[1, 2, 0], [3, 4, 0], [0, 0, 9]]);

        // Test with empty array list
        let empty: [ArrayView<i32, Ix2>; 0] = [];
        let result3 = block_diag(&empty);
        assert_eq!(result3.shape(), &[0, 0]);
    }

    #[test]
    fn test_tridiagonal() {
        let diag = array![1, 2, 3];
        let lower = array![4, 5];
        let upper = array![6, 7];

        let result = tridiagonal(diag.view(), lower.view(), upper.view()).unwrap();
        assert_eq!(result.shape(), &[3, 3]);
        assert_eq!(result, array![[1, 6, 0], [4, 2, 7], [0, 5, 3]]);

        // Test with incorrect diagonals size (should return error)
        let bad_lower = array![4];
        assert!(tridiagonal(diag.view(), bad_lower.view(), upper.view()).is_err());
    }

    #[test]
    fn test_hankel() {
        let first_col = array![1, 2, 3];
        let last_row = array![3, 4, 5];

        let result = hankel(first_col.view(), last_row.view()).unwrap();
        assert_eq!(result.shape(), &[3, 3]);
        assert_eq!(result, array![[1, 2, 3], [2, 3, 4], [3, 4, 5]]);

        // Test with mismatched elements (should return error)
        let bad_row = array![9, 4, 5];
        assert!(hankel(first_col.view(), bad_row.view()).is_err());
    }

    #[test]
    fn test_trace() {
        let a = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
        let tr = trace(a.view()).unwrap();
        assert_eq!(tr, 15); // 1 + 5 + 9 = 15

        // Test non-square matrix (should return error)
        let b = array![[1, 2, 3], [4, 5, 6]];
        assert!(trace(b.view()).is_err());
    }

    #[test]
    fn test_vander() {
        let x = array![1.0, 2.0, 3.0];

        // Default behavior: decreasing powers from n-1 to 0
        let v1 = vander(x.view(), None, None).unwrap();
        assert_eq!(v1.shape(), &[3, 3]);
        // Each row i should be [x_i^2, x_i^1, x_i^0]
        for i in 0..3 {
            assert_abs_diff_eq!(v1[[i, 0]], x[i] * x[i]);
            assert_abs_diff_eq!(v1[[i, 1]], x[i]);
            assert_abs_diff_eq!(v1[[i, 2]], 1.0);
        }

        // Increasing powers from 0 to n-1
        let v2 = vander(x.view(), None, Some(true)).unwrap();
        assert_eq!(v2.shape(), &[3, 3]);
        // Each row i should be [x_i^0, x_i^1, x_i^2]
        for i in 0..3 {
            assert_abs_diff_eq!(v2[[i, 0]], 1.0);
            assert_abs_diff_eq!(v2[[i, 1]], x[i]);
            assert_abs_diff_eq!(v2[[i, 2]], x[i] * x[i]);
        }

        // Specify 4 columns (decreasing power)
        let v3 = vander(x.view(), Some(4), None).unwrap();
        assert_eq!(v3.shape(), &[3, 4]);
        // Each row i should be [x_i^3, x_i^2, x_i^1, x_i^0]
        for i in 0..3 {
            assert_abs_diff_eq!(v3[[i, 0]], x[i] * x[i] * x[i]);
            assert_abs_diff_eq!(v3[[i, 1]], x[i] * x[i]);
            assert_abs_diff_eq!(v3[[i, 2]], x[i]);
            assert_abs_diff_eq!(v3[[i, 3]], 1.0);
        }
    }
}