scirs2_sparse/
dok.rs

1//! Dictionary of Keys (DOK) matrix format
2//!
3//! This module provides the DOK matrix format implementation, which is
4//! efficient for incremental matrix construction.
5
6use crate::error::{SparseError, SparseResult};
7use scirs2_core::numeric::Zero;
8use std::collections::HashMap;
9
/// Dictionary of Keys (DOK) matrix
///
/// A sparse matrix format that stores elements in a dictionary (hash map)
/// keyed by `(row, col)`, making it efficient for incremental construction.
///
/// Only non-zero entries are stored: assigning zero through `set` removes the
/// corresponding key, so the map's length is the number of stored non-zeros.
#[derive(Debug, Clone)]
pub struct DokMatrix<T> {
    /// Number of rows
    rows: usize,
    /// Number of columns
    cols: usize,
    /// Dictionary of (row, col) -> value; holds non-zero entries only
    data: HashMap<(usize, usize), T>,
}
22
23impl<T> DokMatrix<T>
24where
25    T: Clone + Copy + Zero + std::cmp::PartialEq,
26{
27    /// Create a new DOK matrix
28    ///
29    /// # Arguments
30    ///
31    /// * `shape` - Tuple containing the matrix dimensions (rows, cols)
32    ///
33    /// # Returns
34    ///
35    /// * A new empty DOK matrix
36    ///
37    /// # Examples
38    ///
39    /// ```
40    /// use scirs2_sparse::dok::DokMatrix;
41    ///
42    /// // Create a 3x3 sparse matrix
43    /// let mut matrix = DokMatrix::<f64>::new((3, 3));
44    ///
45    /// // Set some values
46    /// matrix.set(0, 0, 1.0);
47    /// matrix.set(0, 2, 2.0);
48    /// matrix.set(1, 2, 3.0);
49    /// matrix.set(2, 0, 4.0);
50    /// matrix.set(2, 1, 5.0);
51    /// ```
52    pub fn new(shape: (usize, usize)) -> Self {
53        let (rows, cols) = shape;
54
55        DokMatrix {
56            rows,
57            cols,
58            data: HashMap::new(),
59        }
60    }
61
62    /// Set a value in the matrix
63    ///
64    /// # Arguments
65    ///
66    /// * `row` - Row index
67    /// * `col` - Column index
68    /// * `value` - Value to set
69    ///
70    /// # Returns
71    ///
72    /// * Ok(()) if successful, Error otherwise
73    pub fn set(&mut self, row: usize, col: usize, value: T) -> SparseResult<()> {
74        if row >= self.rows || col >= self.cols {
75            return Err(SparseError::ValueError(
76                "Row or column index out of bounds".to_string(),
77            ));
78        }
79
80        if value == T::zero() {
81            // Remove zero entries
82            self.data.remove(&(row, col));
83        } else {
84            // Set non-zero value
85            self.data.insert((row, col), value);
86        }
87
88        Ok(())
89    }
90
91    /// Get a value from the matrix
92    ///
93    /// # Arguments
94    ///
95    /// * `row` - Row index
96    /// * `col` - Column index
97    ///
98    /// # Returns
99    ///
100    /// * Value at the specified position, or zero if not set
101    pub fn get(&self, row: usize, col: usize) -> T {
102        if row >= self.rows || col >= self.cols {
103            return T::zero();
104        }
105
106        *self.data.get(&(row, col)).unwrap_or(&T::zero())
107    }
108
109    /// Get the number of rows in the matrix
110    pub fn rows(&self) -> usize {
111        self.rows
112    }
113
114    /// Get the number of columns in the matrix
115    pub fn cols(&self) -> usize {
116        self.cols
117    }
118
119    /// Get the shape (dimensions) of the matrix
120    pub fn shape(&self) -> (usize, usize) {
121        (self.rows, self.cols)
122    }
123
124    /// Get the number of non-zero elements in the matrix
125    pub fn nnz(&self) -> usize {
126        self.data.len()
127    }
128
129    /// Convert to dense matrix (as Vec<Vec<T>>)
130    pub fn to_dense(&self) -> Vec<Vec<T>>
131    where
132        T: Zero + Copy,
133    {
134        let mut result = vec![vec![T::zero(); self.cols]; self.rows];
135
136        for (&(row, col), &value) in &self.data {
137            result[row][col] = value;
138        }
139
140        result
141    }
142
143    /// Convert to COO representation
144    ///
145    /// # Returns
146    ///
147    /// * Tuple of (data, row_indices, col_indices)
148    pub fn to_coo(&self) -> (Vec<T>, Vec<usize>, Vec<usize>) {
149        let nnz = self.nnz();
150        let mut data = Vec::with_capacity(nnz);
151        let mut row_indices = Vec::with_capacity(nnz);
152        let mut col_indices = Vec::with_capacity(nnz);
153
154        // Sort by row, then column for deterministic output
155        let mut entries: Vec<_> = self.data.iter().collect();
156        entries.sort_by_key(|(&(row, col), _)| (row, col));
157
158        for (&(row, col), &value) in entries {
159            data.push(value);
160            row_indices.push(row);
161            col_indices.push(col);
162        }
163
164        (data, row_indices, col_indices)
165    }
166}
167
#[cfg(test)]
mod tests {
    use super::*;

    // Shared 3x3 fixture used by several tests:
    //
    //     [1 0 2]
    //     [0 0 3]
    //     [4 5 0]
    fn sample_matrix() -> DokMatrix<f64> {
        let mut matrix = DokMatrix::<f64>::new((3, 3));
        matrix.set(0, 0, 1.0).unwrap();
        matrix.set(0, 2, 2.0).unwrap();
        matrix.set(1, 2, 3.0).unwrap();
        matrix.set(2, 0, 4.0).unwrap();
        matrix.set(2, 1, 5.0).unwrap();
        matrix
    }

    #[test]
    fn test_dok_shape_accessors() {
        let matrix = DokMatrix::<f64>::new((2, 4));
        assert_eq!(matrix.rows(), 2);
        assert_eq!(matrix.cols(), 4);
        assert_eq!(matrix.shape(), (2, 4));
        assert_eq!(matrix.nnz(), 0);
    }

    #[test]
    fn test_dok_create_and_access() {
        let mut matrix = sample_matrix();
        assert_eq!(matrix.nnz(), 5);

        // Access values
        assert_eq!(matrix.get(0, 0), 1.0);
        assert_eq!(matrix.get(0, 1), 0.0); // Zero entry
        assert_eq!(matrix.get(0, 2), 2.0);
        assert_eq!(matrix.get(1, 2), 3.0);
        assert_eq!(matrix.get(2, 0), 4.0);
        assert_eq!(matrix.get(2, 1), 5.0);

        // Setting a value to zero should remove it
        matrix.set(0, 0, 0.0).unwrap();
        assert_eq!(matrix.nnz(), 4);
        assert_eq!(matrix.get(0, 0), 0.0);

        // Out of bounds access should return zero
        assert_eq!(matrix.get(3, 0), 0.0);
        assert_eq!(matrix.get(0, 3), 0.0);
    }

    #[test]
    fn test_dok_set_out_of_bounds_is_error() {
        let mut matrix = sample_matrix();
        assert!(matrix.set(3, 0, 1.0).is_err());
        assert!(matrix.set(0, 3, 1.0).is_err());
        // A rejected write must not change the stored entries.
        assert_eq!(matrix.nnz(), 5);
    }

    #[test]
    fn test_dok_set_zero_on_absent_entry_stores_nothing() {
        let mut matrix = sample_matrix();
        matrix.set(1, 1, 0.0).unwrap();
        assert_eq!(matrix.nnz(), 5);
        assert_eq!(matrix.get(1, 1), 0.0);
    }

    #[test]
    fn test_dok_overwrite_keeps_nnz() {
        let mut matrix = sample_matrix();
        matrix.set(0, 0, 7.0).unwrap();
        assert_eq!(matrix.nnz(), 5);
        assert_eq!(matrix.get(0, 0), 7.0);
    }

    #[test]
    fn test_dok_to_dense() {
        let dense = sample_matrix().to_dense();

        let expected = vec![
            vec![1.0, 0.0, 2.0],
            vec![0.0, 0.0, 3.0],
            vec![4.0, 5.0, 0.0],
        ];

        assert_eq!(dense, expected);
    }

    #[test]
    fn test_dok_to_coo() {
        let (data, row_indices, col_indices) = sample_matrix().to_coo();

        // Check the content (sorted by row, then column)
        assert_eq!(data, [1.0, 2.0, 3.0, 4.0, 5.0]);
        assert_eq!(row_indices, [0, 0, 1, 2, 2]);
        assert_eq!(col_indices, [0, 2, 2, 0, 1]);
    }
}
255}