1pub mod error;
66pub use error::{SparseError, SparseResult};
67
68pub mod sparray;
70pub use sparray::{is_sparse, SparseArray, SparseSum};
71
72pub mod sym_sparray;
74pub use sym_sparray::SymSparseArray;
75
76pub mod csr_array;
80pub use csr_array::CsrArray;
81
82pub mod csc_array;
83pub use csc_array::CscArray;
84
85pub mod coo_array;
86pub use coo_array::CooArray;
87
88pub mod dok_array;
89pub use dok_array::DokArray;
90
91pub mod lil_array;
92pub use lil_array::LilArray;
93
94pub mod dia_array;
95pub use dia_array::DiaArray;
96
97pub mod bsr_array;
98pub use bsr_array::BsrArray;
99
100pub mod sym_csr;
102pub use sym_csr::{SymCsrArray, SymCsrMatrix};
103
104pub mod sym_coo;
105pub use sym_coo::{SymCooArray, SymCooMatrix};
106
107pub mod csr;
109pub use csr::CsrMatrix;
110
111pub mod csc;
112pub use csc::CscMatrix;
113
114pub mod coo;
115pub use coo::CooMatrix;
116
117pub mod dok;
118pub use dok::DokMatrix;
119
120pub mod lil;
121pub use lil::LilMatrix;
122
123pub mod dia;
124pub use dia::DiaMatrix;
125
126pub mod bsr;
127pub use bsr::BsrMatrix;
128
129pub mod utils;
131
132pub mod linalg;
134pub use linalg::{
136 add,
138 bicg,
140 bicgstab,
141 cg,
142 diag_matrix,
143 expm,
144 expm_multiply,
146 eye,
147 gmres,
148 inv,
149 matmul,
150 matrix_power,
151 multiply,
152 norm,
153 onenormest,
154 sparse_direct_solve,
155 sparse_lstsq,
156 spsolve,
157 AsLinearOperator,
159 BiCGOptions,
161 BiCGSTABOptions,
162 BiCGSTABResult,
163 CGOptions,
164 CGSOptions,
165 CGSResult,
166 DiagonalOperator,
168 GMRESOptions,
169 ILU0Preconditioner,
171 IdentityOperator,
172 IterationResult,
173 JacobiPreconditioner,
174 LinearOperator,
175 SSORPreconditioner,
176 ScaledIdentityOperator,
177};
178
179pub mod convert;
181
182pub mod construct;
184pub mod construct_sym;
185
186pub mod combine;
188pub use combine::{block_diag, bmat, hstack, kron, kronsum, tril, triu, vstack};
189
190pub mod index_dtype;
192pub use index_dtype::{can_cast_safely, get_index_dtype, safely_cast_index_arrays};
193
194pub mod sym_ops;
196pub use sym_ops::{
197 sym_coo_matvec, sym_csr_matvec, sym_csr_quadratic_form, sym_csr_rank1_update, sym_csr_trace,
198};
199
/// Marker type naming the "sparse efficiency" warning category.
///
/// This is a field-less unit struct; nothing in this file constructs or emits it.
/// NOTE(review): the name matches SciPy's `SparseEfficiencyWarning` — presumably kept for
/// API parity; confirm the intended use with the callers.
///
/// The common traits are derived so the marker can be debug-printed, copied, compared and
/// used as a map/set key like any other zero-sized tag.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct SparseEfficiencyWarning;
/// Marker type naming the general sparse warning category.
///
/// This is a field-less unit struct; nothing in this file constructs or emits it.
/// NOTE(review): the name matches SciPy's `SparseWarning` — presumably kept for API parity;
/// confirm the intended use with the callers.
///
/// The common traits are derived so the marker can be debug-printed, copied, compared and
/// used as a map/set key like any other zero-sized tag.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct SparseWarning;
203
204pub fn is_sparse_array<T>(obj: &dyn SparseArray<T>) -> bool
206where
207 T: num_traits::Float
208 + std::fmt::Debug
209 + Copy
210 + std::ops::Add<Output = T>
211 + std::ops::Sub<Output = T>
212 + std::ops::Mul<Output = T>
213 + std::ops::Div<Output = T>
214 + 'static,
215{
216 sparray::is_sparse(obj)
217}
218
219pub fn is_sym_sparse_array<T>(obj: &dyn SymSparseArray<T>) -> bool
221where
222 T: num_traits::Float
223 + std::fmt::Debug
224 + Copy
225 + std::ops::Add<Output = T>
226 + std::ops::Sub<Output = T>
227 + std::ops::Mul<Output = T>
228 + std::ops::Div<Output = T>
229 + 'static,
230{
231 obj.is_symmetric()
232}
233
234pub fn is_sparse_matrix(obj: &dyn std::any::Any) -> bool {
236 obj.is::<CsrMatrix<f64>>()
237 || obj.is::<CscMatrix<f64>>()
238 || obj.is::<CooMatrix<f64>>()
239 || obj.is::<DokMatrix<f64>>()
240 || obj.is::<LilMatrix<f64>>()
241 || obj.is::<DiaMatrix<f64>>()
242 || obj.is::<BsrMatrix<f64>>()
243 || obj.is::<SymCsrMatrix<f64>>()
244 || obj.is::<SymCooMatrix<f64>>()
245 || obj.is::<CsrMatrix<f32>>()
246 || obj.is::<CscMatrix<f32>>()
247 || obj.is::<CooMatrix<f32>>()
248 || obj.is::<DokMatrix<f32>>()
249 || obj.is::<LilMatrix<f32>>()
250 || obj.is::<DiaMatrix<f32>>()
251 || obj.is::<BsrMatrix<f32>>()
252 || obj.is::<SymCsrMatrix<f32>>()
253 || obj.is::<SymCooMatrix<f32>>()
254}
255
// Smoke tests for the array types and helper functions re-exported from the crate root.
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Builds a 3x3 CSR array from five triplets and checks shape, nnz and the sparse
    /// predicate.
    #[test]
    fn test_csr_array() {
        let row_idx = vec![0, 0, 1, 2, 2];
        let col_idx = vec![0, 2, 2, 0, 1];
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let csr = CsrArray::from_triplets(&row_idx, &col_idx, &values, (3, 3), false).unwrap();

        assert_eq!(csr.shape(), (3, 3));
        assert_eq!(csr.nnz(), 5);
        assert!(is_sparse_array(&csr));
    }

    /// Builds a 3x3 COO array from five triplets and checks shape, nnz and the sparse
    /// predicate.
    #[test]
    fn test_coo_array() {
        let row_idx = vec![0, 0, 1, 2, 2];
        let col_idx = vec![0, 2, 2, 0, 1];
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let coo = CooArray::from_triplets(&row_idx, &col_idx, &values, (3, 3), false).unwrap();

        assert_eq!(coo.shape(), (3, 3));
        assert_eq!(coo.nnz(), 5);
        assert!(is_sparse_array(&coo));
    }

    /// Covers DOK construction from triplets, element-wise `set`/`get`, and the stored
    /// count after a zero is written over an existing entry.
    #[test]
    fn test_dok_array() {
        let row_idx = vec![0, 0, 1, 2, 2];
        let col_idx = vec![0, 2, 2, 0, 1];
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let filled = DokArray::from_triplets(&row_idx, &col_idx, &values, (3, 3)).unwrap();

        assert_eq!(filled.shape(), (3, 3));
        assert_eq!(filled.nnz(), 5);
        assert!(is_sparse_array(&filled));

        // Start from an empty 2x2 array and populate the diagonal by hand.
        let mut small = DokArray::<f64>::new((2, 2));
        small.set(0, 0, 1.0).unwrap();
        small.set(1, 1, 2.0).unwrap();

        for &(r, c, want) in &[(0, 0, 1.0), (0, 1, 0.0), (1, 1, 2.0)] {
            assert_eq!(small.get(r, c), want, "entry ({}, {})", r, c);
        }

        // Writing 0.0 over (0, 0) is expected to leave a single stored entry.
        small.set(0, 0, 0.0).unwrap();
        assert_eq!(small.nnz(), 1);
    }

    /// Covers LIL construction from triplets, element-wise `set`/`get`, the sorted-indices
    /// flag, and the stored count after a zero is written over an existing entry.
    #[test]
    fn test_lil_array() {
        let row_idx = vec![0, 0, 1, 2, 2];
        let col_idx = vec![0, 2, 2, 0, 1];
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let filled = LilArray::from_triplets(&row_idx, &col_idx, &values, (3, 3)).unwrap();

        assert_eq!(filled.shape(), (3, 3));
        assert_eq!(filled.nnz(), 5);
        assert!(is_sparse_array(&filled));

        // Start from an empty 2x2 array and populate the diagonal by hand.
        let mut small = LilArray::<f64>::new((2, 2));
        small.set(0, 0, 1.0).unwrap();
        small.set(1, 1, 2.0).unwrap();

        for &(r, c, want) in &[(0, 0, 1.0), (0, 1, 0.0), (1, 1, 2.0)] {
            assert_eq!(small.get(r, c), want, "entry ({}, {})", r, c);
        }

        assert!(small.has_sorted_indices());

        // Writing 0.0 over (0, 0) is expected to leave a single stored entry.
        small.set(0, 0, 0.0).unwrap();
        assert_eq!(small.nnz(), 1);
    }

    /// Covers DIA construction from explicit bands and from triplets, element lookup, and
    /// conversion to CSR.
    #[test]
    fn test_dia_array() {
        use ndarray::Array1;

        // Band at offset 0 (read back on the main diagonal) and band at offset 1 (read back
        // one column to the right of it); the trailing 0.0 pads the shorter band.
        let main_band = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let upper_band = Array1::from_vec(vec![4.0, 5.0, 0.0]);
        let shape = (3, 3);

        let dia = DiaArray::new(vec![main_band, upper_band], vec![0, 1], shape).unwrap();

        assert_eq!(dia.shape(), (3, 3));
        assert_eq!(dia.nnz(), 5);
        assert!(is_sparse_array(&dia));

        // The five populated cells, shared by both construction paths below.
        let nonzeros = [
            (0, 0, 1.0),
            (1, 1, 2.0),
            (2, 2, 3.0),
            (0, 1, 4.0),
            (1, 2, 5.0),
        ];
        for &(r, c, want) in &nonzeros {
            assert_eq!(dia.get(r, c), want, "entry ({}, {})", r, c);
        }
        assert_eq!(dia.get(0, 2), 0.0);

        // Same matrix again, this time described as coordinate triplets.
        let row_idx = vec![0, 0, 1, 1, 2];
        let col_idx = vec![0, 1, 1, 2, 2];
        let triplet_values = vec![1.0, 4.0, 2.0, 5.0, 3.0];

        let rebuilt = DiaArray::from_triplets(&row_idx, &col_idx, &triplet_values, shape).unwrap();

        for &(r, c, want) in &nonzeros {
            assert_eq!(rebuilt.get(r, c), want, "entry ({}, {})", r, c);
        }

        let as_csr = dia.to_csr().unwrap();
        assert_eq!(as_csr.nnz(), 5);
        assert_eq!(as_csr.get(0, 0), 1.0);
        assert_eq!(as_csr.get(0, 1), 4.0);
    }

    /// Converts a COO array to CSR and checks that both densify to the same values.
    #[test]
    fn test_format_conversions() {
        let row_idx = vec![0, 0, 1, 2, 2];
        let col_idx = vec![0, 2, 1, 0, 2];
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let shape = (3, 3);

        let coo = CooArray::from_triplets(&row_idx, &col_idx, &values, shape, false).unwrap();
        let csr = coo.to_csr().unwrap();

        let dense_from_coo = coo.to_array();
        let dense_from_csr = csr.to_array();

        // Walk every cell in row-major order.
        for (i, j) in (0..shape.0).flat_map(|i| (0..shape.1).map(move |j| (i, j))) {
            assert_relative_eq!(dense_from_coo[[i, j]], dense_from_csr[[i, j]]);
        }
    }

    /// Squares the same matrix through the COO and the CSR `dot` implementations and checks
    /// that the dense results agree to within 1e-10.
    #[test]
    fn test_dot_product() {
        let row_idx = vec![0, 0, 1, 2, 2];
        let col_idx = vec![0, 2, 1, 0, 2];
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let shape = (3, 3);

        let coo = CooArray::from_triplets(&row_idx, &col_idx, &values, shape, false).unwrap();
        let csr = CsrArray::from_triplets(&row_idx, &col_idx, &values, shape, false).unwrap();

        let coo_squared = coo.dot(&coo).unwrap();
        let csr_squared = csr.dot(&csr).unwrap();

        let dense_from_coo = coo_squared.to_array();
        let dense_from_csr = csr_squared.to_array();

        // Walk every cell in row-major order.
        for (i, j) in (0..shape.0).flat_map(|i| (0..shape.1).map(move |j| (i, j))) {
            assert_relative_eq!(
                dense_from_coo[[i, j]],
                dense_from_csr[[i, j]],
                epsilon = 1e-10
            );
        }
    }

    /// Wraps a hand-built `SymCsrMatrix` in a `SymCsrArray` and checks mirrored lookups and
    /// the entry count of its `to_csr` expansion.
    #[test]
    fn test_sym_csr_array() {
        let values = vec![2.0, 1.0, 2.0, 3.0, 0.0, 3.0, 1.0];
        let col_indices = vec![0, 0, 1, 2, 0, 1, 2];
        let row_ptr = vec![0, 1, 3, 7];

        let matrix = SymCsrMatrix::new(values, row_ptr, col_indices, (3, 3)).unwrap();
        let sym = SymCsrArray::new(matrix);

        assert_eq!(sym.shape(), (3, 3));
        assert!(is_sym_sparse_array(&sym));

        // Each off-diagonal value is probed from both sides of the diagonal.
        for &(r, c, want) in &[
            (0, 0, 2.0),
            (0, 1, 1.0),
            (1, 0, 1.0),
            (1, 2, 3.0),
            (2, 1, 3.0),
        ] {
            assert_eq!(SparseArray::get(&sym, r, c), want, "entry ({}, {})", r, c);
        }

        let expanded = SymSparseArray::to_csr(&sym).unwrap();
        assert_eq!(expanded.nnz(), 10);
    }

    /// Wraps a hand-built `SymCooMatrix` in a `SymCooArray`, checks mirrored lookups, and
    /// then builds a second array from triplets whose (0, 1) and (1, 0) inputs differ.
    #[test]
    fn test_sym_coo_array() {
        let values = vec![2.0, 1.0, 2.0, 3.0, 1.0];
        let row_idx = vec![0, 1, 1, 2, 2];
        let col_idx = vec![0, 0, 1, 1, 2];

        let matrix = SymCooMatrix::new(values, row_idx, col_idx, (3, 3)).unwrap();
        let sym = SymCooArray::new(matrix);

        assert_eq!(sym.shape(), (3, 3));
        assert!(is_sym_sparse_array(&sym));

        // Each off-diagonal value is probed from both sides of the diagonal.
        for &(r, c, want) in &[
            (0, 0, 2.0),
            (0, 1, 1.0),
            (1, 0, 1.0),
            (1, 2, 3.0),
            (2, 1, 3.0),
        ] {
            assert_eq!(SparseArray::get(&sym, r, c), want, "entry ({}, {})", r, c);
        }

        // Input that is not symmetric: (0, 1) carries 1.5 while (1, 0) carries 0.5, and
        // (1, 2) has no (2, 1) counterpart. With the final flag set to `true` the test
        // expects both (0, 1) and (1, 0) to read back as 1.0.
        let mixed_rows = vec![0, 0, 1, 1, 2, 1, 0];
        let mixed_cols = vec![0, 1, 1, 2, 2, 0, 2];
        let mixed_values = vec![2.0, 1.5, 2.0, 3.5, 1.0, 0.5, 0.0];

        let symmetrized =
            SymCooArray::from_triplets(&mixed_rows, &mixed_cols, &mixed_values, (3, 3), true)
                .unwrap();

        for &(r, c, want) in &[(0, 1, 1.0), (1, 0, 1.0), (0, 2, 0.0)] {
            assert_eq!(
                SparseArray::get(&symmetrized, r, c),
                want,
                "entry ({}, {})",
                r,
                c
            );
        }
    }

    /// Exercises the `construct_sym` builders: identity, tridiagonal and banded.
    #[test]
    fn test_construct_sym_utils() {
        // 3x3 identity, requested in "csr" format.
        let identity = construct_sym::eye_sym_array::<f64>(3, "csr").unwrap();

        assert_eq!(identity.shape(), (3, 3));
        for &(r, c, want) in &[(0, 0, 1.0), (1, 1, 1.0), (2, 2, 1.0), (0, 1, 0.0)] {
            assert_eq!(
                SparseArray::get(&*identity, r, c),
                want,
                "entry ({}, {})",
                r,
                c
            );
        }

        // Tridiagonal: 2.0 on the diagonal, 1.0 next to it, requested in "coo" format.
        let diag = vec![2.0, 2.0, 2.0];
        let off_diag = vec![1.0, 1.0];

        let tridiag = construct_sym::tridiagonal_sym_array(&diag, &off_diag, "coo").unwrap();

        assert_eq!(tridiag.shape(), (3, 3));
        for &(r, c, want) in &[
            (0, 0, 2.0),
            (1, 1, 2.0),
            (2, 2, 2.0),
            (0, 1, 1.0),
            (1, 0, 1.0),
            (1, 2, 1.0),
            (0, 2, 0.0),
        ] {
            assert_eq!(
                SparseArray::get(&*tridiag, r, c),
                want,
                "entry ({}, {})",
                r,
                c
            );
        }

        // Banded 5x5: one vector per band, each one element shorter than the previous.
        let bands = vec![
            vec![2.0, 2.0, 2.0, 2.0, 2.0],
            vec![1.0, 1.0, 1.0, 1.0],
            vec![0.5, 0.5, 0.5],
        ];

        let banded = construct_sym::banded_sym_array(&bands, 5, "csr").unwrap();

        assert_eq!(banded.shape(), (5, 5));
        for &(r, c, want) in &[(0, 0, 2.0), (0, 1, 1.0), (0, 2, 0.5), (2, 0, 0.5)] {
            assert_eq!(
                SparseArray::get(&*banded, r, c),
                want,
                "entry ({}, {})",
                r,
                c
            );
        }
    }

    /// Round-trips a symmetric COO array through symmetric CSR and through the general
    /// CSR/COO expansions, checking that every cell agrees along the way.
    #[test]
    fn test_sym_conversions() {
        let values = vec![2.0, 1.0, 2.0, 3.0, 1.0];
        let row_idx = vec![0, 1, 1, 2, 2];
        let col_idx = vec![0, 0, 1, 1, 2];

        let sym_coo =
            SymCooArray::from_triplets(&row_idx, &col_idx, &values, (3, 3), true).unwrap();
        let sym_csr = sym_coo.to_sym_csr().unwrap();

        // Symmetric COO and symmetric CSR must agree cell by cell.
        for (i, j) in (0..3).flat_map(|i| (0..3).map(move |j| (i, j))) {
            assert_eq!(
                SparseArray::get(&sym_coo, i, j),
                SparseArray::get(&sym_csr, i, j)
            );
        }

        let full_csr = SymSparseArray::to_csr(&sym_coo).unwrap();
        let full_coo = SymSparseArray::to_coo(&sym_csr).unwrap();

        assert_eq!(full_csr.nnz(), 7);
        assert_eq!(full_coo.nnz(), 7);

        // The general expansions must agree with each other and with the symmetric source.
        for (i, j) in (0..3).flat_map(|i| (0..3).map(move |j| (i, j))) {
            assert_eq!(
                SparseArray::get(&full_csr, i, j),
                SparseArray::get(&full_coo, i, j)
            );
            assert_eq!(
                SparseArray::get(&full_csr, i, j),
                SparseArray::get(&sym_csr, i, j)
            );
        }
    }
}