1use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
7use scirs2_core::numeric::{Float, SparseElement};
8use std::fmt::{self, Debug};
9use std::ops::{Add, Div, Mul, Sub};
10
11use crate::coo_array::CooArray;
12use crate::csr_array::CsrArray;
13use crate::dok_array::DokArray;
14use crate::error::{SparseError, SparseResult};
15use crate::lil_array::LilArray;
16use crate::sparray::{SparseArray, SparseSum};
17
18#[derive(Clone)]
32pub struct DiaArray<T>
33where
34 T: SparseElement + Div<Output = T> + Float + 'static + std::ops::AddAssign,
35{
36 data: Vec<Array1<T>>,
38 offsets: Vec<isize>,
40 shape: (usize, usize),
42}
43
44impl<T> DiaArray<T>
45where
46 T: SparseElement + Div<Output = T> + Float + 'static + std::ops::AddAssign,
47{
48 pub fn new(
80 data: Vec<Array1<T>>,
81 offsets: Vec<isize>,
82 shape: (usize, usize),
83 ) -> SparseResult<Self> {
84 let (rows, cols) = shape;
85 let max_dim = rows.max(cols);
86
87 if data.len() != offsets.len() {
89 return Err(SparseError::DimensionMismatch {
90 expected: data.len(),
91 found: offsets.len(),
92 });
93 }
94
95 for diag in data.iter() {
96 if diag.len() != max_dim {
97 return Err(SparseError::DimensionMismatch {
98 expected: max_dim,
99 found: diag.len(),
100 });
101 }
102 }
103
104 Ok(DiaArray {
105 data,
106 offsets,
107 shape,
108 })
109 }
110
111 pub fn empty(shape: (usize, usize)) -> Self {
121 DiaArray {
122 data: Vec::new(),
123 offsets: Vec::new(),
124 shape,
125 }
126 }
127
128 pub fn from_triplets(
141 row: &[usize],
142 col: &[usize],
143 data: &[T],
144 shape: (usize, usize),
145 ) -> SparseResult<Self> {
146 if row.len() != col.len() || row.len() != data.len() {
147 return Err(SparseError::InconsistentData {
148 reason: "Lengths of row, col, and data arrays must be equal".to_string(),
149 });
150 }
151
152 let (rows, cols) = shape;
153 let max_dim = rows.max(cols);
154
155 let mut diagonal_offsets = std::collections::HashSet::new();
157 for (&r, &c) in row.iter().zip(col.iter()) {
158 if r >= rows || c >= cols {
159 return Err(SparseError::IndexOutOfBounds {
160 index: (r, c),
161 shape,
162 });
163 }
164 let offset = c as isize - r as isize;
166 diagonal_offsets.insert(offset);
167 }
168
169 let mut offsets: Vec<isize> = diagonal_offsets.into_iter().collect();
171 offsets.sort();
172
173 let mut diag_data = Vec::with_capacity(offsets.len());
175 for _ in 0..offsets.len() {
176 diag_data.push(Array1::zeros(max_dim));
177 }
178
179 for (&r, (&c, &val)) in row.iter().zip(col.iter().zip(data.iter())) {
181 let offset = c as isize - r as isize;
182 let diag_idx = offsets
183 .iter()
184 .position(|&o| o == offset)
185 .expect("Operation failed");
186
187 let index = if offset >= 0 { r } else { c };
190 diag_data[diag_idx][index] = val;
191 }
192
193 DiaArray::new(diag_data, offsets, shape)
194 }
195
196 fn to_coo_internal(&self) -> (Vec<usize>, Vec<usize>, Vec<T>) {
198 let (rows, cols) = self.shape;
199 let mut row_indices = Vec::new();
200 let mut col_indices = Vec::new();
201 let mut values = Vec::new();
202
203 for (diag_idx, &offset) in self.offsets.iter().enumerate() {
204 let diag = &self.data[diag_idx];
205
206 if offset >= 0 {
207 let offset_usize = offset as usize;
209 let length = rows.min(cols.saturating_sub(offset_usize));
210
211 for i in 0..length {
212 let value = diag[i];
213 if !SparseElement::is_zero(&value) {
214 row_indices.push(i);
215 col_indices.push(i + offset_usize);
216 values.push(value);
217 }
218 }
219 } else {
220 let offset_usize = (-offset) as usize;
222 let length = cols.min(rows.saturating_sub(offset_usize));
223
224 for i in 0..length {
225 let value = diag[i];
226 if !SparseElement::is_zero(&value) {
227 row_indices.push(i + offset_usize);
228 col_indices.push(i);
229 values.push(value);
230 }
231 }
232 }
233 }
234
235 (row_indices, col_indices, values)
236 }
237}
238
239impl<T> SparseArray<T> for DiaArray<T>
240where
241 T: SparseElement + Div<Output = T> + Float + 'static + std::ops::AddAssign,
242{
243 fn shape(&self) -> (usize, usize) {
244 self.shape
245 }
246
247 fn nnz(&self) -> usize {
248 let (rows, cols) = self.shape;
249 let mut count = 0;
250
251 for (diag_idx, &offset) in self.offsets.iter().enumerate() {
252 let diag = &self.data[diag_idx];
253
254 let length = if offset >= 0 {
256 rows.min(cols.saturating_sub(offset as usize))
257 } else {
258 cols.min(rows.saturating_sub((-offset) as usize))
259 };
260
261 let start_idx = 0; for i in start_idx..start_idx + length {
264 if !SparseElement::is_zero(&diag[i]) {
265 count += 1;
266 }
267 }
268 }
269
270 count
271 }
272
273 fn dtype(&self) -> &str {
274 "float" }
276
277 fn to_array(&self) -> Array2<T> {
278 let (rows, cols) = self.shape;
280 let mut result = Array2::zeros((rows, cols));
281
282 for (diag_idx, &offset) in self.offsets.iter().enumerate() {
288 let diag = &self.data[diag_idx];
289
290 if offset >= 0 {
291 let offset_usize = offset as usize;
293 for i in 0..rows.min(cols.saturating_sub(offset_usize)) {
294 result[[i, i + offset_usize]] = diag[i];
295 }
296 } else {
297 let offset_usize = (-offset) as usize;
299 for i in 0..cols.min(rows.saturating_sub(offset_usize)) {
300 result[[i + offset_usize, i]] = diag[i];
301 }
302 }
303 }
304
305 result
306 }
307
308 fn toarray(&self) -> Array2<T> {
309 self.to_array()
310 }
311
312 fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
313 let (row_indices, col_indices, values) = self.to_coo_internal();
314 let row_array = Array1::from_vec(row_indices);
315 let col_array = Array1::from_vec(col_indices);
316 let data_array = Array1::from_vec(values);
317
318 CooArray::from_triplets(
319 &row_array.to_vec(),
320 &col_array.to_vec(),
321 &data_array.to_vec(),
322 self.shape,
323 false,
324 )
325 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
326 }
327
328 fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
329 let (row_indices, col_indices, values) = self.to_coo_internal();
330 CsrArray::from_triplets(&row_indices, &col_indices, &values, self.shape, false)
331 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
332 }
333
334 fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
335 self.to_coo()?.to_csc()
336 }
337
338 fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
339 let (row_indices, col_indices, values) = self.to_coo_internal();
340 DokArray::from_triplets(&row_indices, &col_indices, &values, self.shape)
341 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
342 }
343
344 fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
345 let (row_indices, col_indices, values) = self.to_coo_internal();
346 LilArray::from_triplets(&row_indices, &col_indices, &values, self.shape)
347 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
348 }
349
350 fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
351 Ok(Box::new(self.clone()))
352 }
353
354 fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
355 self.to_coo()?.to_bsr()
356 }
357
358 fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
359 let csr_self = self.to_csr()?;
361 let csr_other = other.to_csr()?;
362 csr_self.add(&*csr_other)
363 }
364
365 fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
366 let csr_self = self.to_csr()?;
368 let csr_other = other.to_csr()?;
369 csr_self.sub(&*csr_other)
370 }
371
372 fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
373 let csr_self = self.to_csr()?;
375 let csr_other = other.to_csr()?;
376 csr_self.mul(&*csr_other)
377 }
378
379 fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
380 let csr_self = self.to_csr()?;
382 let csr_other = other.to_csr()?;
383 csr_self.div(&*csr_other)
384 }
385
386 fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
387 let (_, n) = self.shape();
389 let (p, q) = other.shape();
390
391 if n != p {
392 return Err(SparseError::DimensionMismatch {
393 expected: n,
394 found: p,
395 });
396 }
397
398 if q == 1 {
400 let other_array = other.to_array();
402 let vec_view = other_array.column(0);
403
404 let result = self.dot_vector(&vec_view)?;
406
407 let mut rows = Vec::new();
409 let mut cols = Vec::new();
410 let mut values = Vec::new();
411
412 for (i, &val) in result.iter().enumerate() {
413 if !SparseElement::is_zero(&val) {
414 rows.push(i);
415 cols.push(0);
416 values.push(val);
417 }
418 }
419
420 CooArray::from_triplets(&rows, &cols, &values, (result.len(), 1), false)
421 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
422 } else {
423 let csr_self = self.to_csr()?;
425 csr_self.dot(other)
426 }
427 }
428
429 fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
430 let (rows, cols) = self.shape;
431
432 if cols != other.len() {
433 return Err(SparseError::DimensionMismatch {
434 expected: cols,
435 found: other.len(),
436 });
437 }
438
439 let mut result = Array1::zeros(rows);
440
441 for (diag_idx, &offset) in self.offsets.iter().enumerate() {
442 let diag = &self.data[diag_idx];
443
444 if offset >= 0 {
445 let offset_usize = offset as usize;
447 let length = rows.min(cols.saturating_sub(offset_usize));
448
449 for i in 0..length {
450 result[i] += diag[i] * other[i + offset_usize];
451 }
452 } else {
453 let offset_usize = (-offset) as usize;
455 let length = cols.min(rows.saturating_sub(offset_usize));
456
457 for i in 0..length {
458 result[i + offset_usize] += diag[i] * other[i];
459 }
460 }
461 }
462
463 Ok(result)
464 }
465
466 fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
467 let (row_indices, col_indices, values) = self.to_coo_internal();
470
471 let transposed_rows = col_indices;
473 let transposed_cols = row_indices;
474
475 CooArray::from_triplets(
477 &transposed_rows,
478 &transposed_cols,
479 &values,
480 (self.shape.1, self.shape.0),
481 false,
482 )?
483 .to_dia()
484 }
485
486 fn copy(&self) -> Box<dyn SparseArray<T>> {
487 Box::new(self.clone())
488 }
489
490 fn get(&self, i: usize, j: usize) -> T {
491 if i >= self.shape.0 || j >= self.shape.1 {
492 return T::sparse_zero();
493 }
494
495 let offset = j as isize - i as isize;
497
498 if let Some(diag_idx) = self.offsets.iter().position(|&o| o == offset) {
500 let diag = &self.data[diag_idx];
501
502 let index = if offset >= 0 { i } else { j };
505
506 if index < diag.len() {
508 return diag[index];
509 }
510 }
511
512 T::sparse_zero()
513 }
514
515 fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
516 if i >= self.shape.0 || j >= self.shape.1 {
517 return Err(SparseError::IndexOutOfBounds {
518 index: (i, j),
519 shape: self.shape,
520 });
521 }
522
523 let offset = j as isize - i as isize;
525
526 let diag_idx = match self.offsets.iter().position(|&o| o == offset) {
528 Some(idx) => idx,
529 None => {
530 self.offsets.push(offset);
532 self.data
533 .push(Array1::zeros(self.shape.0.max(self.shape.1)));
534
535 let mut offset_data: Vec<(isize, Array1<T>)> = self
537 .offsets
538 .iter()
539 .cloned()
540 .zip(self.data.drain(..))
541 .collect();
542 offset_data.sort_by_key(|&(offset_, _)| offset_);
543
544 self.offsets = offset_data.iter().map(|&(offset_, _)| offset_).collect();
545 self.data = offset_data.into_iter().map(|(_, data)| data).collect();
546
547 self.offsets
549 .iter()
550 .position(|&o| o == offset)
551 .expect("Operation failed")
552 }
553 };
554
555 let index = if offset >= 0 { i } else { j };
557 self.data[diag_idx][index] = value;
558
559 Ok(())
560 }
561
562 fn eliminate_zeros(&mut self) {
563 let mut new_offsets = Vec::new();
565 let mut new_data = Vec::new();
566
567 for (diag_idx, &offset) in self.offsets.iter().enumerate() {
568 let diag = &self.data[diag_idx];
569
570 let length = if offset >= 0 {
572 self.shape
573 .0
574 .min(self.shape.1.saturating_sub(offset as usize))
575 } else {
576 self.shape
577 .1
578 .min(self.shape.0.saturating_sub((-offset) as usize))
579 };
580
581 let has_nonzero = (0..length).any(|i| !SparseElement::is_zero(&diag[i]));
582
583 if has_nonzero {
584 new_offsets.push(offset);
585 new_data.push(diag.clone());
586 }
587 }
588
589 self.offsets = new_offsets;
590 self.data = new_data;
591 }
592
593 fn sort_indices(&mut self) {
594 let mut offset_data: Vec<(isize, Array1<T>)> = self
597 .offsets
598 .iter()
599 .cloned()
600 .zip(self.data.drain(..))
601 .collect();
602 offset_data.sort_by_key(|&(offset_, _)| offset_);
603
604 self.offsets = offset_data.iter().map(|&(offset_, _)| offset_).collect();
605 self.data = offset_data.into_iter().map(|(_, data)| data).collect();
606 }
607
608 fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
609 let mut result = self.clone();
611 result.sort_indices();
612 Box::new(result)
613 }
614
615 fn has_sorted_indices(&self) -> bool {
616 self.offsets.windows(2).all(|w| w[0] <= w[1])
618 }
619
620 fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>> {
621 match axis {
622 None => {
623 let mut total = T::sparse_zero();
625
626 for (diag_idx, &offset) in self.offsets.iter().enumerate() {
627 let diag = &self.data[diag_idx];
628
629 let length = if offset >= 0 {
630 self.shape
631 .0
632 .min(self.shape.1.saturating_sub(offset as usize))
633 } else {
634 self.shape
635 .1
636 .min(self.shape.0.saturating_sub((-offset) as usize))
637 };
638
639 for i in 0..length {
640 total += diag[i];
641 }
642 }
643
644 Ok(SparseSum::Scalar(total))
645 }
646 Some(0) => {
647 let mut result = Array1::zeros(self.shape.1);
649
650 for (diag_idx, &offset) in self.offsets.iter().enumerate() {
651 let diag = &self.data[diag_idx];
652
653 if offset >= 0 {
654 let offset_usize = offset as usize;
656 let length = self.shape.0.min(self.shape.1.saturating_sub(offset_usize));
657
658 for i in 0..length {
659 result[i + offset_usize] += diag[i];
660 }
661 } else {
662 let offset_usize = (-offset) as usize;
664 let length = self.shape.1.min(self.shape.0.saturating_sub(offset_usize));
665
666 for i in 0..length {
667 result[i] += diag[i];
668 }
669 }
670 }
671
672 match Array2::from_shape_vec((1, self.shape.1), result.to_vec()) {
674 Ok(result_2d) => {
675 let mut row_indices = Vec::new();
677 let mut col_indices = Vec::new();
678 let mut values = Vec::new();
679
680 for j in 0..self.shape.1 {
681 let val: T = result_2d[[0, j]];
682 if !SparseElement::is_zero(&val) {
683 row_indices.push(0);
684 col_indices.push(j);
685 values.push(val);
686 }
687 }
688
689 match CooArray::from_triplets(
691 &row_indices,
692 &col_indices,
693 &values,
694 (1, self.shape.1),
695 false,
696 ) {
697 Ok(coo_array) => Ok(SparseSum::SparseArray(Box::new(coo_array))),
698 Err(e) => Err(e),
699 }
700 }
701 Err(_) => Err(SparseError::InconsistentData {
702 reason: "Failed to create 2D array from result vector".to_string(),
703 }),
704 }
705 }
706 Some(1) => {
707 let mut result = Array1::zeros(self.shape.0);
709
710 for (diag_idx, &offset) in self.offsets.iter().enumerate() {
711 let diag = &self.data[diag_idx];
712
713 if offset >= 0 {
714 let offset_usize = offset as usize;
716 let length = self.shape.0.min(self.shape.1.saturating_sub(offset_usize));
717
718 for i in 0..length {
719 result[i] += diag[i];
720 }
721 } else {
722 let offset_usize = (-offset) as usize;
724 let length = self.shape.1.min(self.shape.0.saturating_sub(offset_usize));
725
726 for i in 0..length {
727 result[i + offset_usize] += diag[i];
728 }
729 }
730 }
731
732 match Array2::from_shape_vec((self.shape.0, 1), result.to_vec()) {
734 Ok(result_2d) => {
735 let mut row_indices = Vec::new();
737 let mut col_indices = Vec::new();
738 let mut values = Vec::new();
739
740 for i in 0..self.shape.0 {
741 let val: T = result_2d[[i, 0]];
742 if !SparseElement::is_zero(&val) {
743 row_indices.push(i);
744 col_indices.push(0);
745 values.push(val);
746 }
747 }
748
749 match CooArray::from_triplets(
751 &row_indices,
752 &col_indices,
753 &values,
754 (self.shape.0, 1),
755 false,
756 ) {
757 Ok(coo_array) => Ok(SparseSum::SparseArray(Box::new(coo_array))),
758 Err(e) => Err(e),
759 }
760 }
761 Err(_) => Err(SparseError::InconsistentData {
762 reason: "Failed to create 2D array from result vector".to_string(),
763 }),
764 }
765 }
766 _ => Err(SparseError::InvalidAxis),
767 }
768 }
769
770 fn max(&self) -> T {
771 let mut max_val = T::neg_infinity();
772
773 for (diag_idx, &offset) in self.offsets.iter().enumerate() {
774 let diag = &self.data[diag_idx];
775
776 let length = if offset >= 0 {
777 self.shape
778 .0
779 .min(self.shape.1.saturating_sub(offset as usize))
780 } else {
781 self.shape
782 .1
783 .min(self.shape.0.saturating_sub((-offset) as usize))
784 };
785
786 for i in 0..length {
787 max_val = max_val.max(diag[i]);
788 }
789 }
790
791 if max_val == T::neg_infinity() {
793 T::sparse_zero()
794 } else {
795 max_val
796 }
797 }
798
799 fn min(&self) -> T {
800 let mut min_val = T::sparse_zero();
801 let mut has_nonzero = false;
802
803 for (diag_idx, &offset) in self.offsets.iter().enumerate() {
804 let diag = &self.data[diag_idx];
805
806 let length = if offset >= 0 {
807 self.shape
808 .0
809 .min(self.shape.1.saturating_sub(offset as usize))
810 } else {
811 self.shape
812 .1
813 .min(self.shape.0.saturating_sub((-offset) as usize))
814 };
815
816 for i in 0..length {
817 if !SparseElement::is_zero(&diag[i]) {
818 has_nonzero = true;
819 min_val = min_val.min(diag[i]);
820 }
821 }
822 }
823
824 if !has_nonzero {
826 T::sparse_zero()
827 } else {
828 min_val
829 }
830 }
831
832 fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
833 let (row_indices, col_indices, values) = self.to_coo_internal();
834
835 (
836 Array1::from_vec(row_indices),
837 Array1::from_vec(col_indices),
838 Array1::from_vec(values),
839 )
840 }
841
842 fn slice(
843 &self,
844 row_range: (usize, usize),
845 col_range: (usize, usize),
846 ) -> SparseResult<Box<dyn SparseArray<T>>> {
847 let (start_row, end_row) = row_range;
848 let (start_col, end_col) = col_range;
849 let (rows, cols) = self.shape;
850
851 if start_row >= rows || end_row > rows || start_col >= cols || end_col > cols {
852 return Err(SparseError::IndexOutOfBounds {
853 index: (start_row.max(end_row), start_col.max(end_col)),
854 shape: (rows, cols),
855 });
856 }
857
858 if start_row >= end_row || start_col >= end_col {
859 return Err(SparseError::InvalidSliceRange);
860 }
861
862 let coo = self.to_coo()?;
864 coo.slice(row_range, col_range)?.to_dia()
865 }
866
867 fn as_any(&self) -> &dyn std::any::Any {
868 self
869 }
870}
871
872impl<T> fmt::Display for DiaArray<T>
874where
875 T: SparseElement + Div<Output = T> + Float + 'static + std::ops::AddAssign,
876{
877 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
878 writeln!(
879 f,
880 "DiaArray of shape {:?} with {} stored elements",
881 self.shape,
882 self.nnz()
883 )?;
884 writeln!(f, "Offsets: {:?}", self.offsets)?;
885
886 if self.offsets.len() <= 5 {
887 for (i, &offset) in self.offsets.iter().enumerate() {
888 let diag = &self.data[i];
889 let length = if offset >= 0 {
890 self.shape
891 .0
892 .min(self.shape.1.saturating_sub(offset as usize))
893 } else {
894 self.shape
895 .1
896 .min(self.shape.0.saturating_sub((-offset) as usize))
897 };
898
899 write!(f, "Diagonal {offset}: [")?;
900 for j in 0..length.min(10) {
901 if j > 0 {
902 write!(f, ", ")?;
903 }
904 write!(f, "{:?}", diag[j])?;
905 }
906 if length > 10 {
907 write!(f, ", ...")?;
908 }
909 writeln!(f, "]")?;
910 }
911 } else {
912 writeln!(f, "({} diagonals)", self.offsets.len())?;
913 }
914
915 Ok(())
916 }
917}
918
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `DiaArray<f64>` from plain vectors of diagonal values.
    fn build(diagonals: Vec<Vec<f64>>, offsets: Vec<isize>, shape: (usize, usize)) -> DiaArray<f64> {
        let data: Vec<Array1<f64>> = diagonals.into_iter().map(Array1::from_vec).collect();
        DiaArray::new(data, offsets, shape).expect("Operation failed")
    }

    /// 3x3 fixture with offsets [0, 1, -1]; dense form [[1, 4, 0], [0, 2, 5], [0, 6, 3]].
    fn tridiagonal() -> DiaArray<f64> {
        build(
            vec![
                vec![1.0, 2.0, 3.0],
                vec![4.0, 5.0, 0.0],
                vec![0.0, 6.0, 7.0],
            ],
            vec![0, 1, -1],
            (3, 3),
        )
    }

    #[test]
    fn test_dia_array_create() {
        let array = build(
            vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 0.0]],
            vec![0, 1],
            (3, 3),
        );

        assert_eq!(array.shape(), (3, 3));
        assert_eq!(array.nnz(), 5);

        let expected: [((usize, usize), f64); 6] = [
            ((0, 0), 1.0),
            ((1, 1), 2.0),
            ((2, 2), 3.0),
            ((0, 1), 4.0),
            ((1, 2), 5.0),
            ((0, 2), 0.0),
        ];
        for &((i, j), value) in expected.iter() {
            assert_eq!(array.get(i, j), value);
        }
    }

    #[test]
    fn test_dia_array_from_triplets() {
        let entries: [(usize, usize, f64); 7] = [
            (0, 0, 1.0),
            (0, 1, 4.0),
            (1, 0, 2.0),
            (1, 1, 3.0),
            (1, 2, 5.0),
            (2, 1, 6.0),
            (2, 2, 7.0),
        ];
        let rows: Vec<usize> = entries.iter().map(|e| e.0).collect();
        let cols: Vec<usize> = entries.iter().map(|e| e.1).collect();
        let vals: Vec<f64> = entries.iter().map(|e| e.2).collect();

        let array = DiaArray::from_triplets(&rows, &cols, &vals, (3, 3)).expect("Operation failed");

        assert_eq!(array.offsets.len(), 3);
        for offset in [-1isize, 0, 1].iter() {
            assert!(array.offsets.contains(offset));
        }

        for &(r, c, v) in entries.iter() {
            assert_eq!(array.get(r, c), v);
        }
    }

    #[test]
    fn test_dia_array_conversion() {
        let array = build(
            vec![
                vec![1.0, 3.0, 7.0],
                vec![4.0, 5.0, 0.0],
                vec![0.0, 2.0, 0.0],
            ],
            vec![0, 1, -1],
            (3, 3),
        );

        let coo = array.to_coo().expect("Operation failed");
        assert_eq!(coo.shape(), (3, 3));
        assert_eq!(coo.nnz(), 6);

        let expected =
            Array2::from_shape_vec((3, 3), vec![1.0, 4.0, 0.0, 0.0, 3.0, 5.0, 0.0, 2.0, 7.0])
                .expect("Operation failed");
        assert_eq!(array.to_array(), expected);
    }

    #[test]
    fn test_dia_array_operations() {
        let left = build(vec![vec![1.0, 2.0, 3.0]], vec![0], (3, 3));
        let right = build(vec![vec![4.0, 5.0, 6.0]], vec![0], (3, 3));

        let sum = left.add(&right).expect("Operation failed");
        let product = left.mul(&right).expect("Operation failed");
        let dot = left.dot(&right).expect("Operation failed");

        let expected_sum = [5.0, 7.0, 9.0];
        let expected_product = [4.0, 10.0, 18.0];
        for k in 0..3 {
            assert_eq!(sum.get(k, k), expected_sum[k]);
            assert_eq!(product.get(k, k), expected_product[k]);
            assert_eq!(dot.get(k, k), expected_product[k]);
        }
    }

    #[test]
    fn test_dia_array_dot_vector() {
        let array = tridiagonal();
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        let y = array.dot_vector(&x.view()).expect("Operation failed");

        assert_eq!(y, Array1::from_vec(vec![9.0, 19.0, 21.0]));
    }

    #[test]
    fn test_dia_array_transpose() {
        let array = tridiagonal();
        let transposed = array.transpose().expect("Operation failed");

        assert_eq!(transposed.shape(), (3, 3));

        let original = array.to_array();
        assert_eq!(transposed.to_array(), original.t());
    }

    #[test]
    fn test_dia_array_sum() {
        let array = build(
            vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 0.0]],
            vec![0, 1],
            (3, 3),
        );

        match array.sum(None).expect("Operation failed") {
            SparseSum::Scalar(total) => assert_eq!(total, 15.0),
            _ => panic!("Expected SparseSum::Scalar"),
        }

        match array.sum(Some(0)).expect("Operation failed") {
            SparseSum::SparseArray(column_totals) => {
                assert_eq!(column_totals.shape(), (1, 3));
                for (j, expected) in [1.0, 6.0, 8.0].iter().copied().enumerate() {
                    assert_eq!(column_totals.get(0, j), expected);
                }
            }
            _ => panic!("Expected SparseSum::SparseArray"),
        }

        match array.sum(Some(1)).expect("Operation failed") {
            SparseSum::SparseArray(row_totals) => {
                assert_eq!(row_totals.shape(), (3, 1));
                for (i, expected) in [5.0, 7.0, 3.0].iter().copied().enumerate() {
                    assert_eq!(row_totals.get(i, 0), expected);
                }
            }
            _ => panic!("Expected SparseSum::SparseArray"),
        }
    }
}