// scirs2_sparse/construct.rs

// Construction utilities for sparse arrays
//
// This module provides functions for constructing sparse arrays,
// including identity matrices, diagonal matrices, and random arrays.
5
6use ndarray::Array1;
7use num_traits::Float;
8use rand::seq::SliceRandom;
9use rand::{Rng, SeedableRng};
10use std::fmt::Debug;
11use std::ops::{Add, Div, Mul, Sub};
12
13use crate::coo_array::CooArray;
14use crate::csr_array::CsrArray;
15use crate::dok_array::DokArray;
16use crate::error::{SparseError, SparseResult};
17use crate::lil_array::LilArray;
18use crate::sparray::SparseArray;
19
20/// Creates a sparse identity array of size n x n
21///
22/// # Arguments
23/// * `n` - Size of the square array
24/// * `format` - Format of the output array ("csr" or "coo")
25///
26/// # Returns
27/// A sparse array representing the identity matrix
28///
29/// # Examples
30///
31/// ```
32/// use scirs2_sparse::construct::eye_array;
33///
34/// let eye: Box<dyn scirs2_sparse::SparseArray<f64>> = eye_array(3, "csr").unwrap();
35/// assert_eq!(eye.shape(), (3, 3));
36/// assert_eq!(eye.nnz(), 3);
37/// assert_eq!(eye.get(0, 0), 1.0);
38/// assert_eq!(eye.get(1, 1), 1.0);
39/// assert_eq!(eye.get(2, 2), 1.0);
40/// assert_eq!(eye.get(0, 1), 0.0);
41/// ```
42pub fn eye_array<T>(n: usize, format: &str) -> SparseResult<Box<dyn SparseArray<T>>>
43where
44    T: Float
45        + Add<Output = T>
46        + Sub<Output = T>
47        + Mul<Output = T>
48        + Div<Output = T>
49        + Debug
50        + Copy
51        + 'static,
52{
53    if n == 0 {
54        return Err(SparseError::ValueError(
55            "Matrix dimension must be positive".to_string(),
56        ));
57    }
58
59    let mut rows = Vec::with_capacity(n);
60    let mut cols = Vec::with_capacity(n);
61    let mut data = Vec::with_capacity(n);
62
63    for i in 0..n {
64        rows.push(i);
65        cols.push(i);
66        data.push(T::one());
67    }
68
69    match format.to_lowercase().as_str() {
70        "csr" => CsrArray::from_triplets(&rows, &cols, &data, (n, n), true)
71            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
72        "coo" => CooArray::from_triplets(&rows, &cols, &data, (n, n), true)
73            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
74        "dok" => DokArray::from_triplets(&rows, &cols, &data, (n, n))
75            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
76        "lil" => LilArray::from_triplets(&rows, &cols, &data, (n, n))
77            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
78        _ => Err(SparseError::ValueError(format!(
79            "Unknown sparse format: {}. Supported formats are 'csr', 'coo', 'dok', and 'lil'",
80            format
81        ))),
82    }
83}
84
85/// Creates a sparse identity array of size m x n with k-th diagonal filled with ones
86///
87/// # Arguments
88/// * `m` - Number of rows
89/// * `n` - Number of columns
90/// * `k` - Diagonal index (0 = main diagonal, >0 = above main, <0 = below main)
91/// * `format` - Format of the output array ("csr" or "coo")
92///
93/// # Returns
94/// A sparse array with ones on the specified diagonal
95///
96/// # Examples
97///
98/// ```
99/// use scirs2_sparse::construct::eye_array_k;
100///
101/// // Identity with main diagonal (k=0)
102/// let eye: Box<dyn scirs2_sparse::SparseArray<f64>> = eye_array_k(3, 3, 0, "csr").unwrap();
103/// assert_eq!(eye.get(0, 0), 1.0);
104/// assert_eq!(eye.get(1, 1), 1.0);
105/// assert_eq!(eye.get(2, 2), 1.0);
106///
107/// // Superdiagonal (k=1)
108/// let superdiag: Box<dyn scirs2_sparse::SparseArray<f64>> = eye_array_k(3, 4, 1, "csr").unwrap();
109/// assert_eq!(superdiag.get(0, 1), 1.0);
110/// assert_eq!(superdiag.get(1, 2), 1.0);
111/// assert_eq!(superdiag.get(2, 3), 1.0);
112///
113/// // Subdiagonal (k=-1)
114/// let subdiag: Box<dyn scirs2_sparse::SparseArray<f64>> = eye_array_k(4, 3, -1, "csr").unwrap();
115/// assert_eq!(subdiag.get(1, 0), 1.0);
116/// assert_eq!(subdiag.get(2, 1), 1.0);
117/// assert_eq!(subdiag.get(3, 2), 1.0);
118/// ```
119pub fn eye_array_k<T>(
120    m: usize,
121    n: usize,
122    k: isize,
123    format: &str,
124) -> SparseResult<Box<dyn SparseArray<T>>>
125where
126    T: Float
127        + Add<Output = T>
128        + Sub<Output = T>
129        + Mul<Output = T>
130        + Div<Output = T>
131        + Debug
132        + Copy
133        + 'static,
134{
135    if m == 0 || n == 0 {
136        return Err(SparseError::ValueError(
137            "Matrix dimensions must be positive".to_string(),
138        ));
139    }
140
141    let mut rows = Vec::new();
142    let mut cols = Vec::new();
143    let mut data = Vec::new();
144
145    // Calculate diagonal elements
146    if k >= 0 {
147        let k_usize = k as usize;
148        let len = std::cmp::min(m, n.saturating_sub(k_usize));
149
150        for i in 0..len {
151            rows.push(i);
152            cols.push(i + k_usize);
153            data.push(T::one());
154        }
155    } else {
156        let k_abs = (-k) as usize;
157        let len = std::cmp::min(m.saturating_sub(k_abs), n);
158
159        for i in 0..len {
160            rows.push(i + k_abs);
161            cols.push(i);
162            data.push(T::one());
163        }
164    }
165
166    match format.to_lowercase().as_str() {
167        "csr" => CsrArray::from_triplets(&rows, &cols, &data, (m, n), true)
168            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
169        "coo" => CooArray::from_triplets(&rows, &cols, &data, (m, n), true)
170            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
171        "dok" => DokArray::from_triplets(&rows, &cols, &data, (m, n))
172            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
173        "lil" => LilArray::from_triplets(&rows, &cols, &data, (m, n))
174            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
175        _ => Err(SparseError::ValueError(format!(
176            "Unknown sparse format: {}. Supported formats are 'csr', 'coo', 'dok', and 'lil'",
177            format
178        ))),
179    }
180}
181
182/// Creates a sparse array from the specified diagonals
183///
184/// # Arguments
185/// * `diagonals` - Data for the diagonals
186/// * `offsets` - Offset for each diagonal (0 = main, >0 = above main, <0 = below main)
187/// * `shape` - Shape of the output array (m, n)
188/// * `format` - Format of the output array ("csr" or "coo")
189///
190/// # Returns
191/// A sparse array with the specified diagonals
192///
193/// # Examples
194///
195/// ```
196/// use scirs2_sparse::construct::diags_array;
197/// use ndarray::Array1;
198///
199/// let diags = vec![
200///     Array1::from_vec(vec![1.0, 2.0, 3.0]), // main diagonal
201///     Array1::from_vec(vec![4.0, 5.0])       // superdiagonal
202/// ];
203/// let offsets = vec![0, 1];
204/// let shape = (3, 3);
205///
206/// let result = diags_array(&diags, &offsets, shape, "csr").unwrap();
207/// assert_eq!(result.shape(), (3, 3));
208/// assert_eq!(result.get(0, 0), 1.0);
209/// assert_eq!(result.get(1, 1), 2.0);
210/// assert_eq!(result.get(2, 2), 3.0);
211/// assert_eq!(result.get(0, 1), 4.0);
212/// assert_eq!(result.get(1, 2), 5.0);
213/// ```
214pub fn diags_array<T>(
215    diagonals: &[Array1<T>],
216    offsets: &[isize],
217    shape: (usize, usize),
218    format: &str,
219) -> SparseResult<Box<dyn SparseArray<T>>>
220where
221    T: Float
222        + Add<Output = T>
223        + Sub<Output = T>
224        + Mul<Output = T>
225        + Div<Output = T>
226        + Debug
227        + Copy
228        + 'static,
229{
230    if diagonals.len() != offsets.len() {
231        return Err(SparseError::InconsistentData {
232            reason: "Number of diagonals must match number of offsets".to_string(),
233        });
234    }
235
236    if shape.0 == 0 || shape.1 == 0 {
237        return Err(SparseError::ValueError(
238            "Matrix dimensions must be positive".to_string(),
239        ));
240    }
241
242    let (m, n) = shape;
243    let mut rows = Vec::new();
244    let mut cols = Vec::new();
245    let mut data = Vec::new();
246
247    for (i, (diag, &offset)) in diagonals.iter().zip(offsets.iter()).enumerate() {
248        if offset >= 0 {
249            let offset_usize = offset as usize;
250            let max_len = std::cmp::min(m, n.saturating_sub(offset_usize));
251
252            if diag.len() > max_len {
253                return Err(SparseError::InconsistentData {
254                    reason: format!("Diagonal {} is too long ({} > {})", i, diag.len(), max_len),
255                });
256            }
257
258            for (j, &value) in diag.iter().enumerate() {
259                if !value.is_zero() {
260                    rows.push(j);
261                    cols.push(j + offset_usize);
262                    data.push(value);
263                }
264            }
265        } else {
266            let offset_abs = (-offset) as usize;
267            let max_len = std::cmp::min(m.saturating_sub(offset_abs), n);
268
269            if diag.len() > max_len {
270                return Err(SparseError::InconsistentData {
271                    reason: format!("Diagonal {} is too long ({} > {})", i, diag.len(), max_len),
272                });
273            }
274
275            for (j, &value) in diag.iter().enumerate() {
276                if !value.is_zero() {
277                    rows.push(j + offset_abs);
278                    cols.push(j);
279                    data.push(value);
280                }
281            }
282        }
283    }
284
285    match format.to_lowercase().as_str() {
286        "csr" => CsrArray::from_triplets(&rows, &cols, &data, shape, false)
287            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
288        "coo" => CooArray::from_triplets(&rows, &cols, &data, shape, false)
289            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
290        "dok" => DokArray::from_triplets(&rows, &cols, &data, shape)
291            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
292        "lil" => LilArray::from_triplets(&rows, &cols, &data, shape)
293            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
294        _ => Err(SparseError::ValueError(format!(
295            "Unknown sparse format: {}. Supported formats are 'csr', 'coo', 'dok', and 'lil'",
296            format
297        ))),
298    }
299}
300
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that two sparse arrays have the same shape and the same value at every position.
    fn assert_same_entries(a: &dyn SparseArray<f64>, b: &dyn SparseArray<f64>) {
        assert_eq!(a.shape(), b.shape());
        let (rows, cols) = a.shape();
        for i in 0..rows {
            for j in 0..cols {
                assert_eq!(a.get(i, j), b.get(i, j), "entries differ at ({}, {})", i, j);
            }
        }
    }

    #[test]
    fn test_eye_array() {
        let eye = eye_array::<f64>(3, "csr").unwrap();

        assert_eq!(eye.shape(), (3, 3));
        assert_eq!(eye.nnz(), 3);
        assert_eq!(eye.get(0, 0), 1.0);
        assert_eq!(eye.get(1, 1), 1.0);
        assert_eq!(eye.get(2, 2), 1.0);
        assert_eq!(eye.get(0, 1), 0.0);

        // Try COO format
        let eye_coo = eye_array::<f64>(3, "coo").unwrap();
        assert_eq!(eye_coo.shape(), (3, 3));
        assert_eq!(eye_coo.nnz(), 3);

        // Try DOK format
        let eye_dok = eye_array::<f64>(3, "dok").unwrap();
        assert_eq!(eye_dok.shape(), (3, 3));
        assert_eq!(eye_dok.nnz(), 3);
        assert_eq!(eye_dok.get(0, 0), 1.0);
        assert_eq!(eye_dok.get(1, 1), 1.0);
        assert_eq!(eye_dok.get(2, 2), 1.0);

        // Try LIL format
        let eye_lil = eye_array::<f64>(3, "lil").unwrap();
        assert_eq!(eye_lil.shape(), (3, 3));
        assert_eq!(eye_lil.nnz(), 3);
        assert_eq!(eye_lil.get(0, 0), 1.0);
        assert_eq!(eye_lil.get(1, 1), 1.0);
        assert_eq!(eye_lil.get(2, 2), 1.0);

        // Format names are matched case-insensitively
        let eye_upper = eye_array::<f64>(3, "CSR").unwrap();
        assert_eq!(eye_upper.shape(), (3, 3));
        assert_eq!(eye_upper.nnz(), 3);
    }

    #[test]
    fn test_eye_array_invalid_input() {
        // Zero-sized matrices are rejected
        assert!(eye_array::<f64>(0, "csr").is_err());
        assert!(eye_array_k::<f64>(0, 3, 0, "csr").is_err());
        assert!(eye_array_k::<f64>(3, 0, 0, "csr").is_err());

        // Unknown formats are rejected
        assert!(eye_array::<f64>(3, "bsr").is_err());
        assert!(eye_array_k::<f64>(3, 3, 1, "bsr").is_err());
    }

    #[test]
    fn test_eye_array_k() {
        // Identity with main diagonal (k=0)
        let eye = eye_array_k::<f64>(3, 3, 0, "csr").unwrap();
        assert_eq!(eye.get(0, 0), 1.0);
        assert_eq!(eye.get(1, 1), 1.0);
        assert_eq!(eye.get(2, 2), 1.0);

        // Superdiagonal (k=1)
        let superdiag = eye_array_k::<f64>(3, 4, 1, "csr").unwrap();
        assert_eq!(superdiag.nnz(), 3);
        assert_eq!(superdiag.get(0, 0), 0.0);
        assert_eq!(superdiag.get(0, 1), 1.0);
        assert_eq!(superdiag.get(1, 2), 1.0);
        assert_eq!(superdiag.get(2, 3), 1.0);

        // Subdiagonal (k=-1)
        let subdiag = eye_array_k::<f64>(4, 3, -1, "csr").unwrap();
        assert_eq!(subdiag.nnz(), 3);
        assert_eq!(subdiag.get(0, 0), 0.0);
        assert_eq!(subdiag.get(1, 0), 1.0);
        assert_eq!(subdiag.get(2, 1), 1.0);
        assert_eq!(subdiag.get(3, 2), 1.0);

        // Try LIL format
        let eye_lil = eye_array_k::<f64>(3, 3, 0, "lil").unwrap();
        assert_eq!(eye_lil.get(0, 0), 1.0);
        assert_eq!(eye_lil.get(1, 1), 1.0);
        assert_eq!(eye_lil.get(2, 2), 1.0);
    }

    #[test]
    fn test_diags_array() {
        let diags = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]), // main diagonal
            Array1::from_vec(vec![4.0, 5.0]),      // superdiagonal
        ];
        let offsets = vec![0, 1];
        let shape = (3, 3);

        let result = diags_array(&diags, &offsets, shape, "csr").unwrap();
        assert_eq!(result.shape(), (3, 3));
        assert_eq!(result.get(0, 0), 1.0);
        assert_eq!(result.get(1, 1), 2.0);
        assert_eq!(result.get(2, 2), 3.0);
        assert_eq!(result.get(0, 1), 4.0);
        assert_eq!(result.get(1, 2), 5.0);

        // Try with multiple diagonals and subdiagonals
        let diags = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]), // main diagonal
            Array1::from_vec(vec![4.0, 5.0]),      // superdiagonal
            Array1::from_vec(vec![6.0, 7.0]),      // subdiagonal
        ];
        let offsets = vec![0, 1, -1];
        let shape = (3, 3);

        let result = diags_array(&diags, &offsets, shape, "csr").unwrap();
        assert_eq!(result.shape(), (3, 3));
        assert_eq!(result.get(0, 0), 1.0);
        assert_eq!(result.get(1, 1), 2.0);
        assert_eq!(result.get(2, 2), 3.0);
        assert_eq!(result.get(0, 1), 4.0);
        assert_eq!(result.get(1, 2), 5.0);
        assert_eq!(result.get(1, 0), 6.0);
        assert_eq!(result.get(2, 1), 7.0);

        // Try LIL format
        let result_lil = diags_array(&diags, &offsets, shape, "lil").unwrap();
        assert_eq!(result_lil.shape(), (3, 3));
        assert_eq!(result_lil.get(0, 0), 1.0);
        assert_eq!(result_lil.get(1, 1), 2.0);
        assert_eq!(result_lil.get(2, 2), 3.0);
        assert_eq!(result_lil.get(0, 1), 4.0);
        assert_eq!(result_lil.get(1, 2), 5.0);
        assert_eq!(result_lil.get(1, 0), 6.0);
        assert_eq!(result_lil.get(2, 1), 7.0);
    }

    #[test]
    fn test_diags_array_invalid_input() {
        let diag = vec![Array1::from_vec(vec![1.0, 2.0])];

        // More offsets than diagonals
        assert!(diags_array(&diag, &[0, 1], (2, 2), "csr").is_err());
        // Zero-sized shape
        assert!(diags_array(&diag, &[0], (0, 2), "csr").is_err());
        // The superdiagonal of a 2x2 array only holds one entry
        assert!(diags_array(&diag, &[1], (2, 2), "csr").is_err());
        // The subdiagonal of a 2x2 array only holds one entry
        assert!(diags_array(&diag, &[-1], (2, 2), "csr").is_err());
        // Unknown format
        assert!(diags_array(&diag, &[0], (2, 2), "bsr").is_err());
    }

    #[test]
    fn test_random_array() {
        let shape = (10, 10);
        let density = 0.3;

        let random = random_array::<f64>(shape, density, None, "csr").unwrap();
        assert_eq!(random.shape(), shape);

        // Positions are sampled without replacement and stored values are never
        // zero, so the count is exactly round(10 * 10 * 0.3) = 30 regardless of
        // the random draw.
        assert_eq!(random.nnz(), 30);

        // Test with custom RNG seed
        let random_seeded = random_array::<f64>(shape, density, Some(42), "csr").unwrap();
        assert_eq!(random_seeded.shape(), shape);
        assert_eq!(random_seeded.nnz(), 30);

        // Test LIL format: round(5 * 5 * 0.5) = round(12.5) = 13
        let random_lil = random_array::<f64>((5, 5), 0.5, Some(42), "lil").unwrap();
        assert_eq!(random_lil.shape(), (5, 5));
        assert_eq!(random_lil.nnz(), 13);
    }

    #[test]
    fn test_random_array_seeded_is_reproducible() {
        // 0.2 exercises the rejection-sampling path, 0.6 the shuffle path
        for &density in &[0.2, 0.6] {
            let first = random_array::<f64>((6, 6), density, Some(42), "csr").unwrap();
            let second = random_array::<f64>((6, 6), density, Some(42), "csr").unwrap();
            assert_same_entries(&*first, &*second);
        }
    }

    #[test]
    fn test_random_array_full_density() {
        let full = random_array::<f64>((4, 4), 1.0, Some(7), "coo").unwrap();
        assert_eq!(full.nnz(), 16);
        for i in 0..4 {
            for j in 0..4 {
                let value = full.get(i, j);
                assert!(
                    value != 0.0 && (-1.0..1.0).contains(&value),
                    "unexpected value {} at ({}, {})",
                    value,
                    i,
                    j
                );
            }
        }
    }

    #[test]
    fn test_random_array_invalid_input() {
        assert!(random_array::<f64>((3, 3), 1.5, None, "csr").is_err());
        assert!(random_array::<f64>((3, 3), -0.1, None, "csr").is_err());
        assert!(random_array::<f64>((3, 3), f64::NAN, None, "csr").is_err());
        assert!(random_array::<f64>((0, 3), 0.5, None, "csr").is_err());
        assert!(random_array::<f64>((3, 3), 0.5, Some(1), "bsr").is_err());
    }
}
458
459/// Creates a random sparse array with specified density
460///
461/// # Arguments
462/// * `shape` - Shape of the output array (m, n)
463/// * `density` - Density of non-zero elements (between 0.0 and 1.0)
464/// * `seed` - Optional seed for the random number generator
465/// * `format` - Format of the output array ("csr" or "coo")
466///
467/// # Returns
468/// A sparse array with random non-zero elements
469///
470/// # Examples
471///
472/// ```
473/// use scirs2_sparse::construct::random_array;
474///
475/// // Create a 10x10 array with 30% non-zero elements
476/// let random = random_array::<f64>((10, 10), 0.3, None, "csr").unwrap();
477/// assert_eq!(random.shape(), (10, 10));
478///
479/// // Create a random array with a specific seed
480/// let seeded = random_array::<f64>((5, 5), 0.5, Some(42), "coo").unwrap();
481/// assert_eq!(seeded.shape(), (5, 5));
482/// ```
483pub fn random_array<T>(
484    shape: (usize, usize),
485    density: f64,
486    seed: Option<u64>,
487    format: &str,
488) -> SparseResult<Box<dyn SparseArray<T>>>
489where
490    T: Float
491        + Add<Output = T>
492        + Sub<Output = T>
493        + Mul<Output = T>
494        + Div<Output = T>
495        + Debug
496        + Copy
497        + 'static,
498{
499    let (m, n) = shape;
500
501    if !(0.0..=1.0).contains(&density) {
502        return Err(SparseError::ValueError(
503            "Density must be between 0.0 and 1.0".to_string(),
504        ));
505    }
506
507    if m == 0 || n == 0 {
508        return Err(SparseError::ValueError(
509            "Matrix dimensions must be positive".to_string(),
510        ));
511    }
512
513    // Calculate the number of non-zero elements
514    let nnz = (m * n) as f64 * density;
515    let nnz = nnz.round() as usize;
516
517    // Create random indices
518    let mut rows = Vec::with_capacity(nnz);
519    let mut cols = Vec::with_capacity(nnz);
520    let mut data = Vec::with_capacity(nnz);
521
522    // Create RNG
523    let mut rng = if let Some(seed_value) = seed {
524        rand::rngs::StdRng::seed_from_u64(seed_value)
525    } else {
526        // For a random seed, use rng
527        let seed = rand::Rng::random::<u64>(&mut rand::rng());
528        rand::rngs::StdRng::seed_from_u64(seed)
529    };
530
531    // Generate random elements
532    let total = m * n;
533
534    if density > 0.4 {
535        // For high densities, more efficient to generate a mask
536        let mut indices: Vec<usize> = (0..total).collect();
537        indices.shuffle(&mut rng);
538
539        for &idx in indices.iter().take(nnz) {
540            let row = idx / n;
541            let col = idx % n;
542
543            rows.push(row);
544            cols.push(col);
545
546            // Generate random non-zero value
547            // For simplicity, using values between -1 and 1
548            let mut val: f64 = rng.random_range(-1.0..1.0);
549            // Make sure the value is not zero
550            while val.abs() < 1e-10 {
551                val = rng.random_range(-1.0..1.0);
552            }
553            data.push(T::from(val).unwrap());
554        }
555    } else {
556        // For low densities, use a set to track already-chosen positions
557        let mut positions = std::collections::HashSet::with_capacity(nnz);
558
559        while positions.len() < nnz {
560            let row = rng.random_range(0..m);
561            let col = rng.random_range(0..n);
562            let pos = row * n + col; // Using row/col as usize indices
563
564            if positions.insert(pos) {
565                rows.push(row);
566                cols.push(col);
567
568                // Generate random non-zero value
569                let mut val: f64 = rng.random_range(-1.0..1.0);
570                // Make sure the value is not zero
571                while val.abs() < 1e-10 {
572                    val = rng.random_range(-1.0..1.0);
573                }
574                data.push(T::from(val).unwrap());
575            }
576        }
577    }
578
579    // Create the output array
580    match format.to_lowercase().as_str() {
581        "csr" => CsrArray::from_triplets(&rows, &cols, &data, shape, false)
582            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
583        "coo" => CooArray::from_triplets(&rows, &cols, &data, shape, false)
584            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
585        "dok" => DokArray::from_triplets(&rows, &cols, &data, shape)
586            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
587        "lil" => LilArray::from_triplets(&rows, &cols, &data, shape)
588            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
589        _ => Err(SparseError::ValueError(format!(
590            "Unknown sparse format: {}. Supported formats are 'csr', 'coo', 'dok', and 'lil'",
591            format
592        ))),
593    }
594}