1use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
7use scirs2_core::numeric::Float;
8use std::fmt::{self, Debug};
9use std::ops::{Add, Div, Mul, Sub};
10
11use crate::coo_array::CooArray;
12use crate::csr_array::CsrArray;
13use crate::dok_array::DokArray;
14use crate::error::{SparseError, SparseResult};
15use crate::lil_array::LilArray;
16use crate::sparray::{SparseArray, SparseSum};
17
18#[derive(Clone)]
32pub struct DiaArray<T>
33where
34 T: Float
35 + Add<Output = T>
36 + Sub<Output = T>
37 + Mul<Output = T>
38 + Div<Output = T>
39 + Debug
40 + Copy
41 + 'static
42 + std::ops::AddAssign,
43{
44 data: Vec<Array1<T>>,
46 offsets: Vec<isize>,
48 shape: (usize, usize),
50}
51
52impl<T> DiaArray<T>
53where
54 T: Float
55 + Add<Output = T>
56 + Sub<Output = T>
57 + Mul<Output = T>
58 + Div<Output = T>
59 + Debug
60 + Copy
61 + 'static
62 + std::ops::AddAssign,
63{
64 pub fn new(
96 data: Vec<Array1<T>>,
97 offsets: Vec<isize>,
98 shape: (usize, usize),
99 ) -> SparseResult<Self> {
100 let (rows, cols) = shape;
101 let max_dim = rows.max(cols);
102
103 if data.len() != offsets.len() {
105 return Err(SparseError::DimensionMismatch {
106 expected: data.len(),
107 found: offsets.len(),
108 });
109 }
110
111 for diag in data.iter() {
112 if diag.len() != max_dim {
113 return Err(SparseError::DimensionMismatch {
114 expected: max_dim,
115 found: diag.len(),
116 });
117 }
118 }
119
120 Ok(DiaArray {
121 data,
122 offsets,
123 shape,
124 })
125 }
126
127 pub fn empty(shape: (usize, usize)) -> Self {
137 DiaArray {
138 data: Vec::new(),
139 offsets: Vec::new(),
140 shape,
141 }
142 }
143
144 pub fn from_triplets(
157 row: &[usize],
158 col: &[usize],
159 data: &[T],
160 shape: (usize, usize),
161 ) -> SparseResult<Self> {
162 if row.len() != col.len() || row.len() != data.len() {
163 return Err(SparseError::InconsistentData {
164 reason: "Lengths of row, col, and data arrays must be equal".to_string(),
165 });
166 }
167
168 let (rows, cols) = shape;
169 let max_dim = rows.max(cols);
170
171 let mut diagonal_offsets = std::collections::HashSet::new();
173 for (&r, &c) in row.iter().zip(col.iter()) {
174 if r >= rows || c >= cols {
175 return Err(SparseError::IndexOutOfBounds {
176 index: (r, c),
177 shape,
178 });
179 }
180 let offset = c as isize - r as isize;
182 diagonal_offsets.insert(offset);
183 }
184
185 let mut offsets: Vec<isize> = diagonal_offsets.into_iter().collect();
187 offsets.sort();
188
189 let mut diag_data = Vec::with_capacity(offsets.len());
191 for _ in 0..offsets.len() {
192 diag_data.push(Array1::zeros(max_dim));
193 }
194
195 for (&r, (&c, &val)) in row.iter().zip(col.iter().zip(data.iter())) {
197 let offset = c as isize - r as isize;
198 let diag_idx = offsets.iter().position(|&o| o == offset).unwrap();
199
200 let index = if offset >= 0 { r } else { c };
203 diag_data[diag_idx][index] = val;
204 }
205
206 DiaArray::new(diag_data, offsets, shape)
207 }
208
209 fn to_coo_internal(&self) -> (Vec<usize>, Vec<usize>, Vec<T>) {
211 let (rows, cols) = self.shape;
212 let mut row_indices = Vec::new();
213 let mut col_indices = Vec::new();
214 let mut values = Vec::new();
215
216 for (diag_idx, &offset) in self.offsets.iter().enumerate() {
217 let diag = &self.data[diag_idx];
218
219 if offset >= 0 {
220 let offset_usize = offset as usize;
222 let length = rows.min(cols.saturating_sub(offset_usize));
223
224 for i in 0..length {
225 let value = diag[i];
226 if !value.is_zero() {
227 row_indices.push(i);
228 col_indices.push(i + offset_usize);
229 values.push(value);
230 }
231 }
232 } else {
233 let offset_usize = (-offset) as usize;
235 let length = cols.min(rows.saturating_sub(offset_usize));
236
237 for i in 0..length {
238 let value = diag[i];
239 if !value.is_zero() {
240 row_indices.push(i + offset_usize);
241 col_indices.push(i);
242 values.push(value);
243 }
244 }
245 }
246 }
247
248 (row_indices, col_indices, values)
249 }
250}
251
252impl<T> SparseArray<T> for DiaArray<T>
253where
254 T: Float
255 + Add<Output = T>
256 + Sub<Output = T>
257 + Mul<Output = T>
258 + Div<Output = T>
259 + Debug
260 + Copy
261 + 'static
262 + std::ops::AddAssign,
263{
264 fn shape(&self) -> (usize, usize) {
265 self.shape
266 }
267
268 fn nnz(&self) -> usize {
269 let (rows, cols) = self.shape;
270 let mut count = 0;
271
272 for (diag_idx, &offset) in self.offsets.iter().enumerate() {
273 let diag = &self.data[diag_idx];
274
275 let length = if offset >= 0 {
277 rows.min(cols.saturating_sub(offset as usize))
278 } else {
279 cols.min(rows.saturating_sub((-offset) as usize))
280 };
281
282 let start_idx = 0; for i in start_idx..start_idx + length {
285 if !diag[i].is_zero() {
286 count += 1;
287 }
288 }
289 }
290
291 count
292 }
293
294 fn dtype(&self) -> &str {
295 "float" }
297
298 fn to_array(&self) -> Array2<T> {
299 let (rows, cols) = self.shape;
301 let mut result = Array2::zeros((rows, cols));
302
303 for (diag_idx, &offset) in self.offsets.iter().enumerate() {
309 let diag = &self.data[diag_idx];
310
311 if offset >= 0 {
312 let offset_usize = offset as usize;
314 for i in 0..rows.min(cols.saturating_sub(offset_usize)) {
315 result[[i, i + offset_usize]] = diag[i];
316 }
317 } else {
318 let offset_usize = (-offset) as usize;
320 for i in 0..cols.min(rows.saturating_sub(offset_usize)) {
321 result[[i + offset_usize, i]] = diag[i];
322 }
323 }
324 }
325
326 result
327 }
328
329 fn toarray(&self) -> Array2<T> {
330 self.to_array()
331 }
332
333 fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
334 let (row_indices, col_indices, values) = self.to_coo_internal();
335 let row_array = Array1::from_vec(row_indices);
336 let col_array = Array1::from_vec(col_indices);
337 let data_array = Array1::from_vec(values);
338
339 CooArray::from_triplets(
340 &row_array.to_vec(),
341 &col_array.to_vec(),
342 &data_array.to_vec(),
343 self.shape,
344 false,
345 )
346 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
347 }
348
349 fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
350 let (row_indices, col_indices, values) = self.to_coo_internal();
351 CsrArray::from_triplets(&row_indices, &col_indices, &values, self.shape, false)
352 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
353 }
354
355 fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
356 self.to_coo()?.to_csc()
357 }
358
359 fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
360 let (row_indices, col_indices, values) = self.to_coo_internal();
361 DokArray::from_triplets(&row_indices, &col_indices, &values, self.shape)
362 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
363 }
364
365 fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
366 let (row_indices, col_indices, values) = self.to_coo_internal();
367 LilArray::from_triplets(&row_indices, &col_indices, &values, self.shape)
368 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
369 }
370
371 fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
372 Ok(Box::new(self.clone()))
373 }
374
375 fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
376 self.to_coo()?.to_bsr()
377 }
378
379 fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
380 let csr_self = self.to_csr()?;
382 let csr_other = other.to_csr()?;
383 csr_self.add(&*csr_other)
384 }
385
386 fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
387 let csr_self = self.to_csr()?;
389 let csr_other = other.to_csr()?;
390 csr_self.sub(&*csr_other)
391 }
392
393 fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
394 let csr_self = self.to_csr()?;
396 let csr_other = other.to_csr()?;
397 csr_self.mul(&*csr_other)
398 }
399
400 fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
401 let csr_self = self.to_csr()?;
403 let csr_other = other.to_csr()?;
404 csr_self.div(&*csr_other)
405 }
406
407 fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
408 let (_, n) = self.shape();
410 let (p, q) = other.shape();
411
412 if n != p {
413 return Err(SparseError::DimensionMismatch {
414 expected: n,
415 found: p,
416 });
417 }
418
419 if q == 1 {
421 let other_array = other.to_array();
423 let vec_view = other_array.column(0);
424
425 let result = self.dot_vector(&vec_view)?;
427
428 let mut rows = Vec::new();
430 let mut cols = Vec::new();
431 let mut values = Vec::new();
432
433 for (i, &val) in result.iter().enumerate() {
434 if !val.is_zero() {
435 rows.push(i);
436 cols.push(0);
437 values.push(val);
438 }
439 }
440
441 CooArray::from_triplets(&rows, &cols, &values, (result.len(), 1), false)
442 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
443 } else {
444 let csr_self = self.to_csr()?;
446 csr_self.dot(other)
447 }
448 }
449
450 fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
451 let (rows, cols) = self.shape;
452
453 if cols != other.len() {
454 return Err(SparseError::DimensionMismatch {
455 expected: cols,
456 found: other.len(),
457 });
458 }
459
460 let mut result = Array1::zeros(rows);
461
462 for (diag_idx, &offset) in self.offsets.iter().enumerate() {
463 let diag = &self.data[diag_idx];
464
465 if offset >= 0 {
466 let offset_usize = offset as usize;
468 let length = rows.min(cols.saturating_sub(offset_usize));
469
470 for i in 0..length {
471 result[i] += diag[i] * other[i + offset_usize];
472 }
473 } else {
474 let offset_usize = (-offset) as usize;
476 let length = cols.min(rows.saturating_sub(offset_usize));
477
478 for i in 0..length {
479 result[i + offset_usize] += diag[i] * other[i];
480 }
481 }
482 }
483
484 Ok(result)
485 }
486
487 fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
488 let (row_indices, col_indices, values) = self.to_coo_internal();
491
492 let transposed_rows = col_indices;
494 let transposed_cols = row_indices;
495
496 CooArray::from_triplets(
498 &transposed_rows,
499 &transposed_cols,
500 &values,
501 (self.shape.1, self.shape.0),
502 false,
503 )?
504 .to_dia()
505 }
506
507 fn copy(&self) -> Box<dyn SparseArray<T>> {
508 Box::new(self.clone())
509 }
510
511 fn get(&self, i: usize, j: usize) -> T {
512 if i >= self.shape.0 || j >= self.shape.1 {
513 return T::zero();
514 }
515
516 let offset = j as isize - i as isize;
518
519 if let Some(diag_idx) = self.offsets.iter().position(|&o| o == offset) {
521 let diag = &self.data[diag_idx];
522
523 let index = if offset >= 0 { i } else { j };
526
527 if index < diag.len() {
529 return diag[index];
530 }
531 }
532
533 T::zero()
534 }
535
536 fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
537 if i >= self.shape.0 || j >= self.shape.1 {
538 return Err(SparseError::IndexOutOfBounds {
539 index: (i, j),
540 shape: self.shape,
541 });
542 }
543
544 let offset = j as isize - i as isize;
546
547 let diag_idx = match self.offsets.iter().position(|&o| o == offset) {
549 Some(idx) => idx,
550 None => {
551 self.offsets.push(offset);
553 self.data
554 .push(Array1::zeros(self.shape.0.max(self.shape.1)));
555
556 let mut offset_data: Vec<(isize, Array1<T>)> = self
558 .offsets
559 .iter()
560 .cloned()
561 .zip(self.data.drain(..))
562 .collect();
563 offset_data.sort_by_key(|&(offset_, _)| offset_);
564
565 self.offsets = offset_data.iter().map(|&(offset_, _)| offset_).collect();
566 self.data = offset_data.into_iter().map(|(_, data)| data).collect();
567
568 self.offsets.iter().position(|&o| o == offset).unwrap()
570 }
571 };
572
573 let index = if offset >= 0 { i } else { j };
575 self.data[diag_idx][index] = value;
576
577 Ok(())
578 }
579
580 fn eliminate_zeros(&mut self) {
581 let mut new_offsets = Vec::new();
583 let mut new_data = Vec::new();
584
585 for (diag_idx, &offset) in self.offsets.iter().enumerate() {
586 let diag = &self.data[diag_idx];
587
588 let length = if offset >= 0 {
590 self.shape
591 .0
592 .min(self.shape.1.saturating_sub(offset as usize))
593 } else {
594 self.shape
595 .1
596 .min(self.shape.0.saturating_sub((-offset) as usize))
597 };
598
599 let has_nonzero = (0..length).any(|i| !diag[i].is_zero());
600
601 if has_nonzero {
602 new_offsets.push(offset);
603 new_data.push(diag.clone());
604 }
605 }
606
607 self.offsets = new_offsets;
608 self.data = new_data;
609 }
610
611 fn sort_indices(&mut self) {
612 let mut offset_data: Vec<(isize, Array1<T>)> = self
615 .offsets
616 .iter()
617 .cloned()
618 .zip(self.data.drain(..))
619 .collect();
620 offset_data.sort_by_key(|&(offset_, _)| offset_);
621
622 self.offsets = offset_data.iter().map(|&(offset_, _)| offset_).collect();
623 self.data = offset_data.into_iter().map(|(_, data)| data).collect();
624 }
625
626 fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
627 let mut result = self.clone();
629 result.sort_indices();
630 Box::new(result)
631 }
632
633 fn has_sorted_indices(&self) -> bool {
634 self.offsets.windows(2).all(|w| w[0] <= w[1])
636 }
637
638 fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>> {
639 match axis {
640 None => {
641 let mut total = T::zero();
643
644 for (diag_idx, &offset) in self.offsets.iter().enumerate() {
645 let diag = &self.data[diag_idx];
646
647 let length = if offset >= 0 {
648 self.shape
649 .0
650 .min(self.shape.1.saturating_sub(offset as usize))
651 } else {
652 self.shape
653 .1
654 .min(self.shape.0.saturating_sub((-offset) as usize))
655 };
656
657 for i in 0..length {
658 total += diag[i];
659 }
660 }
661
662 Ok(SparseSum::Scalar(total))
663 }
664 Some(0) => {
665 let mut result = Array1::zeros(self.shape.1);
667
668 for (diag_idx, &offset) in self.offsets.iter().enumerate() {
669 let diag = &self.data[diag_idx];
670
671 if offset >= 0 {
672 let offset_usize = offset as usize;
674 let length = self.shape.0.min(self.shape.1.saturating_sub(offset_usize));
675
676 for i in 0..length {
677 result[i + offset_usize] += diag[i];
678 }
679 } else {
680 let offset_usize = (-offset) as usize;
682 let length = self.shape.1.min(self.shape.0.saturating_sub(offset_usize));
683
684 for i in 0..length {
685 result[i] += diag[i];
686 }
687 }
688 }
689
690 match Array2::from_shape_vec((1, self.shape.1), result.to_vec()) {
692 Ok(result_2d) => {
693 let mut row_indices = Vec::new();
695 let mut col_indices = Vec::new();
696 let mut values = Vec::new();
697
698 for j in 0..self.shape.1 {
699 let val: T = result_2d[[0, j]];
700 if !val.is_zero() {
701 row_indices.push(0);
702 col_indices.push(j);
703 values.push(val);
704 }
705 }
706
707 match CooArray::from_triplets(
709 &row_indices,
710 &col_indices,
711 &values,
712 (1, self.shape.1),
713 false,
714 ) {
715 Ok(coo_array) => Ok(SparseSum::SparseArray(Box::new(coo_array))),
716 Err(e) => Err(e),
717 }
718 }
719 Err(_) => Err(SparseError::InconsistentData {
720 reason: "Failed to create 2D array from result vector".to_string(),
721 }),
722 }
723 }
724 Some(1) => {
725 let mut result = Array1::zeros(self.shape.0);
727
728 for (diag_idx, &offset) in self.offsets.iter().enumerate() {
729 let diag = &self.data[diag_idx];
730
731 if offset >= 0 {
732 let offset_usize = offset as usize;
734 let length = self.shape.0.min(self.shape.1.saturating_sub(offset_usize));
735
736 for i in 0..length {
737 result[i] += diag[i];
738 }
739 } else {
740 let offset_usize = (-offset) as usize;
742 let length = self.shape.1.min(self.shape.0.saturating_sub(offset_usize));
743
744 for i in 0..length {
745 result[i + offset_usize] += diag[i];
746 }
747 }
748 }
749
750 match Array2::from_shape_vec((self.shape.0, 1), result.to_vec()) {
752 Ok(result_2d) => {
753 let mut row_indices = Vec::new();
755 let mut col_indices = Vec::new();
756 let mut values = Vec::new();
757
758 for i in 0..self.shape.0 {
759 let val: T = result_2d[[i, 0]];
760 if !val.is_zero() {
761 row_indices.push(i);
762 col_indices.push(0);
763 values.push(val);
764 }
765 }
766
767 match CooArray::from_triplets(
769 &row_indices,
770 &col_indices,
771 &values,
772 (self.shape.0, 1),
773 false,
774 ) {
775 Ok(coo_array) => Ok(SparseSum::SparseArray(Box::new(coo_array))),
776 Err(e) => Err(e),
777 }
778 }
779 Err(_) => Err(SparseError::InconsistentData {
780 reason: "Failed to create 2D array from result vector".to_string(),
781 }),
782 }
783 }
784 _ => Err(SparseError::InvalidAxis),
785 }
786 }
787
788 fn max(&self) -> T {
789 let mut max_val = T::neg_infinity();
790
791 for (diag_idx, &offset) in self.offsets.iter().enumerate() {
792 let diag = &self.data[diag_idx];
793
794 let length = if offset >= 0 {
795 self.shape
796 .0
797 .min(self.shape.1.saturating_sub(offset as usize))
798 } else {
799 self.shape
800 .1
801 .min(self.shape.0.saturating_sub((-offset) as usize))
802 };
803
804 for i in 0..length {
805 max_val = max_val.max(diag[i]);
806 }
807 }
808
809 if max_val == T::neg_infinity() {
811 T::zero()
812 } else {
813 max_val
814 }
815 }
816
817 fn min(&self) -> T {
818 let mut min_val = T::infinity();
819 let mut has_nonzero = false;
820
821 for (diag_idx, &offset) in self.offsets.iter().enumerate() {
822 let diag = &self.data[diag_idx];
823
824 let length = if offset >= 0 {
825 self.shape
826 .0
827 .min(self.shape.1.saturating_sub(offset as usize))
828 } else {
829 self.shape
830 .1
831 .min(self.shape.0.saturating_sub((-offset) as usize))
832 };
833
834 for i in 0..length {
835 if !diag[i].is_zero() {
836 has_nonzero = true;
837 min_val = min_val.min(diag[i]);
838 }
839 }
840 }
841
842 if !has_nonzero {
844 T::zero()
845 } else {
846 min_val
847 }
848 }
849
850 fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
851 let (row_indices, col_indices, values) = self.to_coo_internal();
852
853 (
854 Array1::from_vec(row_indices),
855 Array1::from_vec(col_indices),
856 Array1::from_vec(values),
857 )
858 }
859
860 fn slice(
861 &self,
862 row_range: (usize, usize),
863 col_range: (usize, usize),
864 ) -> SparseResult<Box<dyn SparseArray<T>>> {
865 let (start_row, end_row) = row_range;
866 let (start_col, end_col) = col_range;
867 let (rows, cols) = self.shape;
868
869 if start_row >= rows || end_row > rows || start_col >= cols || end_col > cols {
870 return Err(SparseError::IndexOutOfBounds {
871 index: (start_row.max(end_row), start_col.max(end_col)),
872 shape: (rows, cols),
873 });
874 }
875
876 if start_row >= end_row || start_col >= end_col {
877 return Err(SparseError::InvalidSliceRange);
878 }
879
880 let coo = self.to_coo()?;
882 coo.slice(row_range, col_range)?.to_dia()
883 }
884
885 fn as_any(&self) -> &dyn std::any::Any {
886 self
887 }
888}
889
890impl<T> fmt::Display for DiaArray<T>
892where
893 T: Float
894 + Add<Output = T>
895 + Sub<Output = T>
896 + Mul<Output = T>
897 + Div<Output = T>
898 + Debug
899 + Copy
900 + 'static
901 + std::ops::AddAssign,
902{
903 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
904 writeln!(
905 f,
906 "DiaArray of shape {:?} with {} stored elements",
907 self.shape,
908 self.nnz()
909 )?;
910 writeln!(f, "Offsets: {:?}", self.offsets)?;
911
912 if self.offsets.len() <= 5 {
913 for (i, &offset) in self.offsets.iter().enumerate() {
914 let diag = &self.data[i];
915 let length = if offset >= 0 {
916 self.shape
917 .0
918 .min(self.shape.1.saturating_sub(offset as usize))
919 } else {
920 self.shape
921 .1
922 .min(self.shape.0.saturating_sub((-offset) as usize))
923 };
924
925 write!(f, "Diagonal {offset}: [")?;
926 for j in 0..length.min(10) {
927 if j > 0 {
928 write!(f, ", ")?;
929 }
930 write!(f, "{:?}", diag[j])?;
931 }
932 if length > 10 {
933 write!(f, ", ...")?;
934 }
935 writeln!(f, "]")?;
936 }
937 } else {
938 writeln!(f, "({} diagonals)", self.offsets.len())?;
939 }
940
941 Ok(())
942 }
943}
944
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `DiaArray<f64>` from literal diagonals, panicking on invalid input.
    fn dia(diagonals: Vec<Vec<f64>>, offsets: Vec<isize>, shape: (usize, usize)) -> DiaArray<f64> {
        let data: Vec<Array1<f64>> = diagonals.into_iter().map(Array1::from_vec).collect();
        DiaArray::new(data, offsets, shape).unwrap()
    }

    /// 3x3 fixture storing the main diagonal plus the first super- and sub-diagonal.
    fn three_band() -> DiaArray<f64> {
        dia(
            vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 0.0], vec![0.0, 6.0, 7.0]],
            vec![0, 1, -1],
            (3, 3),
        )
    }

    #[test]
    fn test_dia_array_create() {
        let array = dia(vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 0.0]], vec![0, 1], (3, 3));

        assert_eq!(array.shape(), (3, 3));
        assert_eq!(array.nnz(), 5);

        let expected = [
            (0, 0, 1.0),
            (1, 1, 2.0),
            (2, 2, 3.0),
            (0, 1, 4.0),
            (1, 2, 5.0),
            (0, 2, 0.0),
        ];
        for &(i, j, value) in expected.iter() {
            assert_eq!(array.get(i, j), value, "entry ({i}, {j})");
        }
    }

    #[test]
    fn test_dia_array_from_triplets() {
        let entries = [
            (0, 0, 1.0),
            (0, 1, 4.0),
            (1, 0, 2.0),
            (1, 1, 3.0),
            (1, 2, 5.0),
            (2, 1, 6.0),
            (2, 2, 7.0),
        ];
        let row: Vec<usize> = entries.iter().map(|e| e.0).collect();
        let col: Vec<usize> = entries.iter().map(|e| e.1).collect();
        let data: Vec<f64> = entries.iter().map(|e| e.2).collect();

        let array = DiaArray::from_triplets(&row, &col, &data, (3, 3)).unwrap();

        assert_eq!(array.offsets.len(), 3);
        for offset in [-1isize, 0, 1] {
            assert!(array.offsets.contains(&offset), "missing offset {offset}");
        }
        for &(i, j, value) in entries.iter() {
            assert_eq!(array.get(i, j), value, "entry ({i}, {j})");
        }
    }

    #[test]
    fn test_dia_array_conversion() {
        let array = dia(
            vec![vec![1.0, 3.0, 7.0], vec![4.0, 5.0, 0.0], vec![0.0, 2.0, 0.0]],
            vec![0, 1, -1],
            (3, 3),
        );

        let coo = array.to_coo().unwrap();
        assert_eq!(coo.shape(), (3, 3));
        assert_eq!(coo.nnz(), 6);

        let expected =
            Array2::from_shape_vec((3, 3), vec![1.0, 4.0, 0.0, 0.0, 3.0, 5.0, 0.0, 2.0, 7.0])
                .unwrap();
        assert_eq!(array.to_array(), expected);
    }

    #[test]
    fn test_dia_array_operations() {
        let lhs = dia(vec![vec![1.0, 2.0, 3.0]], vec![0], (3, 3));
        let rhs = dia(vec![vec![4.0, 5.0, 6.0]], vec![0], (3, 3));

        let sum = lhs.add(&rhs).unwrap();
        let product = lhs.mul(&rhs).unwrap();
        let dot = lhs.dot(&rhs).unwrap();

        // (expected sum, expected product) along the main diagonal; for two
        // diagonal matrices the matrix product equals the element-wise product.
        let expected = [(5.0, 4.0), (7.0, 10.0), (9.0, 18.0)];
        for (k, &(s, p)) in expected.iter().enumerate() {
            assert_eq!(sum.get(k, k), s);
            assert_eq!(product.get(k, k), p);
            assert_eq!(dot.get(k, k), p);
        }
    }

    #[test]
    fn test_dia_array_dot_vector() {
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let y = three_band().dot_vector(&x.view()).unwrap();

        assert_eq!(y, Array1::from_vec(vec![9.0, 19.0, 21.0]));
    }

    #[test]
    fn test_dia_array_transpose() {
        let array = three_band();
        let transposed = array.transpose().unwrap();

        assert_eq!(transposed.shape(), (3, 3));

        let original = array.to_array();
        let flipped = transposed.to_array();
        for ((i, j), &value) in flipped.indexed_iter() {
            assert_eq!(value, original[[j, i]], "entry ({i}, {j})");
        }
    }

    #[test]
    fn test_dia_array_sum() {
        let array = dia(vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 0.0]], vec![0, 1], (3, 3));

        match array.sum(None).unwrap() {
            SparseSum::Scalar(total) => assert_eq!(total, 15.0),
            _ => panic!("Expected SparseSum::Scalar"),
        }

        let column_totals = match array.sum(Some(0)).unwrap() {
            SparseSum::SparseArray(totals) => totals,
            _ => panic!("Expected SparseSum::SparseArray"),
        };
        assert_eq!(column_totals.shape(), (1, 3));
        for (j, &value) in [1.0, 6.0, 8.0].iter().enumerate() {
            assert_eq!(column_totals.get(0, j), value);
        }

        let row_totals = match array.sum(Some(1)).unwrap() {
            SparseSum::SparseArray(totals) => totals,
            _ => panic!("Expected SparseSum::SparseArray"),
        };
        assert_eq!(row_totals.shape(), (3, 1));
        for (i, &value) in [5.0, 7.0, 3.0].iter().enumerate() {
            assert_eq!(row_totals.get(i, 0), value);
        }
    }
}