scirs2_core/ndarray_ext/
matrix.rs

//! Matrix operations for ndarray
//!
//! This module provides matrix creation and manipulation operations similar to
//! those found in `NumPy`/`SciPy`, such as identity, diagonal, block, and other
//! specialized matrix operations.
6
7use ndarray::{Array, ArrayView, Ix1, Ix2};
8use num_traits::{One, Zero};
9
10/// Create an identity matrix
11///
12/// # Arguments
13///
14/// * `n` - Number of rows and columns in the square identity matrix
15///
16/// # Returns
17///
18/// An nxn identity matrix
19///
20/// # Examples
21///
22/// ```
23/// use scirs2_core::ndarray_ext::matrix::eye;
24///
25/// let id3 = eye::<f64>(3);
26/// assert_eq!(id3.shape(), &[3, 3]);
27/// assert_eq!(id3[[0, 0]], 1.0);
28/// assert_eq!(id3[[1, 1]], 1.0);
29/// assert_eq!(id3[[2, 2]], 1.0);
30/// assert_eq!(id3[[0, 1]], 0.0);
31/// ```
32#[allow(dead_code)]
33pub fn eye<T>(n: usize) -> Array<T, Ix2>
34where
35    T: Clone + Zero + One,
36{
37    let mut result = Array::<T, Ix2>::zeros((n, n));
38
39    for i in 0..n {
40        result[[i, i]] = T::one();
41    }
42
43    result
44}
45
46/// Create a matrix with ones on the given diagonal and zeros elsewhere
47///
48/// # Arguments
49///
50/// * `n` - Number of rows
51/// * `m` - Number of columns
52/// * `k` - Diagonal offset (0 for main diagonal, positive for above, negative for below)
53///
54/// # Returns
55///
56/// An n x m matrix with ones on the specified diagonal
57///
58/// # Examples
59///
60/// ```
61/// use scirs2_core::ndarray_ext::matrix::eye_offset;
62///
63/// // 3x3 matrix with ones on the main diagonal (k=0)
64/// let id3 = eye_offset::<f64>(3, 3, 0);
65/// assert_eq!(id3.shape(), &[3, 3]);
66/// assert_eq!(id3[[0, 0]], 1.0);
67/// assert_eq!(id3[[1, 1]], 1.0);
68/// assert_eq!(id3[[2, 2]], 1.0);
69///
70/// // 3x4 matrix with ones on the first superdiagonal (k=1)
71/// let super_diag = eye_offset::<f64>(3, 4, 1);
72/// assert_eq!(super_diag.shape(), &[3, 4]);
73/// assert_eq!(super_diag[[0, 1]], 1.0);
74/// assert_eq!(super_diag[[1, 2]], 1.0);
75/// assert_eq!(super_diag[[2, 3]], 1.0);
76/// ```
77#[allow(dead_code)]
78pub fn eye_offset<T>(n: usize, m: usize, k: isize) -> Array<T, Ix2>
79where
80    T: Clone + Zero + One,
81{
82    let mut result = Array::<T, Ix2>::zeros((n, m));
83
84    // Determine the start and end points for the diagonal
85    let start_i = if k > 0 { 0 } else { (-k) as usize };
86    let start_j = if k < 0 { 0 } else { k as usize };
87
88    let diag_len = std::cmp::min(n - start_i, m - start_j);
89
90    for d in 0..diag_len {
91        result[[start_i + d, start_j + d]] = T::one();
92    }
93
94    result
95}
96
97/// Construct a diagonal matrix from a 1D array
98///
99/// # Arguments
100///
101/// * `_diagvalues` - The values to place on the diagonal
102///
103/// # Returns
104///
105/// A square matrix with the input values on the main diagonal and zeros elsewhere
106///
107/// # Examples
108///
109/// ```
110/// use ndarray::array;
111/// use scirs2_core::ndarray_ext::matrix::_diag;
112///
113/// let values = array![1, 2, 3];
114/// let diagmatrix = _diag(values.view());
115/// assert_eq!(diagmatrix.shape(), &[3, 3]);
116/// assert_eq!(diagmatrix[[0, 0]], 1);
117/// assert_eq!(diagmatrix[[1, 1]], 2);
118/// assert_eq!(diagmatrix[[2, 2]], 3);
119/// assert_eq!(diagmatrix[[0, 1]], 0);
120/// ```
121#[allow(dead_code)]
122pub fn _diag<T>(_diagvalues: ArrayView<T, Ix1>) -> Array<T, Ix2>
123where
124    T: Clone + Zero,
125{
126    let n = _diagvalues.len();
127    let mut result = Array::<T, Ix2>::zeros((n, n));
128
129    for i in 0..n {
130        result[[i, i]] = _diagvalues[i].clone();
131    }
132
133    result
134}
135
136/// Extract a diagonal from a 2D array
137///
138/// # Arguments
139///
140/// * `array` - The input 2D array
141/// * `k` - Diagonal offset (0 for main diagonal, positive for above, negative for below)
142///
143/// # Returns
144///
145/// A 1D array containing the specified diagonal
146///
147/// # Examples
148///
149/// ```
150/// use ndarray::array;
151/// use scirs2_core::ndarray_ext::matrix::diagonal;
152///
153/// let a = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
154///
155/// // Extract main diagonal
156/// let main_diag = diagonal(a.view(), 0).unwrap();
157/// assert_eq!(main_diag, array![1, 5, 9]);
158///
159/// // Extract superdiagonal
160/// let super_diag = diagonal(a.view(), 1).unwrap();
161/// assert_eq!(super_diag, array![2, 6]);
162///
163/// // Extract subdiagonal
164/// let sub_diag = diagonal(a.view(), -1).unwrap();
165/// assert_eq!(sub_diag, array![4, 8]);
166/// ```
167#[allow(dead_code)]
168pub fn diagonal<T>(array: ArrayView<T, Ix2>, k: isize) -> Result<Array<T, Ix1>, &'static str>
169where
170    T: Clone + Zero,
171{
172    let (rows, cols) = (array.shape()[0], array.shape()[1]);
173
174    // Calculate the length of the diagonal
175    let diag_len = if k >= 0 {
176        std::cmp::min(rows, cols.saturating_sub(k as usize))
177    } else {
178        std::cmp::min(cols, rows.saturating_sub((-k) as usize))
179    };
180
181    if diag_len == 0 {
182        return Err("No diagonal elements for the given offset");
183    }
184
185    // Create the result array directly
186    let mut result = Array::<T, Ix1>::zeros(diag_len);
187
188    // Extract the diagonal elements
189    for i in 0..diag_len {
190        let row = if k < 0 { i + (-k) as usize } else { i };
191
192        let col = if k > 0 { i + k as usize } else { i };
193
194        result[i] = array[[row, col]].clone();
195    }
196
197    Ok(result)
198}
199
200/// Create a matrix filled with a given value
201///
202/// # Arguments
203///
204/// * `rows` - Number of rows
205/// * `cols` - Number of columns
206/// * `value` - Value to fill the matrix with
207///
208/// # Returns
209///
210/// A matrix filled with the specified value
211///
212/// # Examples
213///
214/// ```
215/// use scirs2_core::ndarray_ext::matrix::full;
216///
217/// let filled = full(2, 3, 7);
218/// assert_eq!(filled.shape(), &[2, 3]);
219/// assert_eq!(filled[[0, 0]], 7);
220/// assert_eq!(filled[[1, 2]], 7);
221/// ```
222#[allow(dead_code)]
223pub fn full<T>(rows: usize, cols: usize, value: T) -> Array<T, Ix2>
224where
225    T: Clone,
226{
227    Array::<T, Ix2>::from_elem((rows, cols), value)
228}
229
230/// Create a matrix filled with ones
231///
232/// # Arguments
233///
234/// * `rows` - Number of rows
235/// * `cols` - Number of columns
236///
237/// # Returns
238///
239/// A matrix filled with ones
240///
241/// # Examples
242///
243/// ```
244/// use scirs2_core::ndarray_ext::matrix::ones;
245///
246/// let ones_mat = ones::<f64>(2, 3);
247/// assert_eq!(ones_mat.shape(), &[2, 3]);
248/// assert_eq!(ones_mat[[0, 0]], 1.0);
249/// assert_eq!(ones_mat[[1, 2]], 1.0);
250/// ```
251#[allow(dead_code)]
252pub fn ones<T>(rows: usize, cols: usize) -> Array<T, Ix2>
253where
254    T: Clone + One,
255{
256    Array::<T, Ix2>::from_elem((rows, cols), T::one())
257}
258
259/// Create a matrix filled with zeros
260///
261/// # Arguments
262///
263/// * `rows` - Number of rows
264/// * `cols` - Number of columns
265///
266/// # Returns
267///
268/// A matrix filled with zeros
269///
270/// # Examples
271///
272/// ```ignore
273/// use scirs2_core::ndarray_ext::matrix::zeros;
274///
275/// let zeros_mat = zeros::<f64>(2, 3);
276/// assert_eq!(zeros_mat.shape(), &[2, 3]);
277/// assert_eq!(zeros_mat[[0, 0]], 0.0);
278/// assert_eq!(zeros_mat[[1, 2]], 0.0);
279/// ```
280#[allow(dead_code)]
281pub fn zeros<T>(rows: usize, cols: usize) -> Array<T, Ix2>
282where
283    T: Clone + Zero,
284{
285    Array::<T, Ix2>::zeros((rows, cols))
286}
287
288/// Compute the Kronecker product of two 2D arrays
289///
290/// # Arguments
291///
292/// * `a` - First input array
293/// * `b` - Second input array
294///
295/// # Returns
296///
297/// The Kronecker product of the input arrays
298///
299/// # Examples
300///
301/// ```
302/// use ndarray::array;
303/// use scirs2_core::ndarray_ext::matrix::kron;
304///
305/// let a = array![[1, 2], [3, 4]];
306/// let b = array![[0, 5], [6, 7]];
307///
308/// let result = kron(a.view(), b.view());
309/// assert_eq!(result.shape(), &[4, 4]);
310/// assert_eq!(result, array![
311///     [0, 5, 0, 10],
312///     [6, 7, 12, 14],
313///     [0, 15, 0, 20],
314///     [18, 21, 24, 28]
315/// ]);
316/// ```
317#[allow(dead_code)]
318pub fn kron<T>(a: ArrayView<T, Ix2>, b: ArrayView<T, Ix2>) -> Array<T, Ix2>
319where
320    T: Clone + Zero + std::ops::Mul<Output = T>,
321{
322    let (a_rows, a_cols) = (a.shape()[0], a.shape()[1]);
323    let (b_rows, b_cols) = (b.shape()[0], b.shape()[1]);
324
325    let result_rows = a_rows * b_rows;
326    let result_cols = a_cols * b_cols;
327
328    let mut result = Array::<T, Ix2>::zeros((result_rows, result_cols));
329
330    for i in 0..a_rows {
331        for j in 0..a_cols {
332            for k in 0..b_rows {
333                for l in 0..b_cols {
334                    result[[i * b_rows + k, j * b_cols + l]] =
335                        a[[i, j]].clone() * b[[k, l]].clone();
336                }
337            }
338        }
339    }
340
341    result
342}
343
344/// Create a Toeplitz matrix from first row and first column
345///
346/// # Arguments
347///
348/// * `first_row` - First row of the Toeplitz matrix
349/// * `first_col` - First column of the Toeplitz matrix (first element must match first row's first element)
350///
351/// # Returns
352///
353/// A Toeplitz matrix with the specified first row and column
354///
355/// # Examples
356///
357/// ```
358/// use ndarray::array;
359/// use scirs2_core::ndarray_ext::matrix::toeplitz;
360///
361/// let first_row = array![1, 2, 3];
362/// let first_col = array![1, 4, 5];
363/// let result = toeplitz(first_row.view(), first_col.view()).unwrap();
364/// assert_eq!(result.shape(), &[3, 3]);
365/// assert_eq!(result, array![
366///     [1, 2, 3],
367///     [4, 1, 2],
368///     [5, 4, 1]
369/// ]);
370/// ```
371#[allow(dead_code)]
372pub fn toeplitz<T>(
373    first_row: ArrayView<T, Ix1>,
374    first_col: ArrayView<T, Ix1>,
375) -> Result<Array<T, Ix2>, &'static str>
376where
377    T: Clone + PartialEq + Zero,
378{
379    // First elements of first_row and first_col must match
380    if first_row.is_empty() || first_col.is_empty() {
381        return Err("Input arrays must not be empty");
382    }
383
384    if first_row[0] != first_col[0] {
385        return Err("First element of _row and column must match");
386    }
387
388    let n = first_col.len(); // Number of rows
389    let m = first_row.len(); // Number of columns
390
391    let mut result = Array::<T, Ix2>::zeros((n, m));
392
393    for i in 0..n {
394        for j in 0..m {
395            if i <= j {
396                // Upper triangle and main diagonal from first_row
397                result[[i, j]] = first_row[j - i].clone();
398            } else {
399                // Lower triangle from first_col
400                result[[i, j]] = first_col[i - j].clone();
401            }
402        }
403    }
404
405    Ok(result)
406}
407
408/// Create a block diagonal matrix from a sequence of 2D arrays
409///
410/// # Arguments
411///
412/// * `arrays` - A slice of 2D arrays to form the blocks on the diagonal
413///
414/// # Returns
415///
416/// A block diagonal matrix with the input arrays on the diagonal
417///
418/// # Examples
419///
420/// ```
421/// use ndarray::array;
422/// use scirs2_core::ndarray_ext::matrix::block_diag;
423///
424/// let a = array![[1, 2], [3, 4]];
425/// let b = array![[5, 6], [7, 8]];
426///
427/// let result = block_diag(&[a.view(), b.view()]);
428/// assert_eq!(result.shape(), &[4, 4]);
429/// assert_eq!(result, array![
430///     [1, 2, 0, 0],
431///     [3, 4, 0, 0],
432///     [0, 0, 5, 6],
433///     [0, 0, 7, 8]
434/// ]);
435/// ```
436#[allow(dead_code)]
437pub fn block_diag<T>(arrays: &[ArrayView<T, Ix2>]) -> Array<T, Ix2>
438where
439    T: Clone + Zero,
440{
441    if arrays.is_empty() {
442        return Array::<T, Ix2>::zeros((0, 0));
443    }
444
445    // Calculate total dimensions
446    let mut total_rows = 0;
447    let mut total_cols = 0;
448
449    for array in arrays {
450        total_rows += array.shape()[0];
451        total_cols += array.shape()[1];
452    }
453
454    let mut result = Array::<T, Ix2>::zeros((total_rows, total_cols));
455
456    let mut row_offset = 0;
457    let mut col_offset = 0;
458
459    // Place each array on the diagonal
460    for array in arrays {
461        let (rows, cols) = (array.shape()[0], array.shape()[1]);
462
463        for i in 0..rows {
464            for j in 0..cols {
465                result[[row_offset + i, col_offset + j]] = array[[i, j]].clone();
466            }
467        }
468
469        row_offset += rows;
470        col_offset += cols;
471    }
472
473    result
474}
475
476/// Create a tri-diagonal matrix from the three diagonals
477///
478/// # Arguments
479///
480/// * `diag` - Main diagonal
481/// * `lower_diag` - Lower diagonal
482/// * `upper_diag` - Upper diagonal
483///
484/// # Returns
485///
486/// A tri-diagonal matrix with the specified diagonals
487///
488/// # Examples
489///
490/// ```
491/// use ndarray::array;
492/// use scirs2_core::ndarray_ext::matrix::tridiagonal;
493///
494/// let diag = array![1, 2, 3];
495/// let lower = array![4, 5];
496/// let upper = array![6, 7];
497///
498/// let result = tridiagonal(diag.view(), lower.view(), upper.view()).unwrap();
499/// assert_eq!(result.shape(), &[3, 3]);
500/// assert_eq!(result, array![
501///     [1, 6, 0],
502///     [4, 2, 7],
503///     [0, 5, 3]
504/// ]);
505/// ```
506#[allow(dead_code)]
507pub fn tridiagonal<T>(
508    diag: ArrayView<T, Ix1>,
509    lower_diag: ArrayView<T, Ix1>,
510    upper_diag: ArrayView<T, Ix1>,
511) -> Result<Array<T, Ix2>, &'static str>
512where
513    T: Clone + Zero,
514{
515    let n = diag.len();
516
517    // Check that the diagonals have correct sizes
518    if lower_diag.len() != n - 1 || upper_diag.len() != n - 1 {
519        return Err("Lower and upper diagonals must have length n-1 where n is the length of the main diagonal");
520    }
521
522    let mut result = Array::<T, Ix2>::zeros((n, n));
523
524    // Set main diagonal
525    for i in 0..n {
526        result[[i, i]] = diag[i].clone();
527    }
528
529    // Set lower diagonal
530    for i in 1..n {
531        result[[i, i - 1]] = lower_diag[i - 1].clone();
532    }
533
534    // Set upper diagonal
535    for i in 0..n - 1 {
536        result[[i, i + 1]] = upper_diag[i].clone();
537    }
538
539    Ok(result)
540}
541
542/// Create a Hankel matrix from its first column and last row
543///
544/// # Arguments
545///
546/// * `first_col` - First column of the Hankel matrix
547/// * `last_row` - Last row of the Hankel matrix (first element must match last element of first_col)
548///
549/// # Returns
550///
551/// A Hankel matrix with the specified first column and last row
552///
553/// # Examples
554///
555/// ```
556/// use ndarray::array;
557/// use scirs2_core::ndarray_ext::matrix::hankel;
558///
559/// let first_col = array![1, 2, 3];
560/// let last_row = array![3, 4, 5];
561///
562/// let result = hankel(first_col.view(), last_row.view()).unwrap();
563/// assert_eq!(result.shape(), &[3, 3]);
564/// assert_eq!(result, array![
565///     [1, 2, 3],
566///     [2, 3, 4],
567///     [3, 4, 5]
568/// ]);
569/// ```
570#[allow(dead_code)]
571pub fn hankel<T>(
572    first_col: ArrayView<T, Ix1>,
573    last_row: ArrayView<T, Ix1>,
574) -> Result<Array<T, Ix2>, &'static str>
575where
576    T: Clone + PartialEq + Zero,
577{
578    if first_col.is_empty() || last_row.is_empty() {
579        return Err("Input arrays must not be empty");
580    }
581
582    // Last element of first_col must match first element of last_row
583    if first_col[first_col.len() - 1] != last_row[0] {
584        return Err("Last element of first column must match first element of last _row");
585    }
586
587    let n = first_col.len(); // Number of rows
588    let m = last_row.len(); // Number of columns
589
590    let mut result = Array::<T, Ix2>::zeros((n, m));
591
592    // Combine first_col and last_row (minus first element of last_row) to form the "data" array
593    let data_len = n + m - 1;
594    let mut data = Vec::with_capacity(data_len);
595
596    // Fill data with first_col elements
597    for i in 0..n {
598        data.push(first_col[i].clone());
599    }
600
601    // Append last_row elements (skipping first element which should match last element of first_col)
602    for i in 1..m {
603        data.push(last_row[i].clone());
604    }
605
606    // Fill the Hankel matrix
607    for i in 0..n {
608        for j in 0..m {
609            result[[i, j]] = data[i + j].clone();
610        }
611    }
612
613    Ok(result)
614}
615
616/// Calculate the trace of a square matrix (sum of diagonal elements)
617///
618/// # Arguments
619///
620/// * `array` - Input square matrix
621///
622/// # Returns
623///
624/// The trace of the matrix
625///
626/// # Examples
627///
628/// ```
629/// use ndarray::array;
630/// use scirs2_core::ndarray_ext::matrix::trace;
631///
632/// let a = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
633/// let tr = trace(a.view()).unwrap();
634/// assert_eq!(tr, 15);  // 1 + 5 + 9 = 15
635/// ```
636#[allow(dead_code)]
637pub fn trace<T>(array: ArrayView<T, Ix2>) -> Result<T, &'static str>
638where
639    T: Clone + Zero + std::ops::Add<Output = T>,
640{
641    let (rows, cols) = (array.shape()[0], array.shape()[1]);
642
643    if rows != cols {
644        return Err("Trace is only defined for square matrices");
645    }
646
647    let mut result = T::zero();
648
649    for i in 0..rows {
650        result = result + array[[i, i]].clone();
651    }
652
653    Ok(result)
654}
655
656/// Create a matrix vander from a 1D array
657///
658/// # Arguments
659///
660/// * `x` - Input 1D array
661/// * `n` - Optional number of columns in the output (defaults to x.len())
662/// * `increasing` - Optional boolean to determine order (defaults to false)
663///
664/// # Returns
665///
666/// A Vandermonde matrix where each column is a power of the input array
667///
668/// # Examples
669///
670/// ```ignore
671/// use ndarray::array;
672/// use scirs2_core::ndarray_ext::matrix::vander;
673///
674/// let x = array![1.0, 2.0, 3.0];
675///
676/// // Default behavior: decreasing powers from n-1 to 0
677/// let v1 = vander(x.view(), None, None).unwrap();
678/// assert_eq!(v1.shape(), &[3, 3]);
679/// // Powers: x^2, x^1, x^0
680/// assert_eq!(v1, array![
681///     [1.0, 1.0, 1.0],
682///     [4.0, 2.0, 1.0],
683///     [9.0, 3.0, 1.0]
684/// ]);
685///
686/// // Increasing powers: 0 to n-1
687/// let v2 = vander(x.view(), None, Some(true)).unwrap();
688/// assert_eq!(v2.shape(), &[3, 3]);
689/// // Powers: x^0, x^1, x^2
690/// assert_eq!(v2, array![
691///     [1.0, 1.0, 1.0],
692///     [1.0, 2.0, 4.0],
693///     [1.0, 3.0, 9.0]
694/// ]);
695///
696/// // Specify 4 columns
697/// let v3 = vander(x.view(), Some(4), None).unwrap();
698/// assert_eq!(v3.shape(), &[3, 4]);
699/// // Powers: x^3, x^2, x^1, x^0
700/// assert_eq!(v3, array![
701///     [1.0, 1.0, 1.0, 1.0],
702///     [8.0, 4.0, 2.0, 1.0],
703///     [27.0, 9.0, 3.0, 1.0]
704/// ]);
705/// ```
706#[allow(dead_code)]
707pub fn vander<T>(
708    x: ArrayView<T, Ix1>,
709    n: Option<usize>,
710    increasing: Option<bool>,
711) -> Result<Array<T, Ix2>, &'static str>
712where
713    T: Clone + Zero + One + std::ops::Mul<Output = T> + std::ops::Div<Output = T>,
714{
715    let x_len = x.len();
716
717    if x_len == 0 {
718        return Err("Input array must not be empty");
719    }
720
721    let n = n.unwrap_or(x_len);
722    let increasing = increasing.unwrap_or(false);
723
724    let mut result = Array::<T, Ix2>::zeros((x_len, n));
725
726    // Fill in the Vandermonde matrix with powers of x elements
727    for i in 0..x_len {
728        // Initialize accumulator with 1 (x^0)
729        let mut power = T::one();
730
731        if increasing {
732            // First column is x^0 (all ones)
733            for j in 0..n {
734                result[[i, j]] = power.clone();
735
736                if j < n - 1 {
737                    power = power.clone() * x[i].clone();
738                }
739            }
740        } else {
741            // Decreasing powers (last column is x^0)
742            // Calculate highest power first: x^(n-1)
743            for _p in 0..n - 1 {
744                power = power.clone() * x[i].clone();
745            }
746
747            for j in 0..n {
748                result[[i, j]] = power.clone();
749
750                if j < n - 1 {
751                    // For non-increasing powers, we need Div trait for T to handle division
752                    // power = power.clone() / x[i].clone(); // This requires Div trait
753
754                    // Use multiplication by reciprocal instead as a safer approach
755                    power = power.clone() * (T::one() / x[i].clone());
756                }
757            }
758        }
759    }
760
761    Ok(result)
762}
763
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    #[test]
    fn test_eye() {
        let id3 = eye::<f64>(3);
        assert_eq!(id3.shape(), &[3, 3]);
        assert_eq!(id3[[0, 0]], 1.0);
        assert_eq!(id3[[1, 1]], 1.0);
        assert_eq!(id3[[2, 2]], 1.0);
        assert_eq!(id3[[0, 1]], 0.0);
        assert_eq!(id3[[1, 0]], 0.0);
    }

    #[test]
    fn test_eye_offset() {
        // 3x3 matrix with ones on the main diagonal (k=0)
        let id3 = eye_offset::<f64>(3, 3, 0);
        assert_eq!(id3.shape(), &[3, 3]);
        assert_eq!(id3[[0, 0]], 1.0);
        assert_eq!(id3[[1, 1]], 1.0);
        assert_eq!(id3[[2, 2]], 1.0);

        // 3x4 matrix with ones on the first superdiagonal (k=1)
        let super_diag = eye_offset::<f64>(3, 4, 1);
        assert_eq!(super_diag.shape(), &[3, 4]);
        assert_eq!(super_diag[[0, 1]], 1.0);
        assert_eq!(super_diag[[1, 2]], 1.0);
        assert_eq!(super_diag[[2, 3]], 1.0);

        // 4x3 matrix with ones on the first subdiagonal (k=-1)
        let sub_diag = eye_offset::<f64>(4, 3, -1);
        assert_eq!(sub_diag.shape(), &[4, 3]);
        assert_eq!(sub_diag[[1, 0]], 1.0);
        assert_eq!(sub_diag[[2, 1]], 1.0);
        assert_eq!(sub_diag[[3, 2]], 1.0);
    }

    #[test]
    fn test_diag() {
        let values = array![1, 2, 3];
        let diagmatrix = _diag(values.view());
        assert_eq!(diagmatrix.shape(), &[3, 3]);
        assert_eq!(diagmatrix[[0, 0]], 1);
        assert_eq!(diagmatrix[[1, 1]], 2);
        assert_eq!(diagmatrix[[2, 2]], 3);
        assert_eq!(diagmatrix[[0, 1]], 0);
    }

    #[test]
    fn test_diagonal() {
        let a = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];

        // Extract main diagonal
        let main_diag = diagonal(a.view(), 0).unwrap();
        assert_eq!(main_diag, array![1, 5, 9]);

        // Extract superdiagonal
        let super_diag = diagonal(a.view(), 1).unwrap();
        assert_eq!(super_diag, array![2, 6]);

        // Extract subdiagonal
        let sub_diag = diagonal(a.view(), -1).unwrap();
        assert_eq!(sub_diag, array![4, 8]);

        // Test out of bounds (should return error)
        assert!(diagonal(a.view(), 3).is_err());
        assert!(diagonal(a.view(), -3).is_err());
    }

    #[test]
    fn test_full_ones_zeros() {
        // Test full
        let filled = full(2, 3, 7);
        assert_eq!(filled.shape(), &[2, 3]);
        assert_eq!(filled[[0, 0]], 7);
        assert_eq!(filled[[1, 2]], 7);

        // Test ones
        let ones_mat = ones::<f64>(2, 3);
        assert_eq!(ones_mat.shape(), &[2, 3]);
        assert_eq!(ones_mat[[0, 0]], 1.0);
        assert_eq!(ones_mat[[1, 2]], 1.0);

        // Test zeros
        let zeros_mat = zeros::<f64>(2, 3);
        assert_eq!(zeros_mat.shape(), &[2, 3]);
        assert_eq!(zeros_mat[[0, 0]], 0.0);
        assert_eq!(zeros_mat[[1, 2]], 0.0);
    }

    #[test]
    fn test_kron() {
        let a = array![[1, 2], [3, 4]];
        let b = array![[0, 5], [6, 7]];

        let result = kron(a.view(), b.view());
        assert_eq!(result.shape(), &[4, 4]);

        // Check results
        assert_eq!(
            result,
            array![
                [0, 5, 0, 10],
                [6, 7, 12, 14],
                [0, 15, 0, 20],
                [18, 21, 24, 28]
            ]
        );
    }

    #[test]
    fn test_toeplitz() {
        let first_row = array![1, 2, 3];
        let first_col = array![1, 4, 5];

        let result = toeplitz(first_row.view(), first_col.view()).unwrap();
        assert_eq!(result.shape(), &[3, 3]);
        assert_eq!(result, array![[1, 2, 3], [4, 1, 2], [5, 4, 1]]);

        // Test with mismatched first elements (should return error)
        let bad_row = array![9, 2, 3];
        assert!(toeplitz(bad_row.view(), first_col.view()).is_err());
    }

    #[test]
    fn test_block_diag() {
        let a = array![[1, 2], [3, 4]];
        let b = array![[5, 6], [7, 8]];

        let result = block_diag(&[a.view(), b.view()]);
        assert_eq!(result.shape(), &[4, 4]);
        assert_eq!(
            result,
            array![[1, 2, 0, 0], [3, 4, 0, 0], [0, 0, 5, 6], [0, 0, 7, 8]]
        );

        // Test with different size blocks
        let c = array![[9]];
        let result2 = block_diag(&[a.view(), c.view()]);
        assert_eq!(result2.shape(), &[3, 3]);
        assert_eq!(result2, array![[1, 2, 0], [3, 4, 0], [0, 0, 9]]);

        // Test with empty array list
        let empty: [ArrayView<i32, Ix2>; 0] = [];
        let result3 = block_diag(&empty);
        assert_eq!(result3.shape(), &[0, 0]);
    }

    #[test]
    fn test_tridiagonal() {
        let diag = array![1, 2, 3];
        let lower = array![4, 5];
        let upper = array![6, 7];

        let result = tridiagonal(diag.view(), lower.view(), upper.view()).unwrap();
        assert_eq!(result.shape(), &[3, 3]);
        assert_eq!(result, array![[1, 6, 0], [4, 2, 7], [0, 5, 3]]);

        // Test with incorrect diagonals size (should return error)
        let bad_lower = array![4];
        assert!(tridiagonal(diag.view(), bad_lower.view(), upper.view()).is_err());
    }

    #[test]
    fn test_hankel() {
        let first_col = array![1, 2, 3];
        let last_row = array![3, 4, 5];

        let result = hankel(first_col.view(), last_row.view()).unwrap();
        assert_eq!(result.shape(), &[3, 3]);
        assert_eq!(result, array![[1, 2, 3], [2, 3, 4], [3, 4, 5]]);

        // Test with mismatched elements (should return error)
        let bad_row = array![9, 4, 5];
        assert!(hankel(first_col.view(), bad_row.view()).is_err());
    }

    #[test]
    fn test_trace() {
        let a = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
        let tr = trace(a.view()).unwrap();
        assert_eq!(tr, 15); // 1 + 5 + 9 = 15

        // Test non-square matrix (should return error)
        let b = array![[1, 2, 3], [4, 5, 6]];
        assert!(trace(b.view()).is_err());
    }

    #[test]
    fn test_vander() {
        let x = array![1.0, 2.0, 3.0];

        // Default behavior: decreasing powers from n-1 to 0
        let v1 = vander(x.view(), None, None).unwrap();
        assert_eq!(v1.shape(), &[3, 3]);
        // Every row i should be [x_i^2, x_i^1, x_i^0]
        for i in 0..3 {
            assert_abs_diff_eq!(v1[[i, 0]], x[i] * x[i]);
            assert_abs_diff_eq!(v1[[i, 1]], x[i]);
            assert_abs_diff_eq!(v1[[i, 2]], 1.0);
        }

        // Increasing powers from 0 to n-1
        let v2 = vander(x.view(), None, Some(true)).unwrap();
        assert_eq!(v2.shape(), &[3, 3]);
        // Every row i should be [x_i^0, x_i^1, x_i^2]
        for i in 0..3 {
            assert_abs_diff_eq!(v2[[i, 0]], 1.0);
            assert_abs_diff_eq!(v2[[i, 1]], x[i]);
            assert_abs_diff_eq!(v2[[i, 2]], x[i] * x[i]);
        }

        // Specify 4 columns (decreasing power)
        let v3 = vander(x.view(), Some(4), None).unwrap();
        assert_eq!(v3.shape(), &[3, 4]);
        // Every row i should be [x_i^3, x_i^2, x_i^1, x_i^0]
        for i in 0..3 {
            assert_abs_diff_eq!(v3[[i, 0]], x[i] * x[i] * x[i]);
            assert_abs_diff_eq!(v3[[i, 1]], x[i] * x[i]);
            assert_abs_diff_eq!(v3[[i, 2]], x[i]);
            assert_abs_diff_eq!(v3[[i, 3]], 1.0);
        }
    }
}