1use ndarray::{Array1, Array2, ArrayView1};
7use num_traits::Float;
8use std::fmt::Debug;
9use std::ops::{Add, Div, Mul, Sub};
10
11use crate::error::{SparseError, SparseResult};
12
13pub trait SparseArray<T>: std::any::Any
29where
30 T: Float
31 + Add<Output = T>
32 + Sub<Output = T>
33 + Mul<Output = T>
34 + Div<Output = T>
35 + Debug
36 + Copy
37 + 'static,
38{
39 fn shape(&self) -> (usize, usize);
41
42 fn nnz(&self) -> usize;
44
45 fn dtype(&self) -> &str;
47
48 fn to_array(&self) -> Array2<T>;
50
51 fn toarray(&self) -> Array2<T>;
53
54 fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>>;
56
57 fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>>;
59
60 fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>>;
62
63 fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>>;
65
66 fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>>;
68
69 fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>>;
71
72 fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>>;
74
75 fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>>;
77
78 fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>>;
80
81 fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>>;
83
84 fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>>;
86
87 fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>>;
89
90 fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>>;
92
93 fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>>;
95
96 fn copy(&self) -> Box<dyn SparseArray<T>>;
98
99 fn get(&self, i: usize, j: usize) -> T;
101
102 fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()>;
104
105 fn eliminate_zeros(&mut self);
107
108 fn sort_indices(&mut self);
110
111 fn sorted_indices(&self) -> Box<dyn SparseArray<T>>;
113
114 fn has_sorted_indices(&self) -> bool;
116
117 fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>>;
124
125 fn max(&self) -> T;
127
128 fn min(&self) -> T;
130
131 fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>);
133
134 fn slice(
136 &self,
137 row_range: (usize, usize),
138 col_range: (usize, usize),
139 ) -> SparseResult<Box<dyn SparseArray<T>>>;
140
141 fn as_any(&self) -> &dyn std::any::Any;
143}
144
145pub enum SparseSum<T>
148where
149 T: Float + Debug + Copy + 'static,
150{
151 SparseArray(Box<dyn SparseArray<T>>),
153
154 Scalar(T),
156}
157
158impl<T> Debug for SparseSum<T>
159where
160 T: Float + Debug + Copy + 'static,
161{
162 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163 match self {
164 SparseSum::SparseArray(_) => write!(f, "SparseSum::SparseArray(...)"),
165 SparseSum::Scalar(value) => write!(f, "SparseSum::Scalar({:?})", value),
166 }
167 }
168}
169
170impl<T> Clone for SparseSum<T>
171where
172 T: Float + Debug + Copy + 'static,
173{
174 fn clone(&self) -> Self {
175 match self {
176 SparseSum::SparseArray(array) => SparseSum::SparseArray(array.copy()),
177 SparseSum::Scalar(value) => SparseSum::Scalar(*value),
178 }
179 }
180}
181
182pub fn is_sparse<T>(_obj: &dyn SparseArray<T>) -> bool
184where
185 T: Float
186 + Add<Output = T>
187 + Sub<Output = T>
188 + Mul<Output = T>
189 + Div<Output = T>
190 + Debug
191 + Copy
192 + 'static,
193{
194 true }
196
197pub struct SparseArrayBase<T>
199where
200 T: Float
201 + Add<Output = T>
202 + Sub<Output = T>
203 + Mul<Output = T>
204 + Div<Output = T>
205 + Debug
206 + Copy
207 + 'static,
208{
209 data: Array2<T>,
210}
211
212impl<T> SparseArrayBase<T>
213where
214 T: Float
215 + Add<Output = T>
216 + Sub<Output = T>
217 + Mul<Output = T>
218 + Div<Output = T>
219 + Debug
220 + Copy
221 + 'static,
222{
223 pub fn new(data: Array2<T>) -> Self {
225 Self { data }
226 }
227}
228
229impl<T> SparseArray<T> for SparseArrayBase<T>
230where
231 T: Float
232 + Add<Output = T>
233 + Sub<Output = T>
234 + Mul<Output = T>
235 + Div<Output = T>
236 + Debug
237 + Copy
238 + 'static,
239{
240 fn shape(&self) -> (usize, usize) {
241 let shape = self.data.shape();
242 (shape[0], shape[1])
243 }
244
245 fn nnz(&self) -> usize {
246 self.data.iter().filter(|&&x| !x.is_zero()).count()
247 }
248
249 fn dtype(&self) -> &str {
250 "float" }
252
253 fn to_array(&self) -> Array2<T> {
254 self.data.clone()
255 }
256
257 fn toarray(&self) -> Array2<T> {
258 self.data.clone()
259 }
260
261 fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
262 Ok(Box::new(self.clone()))
264 }
265
266 fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
267 Ok(Box::new(self.clone()))
269 }
270
271 fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
272 Ok(Box::new(self.clone()))
274 }
275
276 fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
277 Ok(Box::new(self.clone()))
279 }
280
281 fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
282 Ok(Box::new(self.clone()))
284 }
285
286 fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
287 Ok(Box::new(self.clone()))
289 }
290
291 fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
292 Ok(Box::new(self.clone()))
294 }
295
296 fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
297 let other_array = other.to_array();
298 let result = &self.data + &other_array;
299 Ok(Box::new(SparseArrayBase::new(result)))
300 }
301
302 fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
303 let other_array = other.to_array();
304 let result = &self.data - &other_array;
305 Ok(Box::new(SparseArrayBase::new(result)))
306 }
307
308 fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
309 let other_array = other.to_array();
310 let result = &self.data * &other_array;
311 Ok(Box::new(SparseArrayBase::new(result)))
312 }
313
314 fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
315 let other_array = other.to_array();
316 let result = &self.data / &other_array;
317 Ok(Box::new(SparseArrayBase::new(result)))
318 }
319
320 fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
321 let other_array = other.to_array();
322 let (m, n) = self.shape();
323 let (p, q) = other.shape();
324
325 if n != p {
326 return Err(SparseError::DimensionMismatch {
327 expected: n,
328 found: p,
329 });
330 }
331
332 let mut result = Array2::zeros((m, q));
333 for i in 0..m {
334 for j in 0..q {
335 let mut sum = T::zero();
336 for k in 0..n {
337 sum = sum + self.data[[i, k]] * other_array[[k, j]];
338 }
339 result[[i, j]] = sum;
340 }
341 }
342
343 Ok(Box::new(SparseArrayBase::new(result)))
344 }
345
346 fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
347 let (m, n) = self.shape();
348 if n != other.len() {
349 return Err(SparseError::DimensionMismatch {
350 expected: n,
351 found: other.len(),
352 });
353 }
354
355 let mut result = Array1::zeros(m);
356 for i in 0..m {
357 let mut sum = T::zero();
358 for j in 0..n {
359 sum = sum + self.data[[i, j]] * other[j];
360 }
361 result[i] = sum;
362 }
363
364 Ok(result)
365 }
366
367 fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
368 Ok(Box::new(SparseArrayBase::new(self.data.t().to_owned())))
369 }
370
371 fn copy(&self) -> Box<dyn SparseArray<T>> {
372 Box::new(self.clone())
373 }
374
375 fn get(&self, i: usize, j: usize) -> T {
376 self.data[[i, j]]
377 }
378
379 fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
380 let (m, n) = self.shape();
381 if i >= m || j >= n {
382 return Err(SparseError::IndexOutOfBounds {
383 index: (i, j),
384 shape: (m, n),
385 });
386 }
387 self.data[[i, j]] = value;
388 Ok(())
389 }
390
391 fn eliminate_zeros(&mut self) {
392 }
394
395 fn sort_indices(&mut self) {
396 }
398
399 fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
400 self.copy()
401 }
402
403 fn has_sorted_indices(&self) -> bool {
404 true }
406
407 fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>> {
408 match axis {
409 None => {
410 let mut sum = T::zero();
411 for &val in self.data.iter() {
412 sum = sum + val;
413 }
414 Ok(SparseSum::Scalar(sum))
415 }
416 Some(0) => {
417 let (_, n) = self.shape();
418 let mut result = Array2::zeros((1, n));
419 for j in 0..n {
420 let mut sum = T::zero();
421 for i in 0..self.data.shape()[0] {
422 sum = sum + self.data[[i, j]];
423 }
424 result[[0, j]] = sum;
425 }
426 Ok(SparseSum::SparseArray(Box::new(SparseArrayBase::new(
427 result,
428 ))))
429 }
430 Some(1) => {
431 let (m, _) = self.shape();
432 let mut result = Array2::zeros((m, 1));
433 for i in 0..m {
434 let mut sum = T::zero();
435 for j in 0..self.data.shape()[1] {
436 sum = sum + self.data[[i, j]];
437 }
438 result[[i, 0]] = sum;
439 }
440 Ok(SparseSum::SparseArray(Box::new(SparseArrayBase::new(
441 result,
442 ))))
443 }
444 _ => Err(SparseError::InvalidAxis),
445 }
446 }
447
448 fn max(&self) -> T {
449 self.data
450 .iter()
451 .fold(T::neg_infinity(), |acc, &x| acc.max(x))
452 }
453
454 fn min(&self) -> T {
455 self.data.iter().fold(T::infinity(), |acc, &x| acc.min(x))
456 }
457
458 fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
459 let (m, n) = self.shape();
460 let nnz = self.nnz();
461 let mut rows = Vec::with_capacity(nnz);
462 let mut cols = Vec::with_capacity(nnz);
463 let mut values = Vec::with_capacity(nnz);
464
465 for i in 0..m {
466 for j in 0..n {
467 let value = self.data[[i, j]];
468 if !value.is_zero() {
469 rows.push(i);
470 cols.push(j);
471 values.push(value);
472 }
473 }
474 }
475
476 (
477 Array1::from_vec(rows),
478 Array1::from_vec(cols),
479 Array1::from_vec(values),
480 )
481 }
482
483 fn slice(
484 &self,
485 row_range: (usize, usize),
486 col_range: (usize, usize),
487 ) -> SparseResult<Box<dyn SparseArray<T>>> {
488 let (start_row, end_row) = row_range;
489 let (start_col, end_col) = col_range;
490 let (m, n) = self.shape();
491
492 if start_row >= m
493 || end_row > m
494 || start_col >= n
495 || end_col > n
496 || start_row >= end_row
497 || start_col >= end_col
498 {
499 return Err(SparseError::InvalidSliceRange);
500 }
501
502 let view = self
503 .data
504 .slice(ndarray::s![start_row..end_row, start_col..end_col]);
505 Ok(Box::new(SparseArrayBase::new(view.to_owned())))
506 }
507
508 fn as_any(&self) -> &dyn std::any::Any {
509 self
510 }
511}
512
513impl<T> Clone for SparseArrayBase<T>
514where
515 T: Float
516 + Add<Output = T>
517 + Sub<Output = T>
518 + Mul<Output = T>
519 + Div<Output = T>
520 + Debug
521 + Copy
522 + 'static,
523{
524 fn clone(&self) -> Self {
525 Self {
526 data: self.data.clone(),
527 }
528 }
529}
530
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array;

    #[test]
    fn test_sparse_array_base() {
        let dense = Array::from_shape_vec(
            (3, 3),
            vec![1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0, 5.0],
        )
        .unwrap();
        let matrix: SparseArrayBase<f64> = SparseArrayBase::new(dense);

        assert_eq!(matrix.shape(), (3, 3));
        assert_eq!(matrix.nnz(), 5);

        // (row, column, expected value) — the last entry is an implicit zero.
        let probes: [(usize, usize, f64); 4] =
            [(0, 0, 1.0), (1, 1, 3.0), (2, 2, 5.0), (0, 1, 0.0)];
        for &(i, j, expected) in probes.iter() {
            assert_eq!(matrix.get(i, j), expected);
        }
    }

    #[test]
    fn test_sparse_array_operations() {
        let lhs: SparseArrayBase<f64> =
            SparseArrayBase::new(Array::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap());
        let rhs: SparseArrayBase<f64> =
            SparseArrayBase::new(Array::from_shape_vec((2, 2), vec![5.0, 6.0, 7.0, 8.0]).unwrap());

        // Row-major flattening lets each result be checked with a single assertion.
        let flatten = |result: Box<dyn SparseArray<f64>>| -> Vec<f64> {
            result.to_array().iter().copied().collect()
        };

        assert_eq!(flatten(lhs.add(&rhs).unwrap()), vec![6.0, 8.0, 10.0, 12.0]);
        assert_eq!(flatten(lhs.dot(&rhs).unwrap()), vec![19.0, 22.0, 43.0, 50.0]);
    }
}