1use crate::coo_array::CooArray;
8use crate::csr_array::CsrArray;
9use crate::error::{SparseError, SparseResult};
10use crate::sparray::SparseArray;
11use scirs2_core::numeric::Float;
12use std::fmt::Debug;
13use std::ops::{Add, AddAssign, Div, Mul, Sub};
14
15#[allow(dead_code)]
41pub fn hstack<'a, T>(
42 arrays: &[&'a dyn SparseArray<T>],
43 format: &str,
44) -> SparseResult<Box<dyn SparseArray<T>>>
45where
46 T: 'a
47 + Float
48 + Add<Output = T>
49 + Sub<Output = T>
50 + Mul<Output = T>
51 + Div<Output = T>
52 + Debug
53 + Copy
54 + 'static,
55{
56 if arrays.is_empty() {
57 return Err(SparseError::ValueError(
58 "Cannot stack empty list of arrays".to_string(),
59 ));
60 }
61
62 let firstshape = arrays[0].shape();
64 let m = firstshape.0;
65
66 for (_i, &array) in arrays.iter().enumerate().skip(1) {
67 let shape = array.shape();
68 if shape.0 != m {
69 return Err(SparseError::DimensionMismatch {
70 expected: m,
71 found: shape.0,
72 });
73 }
74 }
75
76 let mut n = 0;
78 for &array in arrays.iter() {
79 n += array.shape().1;
80 }
81
82 let mut rows = Vec::new();
84 let mut cols = Vec::new();
85 let mut data = Vec::new();
86
87 let mut col_offset = 0;
88 for &array in arrays.iter() {
89 let shape = array.shape();
90 let (array_rows, array_cols, array_data) = array.find();
91
92 for i in 0..array_data.len() {
93 rows.push(array_rows[i]);
94 cols.push(array_cols[i] + col_offset);
95 data.push(array_data[i]);
96 }
97
98 col_offset += shape.1;
99 }
100
101 match format.to_lowercase().as_str() {
103 "csr" => CsrArray::from_triplets(&rows, &cols, &data, (m, n), false)
104 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
105 "coo" => CooArray::from_triplets(&rows, &cols, &data, (m, n), false)
106 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
107 _ => Err(SparseError::ValueError(format!(
108 "Unknown sparse format: {format}. Supported formats are 'csr' and 'coo'"
109 ))),
110 }
111}
112
113#[allow(dead_code)]
139pub fn vstack<'a, T>(
140 arrays: &[&'a dyn SparseArray<T>],
141 format: &str,
142) -> SparseResult<Box<dyn SparseArray<T>>>
143where
144 T: 'a
145 + Float
146 + Add<Output = T>
147 + Sub<Output = T>
148 + Mul<Output = T>
149 + Div<Output = T>
150 + Debug
151 + Copy
152 + 'static,
153{
154 if arrays.is_empty() {
155 return Err(SparseError::ValueError(
156 "Cannot stack empty list of arrays".to_string(),
157 ));
158 }
159
160 let firstshape = arrays[0].shape();
162 let n = firstshape.1;
163
164 for (_i, &array) in arrays.iter().enumerate().skip(1) {
165 let shape = array.shape();
166 if shape.1 != n {
167 return Err(SparseError::DimensionMismatch {
168 expected: n,
169 found: shape.1,
170 });
171 }
172 }
173
174 let mut m = 0;
176 for &array in arrays.iter() {
177 m += array.shape().0;
178 }
179
180 let mut rows = Vec::new();
182 let mut cols = Vec::new();
183 let mut data = Vec::new();
184
185 let mut row_offset = 0;
186 for &array in arrays.iter() {
187 let shape = array.shape();
188 let (array_rows, array_cols, array_data) = array.find();
189
190 for i in 0..array_data.len() {
191 rows.push(array_rows[i] + row_offset);
192 cols.push(array_cols[i]);
193 data.push(array_data[i]);
194 }
195
196 row_offset += shape.0;
197 }
198
199 match format.to_lowercase().as_str() {
201 "csr" => CsrArray::from_triplets(&rows, &cols, &data, (m, n), false)
202 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
203 "coo" => CooArray::from_triplets(&rows, &cols, &data, (m, n), false)
204 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
205 _ => Err(SparseError::ValueError(format!(
206 "Unknown sparse format: {format}. Supported formats are 'csr' and 'coo'"
207 ))),
208 }
209}
210
211#[allow(dead_code)]
243pub fn block_diag<'a, T>(
244 arrays: &[&'a dyn SparseArray<T>],
245 format: &str,
246) -> SparseResult<Box<dyn SparseArray<T>>>
247where
248 T: 'a
249 + Float
250 + Add<Output = T>
251 + Sub<Output = T>
252 + Mul<Output = T>
253 + Div<Output = T>
254 + Debug
255 + Copy
256 + 'static,
257{
258 if arrays.is_empty() {
259 return Err(SparseError::ValueError(
260 "Cannot create block diagonal with empty list of arrays".to_string(),
261 ));
262 }
263
264 let mut total_rows = 0;
266 let mut total_cols = 0;
267 for &array in arrays.iter() {
268 let shape = array.shape();
269 total_rows += shape.0;
270 total_cols += shape.1;
271 }
272
273 let mut rows = Vec::new();
275 let mut cols = Vec::new();
276 let mut data = Vec::new();
277
278 let mut row_offset = 0;
279 let mut col_offset = 0;
280 for &array in arrays.iter() {
281 let shape = array.shape();
282 let (array_rows, array_cols, array_data) = array.find();
283
284 for i in 0..array_data.len() {
285 rows.push(array_rows[i] + row_offset);
286 cols.push(array_cols[i] + col_offset);
287 data.push(array_data[i]);
288 }
289
290 row_offset += shape.0;
291 col_offset += shape.1;
292 }
293
294 match format.to_lowercase().as_str() {
296 "csr" => CsrArray::from_triplets(&rows, &cols, &data, (total_rows, total_cols), false)
297 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
298 "coo" => CooArray::from_triplets(&rows, &cols, &data, (total_rows, total_cols), false)
299 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
300 _ => Err(SparseError::ValueError(format!(
301 "Unknown sparse format: {format}. Supported formats are 'csr' and 'coo'"
302 ))),
303 }
304}
305
306#[allow(dead_code)]
336pub fn tril<T>(
337 array: &dyn SparseArray<T>,
338 k: isize,
339 format: &str,
340) -> SparseResult<Box<dyn SparseArray<T>>>
341where
342 T: Float
343 + Add<Output = T>
344 + Sub<Output = T>
345 + Mul<Output = T>
346 + Div<Output = T>
347 + Debug
348 + Copy
349 + 'static,
350{
351 let shape = array.shape();
352 let (rows, cols, data) = array.find();
353
354 let mut tril_rows = Vec::new();
356 let mut tril_cols = Vec::new();
357 let mut tril_data = Vec::new();
358
359 for i in 0..data.len() {
360 let row = rows[i];
361 let col = cols[i];
362
363 if (row as isize) >= (col as isize) - k {
364 tril_rows.push(row);
365 tril_cols.push(col);
366 tril_data.push(data[i]);
367 }
368 }
369
370 match format.to_lowercase().as_str() {
372 "csr" => CsrArray::from_triplets(&tril_rows, &tril_cols, &tril_data, shape, false)
373 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
374 "coo" => CooArray::from_triplets(&tril_rows, &tril_cols, &tril_data, shape, false)
375 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
376 _ => Err(SparseError::ValueError(format!(
377 "Unknown sparse format: {format}. Supported formats are 'csr' and 'coo'"
378 ))),
379 }
380}
381
382#[allow(dead_code)]
412pub fn triu<T>(
413 array: &dyn SparseArray<T>,
414 k: isize,
415 format: &str,
416) -> SparseResult<Box<dyn SparseArray<T>>>
417where
418 T: Float
419 + Add<Output = T>
420 + Sub<Output = T>
421 + Mul<Output = T>
422 + Div<Output = T>
423 + Debug
424 + Copy
425 + 'static,
426{
427 let shape = array.shape();
428 let (rows, cols, data) = array.find();
429
430 let mut triu_rows = Vec::new();
432 let mut triu_cols = Vec::new();
433 let mut triu_data = Vec::new();
434
435 for i in 0..data.len() {
436 let row = rows[i];
437 let col = cols[i];
438
439 if (row as isize) <= (col as isize) - k {
440 triu_rows.push(row);
441 triu_cols.push(col);
442 triu_data.push(data[i]);
443 }
444 }
445
446 match format.to_lowercase().as_str() {
448 "csr" => CsrArray::from_triplets(&triu_rows, &triu_cols, &triu_data, shape, false)
449 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
450 "coo" => CooArray::from_triplets(&triu_rows, &triu_cols, &triu_data, shape, false)
451 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
452 _ => Err(SparseError::ValueError(format!(
453 "Unknown sparse format: {format}. Supported formats are 'csr' and 'coo'"
454 ))),
455 }
456}
457
458#[allow(dead_code)]
493pub fn kron<'a, T>(
494 a: &'a dyn SparseArray<T>,
495 b: &'a dyn SparseArray<T>,
496 format: &str,
497) -> SparseResult<Box<dyn SparseArray<T>>>
498where
499 T: 'a
500 + Float
501 + Add<Output = T>
502 + AddAssign
503 + Sub<Output = T>
504 + Mul<Output = T>
505 + Div<Output = T>
506 + Debug
507 + Copy
508 + 'static,
509{
510 let ashape = a.shape();
511 let bshape = b.shape();
512
513 let outputshape = (ashape.0 * bshape.0, ashape.1 * bshape.1);
515
516 if a.nnz() == 0 || b.nnz() == 0 {
518 let empty_rows: Vec<usize> = Vec::new();
520 let empty_cols: Vec<usize> = Vec::new();
521 let empty_data: Vec<T> = Vec::new();
522
523 return match format.to_lowercase().as_str() {
524 "csr" => {
525 CsrArray::from_triplets(&empty_rows, &empty_cols, &empty_data, outputshape, false)
526 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
527 }
528 "coo" => {
529 CooArray::from_triplets(&empty_rows, &empty_cols, &empty_data, outputshape, false)
530 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
531 }
532 _ => Err(SparseError::ValueError(format!(
533 "Unknown sparse format: {format}. Supported formats are 'csr' and 'coo'"
534 ))),
535 };
536 }
537
538 let b_coo = b.to_coo().unwrap();
540 let (b_rows, b_cols, b_data) = b_coo.find();
541
542 let a_coo = a.to_coo().unwrap();
547 let (a_rows, a_cols, a_data) = a_coo.find();
548
549 let nnz_a = a_data.len();
551 let nnz_b = b_data.len();
552 let nnz_output = nnz_a * nnz_b;
553
554 let mut out_rows = Vec::with_capacity(nnz_output);
556 let mut out_cols = Vec::with_capacity(nnz_output);
557 let mut out_data = Vec::with_capacity(nnz_output);
558
559 for i in 0..nnz_a {
561 for j in 0..nnz_b {
562 let row = a_rows[i] * bshape.0 + b_rows[j];
564 let col = a_cols[i] * bshape.1 + b_cols[j];
565
566 let val = a_data[i] * b_data[j];
568
569 out_rows.push(row);
571 out_cols.push(col);
572 out_data.push(val);
573 }
574 }
575
576 match format.to_lowercase().as_str() {
578 "csr" => CsrArray::from_triplets(&out_rows, &out_cols, &out_data, outputshape, false)
579 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
580 "coo" => CooArray::from_triplets(&out_rows, &out_cols, &out_data, outputshape, false)
581 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
582 _ => Err(SparseError::ValueError(format!(
583 "Unknown sparse format: {format}. Supported formats are 'csr' and 'coo'"
584 ))),
585 }
586}
587
588#[allow(dead_code)]
626pub fn kronsum<'a, T>(
627 a: &'a dyn SparseArray<T>,
628 b: &'a dyn SparseArray<T>,
629 format: &str,
630) -> SparseResult<Box<dyn SparseArray<T>>>
631where
632 T: 'a
633 + Float
634 + Add<Output = T>
635 + AddAssign
636 + Sub<Output = T>
637 + Mul<Output = T>
638 + Div<Output = T>
639 + Debug
640 + Copy
641 + 'static,
642{
643 let ashape = a.shape();
644 let bshape = b.shape();
645
646 if ashape.0 != ashape.1 {
648 return Err(SparseError::ValueError(
649 "First matrix must be square".to_string(),
650 ));
651 }
652 if bshape.0 != bshape.1 {
653 return Err(SparseError::ValueError(
654 "Second matrix must be square".to_string(),
655 ));
656 }
657
658 let m = ashape.0;
660 let n = bshape.0;
661
662 if is_identity_matrix(a) && is_identity_matrix(b) {
665 let outputshape = (m * n, m * n);
666 let mut rows = Vec::new();
667 let mut cols = Vec::new();
668 let mut data = Vec::new();
669
670 for i in 0..m * n {
672 rows.push(i);
673 cols.push(i);
674 data.push(T::one() + T::one()); }
676
677 for i in 0..n {
679 for j in 0..n {
680 if i != j && (b.get(i, j) > T::zero() || b.get(j, i) > T::zero()) {
681 for k in 0..m {
682 rows.push(i * m + k);
683 cols.push(j * m + k);
684 data.push(T::one());
685 }
686 }
687 }
688 }
689
690 for i in 0..n - 1 {
694 for j in 0..m {
695 rows.push(i * m + j);
698 cols.push((i + 1) * m + j);
699 data.push(T::one());
700
701 rows.push((i + 1) * m + j);
703 cols.push(i * m + j);
704 data.push(T::one());
705 }
706 }
707
708 return match format.to_lowercase().as_str() {
710 "csr" => CsrArray::from_triplets(&rows, &cols, &data, outputshape, true)
711 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
712 "coo" => CooArray::from_triplets(&rows, &cols, &data, outputshape, true)
713 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
714 _ => Err(SparseError::ValueError(format!(
715 "Unknown sparse format: {format}. Supported formats are 'csr' and 'coo'"
716 ))),
717 };
718 }
719
720 let outputshape = (m * n, m * n);
722
723 let mut rows = Vec::new();
725 let mut cols = Vec::new();
726 let mut data = Vec::new();
727
728 let (a_rows, a_cols, a_data) = a.find();
730 for i in 0..n {
731 for k in 0..a_data.len() {
732 let row_idx = i * m + a_rows[k];
733 let col_idx = i * m + a_cols[k];
734 rows.push(row_idx);
735 cols.push(col_idx);
736 data.push(a_data[k]);
737 }
738 }
739
740 let (b_rows, b_cols, b_data) = b.find();
742 for k in 0..b_data.len() {
743 let b_row = b_rows[k];
744 let b_col = b_cols[k];
745
746 for i in 0..m {
747 let row_idx = b_row * m + i;
748 let col_idx = b_col * m + i;
749 rows.push(row_idx);
750 cols.push(col_idx);
751 data.push(b_data[k]);
752 }
753 }
754
755 match format.to_lowercase().as_str() {
757 "csr" => CsrArray::from_triplets(&rows, &cols, &data, outputshape, true)
758 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
759 "coo" => CooArray::from_triplets(&rows, &cols, &data, outputshape, true)
760 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
761 _ => Err(SparseError::ValueError(format!(
762 "Unknown sparse format: {format}. Supported formats are 'csr' and 'coo'"
763 ))),
764 }
765}
766
767#[allow(dead_code)]
802pub fn bmat<'a, T>(
803 blocks: &[Vec<Option<&'a dyn SparseArray<T>>>],
804 format: &str,
805) -> SparseResult<Box<dyn SparseArray<T>>>
806where
807 T: 'a
808 + Float
809 + Add<Output = T>
810 + AddAssign
811 + Sub<Output = T>
812 + Mul<Output = T>
813 + Div<Output = T>
814 + Debug
815 + Copy
816 + 'static,
817{
818 if blocks.is_empty() {
819 return Err(SparseError::ValueError(
820 "Empty blocks array provided".to_string(),
821 ));
822 }
823
824 let m = blocks.len(); let n = blocks[0].len(); for (i, row) in blocks.iter().enumerate() {
829 if row.len() != n {
830 return Err(SparseError::ValueError(format!(
831 "Block row {i} has length {}, expected {n}",
832 row.len()
833 )));
834 }
835 }
836
837 let mut row_sizes = vec![0; m];
839 let mut col_sizes = vec![0; n];
840 let mut block_mask = vec![vec![false; n]; m];
841
842 for (i, row_size) in row_sizes.iter_mut().enumerate().take(m) {
844 for (j, col_size) in col_sizes.iter_mut().enumerate().take(n) {
845 if let Some(block) = blocks[i][j] {
846 let shape = block.shape();
847
848 if *row_size == 0 {
850 *row_size = shape.0;
851 } else if *row_size != shape.0 {
852 return Err(SparseError::ValueError(format!(
853 "Inconsistent row dimensions in block row {i}. Expected {}, got {}",
854 row_sizes[i], shape.0
855 )));
856 }
857
858 if *col_size == 0 {
860 *col_size = shape.1;
861 } else if *col_size != shape.1 {
862 return Err(SparseError::ValueError(format!(
863 "Inconsistent column dimensions in block column {j}. Expected {}, got {}",
864 *col_size, shape.1
865 )));
866 }
867
868 block_mask[i][j] = true;
869 }
870 }
871 }
872
873 for (i, &row_size) in row_sizes.iter().enumerate().take(m) {
875 if row_size == 0 {
876 return Err(SparseError::ValueError(format!(
877 "Block row {i} has no arrays, cannot determine dimensions"
878 )));
879 }
880 }
881 for (j, &col_size) in col_sizes.iter().enumerate().take(n) {
882 if col_size == 0 {
883 return Err(SparseError::ValueError(format!(
884 "Block column {j} has no arrays, cannot determine dimensions"
885 )));
886 }
887 }
888
889 let mut row_offsets = vec![0; m + 1];
891 let mut col_offsets = vec![0; n + 1];
892
893 for i in 0..m {
894 row_offsets[i + 1] = row_offsets[i] + row_sizes[i];
895 }
896 for j in 0..n {
897 col_offsets[j + 1] = col_offsets[j] + col_sizes[j];
898 }
899
900 let totalshape = (row_offsets[m], col_offsets[n]);
902
903 let mut has_blocks = false;
905 for mask_row in block_mask.iter().take(m) {
906 for &mask_elem in mask_row.iter().take(n) {
907 if mask_elem {
908 has_blocks = true;
909 break;
910 }
911 }
912 if has_blocks {
913 break;
914 }
915 }
916
917 if !has_blocks {
918 let empty_rows: Vec<usize> = Vec::new();
920 let empty_cols: Vec<usize> = Vec::new();
921 let empty_data: Vec<T> = Vec::new();
922
923 return match format.to_lowercase().as_str() {
924 "csr" => {
925 CsrArray::from_triplets(&empty_rows, &empty_cols, &empty_data, totalshape, false)
926 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
927 }
928 "coo" => {
929 CooArray::from_triplets(&empty_rows, &empty_cols, &empty_data, totalshape, false)
930 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
931 }
932 _ => Err(SparseError::ValueError(format!(
933 "Unknown sparse format: {format}. Supported formats are 'csr' and 'coo'"
934 ))),
935 };
936 }
937
938 let mut rows = Vec::new();
940 let mut cols = Vec::new();
941 let mut data = Vec::new();
942
943 for (i, row_offset) in row_offsets.iter().take(m).enumerate() {
944 for (j, col_offset) in col_offsets.iter().take(n).enumerate() {
945 if let Some(block) = blocks[i][j] {
946 let (block_rows, block_cols, block_data) = block.find();
947
948 for (((row, col), val), _) in block_rows
949 .iter()
950 .zip(block_cols.iter())
951 .zip(block_data.iter())
952 .zip(0..block_data.len())
953 {
954 rows.push(*row + *row_offset);
955 cols.push(*col + *col_offset);
956 data.push(*val);
957 }
958 }
959 }
960 }
961
962 match format.to_lowercase().as_str() {
964 "csr" => CsrArray::from_triplets(&rows, &cols, &data, totalshape, false)
965 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
966 "coo" => CooArray::from_triplets(&rows, &cols, &data, totalshape, false)
967 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
968 _ => Err(SparseError::ValueError(format!(
969 "Unknown sparse format: {format}. Supported formats are 'csr' and 'coo'"
970 ))),
971 }
972}
973
974#[allow(dead_code)]
976fn is_identity_matrix<T>(array: &dyn SparseArray<T>) -> bool
977where
978 T: Float + Debug + Copy + 'static,
979{
980 let shape = array.shape();
981
982 if shape.0 != shape.1 {
984 return false;
985 }
986
987 let n = shape.0;
988
989 if array.nnz() != n {
991 return false;
992 }
993
994 let (rows, cols, data) = array.find();
996
997 if rows.len() != n {
998 return false;
999 }
1000
1001 for i in 0..rows.len() {
1002 if rows[i] != cols[i] {
1004 return false;
1005 }
1006
1007 if (data[i] - T::one()).abs() > T::epsilon() {
1009 return false;
1010 }
1011 }
1012
1013 true
1014}
1015
#[cfg(test)]
mod tests {
    use super::*;
    use crate::construct::eye_array;

    /// Checks each `(row, col, expected)` triple against `matrix.get(row, col)`.
    fn assert_entries(matrix: &dyn SparseArray<f64>, expected: &[(usize, usize, f64)]) {
        for &(row, col, value) in expected {
            assert_eq!(
                matrix.get(row, col),
                value,
                "unexpected value at ({row}, {col})"
            );
        }
    }

    /// Triplets of a dense 3x3 matrix of ones, in row-major order.
    fn ones_3x3_triplets() -> (Vec<usize>, Vec<usize>, Vec<f64>) {
        let (rows, cols): (Vec<usize>, Vec<usize>) =
            (0..9).map(|idx| (idx / 3, idx % 3)).unzip();
        (rows, cols, vec![1.0; 9])
    }

    #[test]
    fn test_hstack() {
        let left = eye_array::<f64>(2, "csr").unwrap();
        let right = eye_array::<f64>(2, "csr").unwrap();
        let stacked = hstack(&[&*left, &*right], "csr").unwrap();

        assert_eq!(stacked.shape(), (2, 4));
        assert_entries(
            &*stacked,
            &[
                (0, 0, 1.0),
                (1, 1, 1.0),
                (0, 2, 1.0),
                (1, 3, 1.0),
                (0, 1, 0.0),
                (0, 3, 0.0),
            ],
        );
    }

    #[test]
    fn test_vstack() {
        let top = eye_array::<f64>(2, "csr").unwrap();
        let bottom = eye_array::<f64>(2, "csr").unwrap();
        let stacked = vstack(&[&*top, &*bottom], "csr").unwrap();

        assert_eq!(stacked.shape(), (4, 2));
        assert_entries(
            &*stacked,
            &[
                (0, 0, 1.0),
                (1, 1, 1.0),
                (2, 0, 1.0),
                (3, 1, 1.0),
                (0, 1, 0.0),
                (1, 0, 0.0),
            ],
        );
    }

    #[test]
    fn test_block_diag() {
        let small = eye_array::<f64>(2, "csr").unwrap();
        let large = eye_array::<f64>(3, "csr").unwrap();
        let diag = block_diag(&[&*small, &*large], "csr").unwrap();

        assert_eq!(diag.shape(), (5, 5));
        assert_entries(
            &*diag,
            &[
                (0, 0, 1.0),
                (1, 1, 1.0),
                (2, 2, 1.0),
                (3, 3, 1.0),
                (4, 4, 1.0),
                (0, 2, 0.0),
                (2, 0, 0.0),
            ],
        );
    }

    #[test]
    fn test_kron() {
        let a = eye_array::<f64>(2, "csr").unwrap();
        let b = eye_array::<f64>(2, "csr").unwrap();
        let c = kron(&*a, &*b, "csr").unwrap();

        assert_eq!(c.shape(), (4, 4));
        assert_entries(
            &*c,
            &[
                (0, 0, 1.0),
                (1, 1, 1.0),
                (2, 2, 1.0),
                (3, 3, 1.0),
                (0, 1, 0.0),
                (0, 2, 0.0),
                (1, 0, 0.0),
            ],
        );

        // a = [[1, 2], [3, 0]], b = diag(4, 5)
        let rows_a = vec![0, 0, 1];
        let cols_a = vec![0, 1, 0];
        let data_a = vec![1.0, 2.0, 3.0];
        let a = CooArray::from_triplets(&rows_a, &cols_a, &data_a, (2, 2), false).unwrap();

        let rows_b = vec![0, 1];
        let cols_b = vec![0, 1];
        let data_b = vec![4.0, 5.0];
        let b = CooArray::from_triplets(&rows_b, &cols_b, &data_b, (2, 2), false).unwrap();

        let c = kron(&a, &b, "csr").unwrap();
        assert_eq!(c.shape(), (4, 4));
        assert_entries(
            &*c,
            &[
                (0, 0, 4.0),
                (0, 2, 8.0),
                (1, 1, 5.0),
                (1, 3, 10.0),
                (2, 0, 12.0),
                (3, 1, 15.0),
                (0, 1, 0.0),
                (0, 3, 0.0),
                (2, 1, 0.0),
                (2, 2, 0.0),
                (2, 3, 0.0),
                (3, 0, 0.0),
                (3, 2, 0.0),
                (3, 3, 0.0),
            ],
        );
    }

    #[test]
    fn test_kronsum() {
        let a = eye_array::<f64>(2, "csr").unwrap();
        let b = eye_array::<f64>(2, "csr").unwrap();

        for &fmt in &["csr", "coo"] {
            let sum = kronsum(&*a, &*b, fmt).unwrap();
            assert_eq!(sum.shape(), (4, 4));

            let (rows, _cols, data) = sum.find();
            assert!(!rows.is_empty());
            assert!(!data.is_empty());
        }
    }

    #[test]
    fn test_tril() {
        let (rows, cols, data) = ones_3x3_triplets();
        let a = CooArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();

        let lower = tril(&a, 0, "csr").unwrap();
        assert_eq!(lower.shape(), (3, 3));
        assert_entries(
            &*lower,
            &[
                (0, 0, 1.0),
                (1, 0, 1.0),
                (1, 1, 1.0),
                (2, 0, 1.0),
                (2, 1, 1.0),
                (2, 2, 1.0),
                (0, 1, 0.0),
                (0, 2, 0.0),
                (1, 2, 0.0),
            ],
        );

        // k = 1 also keeps the first superdiagonal.
        let lower_k1 = tril(&a, 1, "csr").unwrap();
        assert_entries(
            &*lower_k1,
            &[(0, 0, 1.0), (0, 1, 1.0), (0, 2, 0.0), (1, 1, 1.0), (1, 2, 1.0)],
        );
    }

    #[test]
    fn test_triu() {
        let (rows, cols, data) = ones_3x3_triplets();
        let a = CooArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();

        let upper = triu(&a, 0, "csr").unwrap();
        assert_eq!(upper.shape(), (3, 3));
        assert_entries(
            &*upper,
            &[
                (0, 0, 1.0),
                (0, 1, 1.0),
                (0, 2, 1.0),
                (1, 1, 1.0),
                (1, 2, 1.0),
                (2, 2, 1.0),
                (1, 0, 0.0),
                (2, 0, 0.0),
                (2, 1, 0.0),
            ],
        );

        // k = -1 also keeps the first subdiagonal.
        let upper_km1 = triu(&a, -1, "csr").unwrap();
        assert_entries(
            &*upper_km1,
            &[(0, 0, 1.0), (1, 0, 1.0), (2, 0, 0.0), (1, 1, 1.0), (2, 1, 1.0)],
        );
    }

    #[test]
    fn test_bmat() {
        let a = eye_array::<f64>(2, "csr").unwrap();
        let b = eye_array::<f64>(2, "csr").unwrap();

        let full_grid = vec![vec![Some(&*a), Some(&*b)], vec![Some(&*b), Some(&*a)]];
        let c1 = bmat(&full_grid, "csr").unwrap();
        assert_eq!(c1.shape(), (4, 4));
        assert_entries(
            &*c1,
            &[
                (0, 0, 1.0),
                (1, 1, 1.0),
                (2, 2, 1.0),
                (3, 3, 1.0),
                (0, 2, 1.0),
                (1, 3, 1.0),
                (2, 0, 1.0),
                (3, 1, 1.0),
                (0, 1, 0.0),
                (0, 3, 0.0),
                (2, 1, 0.0),
                (2, 3, 0.0),
            ],
        );

        // A `None` block contributes zeros.
        let sparse_grid = vec![vec![Some(&*a), Some(&*b)], vec![None, Some(&*a)]];
        let c2 = bmat(&sparse_grid, "csr").unwrap();
        assert_eq!(c2.shape(), (4, 4));
        assert_entries(
            &*c2,
            &[
                (0, 0, 1.0),
                (1, 1, 1.0),
                (2, 0, 0.0),
                (2, 1, 0.0),
                (2, 2, 1.0),
                (3, 3, 1.0),
            ],
        );

        let b1 = eye_array::<f64>(2, "csr").unwrap();
        let b2 = eye_array::<f64>(2, "csr").unwrap();
        let swapped_grid = vec![vec![Some(&*b1), Some(&*b2)], vec![Some(&*b2), Some(&*b1)]];
        let c3 = bmat(&swapped_grid, "csr").unwrap();
        assert_eq!(c3.shape(), (4, 4));
        assert_entries(
            &*c3,
            &[(0, 0, 1.0), (1, 1, 1.0), (2, 2, 1.0), (3, 3, 1.0)],
        );
    }
}