scirs2_sparse/
construct.rs

1// Construction utilities for sparse arrays
2//
3// This module provides functions for constructing sparse arrays,
4// including identity matrices, diagonal matrices, random arrays, etc.
5
6use ndarray::Array1;
7use num_traits::Float;
8use rand::seq::SliceRandom;
9use rand::{Rng, SeedableRng};
10use std::fmt::Debug;
11use std::ops::{Add, Div, Mul, Sub};
12
13use crate::coo_array::CooArray;
14use crate::csr_array::CsrArray;
15use crate::dok_array::DokArray;
16use crate::error::{SparseError, SparseResult};
17use crate::lil_array::LilArray;
18use crate::sparray::SparseArray;
19
20/// Creates a sparse identity array of size n x n
21///
22/// # Arguments
23/// * `n` - Size of the square array
24/// * `format` - Format of the output array ("csr" or "coo")
25///
26/// # Returns
27/// A sparse array representing the identity matrix
28///
29/// # Examples
30///
31/// ```
32/// use scirs2_sparse::construct::eye_array;
33///
34/// let eye: Box<dyn scirs2_sparse::SparseArray<f64>> = eye_array(3, "csr").unwrap();
35/// assert_eq!(eye.shape(), (3, 3));
36/// assert_eq!(eye.nnz(), 3);
37/// assert_eq!(eye.get(0, 0), 1.0);
38/// assert_eq!(eye.get(1, 1), 1.0);
39/// assert_eq!(eye.get(2, 2), 1.0);
40/// assert_eq!(eye.get(0, 1), 0.0);
41/// ```
42pub fn eye_array<T>(n: usize, format: &str) -> SparseResult<Box<dyn SparseArray<T>>>
43where
44    T: Float
45        + Add<Output = T>
46        + Sub<Output = T>
47        + Mul<Output = T>
48        + Div<Output = T>
49        + Debug
50        + Copy
51        + 'static,
52{
53    if n == 0 {
54        return Err(SparseError::ValueError(
55            "Matrix dimension must be positive".to_string(),
56        ));
57    }
58
59    let mut rows = Vec::with_capacity(n);
60    let mut cols = Vec::with_capacity(n);
61    let mut data = Vec::with_capacity(n);
62
63    for i in 0..n {
64        rows.push(i);
65        cols.push(i);
66        data.push(T::one());
67    }
68
69    match format.to_lowercase().as_str() {
70        "csr" => CsrArray::from_triplets(&rows, &cols, &data, (n, n), true)
71            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
72        "coo" => CooArray::from_triplets(&rows, &cols, &data, (n, n), true)
73            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
74        "dok" => DokArray::from_triplets(&rows, &cols, &data, (n, n))
75            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
76        "lil" => LilArray::from_triplets(&rows, &cols, &data, (n, n))
77            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
78        _ => Err(SparseError::ValueError(format!(
79            "Unknown sparse format: {}. Supported formats are 'csr', 'coo', 'dok', and 'lil'",
80            format
81        ))),
82    }
83}
84
85/// Creates a sparse identity array of size m x n with k-th diagonal filled with ones
86///
87/// # Arguments
88/// * `m` - Number of rows
89/// * `n` - Number of columns
90/// * `k` - Diagonal index (0 = main diagonal, >0 = above main, <0 = below main)
91/// * `format` - Format of the output array ("csr" or "coo")
92///
93/// # Returns
94/// A sparse array with ones on the specified diagonal
95///
96/// # Examples
97///
98/// ```
99/// use scirs2_sparse::construct::eye_array_k;
100///
101/// // Identity with main diagonal (k=0)
102/// let eye: Box<dyn scirs2_sparse::SparseArray<f64>> = eye_array_k(3, 3, 0, "csr").unwrap();
103/// assert_eq!(eye.get(0, 0), 1.0);
104/// assert_eq!(eye.get(1, 1), 1.0);
105/// assert_eq!(eye.get(2, 2), 1.0);
106///
107/// // Superdiagonal (k=1)
108/// let superdiag: Box<dyn scirs2_sparse::SparseArray<f64>> = eye_array_k(3, 4, 1, "csr").unwrap();
109/// assert_eq!(superdiag.get(0, 1), 1.0);
110/// assert_eq!(superdiag.get(1, 2), 1.0);
111/// assert_eq!(superdiag.get(2, 3), 1.0);
112///
113/// // Subdiagonal (k=-1)
114/// let subdiag: Box<dyn scirs2_sparse::SparseArray<f64>> = eye_array_k(4, 3, -1, "csr").unwrap();
115/// assert_eq!(subdiag.get(1, 0), 1.0);
116/// assert_eq!(subdiag.get(2, 1), 1.0);
117/// assert_eq!(subdiag.get(3, 2), 1.0);
118/// ```
119pub fn eye_array_k<T>(
120    m: usize,
121    n: usize,
122    k: isize,
123    format: &str,
124) -> SparseResult<Box<dyn SparseArray<T>>>
125where
126    T: Float
127        + Add<Output = T>
128        + Sub<Output = T>
129        + Mul<Output = T>
130        + Div<Output = T>
131        + Debug
132        + Copy
133        + 'static,
134{
135    if m == 0 || n == 0 {
136        return Err(SparseError::ValueError(
137            "Matrix dimensions must be positive".to_string(),
138        ));
139    }
140
141    let mut rows = Vec::new();
142    let mut cols = Vec::new();
143    let mut data = Vec::new();
144
145    // Calculate diagonal elements
146    if k >= 0 {
147        let k_usize = k as usize;
148        let len = std::cmp::min(m, n.saturating_sub(k_usize));
149
150        for i in 0..len {
151            rows.push(i);
152            cols.push(i + k_usize);
153            data.push(T::one());
154        }
155    } else {
156        let k_abs = (-k) as usize;
157        let len = std::cmp::min(m.saturating_sub(k_abs), n);
158
159        for i in 0..len {
160            rows.push(i + k_abs);
161            cols.push(i);
162            data.push(T::one());
163        }
164    }
165
166    match format.to_lowercase().as_str() {
167        "csr" => CsrArray::from_triplets(&rows, &cols, &data, (m, n), true)
168            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
169        "coo" => CooArray::from_triplets(&rows, &cols, &data, (m, n), true)
170            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
171        "dok" => DokArray::from_triplets(&rows, &cols, &data, (m, n))
172            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
173        "lil" => LilArray::from_triplets(&rows, &cols, &data, (m, n))
174            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
175        _ => Err(SparseError::ValueError(format!(
176            "Unknown sparse format: {}. Supported formats are 'csr', 'coo', 'dok', and 'lil'",
177            format
178        ))),
179    }
180}
181
182/// Creates a sparse array from the specified diagonals
183///
184/// # Arguments
185/// * `diagonals` - Data for the diagonals
186/// * `offsets` - Offset for each diagonal (0 = main, >0 = above main, <0 = below main)
187/// * `shape` - Shape of the output array (m, n)
188/// * `format` - Format of the output array ("csr" or "coo")
189///
190/// # Returns
191/// A sparse array with the specified diagonals
192///
193/// # Examples
194///
195/// ```
196/// use scirs2_sparse::construct::diags_array;
197/// use ndarray::Array1;
198///
199/// let diags = vec![
200///     Array1::from_vec(vec![1.0, 2.0, 3.0]), // main diagonal
201///     Array1::from_vec(vec![4.0, 5.0])       // superdiagonal
202/// ];
203/// let offsets = vec![0, 1];
204/// let shape = (3, 3);
205///
206/// let result = diags_array(&diags, &offsets, shape, "csr").unwrap();
207/// assert_eq!(result.shape(), (3, 3));
208/// assert_eq!(result.get(0, 0), 1.0);
209/// assert_eq!(result.get(1, 1), 2.0);
210/// assert_eq!(result.get(2, 2), 3.0);
211/// assert_eq!(result.get(0, 1), 4.0);
212/// assert_eq!(result.get(1, 2), 5.0);
213/// ```
214pub fn diags_array<T>(
215    diagonals: &[Array1<T>],
216    offsets: &[isize],
217    shape: (usize, usize),
218    format: &str,
219) -> SparseResult<Box<dyn SparseArray<T>>>
220where
221    T: Float
222        + Add<Output = T>
223        + Sub<Output = T>
224        + Mul<Output = T>
225        + Div<Output = T>
226        + Debug
227        + Copy
228        + 'static,
229{
230    if diagonals.len() != offsets.len() {
231        return Err(SparseError::InconsistentData {
232            reason: "Number of diagonals must match number of offsets".to_string(),
233        });
234    }
235
236    if shape.0 == 0 || shape.1 == 0 {
237        return Err(SparseError::ValueError(
238            "Matrix dimensions must be positive".to_string(),
239        ));
240    }
241
242    let (m, n) = shape;
243    let mut rows = Vec::new();
244    let mut cols = Vec::new();
245    let mut data = Vec::new();
246
247    for (i, (diag, &offset)) in diagonals.iter().zip(offsets.iter()).enumerate() {
248        if offset >= 0 {
249            let offset_usize = offset as usize;
250            let max_len = std::cmp::min(m, n.saturating_sub(offset_usize));
251
252            if diag.len() > max_len {
253                return Err(SparseError::InconsistentData {
254                    reason: format!("Diagonal {} is too long ({} > {})", i, diag.len(), max_len),
255                });
256            }
257
258            for (j, &value) in diag.iter().enumerate() {
259                if !value.is_zero() {
260                    rows.push(j);
261                    cols.push(j + offset_usize);
262                    data.push(value);
263                }
264            }
265        } else {
266            let offset_abs = (-offset) as usize;
267            let max_len = std::cmp::min(m.saturating_sub(offset_abs), n);
268
269            if diag.len() > max_len {
270                return Err(SparseError::InconsistentData {
271                    reason: format!("Diagonal {} is too long ({} > {})", i, diag.len(), max_len),
272                });
273            }
274
275            for (j, &value) in diag.iter().enumerate() {
276                if !value.is_zero() {
277                    rows.push(j + offset_abs);
278                    cols.push(j);
279                    data.push(value);
280                }
281            }
282        }
283    }
284
285    match format.to_lowercase().as_str() {
286        "csr" => CsrArray::from_triplets(&rows, &cols, &data, shape, false)
287            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
288        "coo" => CooArray::from_triplets(&rows, &cols, &data, shape, false)
289            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
290        "dok" => DokArray::from_triplets(&rows, &cols, &data, shape)
291            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
292        "lil" => LilArray::from_triplets(&rows, &cols, &data, shape)
293            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
294        _ => Err(SparseError::ValueError(format!(
295            "Unknown sparse format: {}. Supported formats are 'csr', 'coo', 'dok', and 'lil'",
296            format
297        ))),
298    }
299}
300
301/// Creates a random sparse array with specified density
302///
303/// # Arguments
304/// * `shape` - Shape of the output array (m, n)
305/// * `density` - Density of non-zero elements (between 0.0 and 1.0)
306/// * `seed` - Optional seed for the random number generator
307/// * `format` - Format of the output array ("csr" or "coo")
308///
309/// # Returns
310/// A sparse array with random non-zero elements
311///
312/// # Examples
313///
314/// ```
315/// use scirs2_sparse::construct::random_array;
316///
317/// // Create a 10x10 array with 30% non-zero elements
318/// let random = random_array::<f64>((10, 10), 0.3, None, "csr").unwrap();
319/// assert_eq!(random.shape(), (10, 10));
320///
321/// // Create a random array with a specific seed
322/// let seeded = random_array::<f64>((5, 5), 0.5, Some(42), "coo").unwrap();
323/// assert_eq!(seeded.shape(), (5, 5));
324/// ```
325pub fn random_array<T>(
326    shape: (usize, usize),
327    density: f64,
328    seed: Option<u64>,
329    format: &str,
330) -> SparseResult<Box<dyn SparseArray<T>>>
331where
332    T: Float
333        + Add<Output = T>
334        + Sub<Output = T>
335        + Mul<Output = T>
336        + Div<Output = T>
337        + Debug
338        + Copy
339        + 'static,
340{
341    let (m, n) = shape;
342
343    if !(0.0..=1.0).contains(&density) {
344        return Err(SparseError::ValueError(
345            "Density must be between 0.0 and 1.0".to_string(),
346        ));
347    }
348
349    if m == 0 || n == 0 {
350        return Err(SparseError::ValueError(
351            "Matrix dimensions must be positive".to_string(),
352        ));
353    }
354
355    // Calculate the number of non-zero elements
356    let nnz = (m * n) as f64 * density;
357    let nnz = nnz.round() as usize;
358
359    // Create random indices
360    let mut rows = Vec::with_capacity(nnz);
361    let mut cols = Vec::with_capacity(nnz);
362    let mut data = Vec::with_capacity(nnz);
363
364    // Create RNG
365    let mut rng = if let Some(seed_value) = seed {
366        rand::rngs::StdRng::seed_from_u64(seed_value)
367    } else {
368        // For a random seed, use rng
369        let seed = rand::Rng::random::<u64>(&mut rand::rng());
370        rand::rngs::StdRng::seed_from_u64(seed)
371    };
372
373    // Generate random elements
374    let total = m * n;
375
376    if density > 0.4 {
377        // For high densities, more efficient to generate a mask
378        let mut indices: Vec<usize> = (0..total).collect();
379        indices.shuffle(&mut rng);
380
381        for &idx in indices.iter().take(nnz) {
382            let row = idx / n;
383            let col = idx % n;
384
385            rows.push(row);
386            cols.push(col);
387
388            // Generate random non-zero value
389            // For simplicity, using values between -1 and 1
390            let mut val: f64 = rng.random_range(-1.0..1.0);
391            // Make sure the value is not zero
392            while val.abs() < 1e-10 {
393                val = rng.random_range(-1.0..1.0);
394            }
395            data.push(T::from(val).unwrap());
396        }
397    } else {
398        // For low densities, use a set to track already-chosen positions
399        let mut positions = std::collections::HashSet::with_capacity(nnz);
400
401        while positions.len() < nnz {
402            let row = rng.random_range(0..m);
403            let col = rng.random_range(0..n);
404            let pos = row * n + col; // Using row/col as usize indices
405
406            if positions.insert(pos) {
407                rows.push(row);
408                cols.push(col);
409
410                // Generate random non-zero value
411                let mut val: f64 = rng.random_range(-1.0..1.0);
412                // Make sure the value is not zero
413                while val.abs() < 1e-10 {
414                    val = rng.random_range(-1.0..1.0);
415                }
416                data.push(T::from(val).unwrap());
417            }
418        }
419    }
420
421    // Create the output array
422    match format.to_lowercase().as_str() {
423        "csr" => CsrArray::from_triplets(&rows, &cols, &data, shape, false)
424            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
425        "coo" => CooArray::from_triplets(&rows, &cols, &data, shape, false)
426            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
427        "dok" => DokArray::from_triplets(&rows, &cols, &data, shape)
428            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
429        "lil" => LilArray::from_triplets(&rows, &cols, &data, shape)
430            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
431        _ => Err(SparseError::ValueError(format!(
432            "Unknown sparse format: {}. Supported formats are 'csr', 'coo', 'dok', and 'lil'",
433            format
434        ))),
435    }
436}
437
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_eye_array() {
        // CSR: full check, including one off-diagonal zero.
        let eye = eye_array::<f64>(3, "csr").unwrap();
        assert_eq!(eye.shape(), (3, 3));
        assert_eq!(eye.nnz(), 3);
        for i in 0..3 {
            assert_eq!(eye.get(i, i), 1.0);
        }
        assert_eq!(eye.get(0, 1), 0.0);

        // COO: structure only.
        let eye_coo = eye_array::<f64>(3, "coo").unwrap();
        assert_eq!(eye_coo.shape(), (3, 3));
        assert_eq!(eye_coo.nnz(), 3);

        // DOK and LIL: structure plus the diagonal values.
        for &fmt in &["dok", "lil"] {
            let arr = eye_array::<f64>(3, fmt).unwrap();
            assert_eq!(arr.shape(), (3, 3));
            assert_eq!(arr.nnz(), 3);
            for i in 0..3 {
                assert_eq!(arr.get(i, i), 1.0);
            }
        }
    }

    #[test]
    fn test_eye_array_k() {
        // (rows, cols, k, format, positions that must hold 1.0)
        let cases: [(usize, usize, isize, &str, [(usize, usize); 3]); 4] = [
            (3, 3, 0, "csr", [(0, 0), (1, 1), (2, 2)]),
            (3, 4, 1, "csr", [(0, 1), (1, 2), (2, 3)]),
            (4, 3, -1, "csr", [(1, 0), (2, 1), (3, 2)]),
            (3, 3, 0, "lil", [(0, 0), (1, 1), (2, 2)]),
        ];

        for &(m, n, k, fmt, ones) in &cases {
            let arr = eye_array_k::<f64>(m, n, k, fmt).unwrap();
            for &(r, c) in &ones {
                assert_eq!(arr.get(r, c), 1.0);
            }
        }
    }

    #[test]
    fn test_diags_array() {
        // Main diagonal plus one superdiagonal.
        let two_diags = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]),
            Array1::from_vec(vec![4.0, 5.0]),
        ];
        let two_offsets = vec![0, 1];
        let shape = (3, 3);

        let result = diags_array(&two_diags, &two_offsets, shape, "csr").unwrap();
        assert_eq!(result.shape(), (3, 3));
        let expected_two = [
            ((0, 0), 1.0),
            ((1, 1), 2.0),
            ((2, 2), 3.0),
            ((0, 1), 4.0),
            ((1, 2), 5.0),
        ];
        for &((r, c), v) in &expected_two {
            assert_eq!(result.get(r, c), v);
        }

        // Main diagonal, superdiagonal and subdiagonal, checked in both CSR and LIL.
        let three_diags = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]),
            Array1::from_vec(vec![4.0, 5.0]),
            Array1::from_vec(vec![6.0, 7.0]),
        ];
        let three_offsets = vec![0, 1, -1];
        let expected_three = [
            ((0, 0), 1.0),
            ((1, 1), 2.0),
            ((2, 2), 3.0),
            ((0, 1), 4.0),
            ((1, 2), 5.0),
            ((1, 0), 6.0),
            ((2, 1), 7.0),
        ];

        for &fmt in &["csr", "lil"] {
            let arr = diags_array(&three_diags, &three_offsets, shape, fmt).unwrap();
            assert_eq!(arr.shape(), (3, 3));
            for &((r, c), v) in &expected_three {
                assert_eq!(arr.get(r, c), v);
            }
        }
    }

    #[test]
    fn test_random_array() {
        let shape = (10, 10);
        let density = 0.3;

        let random = random_array::<f64>(shape, density, None, "csr").unwrap();
        assert_eq!(random.shape(), shape);

        // The stored count must land within 30% of m * n * density.
        let nnz = random.nnz();
        let expected_nnz = (shape.0 * shape.1) as f64 * density;
        let (low, high) = (expected_nnz * 0.7, expected_nnz * 1.3);
        assert!((nnz as f64) > low, "Too few non-zeros: {}", nnz);
        assert!((nnz as f64) < high, "Too many non-zeros: {}", nnz);

        // A fixed seed still yields the requested shape.
        let random_seeded = random_array::<f64>(shape, density, Some(42), "csr").unwrap();
        assert_eq!(random_seeded.shape(), shape);

        // LIL output obeys the same density bounds.
        let random_lil = random_array::<f64>((5, 5), 0.5, Some(42), "lil").unwrap();
        assert_eq!(random_lil.shape(), (5, 5));
        let nnz_lil = random_lil.nnz();
        let expected_nnz_lil = 25.0 * 0.5;
        let (low_lil, high_lil) = (expected_nnz_lil * 0.7, expected_nnz_lil * 1.3);
        assert!(
            (nnz_lil as f64) > low_lil,
            "Too few non-zeros in LIL: {}",
            nnz_lil
        );
        assert!(
            (nnz_lil as f64) < high_lil,
            "Too many non-zeros in LIL: {}",
            nnz_lil
        );
    }
}