1use scirs2_core::numeric::{SparseElement, Zero};
7
/// Sparse matrix stored as one list of entries per row (LIL layout).
///
/// Row `r` is described by two parallel vectors: `indices[r]` holds the
/// column positions of the stored entries and `data[r]` holds the matching
/// values, so `data[r][k]` belongs to column `indices[r][k]`.
///
/// `Debug`, `Clone` and `PartialEq` are derived so that matrices can be
/// printed, duplicated and compared (field by field) whenever the element
/// type supports the same operations.
#[derive(Debug, Clone, PartialEq)]
pub struct LilMatrix<T> {
    /// Number of rows in the logical matrix.
    rows: usize,
    /// Number of columns in the logical matrix.
    cols: usize,
    /// Stored values, one vector per row.
    data: Vec<Vec<T>>,
    /// Column index of each stored value, one vector per row.
    indices: Vec<Vec<usize>>,
}
22
23impl<T> LilMatrix<T>
24where
25 T: Clone + Copy + Zero + std::cmp::PartialEq + SparseElement,
26{
27 pub fn new(shape: (usize, usize)) -> Self {
53 let (rows, cols) = shape;
54
55 let data = vec![Vec::new(); rows];
56 let indices = vec![Vec::new(); rows];
57
58 LilMatrix {
59 rows,
60 cols,
61 data,
62 indices,
63 }
64 }
65
66 pub fn set(&mut self, row: usize, col: usize, value: T) {
74 if row >= self.rows || col >= self.cols {
75 return;
76 }
77
78 match self.indices[row].binary_search(&col) {
80 Ok(idx) => {
81 if value == T::sparse_zero() {
83 self.data[row].remove(idx);
85 self.indices[row].remove(idx);
86 } else {
87 self.data[row][idx] = value;
89 }
90 }
91 Err(idx) => {
92 if value != T::sparse_zero() {
94 self.data[row].insert(idx, value);
96 self.indices[row].insert(idx, col);
97 }
98 }
99 }
100 }
101
102 pub fn get(&self, row: usize, col: usize) -> T {
113 if row >= self.rows || col >= self.cols {
114 return T::sparse_zero();
115 }
116
117 match self.indices[row].binary_search(&col) {
118 Ok(idx) => self.data[row][idx],
119 Err(_) => T::sparse_zero(),
120 }
121 }
122
123 pub fn rows(&self) -> usize {
125 self.rows
126 }
127
128 pub fn cols(&self) -> usize {
130 self.cols
131 }
132
133 pub fn shape(&self) -> (usize, usize) {
135 (self.rows, self.cols)
136 }
137
138 pub fn nnz(&self) -> usize {
140 self.indices.iter().map(|row| row.len()).sum()
141 }
142
143 pub fn to_dense(&self) -> Vec<Vec<T>>
145 where
146 T: Zero + Copy + SparseElement,
147 {
148 let mut result = vec![vec![T::sparse_zero(); self.cols]; self.rows];
149
150 for (row, (row_indices, row_data)) in self
151 .indices
152 .iter()
153 .zip(&self.data)
154 .enumerate()
155 .take(self.rows)
156 {
157 for (idx, &col) in row_indices.iter().enumerate() {
158 result[row][col] = row_data[idx];
159 }
160 }
161
162 result
163 }
164
165 pub fn to_coo(&self) -> (Vec<T>, Vec<usize>, Vec<usize>) {
171 let nnz = self.nnz();
172 let mut data = Vec::with_capacity(nnz);
173 let mut row_indices = Vec::with_capacity(nnz);
174 let mut col_indices = Vec::with_capacity(nnz);
175
176 for row in 0..self.rows {
177 for (idx, &col) in self.indices[row].iter().enumerate() {
178 data.push(self.data[row][idx]);
179 row_indices.push(row);
180 col_indices.push(col);
181 }
182 }
183
184 (data, row_indices, col_indices)
185 }
186}
187
#[cfg(test)]
mod tests {
    use super::*;

    /// `(row, col, value)` triplets used to populate the 3x3 fixture matrix.
    const ENTRIES: [(usize, usize, f64); 5] = [
        (0, 0, 1.0),
        (0, 2, 2.0),
        (1, 2, 3.0),
        (2, 0, 4.0),
        (2, 1, 5.0),
    ];

    /// Builds the 3x3 fixture matrix shared by all tests in this module.
    fn sample_matrix() -> LilMatrix<f64> {
        let mut matrix = LilMatrix::<f64>::new((3, 3));
        for &(row, col, value) in &ENTRIES {
            matrix.set(row, col, value);
        }
        matrix
    }

    #[test]
    fn test_lil_create_and_access() {
        let mut matrix = sample_matrix();

        assert_eq!(matrix.nnz(), 5);

        // Stored entries read back unchanged; an untouched cell reads as zero.
        assert_eq!(matrix.get(0, 0), 1.0);
        assert_eq!(matrix.get(0, 1), 0.0);
        assert_eq!(matrix.get(0, 2), 2.0);
        assert_eq!(matrix.get(1, 2), 3.0);
        assert_eq!(matrix.get(2, 0), 4.0);
        assert_eq!(matrix.get(2, 1), 5.0);

        // Writing zero over a stored entry removes it.
        matrix.set(0, 0, 0.0);
        assert_eq!(matrix.nnz(), 4);
        assert_eq!(matrix.get(0, 0), 0.0);

        // Reads outside the shape return zero.
        assert_eq!(matrix.get(3, 0), 0.0);
        assert_eq!(matrix.get(0, 3), 0.0);
    }

    #[test]
    fn test_lil_to_dense() {
        let dense = sample_matrix().to_dense();

        let expected = vec![
            vec![1.0, 0.0, 2.0],
            vec![0.0, 0.0, 3.0],
            vec![4.0, 5.0, 0.0],
        ];

        assert_eq!(dense, expected);
    }

    #[test]
    fn test_lil_to_coo() {
        let matrix = sample_matrix();
        let (data, row_indices, col_indices) = matrix.to_coo();

        assert_eq!(data.len(), 5);
        assert_eq!(row_indices.len(), 5);
        assert_eq!(col_indices.len(), 5);

        // Every exported triplet must agree with a direct lookup.
        for ((&val, &row), &col) in data.iter().zip(&row_indices).zip(&col_indices) {
            assert_eq!(matrix.get(row, col), val);
        }
    }
}