1use ndarray::{Array1, Array2, ArrayView1};
7use num_traits::Float;
8use std::fmt::Debug;
9use std::ops::{Add, Div, Mul, Sub};
10
11use crate::error::{SparseError, SparseResult};
12
13pub trait SparseArray<T>: std::any::Any
29where
30 T: Float
31 + Add<Output = T>
32 + Sub<Output = T>
33 + Mul<Output = T>
34 + Div<Output = T>
35 + Debug
36 + Copy
37 + 'static,
38{
39 fn shape(&self) -> (usize, usize);
41
42 fn nnz(&self) -> usize;
44
45 fn dtype(&self) -> &str;
47
48 fn to_array(&self) -> Array2<T>;
50
51 fn toarray(&self) -> Array2<T>;
53
54 fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>>;
56
57 fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>>;
59
60 fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>>;
62
63 fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>>;
65
66 fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>>;
68
69 fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>>;
71
72 fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>>;
74
75 fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>>;
77
78 fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>>;
80
81 fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>>;
83
84 fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>>;
86
87 fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>>;
89
90 fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>>;
92
93 fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>>;
95
96 fn copy(&self) -> Box<dyn SparseArray<T>>;
98
99 fn get(&self, i: usize, j: usize) -> T;
101
102 fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()>;
104
105 fn eliminate_zeros(&mut self);
107
108 fn sort_indices(&mut self);
110
111 fn sorted_indices(&self) -> Box<dyn SparseArray<T>>;
113
114 fn has_sorted_indices(&self) -> bool;
116
117 fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>>;
124
125 fn max(&self) -> T;
127
128 fn min(&self) -> T;
130
131 fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>);
133
134 fn slice(
136 &self,
137 row_range: (usize, usize),
138 col_range: (usize, usize),
139 ) -> SparseResult<Box<dyn SparseArray<T>>>;
140
141 fn as_any(&self) -> &dyn std::any::Any;
143
144 fn get_indptr(&self) -> Option<&Array1<usize>> {
147 None
148 }
149
150 fn indptr(&self) -> Option<&Array1<usize>> {
153 None
154 }
155}
156
157pub enum SparseSum<T>
160where
161 T: Float + Debug + Copy + 'static,
162{
163 SparseArray(Box<dyn SparseArray<T>>),
165
166 Scalar(T),
168}
169
170impl<T> Debug for SparseSum<T>
171where
172 T: Float + Debug + Copy + 'static,
173{
174 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
175 match self {
176 SparseSum::SparseArray(_) => write!(f, "SparseSum::SparseArray(...)"),
177 SparseSum::Scalar(value) => write!(f, "SparseSum::Scalar({value:?})"),
178 }
179 }
180}
181
182impl<T> Clone for SparseSum<T>
183where
184 T: Float + Debug + Copy + 'static,
185{
186 fn clone(&self) -> Self {
187 match self {
188 SparseSum::SparseArray(array) => SparseSum::SparseArray(array.copy()),
189 SparseSum::Scalar(value) => SparseSum::Scalar(*value),
190 }
191 }
192}
193
194#[allow(dead_code)]
196pub fn is_sparse<T>(obj: &dyn SparseArray<T>) -> bool
197where
198 T: Float
199 + Add<Output = T>
200 + Sub<Output = T>
201 + Mul<Output = T>
202 + Div<Output = T>
203 + Debug
204 + Copy
205 + 'static,
206{
207 true }
209
210pub struct SparseArrayBase<T>
212where
213 T: Float
214 + Add<Output = T>
215 + Sub<Output = T>
216 + Mul<Output = T>
217 + Div<Output = T>
218 + Debug
219 + Copy
220 + 'static,
221{
222 data: Array2<T>,
223}
224
225impl<T> SparseArrayBase<T>
226where
227 T: Float
228 + Add<Output = T>
229 + Sub<Output = T>
230 + Mul<Output = T>
231 + Div<Output = T>
232 + Debug
233 + Copy
234 + 'static,
235{
236 pub fn new(data: Array2<T>) -> Self {
238 Self { data }
239 }
240}
241
242impl<T> SparseArray<T> for SparseArrayBase<T>
243where
244 T: Float
245 + Add<Output = T>
246 + Sub<Output = T>
247 + Mul<Output = T>
248 + Div<Output = T>
249 + Debug
250 + Copy
251 + 'static,
252{
253 fn shape(&self) -> (usize, usize) {
254 let shape = self.data.shape();
255 (shape[0], shape[1])
256 }
257
258 fn nnz(&self) -> usize {
259 self.data.iter().filter(|&&x| !x.is_zero()).count()
260 }
261
262 fn dtype(&self) -> &str {
263 "float" }
265
266 fn to_array(&self) -> Array2<T> {
267 self.data.clone()
268 }
269
270 fn toarray(&self) -> Array2<T> {
271 self.data.clone()
272 }
273
274 fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
275 Ok(Box::new(self.clone()))
277 }
278
279 fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
280 Ok(Box::new(self.clone()))
282 }
283
284 fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
285 Ok(Box::new(self.clone()))
287 }
288
289 fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
290 Ok(Box::new(self.clone()))
292 }
293
294 fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
295 Ok(Box::new(self.clone()))
297 }
298
299 fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
300 Ok(Box::new(self.clone()))
302 }
303
304 fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
305 Ok(Box::new(self.clone()))
307 }
308
309 fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
310 let other_array = other.to_array();
311 let result = &self.data + &other_array;
312 Ok(Box::new(SparseArrayBase::new(result)))
313 }
314
315 fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
316 let other_array = other.to_array();
317 let result = &self.data - &other_array;
318 Ok(Box::new(SparseArrayBase::new(result)))
319 }
320
321 fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
322 let other_array = other.to_array();
323 let result = &self.data * &other_array;
324 Ok(Box::new(SparseArrayBase::new(result)))
325 }
326
327 fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
328 let other_array = other.to_array();
329 let result = &self.data / &other_array;
330 Ok(Box::new(SparseArrayBase::new(result)))
331 }
332
333 fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
334 let other_array = other.to_array();
335 let (m, n) = self.shape();
336 let (p, q) = other.shape();
337
338 if n != p {
339 return Err(SparseError::DimensionMismatch {
340 expected: n,
341 found: p,
342 });
343 }
344
345 let mut result = Array2::zeros((m, q));
346 for i in 0..m {
347 for j in 0..q {
348 let mut sum = T::zero();
349 for k in 0..n {
350 sum = sum + self.data[[i, k]] * other_array[[k, j]];
351 }
352 result[[i, j]] = sum;
353 }
354 }
355
356 Ok(Box::new(SparseArrayBase::new(result)))
357 }
358
359 fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
360 let (m, n) = self.shape();
361 if n != other.len() {
362 return Err(SparseError::DimensionMismatch {
363 expected: n,
364 found: other.len(),
365 });
366 }
367
368 let mut result = Array1::zeros(m);
369 for i in 0..m {
370 let mut sum = T::zero();
371 for j in 0..n {
372 sum = sum + self.data[[i, j]] * other[j];
373 }
374 result[i] = sum;
375 }
376
377 Ok(result)
378 }
379
380 fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
381 Ok(Box::new(SparseArrayBase::new(self.data.t().to_owned())))
382 }
383
384 fn copy(&self) -> Box<dyn SparseArray<T>> {
385 Box::new(self.clone())
386 }
387
388 fn get(&self, i: usize, j: usize) -> T {
389 self.data[[i, j]]
390 }
391
392 fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
393 let (m, n) = self.shape();
394 if i >= m || j >= n {
395 return Err(SparseError::IndexOutOfBounds {
396 index: (i, j),
397 shape: (m, n),
398 });
399 }
400 self.data[[i, j]] = value;
401 Ok(())
402 }
403
404 fn eliminate_zeros(&mut self) {
405 }
407
408 fn sort_indices(&mut self) {
409 }
411
412 fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
413 self.copy()
414 }
415
416 fn has_sorted_indices(&self) -> bool {
417 true }
419
420 fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>> {
421 match axis {
422 None => {
423 let mut sum = T::zero();
424 for &val in self.data.iter() {
425 sum = sum + val;
426 }
427 Ok(SparseSum::Scalar(sum))
428 }
429 Some(0) => {
430 let (_, n) = self.shape();
431 let mut result = Array2::zeros((1, n));
432 for j in 0..n {
433 let mut sum = T::zero();
434 for i in 0..self.data.shape()[0] {
435 sum = sum + self.data[[i, j]];
436 }
437 result[[0, j]] = sum;
438 }
439 Ok(SparseSum::SparseArray(Box::new(SparseArrayBase::new(
440 result,
441 ))))
442 }
443 Some(1) => {
444 let (m_, _) = self.shape();
445 let mut result = Array2::zeros((m_, 1));
446 for i in 0..m_ {
447 let mut sum = T::zero();
448 for j in 0..self.data.shape()[1] {
449 sum = sum + self.data[[i, j]];
450 }
451 result[[i, 0]] = sum;
452 }
453 Ok(SparseSum::SparseArray(Box::new(SparseArrayBase::new(
454 result,
455 ))))
456 }
457 _ => Err(SparseError::InvalidAxis),
458 }
459 }
460
461 fn max(&self) -> T {
462 self.data
463 .iter()
464 .fold(T::neg_infinity(), |acc, &x| acc.max(x))
465 }
466
467 fn min(&self) -> T {
468 self.data.iter().fold(T::infinity(), |acc, &x| acc.min(x))
469 }
470
471 fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
472 let (m, n) = self.shape();
473 let nnz = self.nnz();
474 let mut rows = Vec::with_capacity(nnz);
475 let mut cols = Vec::with_capacity(nnz);
476 let mut values = Vec::with_capacity(nnz);
477
478 for i in 0..m {
479 for j in 0..n {
480 let value = self.data[[i, j]];
481 if !value.is_zero() {
482 rows.push(i);
483 cols.push(j);
484 values.push(value);
485 }
486 }
487 }
488
489 (
490 Array1::from_vec(rows),
491 Array1::from_vec(cols),
492 Array1::from_vec(values),
493 )
494 }
495
496 fn slice(
497 &self,
498 row_range: (usize, usize),
499 col_range: (usize, usize),
500 ) -> SparseResult<Box<dyn SparseArray<T>>> {
501 let (start_row, end_row) = row_range;
502 let (start_col, end_col) = col_range;
503 let (m, n) = self.shape();
504
505 if start_row >= m
506 || end_row > m
507 || start_col >= n
508 || end_col > n
509 || start_row >= end_row
510 || start_col >= end_col
511 {
512 return Err(SparseError::InvalidSliceRange);
513 }
514
515 let view = self
516 .data
517 .slice(ndarray::s![start_row..end_row, start_col..end_col]);
518 Ok(Box::new(SparseArrayBase::new(view.to_owned())))
519 }
520
521 fn as_any(&self) -> &dyn std::any::Any {
522 self
523 }
524}
525
526impl<T> Clone for SparseArrayBase<T>
527where
528 T: Float
529 + Add<Output = T>
530 + Sub<Output = T>
531 + Mul<Output = T>
532 + Div<Output = T>
533 + Debug
534 + Copy
535 + 'static,
536{
537 fn clone(&self) -> Self {
538 Self {
539 data: self.data.clone(),
540 }
541 }
542}
543
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array;

    /// The 2x2 array [[1, 2], [3, 4]] shared by several tests.
    fn two_by_two() -> SparseArrayBase<f64> {
        SparseArrayBase::new(Array::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap())
    }

    #[test]
    fn test_sparse_array_base() {
        let data = Array::from_shape_vec((3, 3), vec![1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0, 5.0])
            .unwrap();
        let sparse = SparseArrayBase::new(data);

        assert_eq!(sparse.shape(), (3, 3));
        assert_eq!(sparse.nnz(), 5);
        assert_eq!(sparse.get(0, 0), 1.0);
        assert_eq!(sparse.get(1, 1), 3.0);
        assert_eq!(sparse.get(2, 2), 5.0);
        assert_eq!(sparse.get(0, 1), 0.0);
    }

    #[test]
    fn test_sparse_array_operations() {
        let data1 = Array::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let data2 = Array::from_shape_vec((2, 2), vec![5.0, 6.0, 7.0, 8.0]).unwrap();

        let sparse1 = SparseArrayBase::new(data1);
        let sparse2 = SparseArrayBase::new(data2);

        let result = sparse1.add(&sparse2).unwrap();
        let result_array = result.to_array();
        assert_eq!(result_array[[0, 0]], 6.0);
        assert_eq!(result_array[[0, 1]], 8.0);
        assert_eq!(result_array[[1, 0]], 10.0);
        assert_eq!(result_array[[1, 1]], 12.0);

        let result = sparse1.dot(&sparse2).unwrap();
        let result_array = result.to_array();
        assert_eq!(result_array[[0, 0]], 19.0);
        assert_eq!(result_array[[0, 1]], 22.0);
        assert_eq!(result_array[[1, 0]], 43.0);
        assert_eq!(result_array[[1, 1]], 50.0);
    }

    #[test]
    fn test_dot_dimension_mismatch_is_error() {
        let lhs = two_by_two();
        let rhs =
            SparseArrayBase::new(Array::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap());
        assert!(matches!(
            lhs.dot(&rhs),
            Err(SparseError::DimensionMismatch {
                expected: 2,
                found: 3
            })
        ));
    }

    #[test]
    fn test_dot_vector() {
        let a = two_by_two();
        let v = ndarray::arr1(&[1.0, 1.0]);
        let product = a.dot_vector(&v.view()).unwrap();
        assert_eq!(product.to_vec(), vec![3.0, 7.0]);

        let too_long = ndarray::arr1(&[1.0, 1.0, 1.0]);
        assert!(a.dot_vector(&too_long.view()).is_err());
    }

    #[test]
    fn test_sum_max_min() {
        let a = two_by_two();

        match a.sum(None).unwrap() {
            SparseSum::Scalar(total) => assert_eq!(total, 10.0),
            other => panic!("expected a scalar, got {other:?}"),
        }
        match a.sum(Some(0)).unwrap() {
            SparseSum::SparseArray(cols) => {
                assert_eq!(cols.shape(), (1, 2));
                assert_eq!(cols.to_array(), ndarray::arr2(&[[4.0, 6.0]]));
            }
            other => panic!("expected an array, got {other:?}"),
        }
        match a.sum(Some(1)).unwrap() {
            SparseSum::SparseArray(rows) => {
                assert_eq!(rows.shape(), (2, 1));
                assert_eq!(rows.to_array(), ndarray::arr2(&[[3.0], [7.0]]));
            }
            other => panic!("expected an array, got {other:?}"),
        }
        assert!(matches!(a.sum(Some(2)), Err(SparseError::InvalidAxis)));

        assert_eq!(a.max(), 4.0);
        assert_eq!(a.min(), 1.0);
    }

    #[test]
    fn test_set_and_bounds() {
        let mut a = two_by_two();
        a.set(0, 1, 9.0).unwrap();
        assert_eq!(a.get(0, 1), 9.0);
        assert!(matches!(
            a.set(2, 0, 1.0),
            Err(SparseError::IndexOutOfBounds {
                index: (2, 0),
                shape: (2, 2)
            })
        ));
    }

    #[test]
    fn test_find_and_nnz() {
        let data = Array::from_shape_vec((2, 3), vec![0.0, 5.0, 0.0, 7.0, 0.0, 8.0]).unwrap();
        let a = SparseArrayBase::new(data);
        assert_eq!(a.nnz(), 3);

        let (rows, cols, values) = a.find();
        assert_eq!(rows.to_vec(), vec![0, 1, 1]);
        assert_eq!(cols.to_vec(), vec![1, 0, 2]);
        assert_eq!(values.to_vec(), vec![5.0, 7.0, 8.0]);
    }

    #[test]
    fn test_transpose_and_slice() {
        let a = two_by_two();

        let t = a.transpose().unwrap();
        assert_eq!(t.to_array(), ndarray::arr2(&[[1.0, 3.0], [2.0, 4.0]]));

        let bottom_row = a.slice((1, 2), (0, 2)).unwrap();
        assert_eq!(bottom_row.to_array(), ndarray::arr2(&[[3.0, 4.0]]));

        // Empty and out-of-range selections are rejected.
        assert!(a.slice((1, 1), (0, 2)).is_err());
        assert!(a.slice((0, 3), (0, 2)).is_err());
    }
}