1#![allow(unused_variables)]
7#![allow(unused_assignments)]
8#![allow(unused_mut)]
9
10use scirs2_core::ndarray::Array1;
11use scirs2_core::numeric::Float;
12use scirs2_core::random::seq::SliceRandom;
13use scirs2_core::random::{Rng, SeedableRng};
14use std::fmt::Debug;
15use std::ops::{Add, Div, Mul, Sub};
16
17use crate::coo_array::CooArray;
18use crate::csr_array::CsrArray;
19use crate::dok_array::DokArray;
20use crate::error::{SparseError, SparseResult};
21use crate::lil_array::LilArray;
22use crate::sparray::SparseArray;
23
24use scirs2_core::parallel_ops::*;
26
27#[allow(dead_code)]
50pub fn eye_array<T>(n: usize, format: &str) -> SparseResult<Box<dyn SparseArray<T>>>
51where
52 T: Float
53 + Add<Output = T>
54 + Sub<Output = T>
55 + Mul<Output = T>
56 + Div<Output = T>
57 + Debug
58 + Copy
59 + 'static,
60{
61 if n == 0 {
62 return Err(SparseError::ValueError(
63 "Matrix dimension must be positive".to_string(),
64 ));
65 }
66
67 let mut rows = Vec::with_capacity(n);
68 let mut cols = Vec::with_capacity(n);
69 let mut data = Vec::with_capacity(n);
70
71 for i in 0..n {
72 rows.push(i);
73 cols.push(i);
74 data.push(T::one());
75 }
76
77 match format.to_lowercase().as_str() {
78 "csr" => CsrArray::from_triplets(&rows, &cols, &data, (n, n), true)
79 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
80 "coo" => CooArray::from_triplets(&rows, &cols, &data, (n, n), true)
81 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
82 "dok" => DokArray::from_triplets(&rows, &cols, &data, (n, n))
83 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
84 "lil" => LilArray::from_triplets(&rows, &cols, &data, (n, n))
85 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
86 _ => Err(SparseError::ValueError(format!(
87 "Unknown sparse format: {format}. Supported formats are 'csr', 'coo', 'dok', and 'lil'"
88 ))),
89 }
90}
91
92#[allow(dead_code)]
127pub fn eye_array_k<T>(
128 m: usize,
129 n: usize,
130 k: isize,
131 format: &str,
132) -> SparseResult<Box<dyn SparseArray<T>>>
133where
134 T: Float
135 + Add<Output = T>
136 + Sub<Output = T>
137 + Mul<Output = T>
138 + Div<Output = T>
139 + Debug
140 + Copy
141 + 'static,
142{
143 if m == 0 || n == 0 {
144 return Err(SparseError::ValueError(
145 "Matrix dimensions must be positive".to_string(),
146 ));
147 }
148
149 let mut rows = Vec::new();
150 let mut cols = Vec::new();
151 let mut data = Vec::new();
152
153 if k >= 0 {
155 let k_usize = k as usize;
156 let len = std::cmp::min(m, n.saturating_sub(k_usize));
157
158 for i in 0..len {
159 rows.push(i);
160 cols.push(i + k_usize);
161 data.push(T::one());
162 }
163 } else {
164 let k_abs = (-k) as usize;
165 let len = std::cmp::min(m.saturating_sub(k_abs), n);
166
167 for i in 0..len {
168 rows.push(i + k_abs);
169 cols.push(i);
170 data.push(T::one());
171 }
172 }
173
174 match format.to_lowercase().as_str() {
175 "csr" => CsrArray::from_triplets(&rows, &cols, &data, (m, n), true)
176 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
177 "coo" => CooArray::from_triplets(&rows, &cols, &data, (m, n), true)
178 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
179 "dok" => DokArray::from_triplets(&rows, &cols, &data, (m, n))
180 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
181 "lil" => LilArray::from_triplets(&rows, &cols, &data, (m, n))
182 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
183 _ => Err(SparseError::ValueError(format!(
184 "Unknown sparse format: {format}. Supported formats are 'csr', 'coo', 'dok', and 'lil'"
185 ))),
186 }
187}
188
189#[allow(dead_code)]
222pub fn diags_array<T>(
223 diagonals: &[Array1<T>],
224 offsets: &[isize],
225 shape: (usize, usize),
226 format: &str,
227) -> SparseResult<Box<dyn SparseArray<T>>>
228where
229 T: Float
230 + Add<Output = T>
231 + Sub<Output = T>
232 + Mul<Output = T>
233 + Div<Output = T>
234 + Debug
235 + Copy
236 + 'static,
237{
238 if diagonals.len() != offsets.len() {
239 return Err(SparseError::InconsistentData {
240 reason: "Number of diagonals must match number of offsets".to_string(),
241 });
242 }
243
244 if shape.0 == 0 || shape.1 == 0 {
245 return Err(SparseError::ValueError(
246 "Matrix dimensions must be positive".to_string(),
247 ));
248 }
249
250 let (m, n) = shape;
251 let mut rows = Vec::new();
252 let mut cols = Vec::new();
253 let mut data = Vec::new();
254
255 for (i, (diag, &offset)) in diagonals.iter().zip(offsets.iter()).enumerate() {
256 if offset >= 0 {
257 let offset_usize = offset as usize;
258 let max_len = std::cmp::min(m, n.saturating_sub(offset_usize));
259
260 if diag.len() > max_len {
261 return Err(SparseError::InconsistentData {
262 reason: format!("Diagonal {i} is too long ({} > {})", diag.len(), max_len),
263 });
264 }
265
266 for (j, &value) in diag.iter().enumerate() {
267 if !value.is_zero() {
268 rows.push(j);
269 cols.push(j + offset_usize);
270 data.push(value);
271 }
272 }
273 } else {
274 let offset_abs = (-offset) as usize;
275 let max_len = std::cmp::min(m.saturating_sub(offset_abs), n);
276
277 if diag.len() > max_len {
278 return Err(SparseError::InconsistentData {
279 reason: format!("Diagonal {i} is too long ({} > {})", diag.len(), max_len),
280 });
281 }
282
283 for (j, &value) in diag.iter().enumerate() {
284 if !value.is_zero() {
285 rows.push(j + offset_abs);
286 cols.push(j);
287 data.push(value);
288 }
289 }
290 }
291 }
292
293 match format.to_lowercase().as_str() {
294 "csr" => CsrArray::from_triplets(&rows, &cols, &data, shape, false)
295 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
296 "coo" => CooArray::from_triplets(&rows, &cols, &data, shape, false)
297 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
298 "dok" => DokArray::from_triplets(&rows, &cols, &data, shape)
299 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
300 "lil" => LilArray::from_triplets(&rows, &cols, &data, shape)
301 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
302 _ => Err(SparseError::ValueError(format!(
303 "Unknown sparse format: {format}. Supported formats are 'csr', 'coo', 'dok', and 'lil'"
304 ))),
305 }
306}
307
308#[allow(dead_code)]
333pub fn random_array<T>(
334 shape: (usize, usize),
335 density: f64,
336 seed: Option<u64>,
337 format: &str,
338) -> SparseResult<Box<dyn SparseArray<T>>>
339where
340 T: Float
341 + Add<Output = T>
342 + Sub<Output = T>
343 + Mul<Output = T>
344 + Div<Output = T>
345 + Debug
346 + Copy
347 + 'static,
348{
349 let (m, n) = shape;
350
351 if !(0.0..=1.0).contains(&density) {
352 return Err(SparseError::ValueError(
353 "Density must be between 0.0 and 1.0".to_string(),
354 ));
355 }
356
357 if m == 0 || n == 0 {
358 return Err(SparseError::ValueError(
359 "Matrix dimensions must be positive".to_string(),
360 ));
361 }
362
363 let nnz = (m * n) as f64 * density;
365 let nnz = nnz.round() as usize;
366
367 let mut rows = Vec::with_capacity(nnz);
369 let mut cols = Vec::with_capacity(nnz);
370 let mut data = Vec::with_capacity(nnz);
371
372 let mut rng = if let Some(seed_value) = seed {
374 scirs2_core::random::rngs::StdRng::seed_from_u64(seed_value)
375 } else {
376 let seed = scirs2_core::random::random::<u64>();
378 scirs2_core::random::rngs::StdRng::seed_from_u64(seed)
379 };
380
381 let total = m * n;
383
384 if density > 0.4 {
385 let mut indices: Vec<usize> = (0..total).collect();
387 indices.shuffle(&mut rng);
388
389 for &idx in indices.iter().take(nnz) {
390 let row = idx / n;
391 let col = idx % n;
392
393 rows.push(row);
394 cols.push(col);
395
396 let mut val: f64 = rng.random_range(-1.0..1.0);
399 while val.abs() < 1e-10 {
401 val = rng.random_range(-1.0..1.0);
402 }
403 data.push(T::from(val).unwrap());
404 }
405 } else {
406 let mut positions = std::collections::HashSet::with_capacity(nnz);
408
409 while positions.len() < nnz {
410 let row = rng.random_range(0..m);
411 let col = rng.random_range(0..n);
412 let pos = row * n + col; if positions.insert(pos) {
415 rows.push(row);
416 cols.push(col);
417
418 let mut val: f64 = rng.random_range(-1.0..1.0);
420 while val.abs() < 1e-10 {
422 val = rng.random_range(-1.0..1.0);
423 }
424 data.push(T::from(val).unwrap());
425 }
426 }
427 }
428
429 match format.to_lowercase().as_str() {
431 "csr" => CsrArray::from_triplets(&rows, &cols, &data, shape, false)
432 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
433 "coo" => CooArray::from_triplets(&rows, &cols, &data, shape, false)
434 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
435 "dok" => DokArray::from_triplets(&rows, &cols, &data, shape)
436 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
437 "lil" => LilArray::from_triplets(&rows, &cols, &data, shape)
438 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
439 _ => Err(SparseError::ValueError(format!(
440 "Unknown sparse format: {format}. Supported formats are 'csr', 'coo', 'dok', and 'lil'"
441 ))),
442 }
443}
444
445#[allow(dead_code)]
471pub fn random_array_parallel<T>(
472 shape: (usize, usize),
473 density: f64,
474 seed: Option<u64>,
475 format: &str,
476 parallel_threshold: usize,
477) -> SparseResult<Box<dyn SparseArray<T>>>
478where
479 T: Float
480 + Add<Output = T>
481 + Sub<Output = T>
482 + Mul<Output = T>
483 + Div<Output = T>
484 + Debug
485 + Copy
486 + Send
487 + Sync
488 + 'static,
489{
490 if !(0.0..=1.0).contains(&density) {
491 return Err(SparseError::ValueError(
492 "Density must be between 0.0 and 1.0".to_string(),
493 ));
494 }
495
496 let (rows, cols) = shape;
497 if rows == 0 || cols == 0 {
498 return Err(SparseError::ValueError(
499 "Matrix dimensions must be positive".to_string(),
500 ));
501 }
502
503 let total_elements = rows * cols;
504 let expected_nnz = (total_elements as f64 * density) as usize;
505
506 if total_elements >= parallel_threshold && expected_nnz >= 1000 {
508 parallel_random_construction(shape, density, seed, format)
509 } else {
510 random_array(shape, density, seed, format)
512 }
513}
514
515#[allow(dead_code)]
517fn parallel_random_construction<T>(
518 shape: (usize, usize),
519 density: f64,
520 seed: Option<u64>,
521 format: &str,
522) -> SparseResult<Box<dyn SparseArray<T>>>
523where
524 T: Float
525 + Add<Output = T>
526 + Sub<Output = T>
527 + Mul<Output = T>
528 + Div<Output = T>
529 + Debug
530 + Copy
531 + Send
532 + Sync
533 + 'static,
534{
535 let (rows, cols) = shape;
536 let total_elements = rows * cols;
537 let expected_nnz = (total_elements as f64 * density) as usize;
538
539 let num_chunks = std::cmp::min(scirs2_core::parallel_ops::get_num_threads(), rows.min(cols));
541 let chunk_size = std::cmp::max(1, rows / num_chunks);
542
543 let row_chunks: Vec<_> = (0..rows)
545 .collect::<Vec<_>>()
546 .chunks(chunk_size)
547 .map(|chunk| chunk.to_vec())
548 .collect();
549
550 let chunk_data: Vec<_> = row_chunks.iter().enumerate().collect();
552 let results: Vec<_> = parallel_map(&chunk_data, |(chunk_idx, row_chunk)| {
553 let mut local_rows = Vec::new();
554 let mut local_cols = Vec::new();
555 let mut local_data = Vec::new();
556
557 let chunk_seed = seed.unwrap_or(42) + *chunk_idx as u64 * 1000007; let mut rng = scirs2_core::random::rngs::StdRng::seed_from_u64(chunk_seed);
560
561 for &row in row_chunk.iter() {
562 let row_elements = cols;
564 let row_expected_nnz = std::cmp::max(1, (row_elements as f64 * density) as usize);
565
566 let mut col_indices: Vec<usize> = (0..cols).collect();
568 col_indices.shuffle(&mut rng);
569
570 for &col in col_indices.iter().take(row_expected_nnz) {
572 let mut val = rng.random_range(-1.0..1.0);
574 while val.abs() < 1e-10 {
576 val = rng.random_range(-1.0..1.0);
577 }
578
579 local_rows.push(row);
580 local_cols.push(col);
581 local_data.push(T::from(val).unwrap());
582 }
583 }
584
585 (local_rows, local_cols, local_data)
586 });
587
588 let mut all_rows = Vec::new();
590 let mut all_cols = Vec::new();
591 let mut all_data = Vec::new();
592
593 for (mut rowschunk, mut cols_chunk, mut data_chunk) in results {
594 all_rows.extend(rowschunk);
595 all_cols.append(&mut cols_chunk);
596 all_data.append(&mut data_chunk);
597 }
598
599 match format.to_lowercase().as_str() {
601 "csr" => CsrArray::from_triplets(&all_rows, &all_cols, &all_data, shape, false)
602 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
603 "coo" => CooArray::from_triplets(&all_rows, &all_cols, &all_data, shape, false)
604 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
605 "dok" => DokArray::from_triplets(&all_rows, &all_cols, &all_data, shape)
606 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
607 "lil" => LilArray::from_triplets(&all_rows, &all_cols, &all_data, shape)
608 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
609 _ => Err(SparseError::ValueError(format!(
610 "Unknown sparse format: {format}. Supported formats are 'csr', 'coo', 'dok', and 'lil'"
611 ))),
612 }
613}
614
#[cfg(test)]
mod tests {
    use super::*;

    /// Entries of a 3 x 3 identity matrix as `(row, col, value)` triples.
    const IDENTITY_3: [(usize, usize, f64); 3] = [(0, 0, 1.0), (1, 1, 1.0), (2, 2, 1.0)];

    /// Asserts that `array` holds `value` at every listed `(row, col)` position.
    fn assert_entries(array: &dyn SparseArray<f64>, expected: &[(usize, usize, f64)]) {
        for &(row, col, value) in expected {
            assert_eq!(array.get(row, col), value);
        }
    }

    #[test]
    fn test_eye_array() {
        // CSR: full check, including one off-diagonal zero.
        let csr = eye_array::<f64>(3, "csr").unwrap();
        assert_eq!(csr.shape(), (3, 3));
        assert_eq!(csr.nnz(), 3);
        assert_entries(&*csr, &IDENTITY_3);
        assert_eq!(csr.get(0, 1), 0.0);

        // COO: only shape and entry count are checked.
        let coo = eye_array::<f64>(3, "coo").unwrap();
        assert_eq!(coo.shape(), (3, 3));
        assert_eq!(coo.nnz(), 3);

        // DOK and LIL: shape, entry count and the diagonal values.
        for &format in &["dok", "lil"] {
            let eye = eye_array::<f64>(3, format).unwrap();
            assert_eq!(eye.shape(), (3, 3));
            assert_eq!(eye.nnz(), 3);
            assert_entries(&*eye, &IDENTITY_3);
        }
    }

    #[test]
    fn test_eye_array_k() {
        // (rows, cols, k, format, expected entries)
        let cases: [(usize, usize, isize, &str, [(usize, usize, f64); 3]); 4] = [
            // Main diagonal.
            (3, 3, 0, "csr", IDENTITY_3),
            // First superdiagonal of a wide matrix.
            (3, 4, 1, "csr", [(0, 1, 1.0), (1, 2, 1.0), (2, 3, 1.0)]),
            // First subdiagonal of a tall matrix.
            (4, 3, -1, "csr", [(1, 0, 1.0), (2, 1, 1.0), (3, 2, 1.0)]),
            // Main diagonal in LIL format.
            (3, 3, 0, "lil", IDENTITY_3),
        ];

        for (m, n, k, format, expected) in cases.iter() {
            let array = eye_array_k::<f64>(*m, *n, *k, format).unwrap();
            assert_entries(&*array, expected);
        }
    }

    #[test]
    fn test_diags_array() {
        let shape = (3, 3);

        // Main diagonal plus first superdiagonal.
        let two_diags = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]),
            Array1::from_vec(vec![4.0, 5.0]),
        ];
        let two_offsets: Vec<isize> = vec![0, 1];
        let upper_entries = [
            (0, 0, 1.0),
            (1, 1, 2.0),
            (2, 2, 3.0),
            (0, 1, 4.0),
            (1, 2, 5.0),
        ];

        let upper = diags_array(&two_diags, &two_offsets, shape, "csr").unwrap();
        assert_eq!(upper.shape(), (3, 3));
        assert_entries(&*upper, &upper_entries);

        // Tridiagonal: main, super- and subdiagonal.
        let three_diags = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]),
            Array1::from_vec(vec![4.0, 5.0]),
            Array1::from_vec(vec![6.0, 7.0]),
        ];
        let three_offsets: Vec<isize> = vec![0, 1, -1];
        let tridiagonal_entries = [
            (0, 0, 1.0),
            (1, 1, 2.0),
            (2, 2, 3.0),
            (0, 1, 4.0),
            (1, 2, 5.0),
            (1, 0, 6.0),
            (2, 1, 7.0),
        ];

        for &format in &["csr", "lil"] {
            let tridiagonal = diags_array(&three_diags, &three_offsets, shape, format).unwrap();
            assert_eq!(tridiagonal.shape(), (3, 3));
            assert_entries(&*tridiagonal, &tridiagonal_entries);
        }
    }

    #[test]
    fn test_random_array() {
        let shape = (10, 10);
        let density = 0.3;

        // Unseeded CSR: entry count must be within 30% of the expected count.
        let random = random_array::<f64>(shape, density, None, "csr").unwrap();
        assert_eq!(random.shape(), shape);

        let nnz = random.nnz();
        let expected_nnz = (shape.0 * shape.1) as f64 * density;
        let (lower, upper) = (expected_nnz * 0.7, expected_nnz * 1.3);
        assert!((nnz as f64) > lower, "Too few non-zeros: {nnz}");
        assert!((nnz as f64) < upper, "Too many non-zeros: {nnz}");

        // Seeded CSR: only the shape is checked.
        let random_seeded = random_array::<f64>(shape, density, Some(42), "csr").unwrap();
        assert_eq!(random_seeded.shape(), shape);

        // Seeded LIL at a density above 0.4.
        let random_lil = random_array::<f64>((5, 5), 0.5, Some(42), "lil").unwrap();
        assert_eq!(random_lil.shape(), (5, 5));

        let nnz_lil = random_lil.nnz();
        let expected_nnz_lil = 25.0 * 0.5;
        let (lower_lil, upper_lil) = (expected_nnz_lil * 0.7, expected_nnz_lil * 1.3);
        assert!(
            (nnz_lil as f64) > lower_lil,
            "Too few non-zeros in LIL: {nnz_lil}"
        );
        assert!(
            (nnz_lil as f64) < upper_lil,
            "Too many non-zeros in LIL: {nnz_lil}"
        );
    }
}