1#![allow(unused_variables)]
7#![allow(unused_assignments)]
8#![allow(unused_mut)]
9
10use scirs2_core::ndarray::Array1;
11use scirs2_core::numeric::{Float, SparseElement};
12use scirs2_core::random::seq::SliceRandom;
13use scirs2_core::random::{Rng, SeedableRng};
14use std::fmt::Debug;
15use std::ops::{Add, Div, Mul, Sub};
16
17use crate::coo_array::CooArray;
18use crate::csr_array::CsrArray;
19use crate::dok_array::DokArray;
20use crate::error::{SparseError, SparseResult};
21use crate::lil_array::LilArray;
22use crate::sparray::SparseArray;
23
24use scirs2_core::parallel_ops::*;
26
27#[allow(dead_code)]
50pub fn eye_array<T>(n: usize, format: &str) -> SparseResult<Box<dyn SparseArray<T>>>
51where
52 T: SparseElement + Div<Output = T> + Float + 'static,
53{
54 if n == 0 {
55 return Err(SparseError::ValueError(
56 "Matrix dimension must be positive".to_string(),
57 ));
58 }
59
60 let mut rows = Vec::with_capacity(n);
61 let mut cols = Vec::with_capacity(n);
62 let mut data = Vec::with_capacity(n);
63
64 for i in 0..n {
65 rows.push(i);
66 cols.push(i);
67 data.push(T::sparse_one());
68 }
69
70 match format.to_lowercase().as_str() {
71 "csr" => CsrArray::from_triplets(&rows, &cols, &data, (n, n), true)
72 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
73 "coo" => CooArray::from_triplets(&rows, &cols, &data, (n, n), true)
74 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
75 "dok" => DokArray::from_triplets(&rows, &cols, &data, (n, n))
76 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
77 "lil" => LilArray::from_triplets(&rows, &cols, &data, (n, n))
78 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
79 _ => Err(SparseError::ValueError(format!(
80 "Unknown sparse format: {format}. Supported formats are 'csr', 'coo', 'dok', and 'lil'"
81 ))),
82 }
83}
84
85#[allow(dead_code)]
120pub fn eye_array_k<T>(
121 m: usize,
122 n: usize,
123 k: isize,
124 format: &str,
125) -> SparseResult<Box<dyn SparseArray<T>>>
126where
127 T: SparseElement + Div<Output = T> + Float + 'static,
128{
129 if m == 0 || n == 0 {
130 return Err(SparseError::ValueError(
131 "Matrix dimensions must be positive".to_string(),
132 ));
133 }
134
135 let mut rows = Vec::new();
136 let mut cols = Vec::new();
137 let mut data = Vec::new();
138
139 if k >= 0 {
141 let k_usize = k as usize;
142 let len = std::cmp::min(m, n.saturating_sub(k_usize));
143
144 for i in 0..len {
145 rows.push(i);
146 cols.push(i + k_usize);
147 data.push(T::sparse_one());
148 }
149 } else {
150 let k_abs = (-k) as usize;
151 let len = std::cmp::min(m.saturating_sub(k_abs), n);
152
153 for i in 0..len {
154 rows.push(i + k_abs);
155 cols.push(i);
156 data.push(T::sparse_one());
157 }
158 }
159
160 match format.to_lowercase().as_str() {
161 "csr" => CsrArray::from_triplets(&rows, &cols, &data, (m, n), true)
162 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
163 "coo" => CooArray::from_triplets(&rows, &cols, &data, (m, n), true)
164 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
165 "dok" => DokArray::from_triplets(&rows, &cols, &data, (m, n))
166 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
167 "lil" => LilArray::from_triplets(&rows, &cols, &data, (m, n))
168 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
169 _ => Err(SparseError::ValueError(format!(
170 "Unknown sparse format: {format}. Supported formats are 'csr', 'coo', 'dok', and 'lil'"
171 ))),
172 }
173}
174
175#[allow(dead_code)]
208pub fn diags_array<T>(
209 diagonals: &[Array1<T>],
210 offsets: &[isize],
211 shape: (usize, usize),
212 format: &str,
213) -> SparseResult<Box<dyn SparseArray<T>>>
214where
215 T: SparseElement + Div<Output = T> + Float + 'static,
216{
217 if diagonals.len() != offsets.len() {
218 return Err(SparseError::InconsistentData {
219 reason: "Number of diagonals must match number of offsets".to_string(),
220 });
221 }
222
223 if shape.0 == 0 || shape.1 == 0 {
224 return Err(SparseError::ValueError(
225 "Matrix dimensions must be positive".to_string(),
226 ));
227 }
228
229 let (m, n) = shape;
230 let mut rows = Vec::new();
231 let mut cols = Vec::new();
232 let mut data = Vec::new();
233
234 for (i, (diag, &offset)) in diagonals.iter().zip(offsets.iter()).enumerate() {
235 if offset >= 0 {
236 let offset_usize = offset as usize;
237 let max_len = std::cmp::min(m, n.saturating_sub(offset_usize));
238
239 if diag.len() > max_len {
240 return Err(SparseError::InconsistentData {
241 reason: format!("Diagonal {i} is too long ({} > {})", diag.len(), max_len),
242 });
243 }
244
245 for (j, &value) in diag.iter().enumerate() {
246 if !SparseElement::is_zero(&value) {
247 rows.push(j);
248 cols.push(j + offset_usize);
249 data.push(value);
250 }
251 }
252 } else {
253 let offset_abs = (-offset) as usize;
254 let max_len = std::cmp::min(m.saturating_sub(offset_abs), n);
255
256 if diag.len() > max_len {
257 return Err(SparseError::InconsistentData {
258 reason: format!("Diagonal {i} is too long ({} > {})", diag.len(), max_len),
259 });
260 }
261
262 for (j, &value) in diag.iter().enumerate() {
263 if !SparseElement::is_zero(&value) {
264 rows.push(j + offset_abs);
265 cols.push(j);
266 data.push(value);
267 }
268 }
269 }
270 }
271
272 match format.to_lowercase().as_str() {
273 "csr" => CsrArray::from_triplets(&rows, &cols, &data, shape, false)
274 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
275 "coo" => CooArray::from_triplets(&rows, &cols, &data, shape, false)
276 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
277 "dok" => DokArray::from_triplets(&rows, &cols, &data, shape)
278 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
279 "lil" => LilArray::from_triplets(&rows, &cols, &data, shape)
280 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
281 _ => Err(SparseError::ValueError(format!(
282 "Unknown sparse format: {format}. Supported formats are 'csr', 'coo', 'dok', and 'lil'"
283 ))),
284 }
285}
286
287#[allow(dead_code)]
312pub fn random_array<T>(
313 shape: (usize, usize),
314 density: f64,
315 seed: Option<u64>,
316 format: &str,
317) -> SparseResult<Box<dyn SparseArray<T>>>
318where
319 T: Float + SparseElement + Div<Output = T> + 'static,
320{
321 let (m, n) = shape;
322
323 if !(0.0..=1.0).contains(&density) {
324 return Err(SparseError::ValueError(
325 "Density must be between 0.0 and 1.0".to_string(),
326 ));
327 }
328
329 if m == 0 || n == 0 {
330 return Err(SparseError::ValueError(
331 "Matrix dimensions must be positive".to_string(),
332 ));
333 }
334
335 let nnz = (m * n) as f64 * density;
337 let nnz = nnz.round() as usize;
338
339 let mut rows = Vec::with_capacity(nnz);
341 let mut cols = Vec::with_capacity(nnz);
342 let mut data = Vec::with_capacity(nnz);
343
344 let mut rng = if let Some(seed_value) = seed {
346 scirs2_core::random::rngs::StdRng::seed_from_u64(seed_value)
347 } else {
348 let seed = scirs2_core::random::random::<u64>();
350 scirs2_core::random::rngs::StdRng::seed_from_u64(seed)
351 };
352
353 let total = m * n;
355
356 if density > 0.4 {
357 let mut indices: Vec<usize> = (0..total).collect();
359 indices.shuffle(&mut rng);
360
361 for &idx in indices.iter().take(nnz) {
362 let row = idx / n;
363 let col = idx % n;
364
365 rows.push(row);
366 cols.push(col);
367
368 let mut val: f64 = rng.random_range(-1.0..1.0);
371 while val.abs() < 1e-10 {
373 val = rng.random_range(-1.0..1.0);
374 }
375 data.push(T::from(val).unwrap());
376 }
377 } else {
378 let mut positions = std::collections::HashSet::with_capacity(nnz);
380
381 while positions.len() < nnz {
382 let row = rng.random_range(0..m);
383 let col = rng.random_range(0..n);
384 let pos = row * n + col; if positions.insert(pos) {
387 rows.push(row);
388 cols.push(col);
389
390 let mut val: f64 = rng.random_range(-1.0..1.0);
392 while val.abs() < 1e-10 {
394 val = rng.random_range(-1.0..1.0);
395 }
396 data.push(T::from(val).unwrap());
397 }
398 }
399 }
400
401 match format.to_lowercase().as_str() {
403 "csr" => CsrArray::from_triplets(&rows, &cols, &data, shape, false)
404 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
405 "coo" => CooArray::from_triplets(&rows, &cols, &data, shape, false)
406 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
407 "dok" => DokArray::from_triplets(&rows, &cols, &data, shape)
408 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
409 "lil" => LilArray::from_triplets(&rows, &cols, &data, shape)
410 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
411 _ => Err(SparseError::ValueError(format!(
412 "Unknown sparse format: {format}. Supported formats are 'csr', 'coo', 'dok', and 'lil'"
413 ))),
414 }
415}
416
417#[allow(dead_code)]
443pub fn random_array_parallel<T>(
444 shape: (usize, usize),
445 density: f64,
446 seed: Option<u64>,
447 format: &str,
448 parallel_threshold: usize,
449) -> SparseResult<Box<dyn SparseArray<T>>>
450where
451 T: Float + SparseElement + Div<Output = T> + Send + Sync + 'static,
452{
453 if !(0.0..=1.0).contains(&density) {
454 return Err(SparseError::ValueError(
455 "Density must be between 0.0 and 1.0".to_string(),
456 ));
457 }
458
459 let (rows, cols) = shape;
460 if rows == 0 || cols == 0 {
461 return Err(SparseError::ValueError(
462 "Matrix dimensions must be positive".to_string(),
463 ));
464 }
465
466 let total_elements = rows * cols;
467 let expected_nnz = (total_elements as f64 * density) as usize;
468
469 if total_elements >= parallel_threshold && expected_nnz >= 1000 {
471 parallel_random_construction(shape, density, seed, format)
472 } else {
473 random_array(shape, density, seed, format)
475 }
476}
477
478#[allow(dead_code)]
480fn parallel_random_construction<T>(
481 shape: (usize, usize),
482 density: f64,
483 seed: Option<u64>,
484 format: &str,
485) -> SparseResult<Box<dyn SparseArray<T>>>
486where
487 T: Float + SparseElement + Div<Output = T> + Send + Sync + 'static,
488{
489 let (rows, cols) = shape;
490 let total_elements = rows * cols;
491 let expected_nnz = (total_elements as f64 * density) as usize;
492
493 let num_chunks = std::cmp::min(scirs2_core::parallel_ops::get_num_threads(), rows.min(cols));
495 let chunk_size = std::cmp::max(1, rows / num_chunks);
496
497 let row_chunks: Vec<_> = (0..rows)
499 .collect::<Vec<_>>()
500 .chunks(chunk_size)
501 .map(|chunk| chunk.to_vec())
502 .collect();
503
504 let chunk_data: Vec<_> = row_chunks.iter().enumerate().collect();
506 let results: Vec<_> = parallel_map(&chunk_data, |(chunk_idx, row_chunk)| {
507 let mut local_rows = Vec::new();
508 let mut local_cols = Vec::new();
509 let mut local_data = Vec::new();
510
511 let chunk_seed = seed.unwrap_or(42) + *chunk_idx as u64 * 1000007; let mut rng = scirs2_core::random::rngs::StdRng::seed_from_u64(chunk_seed);
514
515 for &row in row_chunk.iter() {
516 let row_elements = cols;
518 let row_expected_nnz = std::cmp::max(1, (row_elements as f64 * density) as usize);
519
520 let mut col_indices: Vec<usize> = (0..cols).collect();
522 col_indices.shuffle(&mut rng);
523
524 for &col in col_indices.iter().take(row_expected_nnz) {
526 let mut val = rng.random_range(-1.0..1.0);
528 while val.abs() < 1e-10 {
530 val = rng.random_range(-1.0..1.0);
531 }
532
533 local_rows.push(row);
534 local_cols.push(col);
535 local_data.push(T::from(val).unwrap());
536 }
537 }
538
539 (local_rows, local_cols, local_data)
540 });
541
542 let mut all_rows = Vec::new();
544 let mut all_cols = Vec::new();
545 let mut all_data = Vec::new();
546
547 for (mut rowschunk, mut cols_chunk, mut data_chunk) in results {
548 all_rows.extend(rowschunk);
549 all_cols.append(&mut cols_chunk);
550 all_data.append(&mut data_chunk);
551 }
552
553 match format.to_lowercase().as_str() {
555 "csr" => CsrArray::from_triplets(&all_rows, &all_cols, &all_data, shape, false)
556 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
557 "coo" => CooArray::from_triplets(&all_rows, &all_cols, &all_data, shape, false)
558 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
559 "dok" => DokArray::from_triplets(&all_rows, &all_cols, &all_data, shape)
560 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
561 "lil" => LilArray::from_triplets(&all_rows, &all_cols, &all_data, shape)
562 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
563 _ => Err(SparseError::ValueError(format!(
564 "Unknown sparse format: {format}. Supported formats are 'csr', 'coo', 'dok', and 'lil'"
565 ))),
566 }
567}
568
#[cfg(test)]
mod tests {
    use super::*;

    /// Main diagonal of a 3x3 identity, as `(row, col, value)` triples.
    const IDENTITY_3: [(usize, usize, f64); 3] = [(0, 0, 1.0), (1, 1, 1.0), (2, 2, 1.0)];

    /// Asserts that `array` stores `value` at every listed `(row, col)` position.
    fn check_values(array: &dyn SparseArray<f64>, expected: &[(usize, usize, f64)]) {
        for &(row, col, value) in expected {
            assert_eq!(array.get(row, col), value);
        }
    }

    #[test]
    fn test_eye_array() {
        for &fmt in &["csr", "coo", "dok", "lil"] {
            let eye = eye_array::<f64>(3, fmt).unwrap();
            assert_eq!(eye.shape(), (3, 3));
            assert_eq!(eye.nnz(), 3);

            // Element access is only checked for the non-COO formats.
            if fmt != "coo" {
                check_values(&*eye, &IDENTITY_3);
            }
            if fmt == "csr" {
                assert_eq!(eye.get(0, 1), 0.0);
            }
        }
    }

    #[test]
    fn test_eye_array_k() {
        let main_csr = eye_array_k::<f64>(3, 3, 0, "csr").unwrap();
        check_values(&*main_csr, &IDENTITY_3);

        let upper = eye_array_k::<f64>(3, 4, 1, "csr").unwrap();
        check_values(&*upper, &[(0, 1, 1.0), (1, 2, 1.0), (2, 3, 1.0)]);

        let lower = eye_array_k::<f64>(4, 3, -1, "csr").unwrap();
        check_values(&*lower, &[(1, 0, 1.0), (2, 1, 1.0), (3, 2, 1.0)]);

        let main_lil = eye_array_k::<f64>(3, 3, 0, "lil").unwrap();
        check_values(&*main_lil, &IDENTITY_3);
    }

    #[test]
    fn test_diags_array() {
        // Main diagonal plus first superdiagonal.
        let bidiagonal = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]),
            Array1::from_vec(vec![4.0, 5.0]),
        ];
        let upper = diags_array(&bidiagonal, &[0, 1], (3, 3), "csr").unwrap();
        assert_eq!(upper.shape(), (3, 3));
        check_values(
            &*upper,
            &[(0, 0, 1.0), (1, 1, 2.0), (2, 2, 3.0), (0, 1, 4.0), (1, 2, 5.0)],
        );

        // Tridiagonal: main, super- and subdiagonal, checked in two formats.
        let tridiagonal = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]),
            Array1::from_vec(vec![4.0, 5.0]),
            Array1::from_vec(vec![6.0, 7.0]),
        ];
        let offsets: [isize; 3] = [0, 1, -1];
        let expected: [(usize, usize, f64); 7] = [
            (0, 0, 1.0),
            (1, 1, 2.0),
            (2, 2, 3.0),
            (0, 1, 4.0),
            (1, 2, 5.0),
            (1, 0, 6.0),
            (2, 1, 7.0),
        ];
        for &fmt in &["csr", "lil"] {
            let tri = diags_array(&tridiagonal, &offsets, (3, 3), fmt).unwrap();
            assert_eq!(tri.shape(), (3, 3));
            check_values(&*tri, &expected);
        }
    }

    #[test]
    fn test_random_array() {
        let shape = (10, 10);
        let density = 0.3;
        let expected_nnz = (shape.0 * shape.1) as f64 * density;

        let unseeded = random_array::<f64>(shape, density, None, "csr").unwrap();
        assert_eq!(unseeded.shape(), shape);
        let nnz = unseeded.nnz();
        assert!(
            (nnz as f64) > expected_nnz * 0.7,
            "Too few non-zeros: {nnz}"
        );
        assert!(
            (nnz as f64) < expected_nnz * 1.3,
            "Too many non-zeros: {nnz}"
        );

        let seeded = random_array::<f64>(shape, density, Some(42), "csr").unwrap();
        assert_eq!(seeded.shape(), shape);

        let lil = random_array::<f64>((5, 5), 0.5, Some(42), "lil").unwrap();
        assert_eq!(lil.shape(), (5, 5));
        let nnz_lil = lil.nnz();
        let expected_nnz_lil = 25.0 * 0.5;
        assert!(
            (nnz_lil as f64) > expected_nnz_lil * 0.7,
            "Too few non-zeros in LIL: {nnz_lil}"
        );
        assert!(
            (nnz_lil as f64) < expected_nnz_lil * 1.3,
            "Too many non-zeros in LIL: {nnz_lil}"
        );
    }
}