scirs2_sparse/
construct.rs

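// Construction utilities for sparse arrays
//
// This module provides functions for constructing sparse arrays: identity
// matrices (optionally on an offset diagonal), arrays assembled from diagonals,
// and random sparse arrays (sequential and parallel), in CSR, COO, DOK or LIL format.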
5
6#![allow(unused_variables)]
7#![allow(unused_assignments)]
8#![allow(unused_mut)]
9
10use scirs2_core::ndarray::Array1;
11use scirs2_core::numeric::Float;
12use scirs2_core::random::seq::SliceRandom;
13use scirs2_core::random::{Rng, SeedableRng};
14use std::fmt::Debug;
15use std::ops::{Add, Div, Mul, Sub};
16
17use crate::coo_array::CooArray;
18use crate::csr_array::CsrArray;
19use crate::dok_array::DokArray;
20use crate::error::{SparseError, SparseResult};
21use crate::lil_array::LilArray;
22use crate::sparray::SparseArray;
23
24// Import parallel operations from scirs2-core
25use scirs2_core::parallel_ops::*;
26
27/// Creates a sparse identity array of size n x n
28///
29/// # Arguments
30/// * `n` - Size of the square array
31/// * `format` - Format of the output array ("csr" or "coo")
32///
33/// # Returns
34/// A sparse array representing the identity matrix
35///
36/// # Examples
37///
38/// ```
39/// use scirs2_sparse::construct::eye_array;
40///
41/// let eye: Box<dyn scirs2_sparse::SparseArray<f64>> = eye_array(3, "csr").unwrap();
42/// assert_eq!(eye.shape(), (3, 3));
43/// assert_eq!(eye.nnz(), 3);
44/// assert_eq!(eye.get(0, 0), 1.0);
45/// assert_eq!(eye.get(1, 1), 1.0);
46/// assert_eq!(eye.get(2, 2), 1.0);
47/// assert_eq!(eye.get(0, 1), 0.0);
48/// ```
49#[allow(dead_code)]
50pub fn eye_array<T>(n: usize, format: &str) -> SparseResult<Box<dyn SparseArray<T>>>
51where
52    T: Float
53        + Add<Output = T>
54        + Sub<Output = T>
55        + Mul<Output = T>
56        + Div<Output = T>
57        + Debug
58        + Copy
59        + 'static,
60{
61    if n == 0 {
62        return Err(SparseError::ValueError(
63            "Matrix dimension must be positive".to_string(),
64        ));
65    }
66
67    let mut rows = Vec::with_capacity(n);
68    let mut cols = Vec::with_capacity(n);
69    let mut data = Vec::with_capacity(n);
70
71    for i in 0..n {
72        rows.push(i);
73        cols.push(i);
74        data.push(T::one());
75    }
76
77    match format.to_lowercase().as_str() {
78        "csr" => CsrArray::from_triplets(&rows, &cols, &data, (n, n), true)
79            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
80        "coo" => CooArray::from_triplets(&rows, &cols, &data, (n, n), true)
81            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
82        "dok" => DokArray::from_triplets(&rows, &cols, &data, (n, n))
83            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
84        "lil" => LilArray::from_triplets(&rows, &cols, &data, (n, n))
85            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
86        _ => Err(SparseError::ValueError(format!(
87            "Unknown sparse format: {format}. Supported formats are 'csr', 'coo', 'dok', and 'lil'"
88        ))),
89    }
90}
91
92/// Creates a sparse identity array of size m x n with k-th diagonal filled with ones
93///
94/// # Arguments
95/// * `m` - Number of rows
96/// * `n` - Number of columns
97/// * `k` - Diagonal index (0 = main diagonal, >0 = above main, <0 = below main)
98/// * `format` - Format of the output array ("csr" or "coo")
99///
100/// # Returns
101/// A sparse array with ones on the specified diagonal
102///
103/// # Examples
104///
105/// ```
106/// use scirs2_sparse::construct::eye_array_k;
107///
108/// // Identity with main diagonal (k=0)
109/// let eye: Box<dyn scirs2_sparse::SparseArray<f64>> = eye_array_k(3, 3, 0, "csr").unwrap();
110/// assert_eq!(eye.get(0, 0), 1.0);
111/// assert_eq!(eye.get(1, 1), 1.0);
112/// assert_eq!(eye.get(2, 2), 1.0);
113///
114/// // Superdiagonal (k=1)
115/// let superdiag: Box<dyn scirs2_sparse::SparseArray<f64>> = eye_array_k(3, 4, 1, "csr").unwrap();
116/// assert_eq!(superdiag.get(0, 1), 1.0);
117/// assert_eq!(superdiag.get(1, 2), 1.0);
118/// assert_eq!(superdiag.get(2, 3), 1.0);
119///
120/// // Subdiagonal (k=-1)
121/// let subdiag: Box<dyn scirs2_sparse::SparseArray<f64>> = eye_array_k(4, 3, -1, "csr").unwrap();
122/// assert_eq!(subdiag.get(1, 0), 1.0);
123/// assert_eq!(subdiag.get(2, 1), 1.0);
124/// assert_eq!(subdiag.get(3, 2), 1.0);
125/// ```
126#[allow(dead_code)]
127pub fn eye_array_k<T>(
128    m: usize,
129    n: usize,
130    k: isize,
131    format: &str,
132) -> SparseResult<Box<dyn SparseArray<T>>>
133where
134    T: Float
135        + Add<Output = T>
136        + Sub<Output = T>
137        + Mul<Output = T>
138        + Div<Output = T>
139        + Debug
140        + Copy
141        + 'static,
142{
143    if m == 0 || n == 0 {
144        return Err(SparseError::ValueError(
145            "Matrix dimensions must be positive".to_string(),
146        ));
147    }
148
149    let mut rows = Vec::new();
150    let mut cols = Vec::new();
151    let mut data = Vec::new();
152
153    // Calculate diagonal elements
154    if k >= 0 {
155        let k_usize = k as usize;
156        let len = std::cmp::min(m, n.saturating_sub(k_usize));
157
158        for i in 0..len {
159            rows.push(i);
160            cols.push(i + k_usize);
161            data.push(T::one());
162        }
163    } else {
164        let k_abs = (-k) as usize;
165        let len = std::cmp::min(m.saturating_sub(k_abs), n);
166
167        for i in 0..len {
168            rows.push(i + k_abs);
169            cols.push(i);
170            data.push(T::one());
171        }
172    }
173
174    match format.to_lowercase().as_str() {
175        "csr" => CsrArray::from_triplets(&rows, &cols, &data, (m, n), true)
176            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
177        "coo" => CooArray::from_triplets(&rows, &cols, &data, (m, n), true)
178            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
179        "dok" => DokArray::from_triplets(&rows, &cols, &data, (m, n))
180            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
181        "lil" => LilArray::from_triplets(&rows, &cols, &data, (m, n))
182            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
183        _ => Err(SparseError::ValueError(format!(
184            "Unknown sparse format: {format}. Supported formats are 'csr', 'coo', 'dok', and 'lil'"
185        ))),
186    }
187}
188
189/// Creates a sparse array from the specified diagonals
190///
191/// # Arguments
192/// * `diagonals` - Data for the diagonals
193/// * `offsets` - Offset for each diagonal (0 = main, >0 = above main, <0 = below main)
194/// * `shape` - Shape of the output array (m, n)
195/// * `format` - Format of the output array ("csr" or "coo")
196///
197/// # Returns
198/// A sparse array with the specified diagonals
199///
200/// # Examples
201///
202/// ```
203/// use scirs2_sparse::construct::diags_array;
204/// use scirs2_core::ndarray::Array1;
205///
206/// let diags = vec![
207///     Array1::from_vec(vec![1.0, 2.0, 3.0]), // main diagonal
208///     Array1::from_vec(vec![4.0, 5.0])       // superdiagonal
209/// ];
210/// let offsets = vec![0, 1];
211/// let shape = (3, 3);
212///
213/// let result = diags_array(&diags, &offsets, shape, "csr").unwrap();
214/// assert_eq!(result.shape(), (3, 3));
215/// assert_eq!(result.get(0, 0), 1.0);
216/// assert_eq!(result.get(1, 1), 2.0);
217/// assert_eq!(result.get(2, 2), 3.0);
218/// assert_eq!(result.get(0, 1), 4.0);
219/// assert_eq!(result.get(1, 2), 5.0);
220/// ```
221#[allow(dead_code)]
222pub fn diags_array<T>(
223    diagonals: &[Array1<T>],
224    offsets: &[isize],
225    shape: (usize, usize),
226    format: &str,
227) -> SparseResult<Box<dyn SparseArray<T>>>
228where
229    T: Float
230        + Add<Output = T>
231        + Sub<Output = T>
232        + Mul<Output = T>
233        + Div<Output = T>
234        + Debug
235        + Copy
236        + 'static,
237{
238    if diagonals.len() != offsets.len() {
239        return Err(SparseError::InconsistentData {
240            reason: "Number of diagonals must match number of offsets".to_string(),
241        });
242    }
243
244    if shape.0 == 0 || shape.1 == 0 {
245        return Err(SparseError::ValueError(
246            "Matrix dimensions must be positive".to_string(),
247        ));
248    }
249
250    let (m, n) = shape;
251    let mut rows = Vec::new();
252    let mut cols = Vec::new();
253    let mut data = Vec::new();
254
255    for (i, (diag, &offset)) in diagonals.iter().zip(offsets.iter()).enumerate() {
256        if offset >= 0 {
257            let offset_usize = offset as usize;
258            let max_len = std::cmp::min(m, n.saturating_sub(offset_usize));
259
260            if diag.len() > max_len {
261                return Err(SparseError::InconsistentData {
262                    reason: format!("Diagonal {i} is too long ({} > {})", diag.len(), max_len),
263                });
264            }
265
266            for (j, &value) in diag.iter().enumerate() {
267                if !value.is_zero() {
268                    rows.push(j);
269                    cols.push(j + offset_usize);
270                    data.push(value);
271                }
272            }
273        } else {
274            let offset_abs = (-offset) as usize;
275            let max_len = std::cmp::min(m.saturating_sub(offset_abs), n);
276
277            if diag.len() > max_len {
278                return Err(SparseError::InconsistentData {
279                    reason: format!("Diagonal {i} is too long ({} > {})", diag.len(), max_len),
280                });
281            }
282
283            for (j, &value) in diag.iter().enumerate() {
284                if !value.is_zero() {
285                    rows.push(j + offset_abs);
286                    cols.push(j);
287                    data.push(value);
288                }
289            }
290        }
291    }
292
293    match format.to_lowercase().as_str() {
294        "csr" => CsrArray::from_triplets(&rows, &cols, &data, shape, false)
295            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
296        "coo" => CooArray::from_triplets(&rows, &cols, &data, shape, false)
297            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
298        "dok" => DokArray::from_triplets(&rows, &cols, &data, shape)
299            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
300        "lil" => LilArray::from_triplets(&rows, &cols, &data, shape)
301            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
302        _ => Err(SparseError::ValueError(format!(
303            "Unknown sparse format: {format}. Supported formats are 'csr', 'coo', 'dok', and 'lil'"
304        ))),
305    }
306}
307
308/// Creates a random sparse array with specified density
309///
310/// # Arguments
311/// * `shape` - Shape of the output array (m, n)
312/// * `density` - Density of non-zero elements (between 0.0 and 1.0)
313/// * `seed` - Optional seed for the random number generator
314/// * `format` - Format of the output array ("csr" or "coo")
315///
316/// # Returns
317/// A sparse array with random non-zero elements
318///
319/// # Examples
320///
321/// ```
322/// use scirs2_sparse::construct::random_array;
323///
324/// // Create a 10x10 array with 30% non-zero elements
325/// let random = random_array::<f64>((10, 10), 0.3, None, "csr").unwrap();
326/// assert_eq!(random.shape(), (10, 10));
327///
328/// // Create a random array with a specific seed
329/// let seeded = random_array::<f64>((5, 5), 0.5, Some(42), "coo").unwrap();
330/// assert_eq!(seeded.shape(), (5, 5));
331/// ```
332#[allow(dead_code)]
333pub fn random_array<T>(
334    shape: (usize, usize),
335    density: f64,
336    seed: Option<u64>,
337    format: &str,
338) -> SparseResult<Box<dyn SparseArray<T>>>
339where
340    T: Float
341        + Add<Output = T>
342        + Sub<Output = T>
343        + Mul<Output = T>
344        + Div<Output = T>
345        + Debug
346        + Copy
347        + 'static,
348{
349    let (m, n) = shape;
350
351    if !(0.0..=1.0).contains(&density) {
352        return Err(SparseError::ValueError(
353            "Density must be between 0.0 and 1.0".to_string(),
354        ));
355    }
356
357    if m == 0 || n == 0 {
358        return Err(SparseError::ValueError(
359            "Matrix dimensions must be positive".to_string(),
360        ));
361    }
362
363    // Calculate the number of non-zero elements
364    let nnz = (m * n) as f64 * density;
365    let nnz = nnz.round() as usize;
366
367    // Create random indices
368    let mut rows = Vec::with_capacity(nnz);
369    let mut cols = Vec::with_capacity(nnz);
370    let mut data = Vec::with_capacity(nnz);
371
372    // Create RNG
373    let mut rng = if let Some(seed_value) = seed {
374        scirs2_core::random::rngs::StdRng::seed_from_u64(seed_value)
375    } else {
376        // For a random seed, use rng
377        let seed = scirs2_core::random::random::<u64>();
378        scirs2_core::random::rngs::StdRng::seed_from_u64(seed)
379    };
380
381    // Generate random elements
382    let total = m * n;
383
384    if density > 0.4 {
385        // For high densities, more efficient to generate a mask
386        let mut indices: Vec<usize> = (0..total).collect();
387        indices.shuffle(&mut rng);
388
389        for &idx in indices.iter().take(nnz) {
390            let row = idx / n;
391            let col = idx % n;
392
393            rows.push(row);
394            cols.push(col);
395
396            // Generate random non-zero value
397            // For simplicity, using values between -1 and 1
398            let mut val: f64 = rng.random_range(-1.0..1.0);
399            // Make sure the value is not zero
400            while val.abs() < 1e-10 {
401                val = rng.random_range(-1.0..1.0);
402            }
403            data.push(T::from(val).unwrap());
404        }
405    } else {
406        // For low densities..use a set to track already-chosen positions
407        let mut positions = std::collections::HashSet::with_capacity(nnz);
408
409        while positions.len() < nnz {
410            let row = rng.random_range(0..m);
411            let col = rng.random_range(0..n);
412            let pos = row * n + col; // Using row/col as usize indices
413
414            if positions.insert(pos) {
415                rows.push(row);
416                cols.push(col);
417
418                // Generate random non-zero value
419                let mut val: f64 = rng.random_range(-1.0..1.0);
420                // Make sure the value is not zero
421                while val.abs() < 1e-10 {
422                    val = rng.random_range(-1.0..1.0);
423                }
424                data.push(T::from(val).unwrap());
425            }
426        }
427    }
428
429    // Create the output array
430    match format.to_lowercase().as_str() {
431        "csr" => CsrArray::from_triplets(&rows, &cols, &data, shape, false)
432            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
433        "coo" => CooArray::from_triplets(&rows, &cols, &data, shape, false)
434            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
435        "dok" => DokArray::from_triplets(&rows, &cols, &data, shape)
436            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
437        "lil" => LilArray::from_triplets(&rows, &cols, &data, shape)
438            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
439        _ => Err(SparseError::ValueError(format!(
440            "Unknown sparse format: {format}. Supported formats are 'csr', 'coo', 'dok', and 'lil'"
441        ))),
442    }
443}
444
445/// Creates a large sparse random array using parallel processing
446///
447/// This function uses parallel construction for improved performance when creating
448/// large sparse arrays with many non-zero elements.
449///
450/// # Arguments
451/// * `shape` - Shape of the array (rows, cols)
452/// * `density` - Density of non-zero elements (0.0 to 1.0)
453/// * `seed` - Optional random seed for reproducibility
454/// * `format` - Format of the output array ("csr" or "coo")
455/// * `parallel_threshold` - Minimum number of elements to use parallel construction
456///
457/// # Returns
458/// A sparse array with randomly distributed non-zero elements
459///
460/// # Examples
461///
462/// ```
463/// use scirs2_sparse::construct::random_array_parallel;
464///
465/// // Create a large random sparse array
466/// let large_random = random_array_parallel::<f64>((1000, 1000), 0.01, Some(42), "csr", 10000).unwrap();
467/// assert_eq!(large_random.shape(), (1000, 1000));
468/// assert!(large_random.nnz() > 5000); // Approximately 10000 non-zeros expected
469/// ```
470#[allow(dead_code)]
471pub fn random_array_parallel<T>(
472    shape: (usize, usize),
473    density: f64,
474    seed: Option<u64>,
475    format: &str,
476    parallel_threshold: usize,
477) -> SparseResult<Box<dyn SparseArray<T>>>
478where
479    T: Float
480        + Add<Output = T>
481        + Sub<Output = T>
482        + Mul<Output = T>
483        + Div<Output = T>
484        + Debug
485        + Copy
486        + Send
487        + Sync
488        + 'static,
489{
490    if !(0.0..=1.0).contains(&density) {
491        return Err(SparseError::ValueError(
492            "Density must be between 0.0 and 1.0".to_string(),
493        ));
494    }
495
496    let (rows, cols) = shape;
497    if rows == 0 || cols == 0 {
498        return Err(SparseError::ValueError(
499            "Matrix dimensions must be positive".to_string(),
500        ));
501    }
502
503    let total_elements = rows * cols;
504    let expected_nnz = (total_elements as f64 * density) as usize;
505
506    // Use parallel construction for large matrices
507    if total_elements >= parallel_threshold && expected_nnz >= 1000 {
508        parallel_random_construction(shape, density, seed, format)
509    } else {
510        // Fall back to sequential construction for small matrices
511        random_array(shape, density, seed, format)
512    }
513}
514
515/// Internal parallel construction function
516#[allow(dead_code)]
517fn parallel_random_construction<T>(
518    shape: (usize, usize),
519    density: f64,
520    seed: Option<u64>,
521    format: &str,
522) -> SparseResult<Box<dyn SparseArray<T>>>
523where
524    T: Float
525        + Add<Output = T>
526        + Sub<Output = T>
527        + Mul<Output = T>
528        + Div<Output = T>
529        + Debug
530        + Copy
531        + Send
532        + Sync
533        + 'static,
534{
535    let (rows, cols) = shape;
536    let total_elements = rows * cols;
537    let expected_nnz = (total_elements as f64 * density) as usize;
538
539    // Determine number of chunks based on available parallelism
540    let num_chunks = std::cmp::min(scirs2_core::parallel_ops::get_num_threads(), rows.min(cols));
541    let chunk_size = std::cmp::max(1, rows / num_chunks);
542
543    // Create row chunks for parallel processing
544    let row_chunks: Vec<_> = (0..rows)
545        .collect::<Vec<_>>()
546        .chunks(chunk_size)
547        .map(|chunk| chunk.to_vec())
548        .collect();
549
550    // Generate random elements in parallel using enumerate to get chunk index
551    let chunk_data: Vec<_> = row_chunks.iter().enumerate().collect();
552    let results: Vec<_> = parallel_map(&chunk_data, |(chunk_idx, row_chunk)| {
553        let mut local_rows = Vec::new();
554        let mut local_cols = Vec::new();
555        let mut local_data = Vec::new();
556
557        // Use a different seed for each chunk to ensure good randomization
558        let chunk_seed = seed.unwrap_or(42) + *chunk_idx as u64 * 1000007; // Large prime offset
559        let mut rng = scirs2_core::random::rngs::StdRng::seed_from_u64(chunk_seed);
560
561        for &row in row_chunk.iter() {
562            // Determine how many elements to generate for this row
563            let row_elements = cols;
564            let row_expected_nnz = std::cmp::max(1, (row_elements as f64 * density) as usize);
565
566            // Generate random column indices for this row
567            let mut col_indices: Vec<usize> = (0..cols).collect();
568            col_indices.shuffle(&mut rng);
569
570            // Take the first row_expected_nnz columns
571            for &col in col_indices.iter().take(row_expected_nnz) {
572                // Generate random value
573                let mut val = rng.random_range(-1.0..1.0);
574                // Make sure the value is not zero
575                while val.abs() < 1e-10 {
576                    val = rng.random_range(-1.0..1.0);
577                }
578
579                local_rows.push(row);
580                local_cols.push(col);
581                local_data.push(T::from(val).unwrap());
582            }
583        }
584
585        (local_rows, local_cols, local_data)
586    });
587
588    // Combine results from all chunks
589    let mut all_rows = Vec::new();
590    let mut all_cols = Vec::new();
591    let mut all_data = Vec::new();
592
593    for (mut rowschunk, mut cols_chunk, mut data_chunk) in results {
594        all_rows.extend(rowschunk);
595        all_cols.append(&mut cols_chunk);
596        all_data.append(&mut data_chunk);
597    }
598
599    // Create the output array
600    match format.to_lowercase().as_str() {
601        "csr" => CsrArray::from_triplets(&all_rows, &all_cols, &all_data, shape, false)
602            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
603        "coo" => CooArray::from_triplets(&all_rows, &all_cols, &all_data, shape, false)
604            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
605        "dok" => DokArray::from_triplets(&all_rows, &all_cols, &all_data, shape)
606            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
607        "lil" => LilArray::from_triplets(&all_rows, &all_cols, &all_data, shape)
608            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
609        _ => Err(SparseError::ValueError(format!(
610            "Unknown sparse format: {format}. Supported formats are 'csr', 'coo', 'dok', and 'lil'"
611        ))),
612    }
613}
614
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_eye_array() {
        let eye = eye_array::<f64>(3, "csr").unwrap();

        assert_eq!(eye.shape(), (3, 3));
        assert_eq!(eye.nnz(), 3);
        assert_eq!(eye.get(0, 0), 1.0);
        assert_eq!(eye.get(1, 1), 1.0);
        assert_eq!(eye.get(2, 2), 1.0);
        assert_eq!(eye.get(0, 1), 0.0);

        // Try COO format
        let eye_coo = eye_array::<f64>(3, "coo").unwrap();
        assert_eq!(eye_coo.shape(), (3, 3));
        assert_eq!(eye_coo.nnz(), 3);

        // Try DOK format
        let eye_dok = eye_array::<f64>(3, "dok").unwrap();
        assert_eq!(eye_dok.shape(), (3, 3));
        assert_eq!(eye_dok.nnz(), 3);
        assert_eq!(eye_dok.get(0, 0), 1.0);
        assert_eq!(eye_dok.get(1, 1), 1.0);
        assert_eq!(eye_dok.get(2, 2), 1.0);

        // Try LIL format
        let eye_lil = eye_array::<f64>(3, "lil").unwrap();
        assert_eq!(eye_lil.shape(), (3, 3));
        assert_eq!(eye_lil.nnz(), 3);
        assert_eq!(eye_lil.get(0, 0), 1.0);
        assert_eq!(eye_lil.get(1, 1), 1.0);
        assert_eq!(eye_lil.get(2, 2), 1.0);
    }

    #[test]
    fn test_eye_array_k() {
        // Identity with main diagonal (k=0)
        let eye = eye_array_k::<f64>(3, 3, 0, "csr").unwrap();
        assert_eq!(eye.get(0, 0), 1.0);
        assert_eq!(eye.get(1, 1), 1.0);
        assert_eq!(eye.get(2, 2), 1.0);

        // Superdiagonal (k=1)
        let superdiag = eye_array_k::<f64>(3, 4, 1, "csr").unwrap();
        assert_eq!(superdiag.get(0, 1), 1.0);
        assert_eq!(superdiag.get(1, 2), 1.0);
        assert_eq!(superdiag.get(2, 3), 1.0);

        // Subdiagonal (k=-1)
        let subdiag = eye_array_k::<f64>(4, 3, -1, "csr").unwrap();
        assert_eq!(subdiag.get(1, 0), 1.0);
        assert_eq!(subdiag.get(2, 1), 1.0);
        assert_eq!(subdiag.get(3, 2), 1.0);

        // Try LIL format
        let eye_lil = eye_array_k::<f64>(3, 3, 0, "lil").unwrap();
        assert_eq!(eye_lil.get(0, 0), 1.0);
        assert_eq!(eye_lil.get(1, 1), 1.0);
        assert_eq!(eye_lil.get(2, 2), 1.0);
    }

    #[test]
    fn test_diags_array() {
        let diags = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]), // main diagonal
            Array1::from_vec(vec![4.0, 5.0]),      // superdiagonal
        ];
        let offsets = vec![0, 1];
        let shape = (3, 3);

        let result = diags_array(&diags, &offsets, shape, "csr").unwrap();
        assert_eq!(result.shape(), (3, 3));
        assert_eq!(result.get(0, 0), 1.0);
        assert_eq!(result.get(1, 1), 2.0);
        assert_eq!(result.get(2, 2), 3.0);
        assert_eq!(result.get(0, 1), 4.0);
        assert_eq!(result.get(1, 2), 5.0);

        // Try with multiple diagonals and subdiagonals
        let diags = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]), // main diagonal
            Array1::from_vec(vec![4.0, 5.0]),      // superdiagonal
            Array1::from_vec(vec![6.0, 7.0]),      // subdiagonal
        ];
        let offsets = vec![0, 1, -1];
        let shape = (3, 3);

        let result = diags_array(&diags, &offsets, shape, "csr").unwrap();
        assert_eq!(result.shape(), (3, 3));
        assert_eq!(result.get(0, 0), 1.0);
        assert_eq!(result.get(1, 1), 2.0);
        assert_eq!(result.get(2, 2), 3.0);
        assert_eq!(result.get(0, 1), 4.0);
        assert_eq!(result.get(1, 2), 5.0);
        assert_eq!(result.get(1, 0), 6.0);
        assert_eq!(result.get(2, 1), 7.0);

        // Try LIL format
        let result_lil = diags_array(&diags, &offsets, shape, "lil").unwrap();
        assert_eq!(result_lil.shape(), (3, 3));
        assert_eq!(result_lil.get(0, 0), 1.0);
        assert_eq!(result_lil.get(1, 1), 2.0);
        assert_eq!(result_lil.get(2, 2), 3.0);
        assert_eq!(result_lil.get(0, 1), 4.0);
        assert_eq!(result_lil.get(1, 2), 5.0);
        assert_eq!(result_lil.get(1, 0), 6.0);
        assert_eq!(result_lil.get(2, 1), 7.0);
    }

    #[test]
    fn test_random_array() {
        let shape = (10, 10);
        let density = 0.3;

        let random = random_array::<f64>(shape, density, None, "csr").unwrap();

        // Check shape and sparsity
        assert_eq!(random.shape(), shape);
        let nnz = random.nnz();
        let expected_nnz = (shape.0 * shape.1) as f64 * density;

        // Allow for some random variation, but should be close to expected density
        assert!(
            (nnz as f64) > expected_nnz * 0.7,
            "Too few non-zeros: {nnz}"
        );
        assert!(
            (nnz as f64) < expected_nnz * 1.3,
            "Too many non-zeros: {nnz}"
        );

        // Test with custom RNG seed
        let random_seeded = random_array::<f64>(shape, density, Some(42), "csr").unwrap();
        assert_eq!(random_seeded.shape(), shape);

        // Test LIL format
        let random_lil = random_array::<f64>((5, 5), 0.5, Some(42), "lil").unwrap();
        assert_eq!(random_lil.shape(), (5, 5));
        let nnz_lil = random_lil.nnz();
        let expected_nnz_lil = 25.0 * 0.5;
        assert!(
            (nnz_lil as f64) > expected_nnz_lil * 0.7,
            "Too few non-zeros in LIL: {nnz_lil}"
        );
        assert!(
            (nnz_lil as f64) < expected_nnz_lil * 1.3,
            "Too many non-zeros in LIL: {nnz_lil}"
        );
    }

    #[test]
    fn test_format_name_is_case_insensitive() {
        let eye = eye_array::<f64>(2, "CSR").unwrap();
        assert_eq!(eye.nnz(), 2);
        let eye_k = eye_array_k::<f64>(2, 2, 0, "Lil").unwrap();
        assert_eq!(eye_k.get(1, 1), 1.0);
    }

    #[test]
    fn test_unknown_format_is_rejected() {
        assert!(eye_array::<f64>(3, "csc").is_err());
        assert!(eye_array_k::<f64>(3, 3, 1, "bsr").is_err());
        let diags = vec![Array1::from_vec(vec![1.0, 2.0])];
        assert!(diags_array(&diags, &[0], (2, 2), "dense").is_err());
        assert!(random_array::<f64>((4, 4), 0.25, Some(1), "xyz").is_err());
    }

    #[test]
    fn test_zero_dimensions_are_rejected() {
        assert!(eye_array::<f64>(0, "csr").is_err());
        assert!(eye_array_k::<f64>(0, 3, 0, "csr").is_err());
        assert!(eye_array_k::<f64>(3, 0, 0, "csr").is_err());
        let diags = vec![Array1::from_vec(vec![1.0])];
        assert!(diags_array(&diags, &[0], (0, 1), "csr").is_err());
        assert!(random_array::<f64>((0, 5), 0.5, Some(1), "csr").is_err());
        assert!(random_array_parallel::<f64>((5, 0), 0.5, Some(1), "csr", 0).is_err());
    }

    #[test]
    fn test_invalid_density_is_rejected() {
        assert!(random_array::<f64>((4, 4), -0.1, None, "csr").is_err());
        assert!(random_array::<f64>((4, 4), 1.5, None, "csr").is_err());
        assert!(random_array_parallel::<f64>((4, 4), 1.5, None, "csr", 0).is_err());
    }

    #[test]
    fn test_diags_array_validation() {
        let diags = vec![Array1::from_vec(vec![1.0, 2.0, 3.0])];
        // More offsets than diagonals.
        assert!(diags_array(&diags, &[0, 1], (3, 3), "csr").is_err());
        // The first super-/sub-diagonal of a 3x3 matrix holds only 2 entries.
        assert!(diags_array(&diags, &[1], (3, 3), "csr").is_err());
        assert!(diags_array(&diags, &[-1], (3, 3), "csr").is_err());
    }

    #[test]
    fn test_random_array_is_reproducible_with_seed() {
        // 0.2 exercises the rejection-sampling path, 0.8 the full-shuffle path.
        for &density in &[0.2, 0.8] {
            let a = random_array::<f64>((6, 7), density, Some(7), "csr").unwrap();
            let b = random_array::<f64>((6, 7), density, Some(7), "csr").unwrap();
            let expected = (42.0 * density).round() as usize;
            assert_eq!(a.nnz(), expected);
            assert_eq!(b.nnz(), expected);
            for i in 0..6 {
                for j in 0..7 {
                    assert_eq!(a.get(i, j), b.get(i, j));
                    assert!(a.get(i, j).abs() < 1.0);
                }
            }
        }
    }

    #[test]
    fn test_random_array_parallel_is_reproducible_with_seed() {
        // 100x100 at 20% density expects 2000 non-zeros, so the parallel path is taken.
        let a = random_array_parallel::<f64>((100, 100), 0.2, Some(3), "csr", 0).unwrap();
        let b = random_array_parallel::<f64>((100, 100), 0.2, Some(3), "csr", 0).unwrap();
        assert_eq!(a.shape(), (100, 100));
        assert_eq!(a.nnz(), b.nnz());
        for i in 0..100 {
            for j in 0..100 {
                assert_eq!(a.get(i, j), b.get(i, j));
            }
        }
    }
}