1use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
7use scirs2_core::numeric::{Float, SparseElement};
8use std::fmt::{self, Debug};
9use std::ops::{Add, Div, Mul, Sub};
10
11use crate::coo_array::CooArray;
12use crate::csr_array::CsrArray;
13use crate::dok_array::DokArray;
14use crate::error::{SparseError, SparseResult};
15use crate::lil_array::LilArray;
16use crate::sparray::{SparseArray, SparseSum};
17
18#[derive(Clone)]
32pub struct DiaArray<T>
33where
34 T: SparseElement + Div<Output = T> + Float + 'static + std::ops::AddAssign,
35{
36 data: Vec<Array1<T>>,
38 offsets: Vec<isize>,
40 shape: (usize, usize),
42}
43
44impl<T> DiaArray<T>
45where
46 T: SparseElement + Div<Output = T> + Float + 'static + std::ops::AddAssign,
47{
48 pub fn new(
80 data: Vec<Array1<T>>,
81 offsets: Vec<isize>,
82 shape: (usize, usize),
83 ) -> SparseResult<Self> {
84 let (rows, cols) = shape;
85 let max_dim = rows.max(cols);
86
87 if data.len() != offsets.len() {
89 return Err(SparseError::DimensionMismatch {
90 expected: data.len(),
91 found: offsets.len(),
92 });
93 }
94
95 for diag in data.iter() {
96 if diag.len() != max_dim {
97 return Err(SparseError::DimensionMismatch {
98 expected: max_dim,
99 found: diag.len(),
100 });
101 }
102 }
103
104 Ok(DiaArray {
105 data,
106 offsets,
107 shape,
108 })
109 }
110
111 pub fn empty(shape: (usize, usize)) -> Self {
121 DiaArray {
122 data: Vec::new(),
123 offsets: Vec::new(),
124 shape,
125 }
126 }
127
128 pub fn from_triplets(
141 row: &[usize],
142 col: &[usize],
143 data: &[T],
144 shape: (usize, usize),
145 ) -> SparseResult<Self> {
146 if row.len() != col.len() || row.len() != data.len() {
147 return Err(SparseError::InconsistentData {
148 reason: "Lengths of row, col, and data arrays must be equal".to_string(),
149 });
150 }
151
152 let (rows, cols) = shape;
153 let max_dim = rows.max(cols);
154
155 let mut diagonal_offsets = std::collections::HashSet::new();
157 for (&r, &c) in row.iter().zip(col.iter()) {
158 if r >= rows || c >= cols {
159 return Err(SparseError::IndexOutOfBounds {
160 index: (r, c),
161 shape,
162 });
163 }
164 let offset = c as isize - r as isize;
166 diagonal_offsets.insert(offset);
167 }
168
169 let mut offsets: Vec<isize> = diagonal_offsets.into_iter().collect();
171 offsets.sort();
172
173 let mut diag_data = Vec::with_capacity(offsets.len());
175 for _ in 0..offsets.len() {
176 diag_data.push(Array1::zeros(max_dim));
177 }
178
179 for (&r, (&c, &val)) in row.iter().zip(col.iter().zip(data.iter())) {
181 let offset = c as isize - r as isize;
182 let diag_idx = offsets.iter().position(|&o| o == offset).unwrap();
183
184 let index = if offset >= 0 { r } else { c };
187 diag_data[diag_idx][index] = val;
188 }
189
190 DiaArray::new(diag_data, offsets, shape)
191 }
192
193 fn to_coo_internal(&self) -> (Vec<usize>, Vec<usize>, Vec<T>) {
195 let (rows, cols) = self.shape;
196 let mut row_indices = Vec::new();
197 let mut col_indices = Vec::new();
198 let mut values = Vec::new();
199
200 for (diag_idx, &offset) in self.offsets.iter().enumerate() {
201 let diag = &self.data[diag_idx];
202
203 if offset >= 0 {
204 let offset_usize = offset as usize;
206 let length = rows.min(cols.saturating_sub(offset_usize));
207
208 for i in 0..length {
209 let value = diag[i];
210 if !SparseElement::is_zero(&value) {
211 row_indices.push(i);
212 col_indices.push(i + offset_usize);
213 values.push(value);
214 }
215 }
216 } else {
217 let offset_usize = (-offset) as usize;
219 let length = cols.min(rows.saturating_sub(offset_usize));
220
221 for i in 0..length {
222 let value = diag[i];
223 if !SparseElement::is_zero(&value) {
224 row_indices.push(i + offset_usize);
225 col_indices.push(i);
226 values.push(value);
227 }
228 }
229 }
230 }
231
232 (row_indices, col_indices, values)
233 }
234}
235
236impl<T> SparseArray<T> for DiaArray<T>
237where
238 T: SparseElement + Div<Output = T> + Float + 'static + std::ops::AddAssign,
239{
240 fn shape(&self) -> (usize, usize) {
241 self.shape
242 }
243
244 fn nnz(&self) -> usize {
245 let (rows, cols) = self.shape;
246 let mut count = 0;
247
248 for (diag_idx, &offset) in self.offsets.iter().enumerate() {
249 let diag = &self.data[diag_idx];
250
251 let length = if offset >= 0 {
253 rows.min(cols.saturating_sub(offset as usize))
254 } else {
255 cols.min(rows.saturating_sub((-offset) as usize))
256 };
257
258 let start_idx = 0; for i in start_idx..start_idx + length {
261 if !SparseElement::is_zero(&diag[i]) {
262 count += 1;
263 }
264 }
265 }
266
267 count
268 }
269
270 fn dtype(&self) -> &str {
271 "float" }
273
274 fn to_array(&self) -> Array2<T> {
275 let (rows, cols) = self.shape;
277 let mut result = Array2::zeros((rows, cols));
278
279 for (diag_idx, &offset) in self.offsets.iter().enumerate() {
285 let diag = &self.data[diag_idx];
286
287 if offset >= 0 {
288 let offset_usize = offset as usize;
290 for i in 0..rows.min(cols.saturating_sub(offset_usize)) {
291 result[[i, i + offset_usize]] = diag[i];
292 }
293 } else {
294 let offset_usize = (-offset) as usize;
296 for i in 0..cols.min(rows.saturating_sub(offset_usize)) {
297 result[[i + offset_usize, i]] = diag[i];
298 }
299 }
300 }
301
302 result
303 }
304
305 fn toarray(&self) -> Array2<T> {
306 self.to_array()
307 }
308
309 fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
310 let (row_indices, col_indices, values) = self.to_coo_internal();
311 let row_array = Array1::from_vec(row_indices);
312 let col_array = Array1::from_vec(col_indices);
313 let data_array = Array1::from_vec(values);
314
315 CooArray::from_triplets(
316 &row_array.to_vec(),
317 &col_array.to_vec(),
318 &data_array.to_vec(),
319 self.shape,
320 false,
321 )
322 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
323 }
324
325 fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
326 let (row_indices, col_indices, values) = self.to_coo_internal();
327 CsrArray::from_triplets(&row_indices, &col_indices, &values, self.shape, false)
328 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
329 }
330
331 fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
332 self.to_coo()?.to_csc()
333 }
334
335 fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
336 let (row_indices, col_indices, values) = self.to_coo_internal();
337 DokArray::from_triplets(&row_indices, &col_indices, &values, self.shape)
338 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
339 }
340
341 fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
342 let (row_indices, col_indices, values) = self.to_coo_internal();
343 LilArray::from_triplets(&row_indices, &col_indices, &values, self.shape)
344 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
345 }
346
347 fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
348 Ok(Box::new(self.clone()))
349 }
350
351 fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
352 self.to_coo()?.to_bsr()
353 }
354
355 fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
356 let csr_self = self.to_csr()?;
358 let csr_other = other.to_csr()?;
359 csr_self.add(&*csr_other)
360 }
361
362 fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
363 let csr_self = self.to_csr()?;
365 let csr_other = other.to_csr()?;
366 csr_self.sub(&*csr_other)
367 }
368
369 fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
370 let csr_self = self.to_csr()?;
372 let csr_other = other.to_csr()?;
373 csr_self.mul(&*csr_other)
374 }
375
376 fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
377 let csr_self = self.to_csr()?;
379 let csr_other = other.to_csr()?;
380 csr_self.div(&*csr_other)
381 }
382
383 fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
384 let (_, n) = self.shape();
386 let (p, q) = other.shape();
387
388 if n != p {
389 return Err(SparseError::DimensionMismatch {
390 expected: n,
391 found: p,
392 });
393 }
394
395 if q == 1 {
397 let other_array = other.to_array();
399 let vec_view = other_array.column(0);
400
401 let result = self.dot_vector(&vec_view)?;
403
404 let mut rows = Vec::new();
406 let mut cols = Vec::new();
407 let mut values = Vec::new();
408
409 for (i, &val) in result.iter().enumerate() {
410 if !SparseElement::is_zero(&val) {
411 rows.push(i);
412 cols.push(0);
413 values.push(val);
414 }
415 }
416
417 CooArray::from_triplets(&rows, &cols, &values, (result.len(), 1), false)
418 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
419 } else {
420 let csr_self = self.to_csr()?;
422 csr_self.dot(other)
423 }
424 }
425
426 fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
427 let (rows, cols) = self.shape;
428
429 if cols != other.len() {
430 return Err(SparseError::DimensionMismatch {
431 expected: cols,
432 found: other.len(),
433 });
434 }
435
436 let mut result = Array1::zeros(rows);
437
438 for (diag_idx, &offset) in self.offsets.iter().enumerate() {
439 let diag = &self.data[diag_idx];
440
441 if offset >= 0 {
442 let offset_usize = offset as usize;
444 let length = rows.min(cols.saturating_sub(offset_usize));
445
446 for i in 0..length {
447 result[i] += diag[i] * other[i + offset_usize];
448 }
449 } else {
450 let offset_usize = (-offset) as usize;
452 let length = cols.min(rows.saturating_sub(offset_usize));
453
454 for i in 0..length {
455 result[i + offset_usize] += diag[i] * other[i];
456 }
457 }
458 }
459
460 Ok(result)
461 }
462
463 fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
464 let (row_indices, col_indices, values) = self.to_coo_internal();
467
468 let transposed_rows = col_indices;
470 let transposed_cols = row_indices;
471
472 CooArray::from_triplets(
474 &transposed_rows,
475 &transposed_cols,
476 &values,
477 (self.shape.1, self.shape.0),
478 false,
479 )?
480 .to_dia()
481 }
482
483 fn copy(&self) -> Box<dyn SparseArray<T>> {
484 Box::new(self.clone())
485 }
486
487 fn get(&self, i: usize, j: usize) -> T {
488 if i >= self.shape.0 || j >= self.shape.1 {
489 return T::sparse_zero();
490 }
491
492 let offset = j as isize - i as isize;
494
495 if let Some(diag_idx) = self.offsets.iter().position(|&o| o == offset) {
497 let diag = &self.data[diag_idx];
498
499 let index = if offset >= 0 { i } else { j };
502
503 if index < diag.len() {
505 return diag[index];
506 }
507 }
508
509 T::sparse_zero()
510 }
511
512 fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
513 if i >= self.shape.0 || j >= self.shape.1 {
514 return Err(SparseError::IndexOutOfBounds {
515 index: (i, j),
516 shape: self.shape,
517 });
518 }
519
520 let offset = j as isize - i as isize;
522
523 let diag_idx = match self.offsets.iter().position(|&o| o == offset) {
525 Some(idx) => idx,
526 None => {
527 self.offsets.push(offset);
529 self.data
530 .push(Array1::zeros(self.shape.0.max(self.shape.1)));
531
532 let mut offset_data: Vec<(isize, Array1<T>)> = self
534 .offsets
535 .iter()
536 .cloned()
537 .zip(self.data.drain(..))
538 .collect();
539 offset_data.sort_by_key(|&(offset_, _)| offset_);
540
541 self.offsets = offset_data.iter().map(|&(offset_, _)| offset_).collect();
542 self.data = offset_data.into_iter().map(|(_, data)| data).collect();
543
544 self.offsets.iter().position(|&o| o == offset).unwrap()
546 }
547 };
548
549 let index = if offset >= 0 { i } else { j };
551 self.data[diag_idx][index] = value;
552
553 Ok(())
554 }
555
556 fn eliminate_zeros(&mut self) {
557 let mut new_offsets = Vec::new();
559 let mut new_data = Vec::new();
560
561 for (diag_idx, &offset) in self.offsets.iter().enumerate() {
562 let diag = &self.data[diag_idx];
563
564 let length = if offset >= 0 {
566 self.shape
567 .0
568 .min(self.shape.1.saturating_sub(offset as usize))
569 } else {
570 self.shape
571 .1
572 .min(self.shape.0.saturating_sub((-offset) as usize))
573 };
574
575 let has_nonzero = (0..length).any(|i| !SparseElement::is_zero(&diag[i]));
576
577 if has_nonzero {
578 new_offsets.push(offset);
579 new_data.push(diag.clone());
580 }
581 }
582
583 self.offsets = new_offsets;
584 self.data = new_data;
585 }
586
587 fn sort_indices(&mut self) {
588 let mut offset_data: Vec<(isize, Array1<T>)> = self
591 .offsets
592 .iter()
593 .cloned()
594 .zip(self.data.drain(..))
595 .collect();
596 offset_data.sort_by_key(|&(offset_, _)| offset_);
597
598 self.offsets = offset_data.iter().map(|&(offset_, _)| offset_).collect();
599 self.data = offset_data.into_iter().map(|(_, data)| data).collect();
600 }
601
602 fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
603 let mut result = self.clone();
605 result.sort_indices();
606 Box::new(result)
607 }
608
609 fn has_sorted_indices(&self) -> bool {
610 self.offsets.windows(2).all(|w| w[0] <= w[1])
612 }
613
614 fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>> {
615 match axis {
616 None => {
617 let mut total = T::sparse_zero();
619
620 for (diag_idx, &offset) in self.offsets.iter().enumerate() {
621 let diag = &self.data[diag_idx];
622
623 let length = if offset >= 0 {
624 self.shape
625 .0
626 .min(self.shape.1.saturating_sub(offset as usize))
627 } else {
628 self.shape
629 .1
630 .min(self.shape.0.saturating_sub((-offset) as usize))
631 };
632
633 for i in 0..length {
634 total += diag[i];
635 }
636 }
637
638 Ok(SparseSum::Scalar(total))
639 }
640 Some(0) => {
641 let mut result = Array1::zeros(self.shape.1);
643
644 for (diag_idx, &offset) in self.offsets.iter().enumerate() {
645 let diag = &self.data[diag_idx];
646
647 if offset >= 0 {
648 let offset_usize = offset as usize;
650 let length = self.shape.0.min(self.shape.1.saturating_sub(offset_usize));
651
652 for i in 0..length {
653 result[i + offset_usize] += diag[i];
654 }
655 } else {
656 let offset_usize = (-offset) as usize;
658 let length = self.shape.1.min(self.shape.0.saturating_sub(offset_usize));
659
660 for i in 0..length {
661 result[i] += diag[i];
662 }
663 }
664 }
665
666 match Array2::from_shape_vec((1, self.shape.1), result.to_vec()) {
668 Ok(result_2d) => {
669 let mut row_indices = Vec::new();
671 let mut col_indices = Vec::new();
672 let mut values = Vec::new();
673
674 for j in 0..self.shape.1 {
675 let val: T = result_2d[[0, j]];
676 if !SparseElement::is_zero(&val) {
677 row_indices.push(0);
678 col_indices.push(j);
679 values.push(val);
680 }
681 }
682
683 match CooArray::from_triplets(
685 &row_indices,
686 &col_indices,
687 &values,
688 (1, self.shape.1),
689 false,
690 ) {
691 Ok(coo_array) => Ok(SparseSum::SparseArray(Box::new(coo_array))),
692 Err(e) => Err(e),
693 }
694 }
695 Err(_) => Err(SparseError::InconsistentData {
696 reason: "Failed to create 2D array from result vector".to_string(),
697 }),
698 }
699 }
700 Some(1) => {
701 let mut result = Array1::zeros(self.shape.0);
703
704 for (diag_idx, &offset) in self.offsets.iter().enumerate() {
705 let diag = &self.data[diag_idx];
706
707 if offset >= 0 {
708 let offset_usize = offset as usize;
710 let length = self.shape.0.min(self.shape.1.saturating_sub(offset_usize));
711
712 for i in 0..length {
713 result[i] += diag[i];
714 }
715 } else {
716 let offset_usize = (-offset) as usize;
718 let length = self.shape.1.min(self.shape.0.saturating_sub(offset_usize));
719
720 for i in 0..length {
721 result[i + offset_usize] += diag[i];
722 }
723 }
724 }
725
726 match Array2::from_shape_vec((self.shape.0, 1), result.to_vec()) {
728 Ok(result_2d) => {
729 let mut row_indices = Vec::new();
731 let mut col_indices = Vec::new();
732 let mut values = Vec::new();
733
734 for i in 0..self.shape.0 {
735 let val: T = result_2d[[i, 0]];
736 if !SparseElement::is_zero(&val) {
737 row_indices.push(i);
738 col_indices.push(0);
739 values.push(val);
740 }
741 }
742
743 match CooArray::from_triplets(
745 &row_indices,
746 &col_indices,
747 &values,
748 (self.shape.0, 1),
749 false,
750 ) {
751 Ok(coo_array) => Ok(SparseSum::SparseArray(Box::new(coo_array))),
752 Err(e) => Err(e),
753 }
754 }
755 Err(_) => Err(SparseError::InconsistentData {
756 reason: "Failed to create 2D array from result vector".to_string(),
757 }),
758 }
759 }
760 _ => Err(SparseError::InvalidAxis),
761 }
762 }
763
764 fn max(&self) -> T {
765 let mut max_val = T::neg_infinity();
766
767 for (diag_idx, &offset) in self.offsets.iter().enumerate() {
768 let diag = &self.data[diag_idx];
769
770 let length = if offset >= 0 {
771 self.shape
772 .0
773 .min(self.shape.1.saturating_sub(offset as usize))
774 } else {
775 self.shape
776 .1
777 .min(self.shape.0.saturating_sub((-offset) as usize))
778 };
779
780 for i in 0..length {
781 max_val = max_val.max(diag[i]);
782 }
783 }
784
785 if max_val == T::neg_infinity() {
787 T::sparse_zero()
788 } else {
789 max_val
790 }
791 }
792
793 fn min(&self) -> T {
794 let mut min_val = T::sparse_zero();
795 let mut has_nonzero = false;
796
797 for (diag_idx, &offset) in self.offsets.iter().enumerate() {
798 let diag = &self.data[diag_idx];
799
800 let length = if offset >= 0 {
801 self.shape
802 .0
803 .min(self.shape.1.saturating_sub(offset as usize))
804 } else {
805 self.shape
806 .1
807 .min(self.shape.0.saturating_sub((-offset) as usize))
808 };
809
810 for i in 0..length {
811 if !SparseElement::is_zero(&diag[i]) {
812 has_nonzero = true;
813 min_val = min_val.min(diag[i]);
814 }
815 }
816 }
817
818 if !has_nonzero {
820 T::sparse_zero()
821 } else {
822 min_val
823 }
824 }
825
826 fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
827 let (row_indices, col_indices, values) = self.to_coo_internal();
828
829 (
830 Array1::from_vec(row_indices),
831 Array1::from_vec(col_indices),
832 Array1::from_vec(values),
833 )
834 }
835
836 fn slice(
837 &self,
838 row_range: (usize, usize),
839 col_range: (usize, usize),
840 ) -> SparseResult<Box<dyn SparseArray<T>>> {
841 let (start_row, end_row) = row_range;
842 let (start_col, end_col) = col_range;
843 let (rows, cols) = self.shape;
844
845 if start_row >= rows || end_row > rows || start_col >= cols || end_col > cols {
846 return Err(SparseError::IndexOutOfBounds {
847 index: (start_row.max(end_row), start_col.max(end_col)),
848 shape: (rows, cols),
849 });
850 }
851
852 if start_row >= end_row || start_col >= end_col {
853 return Err(SparseError::InvalidSliceRange);
854 }
855
856 let coo = self.to_coo()?;
858 coo.slice(row_range, col_range)?.to_dia()
859 }
860
861 fn as_any(&self) -> &dyn std::any::Any {
862 self
863 }
864}
865
866impl<T> fmt::Display for DiaArray<T>
868where
869 T: SparseElement + Div<Output = T> + Float + 'static + std::ops::AddAssign,
870{
871 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
872 writeln!(
873 f,
874 "DiaArray of shape {:?} with {} stored elements",
875 self.shape,
876 self.nnz()
877 )?;
878 writeln!(f, "Offsets: {:?}", self.offsets)?;
879
880 if self.offsets.len() <= 5 {
881 for (i, &offset) in self.offsets.iter().enumerate() {
882 let diag = &self.data[i];
883 let length = if offset >= 0 {
884 self.shape
885 .0
886 .min(self.shape.1.saturating_sub(offset as usize))
887 } else {
888 self.shape
889 .1
890 .min(self.shape.0.saturating_sub((-offset) as usize))
891 };
892
893 write!(f, "Diagonal {offset}: [")?;
894 for j in 0..length.min(10) {
895 if j > 0 {
896 write!(f, ", ")?;
897 }
898 write!(f, "{:?}", diag[j])?;
899 }
900 if length > 10 {
901 write!(f, ", ...")?;
902 }
903 writeln!(f, "]")?;
904 }
905 } else {
906 writeln!(f, "({} diagonals)", self.offsets.len())?;
907 }
908
909 Ok(())
910 }
911}
912
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dia_array_create() {
        let data = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]), // main diagonal
            Array1::from_vec(vec![4.0, 5.0, 0.0]), // superdiagonal; slot 2 is padding
        ];
        let offsets = vec![0, 1];
        let shape = (3, 3);

        let array = DiaArray::new(data, offsets, shape).unwrap();

        assert_eq!(array.shape(), (3, 3));
        assert_eq!(array.nnz(), 5);
        assert_eq!(array.get(0, 0), 1.0);
        assert_eq!(array.get(1, 1), 2.0);
        assert_eq!(array.get(2, 2), 3.0);
        assert_eq!(array.get(0, 1), 4.0);
        assert_eq!(array.get(1, 2), 5.0);
        assert_eq!(array.get(0, 2), 0.0);
    }

    #[test]
    fn test_dia_array_from_triplets() {
        // Tridiagonal 3x3 matrix [[1, 4, 0], [2, 3, 5], [0, 6, 7]].
        let row = vec![0, 0, 1, 1, 1, 2, 2];
        let col = vec![0, 1, 0, 1, 2, 1, 2];
        let data = vec![1.0, 4.0, 2.0, 3.0, 5.0, 6.0, 7.0];
        let shape = (3, 3);

        let array = DiaArray::from_triplets(&row, &col, &data, shape).unwrap();

        // Exactly the three diagonals -1, 0 and 1 are stored.
        assert_eq!(array.offsets.len(), 3);
        assert!(array.offsets.contains(&0));
        assert!(array.offsets.contains(&1));
        assert!(array.offsets.contains(&-1));

        assert_eq!(array.get(0, 0), 1.0);
        assert_eq!(array.get(0, 1), 4.0);
        assert_eq!(array.get(1, 0), 2.0);
        assert_eq!(array.get(1, 1), 3.0);
        assert_eq!(array.get(1, 2), 5.0);
        assert_eq!(array.get(2, 1), 6.0);
        assert_eq!(array.get(2, 2), 7.0);
    }

    #[test]
    fn test_dia_array_conversion() {
        let data = vec![
            Array1::from_vec(vec![1.0, 3.0, 7.0]), // main diagonal
            Array1::from_vec(vec![4.0, 5.0, 0.0]), // superdiagonal, indexed by row
            Array1::from_vec(vec![0.0, 2.0, 0.0]), // subdiagonal, indexed by column
        ];
        let offsets = vec![0, 1, -1];
        let shape = (3, 3);

        let array = DiaArray::new(data, offsets, shape).unwrap();

        let coo = array.to_coo().unwrap();
        assert_eq!(coo.shape(), (3, 3));
        // Explicit zeros are not carried over to COO.
        assert_eq!(coo.nnz(), 6);

        let dense = array.to_array();
        let expected =
            Array2::from_shape_vec((3, 3), vec![1.0, 4.0, 0.0, 0.0, 3.0, 5.0, 0.0, 2.0, 7.0])
                .unwrap();
        assert_eq!(dense, expected);
    }

    #[test]
    fn test_dia_array_operations() {
        let data1 = vec![Array1::from_vec(vec![1.0, 2.0, 3.0])];
        let offsets1 = vec![0];
        let shape1 = (3, 3);
        let array1 = DiaArray::new(data1, offsets1, shape1).unwrap();

        let data2 = vec![Array1::from_vec(vec![4.0, 5.0, 6.0])];
        let offsets2 = vec![0];
        let shape2 = (3, 3);
        let array2 = DiaArray::new(data2, offsets2, shape2).unwrap();

        // Element-wise addition.
        let sum = array1.add(&array2).unwrap();
        assert_eq!(sum.get(0, 0), 5.0);
        assert_eq!(sum.get(1, 1), 7.0);
        assert_eq!(sum.get(2, 2), 9.0);

        // Element-wise multiplication.
        let product = array1.mul(&array2).unwrap();
        assert_eq!(product.get(0, 0), 4.0);
        assert_eq!(product.get(1, 1), 10.0);
        assert_eq!(product.get(2, 2), 18.0);

        // Matrix product of two diagonal matrices is again diagonal.
        let dot = array1.dot(&array2).unwrap();
        assert_eq!(dot.get(0, 0), 4.0);
        assert_eq!(dot.get(1, 1), 10.0);
        assert_eq!(dot.get(2, 2), 18.0);
    }

    #[test]
    fn test_dia_array_dot_vector() {
        let data = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]), // main diagonal
            Array1::from_vec(vec![4.0, 5.0, 0.0]), // superdiagonal
            Array1::from_vec(vec![0.0, 6.0, 7.0]), // subdiagonal; slot 2 (7.0) is padding
        ];
        let offsets = vec![0, 1, -1];
        let shape = (3, 3);

        let array = DiaArray::new(data, offsets, shape).unwrap();

        // Matrix is [[1, 4, 0], [0, 2, 5], [0, 6, 3]].
        let vector = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let result = array.dot_vector(&vector.view()).unwrap();

        let expected = Array1::from_vec(vec![9.0, 19.0, 21.0]);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_dia_array_transpose() {
        let data = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]),
            Array1::from_vec(vec![4.0, 5.0, 0.0]),
            Array1::from_vec(vec![0.0, 6.0, 7.0]),
        ];
        let offsets = vec![0, 1, -1];
        let shape = (3, 3);

        let array = DiaArray::new(data, offsets, shape).unwrap();
        let transposed = array.transpose().unwrap();

        assert_eq!(transposed.shape(), (3, 3));

        let original_dense = array.to_array();
        let transposed_dense = transposed.to_array();
        for i in 0..3 {
            for j in 0..3 {
                assert_eq!(transposed_dense[[i, j]], original_dense[[j, i]]);
            }
        }
    }

    #[test]
    fn test_dia_array_sum() {
        let data = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]), // main diagonal
            Array1::from_vec(vec![4.0, 5.0, 0.0]), // superdiagonal
        ];
        let offsets = vec![0, 1];
        let shape = (3, 3);

        let array = DiaArray::new(data, offsets, shape).unwrap();

        // Matrix is [[1, 4, 0], [0, 2, 5], [0, 0, 3]].
        if let SparseSum::Scalar(sum) = array.sum(None).unwrap() {
            assert_eq!(sum, 15.0);
        } else {
            panic!("Expected SparseSum::Scalar");
        }

        // Column totals.
        if let SparseSum::SparseArray(row_sum) = array.sum(Some(0)).unwrap() {
            assert_eq!(row_sum.shape(), (1, 3));
            assert_eq!(row_sum.get(0, 0), 1.0);
            assert_eq!(row_sum.get(0, 1), 6.0);
            assert_eq!(row_sum.get(0, 2), 8.0);
        } else {
            panic!("Expected SparseSum::SparseArray");
        }

        // Row totals.
        if let SparseSum::SparseArray(col_sum) = array.sum(Some(1)).unwrap() {
            assert_eq!(col_sum.shape(), (3, 1));
            assert_eq!(col_sum.get(0, 0), 5.0);
            assert_eq!(col_sum.get(1, 0), 7.0);
            assert_eq!(col_sum.get(2, 0), 3.0);
        } else {
            panic!("Expected SparseSum::SparseArray");
        }

        assert!(array.sum(Some(2)).is_err());
    }

    #[test]
    fn test_dia_array_rectangular() {
        // 2 x 3 matrix [[1, 0, 3], [0, 2, 0]]; storage vectors have length max(2, 3) = 3.
        let data = vec![
            Array1::from_vec(vec![1.0, 2.0, 0.0]), // offset 0; slot 2 is padding
            Array1::from_vec(vec![3.0, 0.0, 0.0]), // offset 2; only slot 0 is in bounds
        ];
        let array = DiaArray::new(data, vec![0, 2], (2, 3)).unwrap();

        let expected =
            Array2::from_shape_vec((2, 3), vec![1.0, 0.0, 3.0, 0.0, 2.0, 0.0]).unwrap();
        assert_eq!(array.to_array(), expected);
        assert_eq!(array.nnz(), 3);

        let ones = Array1::from_vec(vec![1.0, 1.0, 1.0]);
        let product = array.dot_vector(&ones.view()).unwrap();
        assert_eq!(product, Array1::from_vec(vec![4.0, 2.0]));

        let transposed = array.transpose().unwrap();
        assert_eq!(transposed.shape(), (3, 2));
        assert_eq!(transposed.to_array(), expected.t().to_owned());
    }

    #[test]
    fn test_dia_array_get_set() {
        let mut array: DiaArray<f64> = DiaArray::empty((3, 3));
        assert_eq!(array.nnz(), 0);
        assert_eq!(array.get(0, 0), 0.0);

        // Out-of-bounds reads yield zero; out-of-bounds writes are errors.
        assert_eq!(array.get(5, 5), 0.0);
        assert!(array.set(3, 0, 1.0).is_err());

        array.set(2, 0, 4.0).unwrap(); // creates diagonal -2
        array.set(0, 1, 1.0).unwrap(); // creates diagonal 1
        assert_eq!(array.get(2, 0), 4.0);
        assert_eq!(array.get(0, 1), 1.0);
        assert_eq!(array.nnz(), 2);
        assert!(array.has_sorted_indices());
    }

    #[test]
    fn test_dia_array_invalid_input() {
        // Storage vector shorter than max(rows, cols).
        assert!(DiaArray::new(vec![Array1::from_vec(vec![1.0, 2.0])], vec![0], (3, 3)).is_err());
        // An offset without a matching storage vector.
        assert!(DiaArray::<f64>::new(Vec::new(), vec![0], (3, 3)).is_err());
        // Triplet outside the matrix.
        assert!(DiaArray::from_triplets(&[3], &[0], &[1.0], (3, 3)).is_err());
        // Mismatched triplet slice lengths.
        assert!(DiaArray::from_triplets(&[0, 1], &[0], &[1.0], (3, 3)).is_err());
    }

    #[test]
    fn test_dia_array_eliminate_zeros_ignores_padding() {
        // Diagonal 1 is all zero in bounds; its 9.0 sits in a padding slot and must be ignored.
        let data = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]),
            Array1::from_vec(vec![0.0, 0.0, 9.0]),
        ];
        let mut array = DiaArray::new(data, vec![0, 1], (3, 3)).unwrap();
        assert_eq!(array.nnz(), 3);

        array.eliminate_zeros();
        assert_eq!(array.offsets, vec![0]);
        assert_eq!(array.get(1, 1), 2.0);
    }
}