1use crate::coo_array::CooArray;
8use crate::csr_array::CsrArray;
9use crate::error::{SparseError, SparseResult};
10use crate::sparray::SparseArray;
11use scirs2_core::numeric::{Float, SparseElement};
12use std::fmt::Debug;
13use std::ops::{Add, AddAssign, Div, Mul, Sub};
14
15#[allow(dead_code)]
41pub fn hstack<'a, T>(
42 arrays: &[&'a dyn SparseArray<T>],
43 format: &str,
44) -> SparseResult<Box<dyn SparseArray<T>>>
45where
46 T: 'a
47 + Float
48 + SparseElement
49 + Add<Output = T>
50 + Sub<Output = T>
51 + Mul<Output = T>
52 + Div<Output = T>
53 + Debug
54 + Copy
55 + 'static,
56{
57 if arrays.is_empty() {
58 return Err(SparseError::ValueError(
59 "Cannot stack empty list of arrays".to_string(),
60 ));
61 }
62
63 let firstshape = arrays[0].shape();
65 let m = firstshape.0;
66
67 for (_i, &array) in arrays.iter().enumerate().skip(1) {
68 let shape = array.shape();
69 if shape.0 != m {
70 return Err(SparseError::DimensionMismatch {
71 expected: m,
72 found: shape.0,
73 });
74 }
75 }
76
77 let mut n = 0;
79 for &array in arrays.iter() {
80 n += array.shape().1;
81 }
82
83 let mut rows = Vec::new();
85 let mut cols = Vec::new();
86 let mut data = Vec::new();
87
88 let mut col_offset = 0;
89 for &array in arrays.iter() {
90 let shape = array.shape();
91 let (array_rows, array_cols, array_data) = array.find();
92
93 for i in 0..array_data.len() {
94 rows.push(array_rows[i]);
95 cols.push(array_cols[i] + col_offset);
96 data.push(array_data[i]);
97 }
98
99 col_offset += shape.1;
100 }
101
102 match format.to_lowercase().as_str() {
104 "csr" => CsrArray::from_triplets(&rows, &cols, &data, (m, n), false)
105 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
106 "coo" => CooArray::from_triplets(&rows, &cols, &data, (m, n), false)
107 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
108 _ => Err(SparseError::ValueError(format!(
109 "Unknown sparse format: {format}. Supported formats are 'csr' and 'coo'"
110 ))),
111 }
112}
113
114#[allow(dead_code)]
140pub fn vstack<'a, T>(
141 arrays: &[&'a dyn SparseArray<T>],
142 format: &str,
143) -> SparseResult<Box<dyn SparseArray<T>>>
144where
145 T: 'a
146 + Float
147 + SparseElement
148 + Add<Output = T>
149 + Sub<Output = T>
150 + Mul<Output = T>
151 + Div<Output = T>
152 + Debug
153 + Copy
154 + 'static,
155{
156 if arrays.is_empty() {
157 return Err(SparseError::ValueError(
158 "Cannot stack empty list of arrays".to_string(),
159 ));
160 }
161
162 let firstshape = arrays[0].shape();
164 let n = firstshape.1;
165
166 for (_i, &array) in arrays.iter().enumerate().skip(1) {
167 let shape = array.shape();
168 if shape.1 != n {
169 return Err(SparseError::DimensionMismatch {
170 expected: n,
171 found: shape.1,
172 });
173 }
174 }
175
176 let mut m = 0;
178 for &array in arrays.iter() {
179 m += array.shape().0;
180 }
181
182 let mut rows = Vec::new();
184 let mut cols = Vec::new();
185 let mut data = Vec::new();
186
187 let mut row_offset = 0;
188 for &array in arrays.iter() {
189 let shape = array.shape();
190 let (array_rows, array_cols, array_data) = array.find();
191
192 for i in 0..array_data.len() {
193 rows.push(array_rows[i] + row_offset);
194 cols.push(array_cols[i]);
195 data.push(array_data[i]);
196 }
197
198 row_offset += shape.0;
199 }
200
201 match format.to_lowercase().as_str() {
203 "csr" => CsrArray::from_triplets(&rows, &cols, &data, (m, n), false)
204 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
205 "coo" => CooArray::from_triplets(&rows, &cols, &data, (m, n), false)
206 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
207 _ => Err(SparseError::ValueError(format!(
208 "Unknown sparse format: {format}. Supported formats are 'csr' and 'coo'"
209 ))),
210 }
211}
212
213#[allow(dead_code)]
245pub fn block_diag<'a, T>(
246 arrays: &[&'a dyn SparseArray<T>],
247 format: &str,
248) -> SparseResult<Box<dyn SparseArray<T>>>
249where
250 T: 'a
251 + Float
252 + SparseElement
253 + Add<Output = T>
254 + Sub<Output = T>
255 + Mul<Output = T>
256 + Div<Output = T>
257 + Debug
258 + Copy
259 + 'static,
260{
261 if arrays.is_empty() {
262 return Err(SparseError::ValueError(
263 "Cannot create block diagonal with empty list of arrays".to_string(),
264 ));
265 }
266
267 let mut total_rows = 0;
269 let mut total_cols = 0;
270 for &array in arrays.iter() {
271 let shape = array.shape();
272 total_rows += shape.0;
273 total_cols += shape.1;
274 }
275
276 let mut rows = Vec::new();
278 let mut cols = Vec::new();
279 let mut data = Vec::new();
280
281 let mut row_offset = 0;
282 let mut col_offset = 0;
283 for &array in arrays.iter() {
284 let shape = array.shape();
285 let (array_rows, array_cols, array_data) = array.find();
286
287 for i in 0..array_data.len() {
288 rows.push(array_rows[i] + row_offset);
289 cols.push(array_cols[i] + col_offset);
290 data.push(array_data[i]);
291 }
292
293 row_offset += shape.0;
294 col_offset += shape.1;
295 }
296
297 match format.to_lowercase().as_str() {
299 "csr" => CsrArray::from_triplets(&rows, &cols, &data, (total_rows, total_cols), false)
300 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
301 "coo" => CooArray::from_triplets(&rows, &cols, &data, (total_rows, total_cols), false)
302 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
303 _ => Err(SparseError::ValueError(format!(
304 "Unknown sparse format: {format}. Supported formats are 'csr' and 'coo'"
305 ))),
306 }
307}
308
309#[allow(dead_code)]
339pub fn tril<T>(
340 array: &dyn SparseArray<T>,
341 k: isize,
342 format: &str,
343) -> SparseResult<Box<dyn SparseArray<T>>>
344where
345 T: Float
346 + SparseElement
347 + Add<Output = T>
348 + Sub<Output = T>
349 + Mul<Output = T>
350 + Div<Output = T>
351 + Debug
352 + Copy
353 + 'static,
354{
355 let shape = array.shape();
356 let (rows, cols, data) = array.find();
357
358 let mut tril_rows = Vec::new();
360 let mut tril_cols = Vec::new();
361 let mut tril_data = Vec::new();
362
363 for i in 0..data.len() {
364 let row = rows[i];
365 let col = cols[i];
366
367 if (row as isize) >= (col as isize) - k {
368 tril_rows.push(row);
369 tril_cols.push(col);
370 tril_data.push(data[i]);
371 }
372 }
373
374 match format.to_lowercase().as_str() {
376 "csr" => CsrArray::from_triplets(&tril_rows, &tril_cols, &tril_data, shape, false)
377 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
378 "coo" => CooArray::from_triplets(&tril_rows, &tril_cols, &tril_data, shape, false)
379 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
380 _ => Err(SparseError::ValueError(format!(
381 "Unknown sparse format: {format}. Supported formats are 'csr' and 'coo'"
382 ))),
383 }
384}
385
386#[allow(dead_code)]
416pub fn triu<T>(
417 array: &dyn SparseArray<T>,
418 k: isize,
419 format: &str,
420) -> SparseResult<Box<dyn SparseArray<T>>>
421where
422 T: Float
423 + SparseElement
424 + Add<Output = T>
425 + Sub<Output = T>
426 + Mul<Output = T>
427 + Div<Output = T>
428 + Debug
429 + Copy
430 + 'static,
431{
432 let shape = array.shape();
433 let (rows, cols, data) = array.find();
434
435 let mut triu_rows = Vec::new();
437 let mut triu_cols = Vec::new();
438 let mut triu_data = Vec::new();
439
440 for i in 0..data.len() {
441 let row = rows[i];
442 let col = cols[i];
443
444 if (row as isize) <= (col as isize) - k {
445 triu_rows.push(row);
446 triu_cols.push(col);
447 triu_data.push(data[i]);
448 }
449 }
450
451 match format.to_lowercase().as_str() {
453 "csr" => CsrArray::from_triplets(&triu_rows, &triu_cols, &triu_data, shape, false)
454 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
455 "coo" => CooArray::from_triplets(&triu_rows, &triu_cols, &triu_data, shape, false)
456 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
457 _ => Err(SparseError::ValueError(format!(
458 "Unknown sparse format: {format}. Supported formats are 'csr' and 'coo'"
459 ))),
460 }
461}
462
463#[allow(dead_code)]
498pub fn kron<'a, T>(
499 a: &'a dyn SparseArray<T>,
500 b: &'a dyn SparseArray<T>,
501 format: &str,
502) -> SparseResult<Box<dyn SparseArray<T>>>
503where
504 T: 'a
505 + Float
506 + SparseElement
507 + Add<Output = T>
508 + AddAssign
509 + Sub<Output = T>
510 + Mul<Output = T>
511 + Div<Output = T>
512 + Debug
513 + Copy
514 + 'static,
515{
516 let ashape = a.shape();
517 let bshape = b.shape();
518
519 let outputshape = (ashape.0 * bshape.0, ashape.1 * bshape.1);
521
522 if a.nnz() == 0 || b.nnz() == 0 {
524 let empty_rows: Vec<usize> = Vec::new();
526 let empty_cols: Vec<usize> = Vec::new();
527 let empty_data: Vec<T> = Vec::new();
528
529 return match format.to_lowercase().as_str() {
530 "csr" => {
531 CsrArray::from_triplets(&empty_rows, &empty_cols, &empty_data, outputshape, false)
532 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
533 }
534 "coo" => {
535 CooArray::from_triplets(&empty_rows, &empty_cols, &empty_data, outputshape, false)
536 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
537 }
538 _ => Err(SparseError::ValueError(format!(
539 "Unknown sparse format: {format}. Supported formats are 'csr' and 'coo'"
540 ))),
541 };
542 }
543
544 let b_coo = b.to_coo().unwrap();
546 let (b_rows, b_cols, b_data) = b_coo.find();
547
548 let a_coo = a.to_coo().unwrap();
553 let (a_rows, a_cols, a_data) = a_coo.find();
554
555 let nnz_a = a_data.len();
557 let nnz_b = b_data.len();
558 let nnz_output = nnz_a * nnz_b;
559
560 let mut out_rows = Vec::with_capacity(nnz_output);
562 let mut out_cols = Vec::with_capacity(nnz_output);
563 let mut out_data = Vec::with_capacity(nnz_output);
564
565 for i in 0..nnz_a {
567 for j in 0..nnz_b {
568 let row = a_rows[i] * bshape.0 + b_rows[j];
570 let col = a_cols[i] * bshape.1 + b_cols[j];
571
572 let val = a_data[i] * b_data[j];
574
575 out_rows.push(row);
577 out_cols.push(col);
578 out_data.push(val);
579 }
580 }
581
582 match format.to_lowercase().as_str() {
584 "csr" => CsrArray::from_triplets(&out_rows, &out_cols, &out_data, outputshape, false)
585 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
586 "coo" => CooArray::from_triplets(&out_rows, &out_cols, &out_data, outputshape, false)
587 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
588 _ => Err(SparseError::ValueError(format!(
589 "Unknown sparse format: {format}. Supported formats are 'csr' and 'coo'"
590 ))),
591 }
592}
593
594#[allow(dead_code)]
632pub fn kronsum<'a, T>(
633 a: &'a dyn SparseArray<T>,
634 b: &'a dyn SparseArray<T>,
635 format: &str,
636) -> SparseResult<Box<dyn SparseArray<T>>>
637where
638 T: 'a
639 + Float
640 + SparseElement
641 + Add<Output = T>
642 + AddAssign
643 + Sub<Output = T>
644 + Mul<Output = T>
645 + Div<Output = T>
646 + Debug
647 + Copy
648 + 'static,
649{
650 let ashape = a.shape();
651 let bshape = b.shape();
652
653 if ashape.0 != ashape.1 {
655 return Err(SparseError::ValueError(
656 "First matrix must be square".to_string(),
657 ));
658 }
659 if bshape.0 != bshape.1 {
660 return Err(SparseError::ValueError(
661 "Second matrix must be square".to_string(),
662 ));
663 }
664
665 let m = ashape.0;
667 let n = bshape.0;
668
669 if is_identity_matrix(a) && is_identity_matrix(b) {
672 let outputshape = (m * n, m * n);
673 let mut rows = Vec::new();
674 let mut cols = Vec::new();
675 let mut data = Vec::new();
676
677 for i in 0..m * n {
679 rows.push(i);
680 cols.push(i);
681 data.push(T::sparse_one() + T::sparse_one()); }
683
684 for i in 0..n {
686 for j in 0..n {
687 if i != j && (b.get(i, j) > T::sparse_zero() || b.get(j, i) > T::sparse_zero()) {
688 for k in 0..m {
689 rows.push(i * m + k);
690 cols.push(j * m + k);
691 data.push(T::sparse_one());
692 }
693 }
694 }
695 }
696
697 for i in 0..n - 1 {
701 for j in 0..m {
702 rows.push(i * m + j);
705 cols.push((i + 1) * m + j);
706 data.push(T::sparse_one());
707
708 rows.push((i + 1) * m + j);
710 cols.push(i * m + j);
711 data.push(T::sparse_one());
712 }
713 }
714
715 return match format.to_lowercase().as_str() {
717 "csr" => CsrArray::from_triplets(&rows, &cols, &data, outputshape, true)
718 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
719 "coo" => CooArray::from_triplets(&rows, &cols, &data, outputshape, true)
720 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
721 _ => Err(SparseError::ValueError(format!(
722 "Unknown sparse format: {format}. Supported formats are 'csr' and 'coo'"
723 ))),
724 };
725 }
726
727 let outputshape = (m * n, m * n);
729
730 let mut rows = Vec::new();
732 let mut cols = Vec::new();
733 let mut data = Vec::new();
734
735 let (a_rows, a_cols, a_data) = a.find();
737 for i in 0..n {
738 for k in 0..a_data.len() {
739 let row_idx = i * m + a_rows[k];
740 let col_idx = i * m + a_cols[k];
741 rows.push(row_idx);
742 cols.push(col_idx);
743 data.push(a_data[k]);
744 }
745 }
746
747 let (b_rows, b_cols, b_data) = b.find();
749 for k in 0..b_data.len() {
750 let b_row = b_rows[k];
751 let b_col = b_cols[k];
752
753 for i in 0..m {
754 let row_idx = b_row * m + i;
755 let col_idx = b_col * m + i;
756 rows.push(row_idx);
757 cols.push(col_idx);
758 data.push(b_data[k]);
759 }
760 }
761
762 match format.to_lowercase().as_str() {
764 "csr" => CsrArray::from_triplets(&rows, &cols, &data, outputshape, true)
765 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
766 "coo" => CooArray::from_triplets(&rows, &cols, &data, outputshape, true)
767 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
768 _ => Err(SparseError::ValueError(format!(
769 "Unknown sparse format: {format}. Supported formats are 'csr' and 'coo'"
770 ))),
771 }
772}
773
774#[allow(dead_code)]
809pub fn bmat<'a, T>(
810 blocks: &[Vec<Option<&'a dyn SparseArray<T>>>],
811 format: &str,
812) -> SparseResult<Box<dyn SparseArray<T>>>
813where
814 T: 'a
815 + Float
816 + SparseElement
817 + Add<Output = T>
818 + AddAssign
819 + Sub<Output = T>
820 + Mul<Output = T>
821 + Div<Output = T>
822 + Debug
823 + Copy
824 + 'static,
825{
826 if blocks.is_empty() {
827 return Err(SparseError::ValueError(
828 "Empty blocks array provided".to_string(),
829 ));
830 }
831
832 let m = blocks.len(); let n = blocks[0].len(); for (i, row) in blocks.iter().enumerate() {
837 if row.len() != n {
838 return Err(SparseError::ValueError(format!(
839 "Block row {i} has length {}, expected {n}",
840 row.len()
841 )));
842 }
843 }
844
845 let mut row_sizes = vec![0; m];
847 let mut col_sizes = vec![0; n];
848 let mut block_mask = vec![vec![false; n]; m];
849
850 for (i, row_size) in row_sizes.iter_mut().enumerate().take(m) {
852 for (j, col_size) in col_sizes.iter_mut().enumerate().take(n) {
853 if let Some(block) = blocks[i][j] {
854 let shape = block.shape();
855
856 if *row_size == 0 {
858 *row_size = shape.0;
859 } else if *row_size != shape.0 {
860 return Err(SparseError::ValueError(format!(
861 "Inconsistent row dimensions in block row {i}. Expected {}, got {}",
862 row_sizes[i], shape.0
863 )));
864 }
865
866 if *col_size == 0 {
868 *col_size = shape.1;
869 } else if *col_size != shape.1 {
870 return Err(SparseError::ValueError(format!(
871 "Inconsistent column dimensions in block column {j}. Expected {}, got {}",
872 *col_size, shape.1
873 )));
874 }
875
876 block_mask[i][j] = true;
877 }
878 }
879 }
880
881 for (i, &row_size) in row_sizes.iter().enumerate().take(m) {
883 if row_size == 0 {
884 return Err(SparseError::ValueError(format!(
885 "Block row {i} has no arrays, cannot determine dimensions"
886 )));
887 }
888 }
889 for (j, &col_size) in col_sizes.iter().enumerate().take(n) {
890 if col_size == 0 {
891 return Err(SparseError::ValueError(format!(
892 "Block column {j} has no arrays, cannot determine dimensions"
893 )));
894 }
895 }
896
897 let mut row_offsets = vec![0; m + 1];
899 let mut col_offsets = vec![0; n + 1];
900
901 for i in 0..m {
902 row_offsets[i + 1] = row_offsets[i] + row_sizes[i];
903 }
904 for j in 0..n {
905 col_offsets[j + 1] = col_offsets[j] + col_sizes[j];
906 }
907
908 let totalshape = (row_offsets[m], col_offsets[n]);
910
911 let mut has_blocks = false;
913 for mask_row in block_mask.iter().take(m) {
914 for &mask_elem in mask_row.iter().take(n) {
915 if mask_elem {
916 has_blocks = true;
917 break;
918 }
919 }
920 if has_blocks {
921 break;
922 }
923 }
924
925 if !has_blocks {
926 let empty_rows: Vec<usize> = Vec::new();
928 let empty_cols: Vec<usize> = Vec::new();
929 let empty_data: Vec<T> = Vec::new();
930
931 return match format.to_lowercase().as_str() {
932 "csr" => {
933 CsrArray::from_triplets(&empty_rows, &empty_cols, &empty_data, totalshape, false)
934 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
935 }
936 "coo" => {
937 CooArray::from_triplets(&empty_rows, &empty_cols, &empty_data, totalshape, false)
938 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
939 }
940 _ => Err(SparseError::ValueError(format!(
941 "Unknown sparse format: {format}. Supported formats are 'csr' and 'coo'"
942 ))),
943 };
944 }
945
946 let mut rows = Vec::new();
948 let mut cols = Vec::new();
949 let mut data = Vec::new();
950
951 for (i, row_offset) in row_offsets.iter().take(m).enumerate() {
952 for (j, col_offset) in col_offsets.iter().take(n).enumerate() {
953 if let Some(block) = blocks[i][j] {
954 let (block_rows, block_cols, block_data) = block.find();
955
956 for (((row, col), val), _) in block_rows
957 .iter()
958 .zip(block_cols.iter())
959 .zip(block_data.iter())
960 .zip(0..block_data.len())
961 {
962 rows.push(*row + *row_offset);
963 cols.push(*col + *col_offset);
964 data.push(*val);
965 }
966 }
967 }
968 }
969
970 match format.to_lowercase().as_str() {
972 "csr" => CsrArray::from_triplets(&rows, &cols, &data, totalshape, false)
973 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
974 "coo" => CooArray::from_triplets(&rows, &cols, &data, totalshape, false)
975 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
976 _ => Err(SparseError::ValueError(format!(
977 "Unknown sparse format: {format}. Supported formats are 'csr' and 'coo'"
978 ))),
979 }
980}
981
982#[allow(dead_code)]
984fn is_identity_matrix<T>(array: &dyn SparseArray<T>) -> bool
985where
986 T: Float + SparseElement + Debug + Copy + 'static,
987{
988 let shape = array.shape();
989
990 if shape.0 != shape.1 {
992 return false;
993 }
994
995 let n = shape.0;
996
997 if array.nnz() != n {
999 return false;
1000 }
1001
1002 let (rows, cols, data) = array.find();
1004
1005 if rows.len() != n {
1006 return false;
1007 }
1008
1009 for i in 0..rows.len() {
1010 if rows[i] != cols[i] {
1012 return false;
1013 }
1014
1015 if (data[i] - T::sparse_one()).abs() > T::epsilon() {
1017 return false;
1018 }
1019 }
1020
1021 true
1022}
1023
#[cfg(test)]
mod tests {
    use super::*;
    use crate::construct::eye_array;

    #[test]
    fn test_hstack() {
        let a = eye_array::<f64>(2, "csr").unwrap();
        let b = eye_array::<f64>(2, "csr").unwrap();
        let c = hstack(&[&*a, &*b], "csr").unwrap();

        assert_eq!(c.shape(), (2, 4));
        assert_eq!(c.get(0, 0), 1.0);
        assert_eq!(c.get(1, 1), 1.0);
        assert_eq!(c.get(0, 2), 1.0);
        assert_eq!(c.get(1, 3), 1.0);
        assert_eq!(c.get(0, 1), 0.0);
        assert_eq!(c.get(0, 3), 0.0);
    }

    #[test]
    fn test_vstack() {
        let a = eye_array::<f64>(2, "csr").unwrap();
        let b = eye_array::<f64>(2, "csr").unwrap();
        let c = vstack(&[&*a, &*b], "csr").unwrap();

        assert_eq!(c.shape(), (4, 2));
        assert_eq!(c.get(0, 0), 1.0);
        assert_eq!(c.get(1, 1), 1.0);
        assert_eq!(c.get(2, 0), 1.0);
        assert_eq!(c.get(3, 1), 1.0);
        assert_eq!(c.get(0, 1), 0.0);
        assert_eq!(c.get(1, 0), 0.0);
    }

    #[test]
    fn test_stack_errors() {
        let a = eye_array::<f64>(2, "csr").unwrap();
        let b = eye_array::<f64>(3, "csr").unwrap();

        // Mismatched row count (hstack) / column count (vstack).
        assert!(hstack(&[&*a, &*b], "csr").is_err());
        assert!(vstack(&[&*a, &*b], "csr").is_err());

        // Unknown output format.
        assert!(hstack(&[&*a], "dense").is_err());

        // Empty input list.
        let empty: [&dyn SparseArray<f64>; 0] = [];
        assert!(hstack(&empty, "csr").is_err());
        assert!(vstack(&empty, "csr").is_err());
    }

    #[test]
    fn test_block_diag() {
        let a = eye_array::<f64>(2, "csr").unwrap();
        let b = eye_array::<f64>(3, "csr").unwrap();
        let c = block_diag(&[&*a, &*b], "csr").unwrap();

        assert_eq!(c.shape(), (5, 5));
        assert_eq!(c.get(0, 0), 1.0);
        assert_eq!(c.get(1, 1), 1.0);
        assert_eq!(c.get(2, 2), 1.0);
        assert_eq!(c.get(3, 3), 1.0);
        assert_eq!(c.get(4, 4), 1.0);
        assert_eq!(c.get(0, 2), 0.0);
        assert_eq!(c.get(2, 0), 0.0);
    }

    #[test]
    fn test_kron() {
        let a = eye_array::<f64>(2, "csr").unwrap();
        let b = eye_array::<f64>(2, "csr").unwrap();
        let c = kron(&*a, &*b, "csr").unwrap();

        assert_eq!(c.shape(), (4, 4));
        assert_eq!(c.get(0, 0), 1.0);
        assert_eq!(c.get(1, 1), 1.0);
        assert_eq!(c.get(2, 2), 1.0);
        assert_eq!(c.get(3, 3), 1.0);
        assert_eq!(c.get(0, 1), 0.0);
        assert_eq!(c.get(0, 2), 0.0);
        assert_eq!(c.get(1, 0), 0.0);

        let rowsa = vec![0, 0, 1];
        let cols_a = vec![0, 1, 0];
        let data_a = vec![1.0, 2.0, 3.0];
        let a = CooArray::from_triplets(&rowsa, &cols_a, &data_a, (2, 2), false).unwrap();

        let rowsb = vec![0, 1];
        let cols_b = vec![0, 1];
        let data_b = vec![4.0, 5.0];
        let b = CooArray::from_triplets(&rowsb, &cols_b, &data_b, (2, 2), false).unwrap();

        let c = kron(&a, &b, "csr").unwrap();
        assert_eq!(c.shape(), (4, 4));

        assert_eq!(c.get(0, 0), 4.0);
        assert_eq!(c.get(0, 2), 8.0);
        assert_eq!(c.get(1, 1), 5.0);
        assert_eq!(c.get(1, 3), 10.0);
        assert_eq!(c.get(2, 0), 12.0);
        assert_eq!(c.get(3, 1), 15.0);
        assert_eq!(c.get(0, 1), 0.0);
        assert_eq!(c.get(0, 3), 0.0);
        assert_eq!(c.get(2, 1), 0.0);
        assert_eq!(c.get(2, 2), 0.0);
        assert_eq!(c.get(2, 3), 0.0);
        assert_eq!(c.get(3, 0), 0.0);
        assert_eq!(c.get(3, 2), 0.0);
        assert_eq!(c.get(3, 3), 0.0);
    }

    #[test]
    fn test_kronsum() {
        let a = eye_array::<f64>(2, "csr").unwrap();
        let b = eye_array::<f64>(2, "csr").unwrap();
        let c = kronsum(&*a, &*b, "csr").unwrap();

        assert_eq!(c.shape(), (4, 4));
        // Each diagonal entry of kronsum(I, I) is 1 + 1.
        assert_eq!(c.get(0, 0), 2.0);
        assert_eq!(c.get(3, 3), 2.0);

        let (rows, _cols, data) = c.find();
        assert!(!rows.is_empty());
        assert!(!data.is_empty());

        let c_coo = kronsum(&*a, &*b, "coo").unwrap();
        assert_eq!(c_coo.shape(), (4, 4));

        let (coo_rows, _coo_cols, coo_data) = c_coo.find();
        assert!(!coo_rows.is_empty());
        assert!(!coo_data.is_empty());
    }

    #[test]
    fn test_kronsum_general() {
        // A = [[1, 2], [0, 3]], B = diag(4, 5).
        let rows_a = vec![0, 0, 1];
        let cols_a = vec![0, 1, 1];
        let data_a = vec![1.0, 2.0, 3.0];
        let a = CooArray::from_triplets(&rows_a, &cols_a, &data_a, (2, 2), false).unwrap();

        let rows_b = vec![0, 1];
        let cols_b = vec![0, 1];
        let data_b = vec![4.0, 5.0];
        let b = CooArray::from_triplets(&rows_b, &cols_b, &data_b, (2, 2), false).unwrap();

        // kron(I_2, A) + kron(B, I_2)
        //   = [[5, 2, 0, 0],
        //      [0, 7, 0, 0],
        //      [0, 0, 6, 2],
        //      [0, 0, 0, 8]]
        let c = kronsum(&a, &b, "csr").unwrap();
        assert_eq!(c.shape(), (4, 4));
        assert_eq!(c.get(0, 0), 5.0);
        assert_eq!(c.get(0, 1), 2.0);
        assert_eq!(c.get(1, 1), 7.0);
        assert_eq!(c.get(2, 2), 6.0);
        assert_eq!(c.get(2, 3), 2.0);
        assert_eq!(c.get(3, 3), 8.0);
        assert_eq!(c.get(1, 0), 0.0);
        assert_eq!(c.get(0, 2), 0.0);
        assert_eq!(c.get(3, 2), 0.0);
    }

    #[test]
    fn test_kronsum_rejects_non_square() {
        let a = eye_array::<f64>(2, "csr").unwrap();
        let wide = hstack(&[&*a, &*a], "csr").unwrap();

        assert!(kronsum(&*wide, &*a, "csr").is_err());
        assert!(kronsum(&*a, &*wide, "csr").is_err());
    }

    #[test]
    fn test_tril() {
        let rows = vec![0, 0, 0, 1, 1, 1, 2, 2, 2];
        let cols = vec![0, 1, 2, 0, 1, 2, 0, 1, 2];
        let data = vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
        let a = CooArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();

        let b = tril(&a, 0, "csr").unwrap();
        assert_eq!(b.shape(), (3, 3));
        assert_eq!(b.get(0, 0), 1.0);
        assert_eq!(b.get(1, 0), 1.0);
        assert_eq!(b.get(1, 1), 1.0);
        assert_eq!(b.get(2, 0), 1.0);
        assert_eq!(b.get(2, 1), 1.0);
        assert_eq!(b.get(2, 2), 1.0);
        assert_eq!(b.get(0, 1), 0.0);
        assert_eq!(b.get(0, 2), 0.0);
        assert_eq!(b.get(1, 2), 0.0);

        // k = 1 also keeps the first super-diagonal.
        let c = tril(&a, 1, "csr").unwrap();
        assert_eq!(c.get(0, 0), 1.0);
        assert_eq!(c.get(0, 1), 1.0);
        assert_eq!(c.get(0, 2), 0.0);
        assert_eq!(c.get(1, 1), 1.0);
        assert_eq!(c.get(1, 2), 1.0);
    }

    #[test]
    fn test_triu() {
        let rows = vec![0, 0, 0, 1, 1, 1, 2, 2, 2];
        let cols = vec![0, 1, 2, 0, 1, 2, 0, 1, 2];
        let data = vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
        let a = CooArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();

        let b = triu(&a, 0, "csr").unwrap();
        assert_eq!(b.shape(), (3, 3));
        assert_eq!(b.get(0, 0), 1.0);
        assert_eq!(b.get(0, 1), 1.0);
        assert_eq!(b.get(0, 2), 1.0);
        assert_eq!(b.get(1, 1), 1.0);
        assert_eq!(b.get(1, 2), 1.0);
        assert_eq!(b.get(2, 2), 1.0);
        assert_eq!(b.get(1, 0), 0.0);
        assert_eq!(b.get(2, 0), 0.0);
        assert_eq!(b.get(2, 1), 0.0);

        // k = -1 also keeps the first sub-diagonal.
        let c = triu(&a, -1, "csr").unwrap();
        assert_eq!(c.get(0, 0), 1.0);
        assert_eq!(c.get(1, 0), 1.0);
        assert_eq!(c.get(2, 0), 0.0);
        assert_eq!(c.get(1, 1), 1.0);
        assert_eq!(c.get(2, 1), 1.0);
    }

    #[test]
    fn test_bmat() {
        let a = eye_array::<f64>(2, "csr").unwrap();
        let b = eye_array::<f64>(2, "csr").unwrap();

        let blocks1 = vec![vec![Some(&*a), Some(&*b)], vec![Some(&*b), Some(&*a)]];
        let c1 = bmat(&blocks1, "csr").unwrap();

        assert_eq!(c1.shape(), (4, 4));
        assert_eq!(c1.get(0, 0), 1.0);
        assert_eq!(c1.get(1, 1), 1.0);
        assert_eq!(c1.get(2, 2), 1.0);
        assert_eq!(c1.get(3, 3), 1.0);
        assert_eq!(c1.get(0, 2), 1.0);
        assert_eq!(c1.get(1, 3), 1.0);
        assert_eq!(c1.get(2, 0), 1.0);
        assert_eq!(c1.get(3, 1), 1.0);
        assert_eq!(c1.get(0, 1), 0.0);
        assert_eq!(c1.get(0, 3), 0.0);
        assert_eq!(c1.get(2, 1), 0.0);
        assert_eq!(c1.get(2, 3), 0.0);

        // A `None` block is an all-zero block.
        let blocks2 = vec![vec![Some(&*a), Some(&*b)], vec![None, Some(&*a)]];
        let c2 = bmat(&blocks2, "csr").unwrap();

        assert_eq!(c2.shape(), (4, 4));
        assert_eq!(c2.get(0, 0), 1.0);
        assert_eq!(c2.get(1, 1), 1.0);
        assert_eq!(c2.get(2, 0), 0.0);
        assert_eq!(c2.get(2, 1), 0.0);
        assert_eq!(c2.get(2, 2), 1.0);
        assert_eq!(c2.get(3, 3), 1.0);

        let b1 = eye_array::<f64>(2, "csr").unwrap();
        let b2 = eye_array::<f64>(2, "csr").unwrap();

        let blocks3 = vec![vec![Some(&*b1), Some(&*b2)], vec![Some(&*b2), Some(&*b1)]];
        let c3 = bmat(&blocks3, "csr").unwrap();

        assert_eq!(c3.shape(), (4, 4));
        assert_eq!(c3.get(0, 0), 1.0);
        assert_eq!(c3.get(1, 1), 1.0);
        assert_eq!(c3.get(2, 2), 1.0);
        assert_eq!(c3.get(3, 3), 1.0);
    }

    #[test]
    fn test_bmat_rejects_ragged_grid() {
        let a = eye_array::<f64>(2, "csr").unwrap();
        let blocks = vec![vec![Some(&*a), Some(&*a)], vec![Some(&*a)]];
        assert!(bmat(&blocks, "csr").is_err());
    }
}