1use ndarray::Array1;
7use num_traits::Float;
8use rand::seq::SliceRandom;
9use rand::{Rng, SeedableRng};
10use std::fmt::Debug;
11use std::ops::{Add, Div, Mul, Sub};
12
13use crate::coo_array::CooArray;
14use crate::csr_array::CsrArray;
15use crate::dok_array::DokArray;
16use crate::error::{SparseError, SparseResult};
17use crate::lil_array::LilArray;
18use crate::sparray::SparseArray;
19
20pub fn eye_array<T>(n: usize, format: &str) -> SparseResult<Box<dyn SparseArray<T>>>
43where
44 T: Float
45 + Add<Output = T>
46 + Sub<Output = T>
47 + Mul<Output = T>
48 + Div<Output = T>
49 + Debug
50 + Copy
51 + 'static,
52{
53 if n == 0 {
54 return Err(SparseError::ValueError(
55 "Matrix dimension must be positive".to_string(),
56 ));
57 }
58
59 let mut rows = Vec::with_capacity(n);
60 let mut cols = Vec::with_capacity(n);
61 let mut data = Vec::with_capacity(n);
62
63 for i in 0..n {
64 rows.push(i);
65 cols.push(i);
66 data.push(T::one());
67 }
68
69 match format.to_lowercase().as_str() {
70 "csr" => CsrArray::from_triplets(&rows, &cols, &data, (n, n), true)
71 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
72 "coo" => CooArray::from_triplets(&rows, &cols, &data, (n, n), true)
73 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
74 "dok" => DokArray::from_triplets(&rows, &cols, &data, (n, n))
75 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
76 "lil" => LilArray::from_triplets(&rows, &cols, &data, (n, n))
77 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
78 _ => Err(SparseError::ValueError(format!(
79 "Unknown sparse format: {}. Supported formats are 'csr', 'coo', 'dok', and 'lil'",
80 format
81 ))),
82 }
83}
84
85pub fn eye_array_k<T>(
120 m: usize,
121 n: usize,
122 k: isize,
123 format: &str,
124) -> SparseResult<Box<dyn SparseArray<T>>>
125where
126 T: Float
127 + Add<Output = T>
128 + Sub<Output = T>
129 + Mul<Output = T>
130 + Div<Output = T>
131 + Debug
132 + Copy
133 + 'static,
134{
135 if m == 0 || n == 0 {
136 return Err(SparseError::ValueError(
137 "Matrix dimensions must be positive".to_string(),
138 ));
139 }
140
141 let mut rows = Vec::new();
142 let mut cols = Vec::new();
143 let mut data = Vec::new();
144
145 if k >= 0 {
147 let k_usize = k as usize;
148 let len = std::cmp::min(m, n.saturating_sub(k_usize));
149
150 for i in 0..len {
151 rows.push(i);
152 cols.push(i + k_usize);
153 data.push(T::one());
154 }
155 } else {
156 let k_abs = (-k) as usize;
157 let len = std::cmp::min(m.saturating_sub(k_abs), n);
158
159 for i in 0..len {
160 rows.push(i + k_abs);
161 cols.push(i);
162 data.push(T::one());
163 }
164 }
165
166 match format.to_lowercase().as_str() {
167 "csr" => CsrArray::from_triplets(&rows, &cols, &data, (m, n), true)
168 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
169 "coo" => CooArray::from_triplets(&rows, &cols, &data, (m, n), true)
170 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
171 "dok" => DokArray::from_triplets(&rows, &cols, &data, (m, n))
172 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
173 "lil" => LilArray::from_triplets(&rows, &cols, &data, (m, n))
174 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
175 _ => Err(SparseError::ValueError(format!(
176 "Unknown sparse format: {}. Supported formats are 'csr', 'coo', 'dok', and 'lil'",
177 format
178 ))),
179 }
180}
181
182pub fn diags_array<T>(
215 diagonals: &[Array1<T>],
216 offsets: &[isize],
217 shape: (usize, usize),
218 format: &str,
219) -> SparseResult<Box<dyn SparseArray<T>>>
220where
221 T: Float
222 + Add<Output = T>
223 + Sub<Output = T>
224 + Mul<Output = T>
225 + Div<Output = T>
226 + Debug
227 + Copy
228 + 'static,
229{
230 if diagonals.len() != offsets.len() {
231 return Err(SparseError::InconsistentData {
232 reason: "Number of diagonals must match number of offsets".to_string(),
233 });
234 }
235
236 if shape.0 == 0 || shape.1 == 0 {
237 return Err(SparseError::ValueError(
238 "Matrix dimensions must be positive".to_string(),
239 ));
240 }
241
242 let (m, n) = shape;
243 let mut rows = Vec::new();
244 let mut cols = Vec::new();
245 let mut data = Vec::new();
246
247 for (i, (diag, &offset)) in diagonals.iter().zip(offsets.iter()).enumerate() {
248 if offset >= 0 {
249 let offset_usize = offset as usize;
250 let max_len = std::cmp::min(m, n.saturating_sub(offset_usize));
251
252 if diag.len() > max_len {
253 return Err(SparseError::InconsistentData {
254 reason: format!("Diagonal {} is too long ({} > {})", i, diag.len(), max_len),
255 });
256 }
257
258 for (j, &value) in diag.iter().enumerate() {
259 if !value.is_zero() {
260 rows.push(j);
261 cols.push(j + offset_usize);
262 data.push(value);
263 }
264 }
265 } else {
266 let offset_abs = (-offset) as usize;
267 let max_len = std::cmp::min(m.saturating_sub(offset_abs), n);
268
269 if diag.len() > max_len {
270 return Err(SparseError::InconsistentData {
271 reason: format!("Diagonal {} is too long ({} > {})", i, diag.len(), max_len),
272 });
273 }
274
275 for (j, &value) in diag.iter().enumerate() {
276 if !value.is_zero() {
277 rows.push(j + offset_abs);
278 cols.push(j);
279 data.push(value);
280 }
281 }
282 }
283 }
284
285 match format.to_lowercase().as_str() {
286 "csr" => CsrArray::from_triplets(&rows, &cols, &data, shape, false)
287 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
288 "coo" => CooArray::from_triplets(&rows, &cols, &data, shape, false)
289 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
290 "dok" => DokArray::from_triplets(&rows, &cols, &data, shape)
291 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
292 "lil" => LilArray::from_triplets(&rows, &cols, &data, shape)
293 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
294 _ => Err(SparseError::ValueError(format!(
295 "Unknown sparse format: {}. Supported formats are 'csr', 'coo', 'dok', and 'lil'",
296 format
297 ))),
298 }
299}
300
301pub fn random_array<T>(
326 shape: (usize, usize),
327 density: f64,
328 seed: Option<u64>,
329 format: &str,
330) -> SparseResult<Box<dyn SparseArray<T>>>
331where
332 T: Float
333 + Add<Output = T>
334 + Sub<Output = T>
335 + Mul<Output = T>
336 + Div<Output = T>
337 + Debug
338 + Copy
339 + 'static,
340{
341 let (m, n) = shape;
342
343 if !(0.0..=1.0).contains(&density) {
344 return Err(SparseError::ValueError(
345 "Density must be between 0.0 and 1.0".to_string(),
346 ));
347 }
348
349 if m == 0 || n == 0 {
350 return Err(SparseError::ValueError(
351 "Matrix dimensions must be positive".to_string(),
352 ));
353 }
354
355 let nnz = (m * n) as f64 * density;
357 let nnz = nnz.round() as usize;
358
359 let mut rows = Vec::with_capacity(nnz);
361 let mut cols = Vec::with_capacity(nnz);
362 let mut data = Vec::with_capacity(nnz);
363
364 let mut rng = if let Some(seed_value) = seed {
366 rand::rngs::StdRng::seed_from_u64(seed_value)
367 } else {
368 let seed = rand::Rng::random::<u64>(&mut rand::rng());
370 rand::rngs::StdRng::seed_from_u64(seed)
371 };
372
373 let total = m * n;
375
376 if density > 0.4 {
377 let mut indices: Vec<usize> = (0..total).collect();
379 indices.shuffle(&mut rng);
380
381 for &idx in indices.iter().take(nnz) {
382 let row = idx / n;
383 let col = idx % n;
384
385 rows.push(row);
386 cols.push(col);
387
388 let mut val: f64 = rng.random_range(-1.0..1.0);
391 while val.abs() < 1e-10 {
393 val = rng.random_range(-1.0..1.0);
394 }
395 data.push(T::from(val).unwrap());
396 }
397 } else {
398 let mut positions = std::collections::HashSet::with_capacity(nnz);
400
401 while positions.len() < nnz {
402 let row = rng.random_range(0..m);
403 let col = rng.random_range(0..n);
404 let pos = row * n + col; if positions.insert(pos) {
407 rows.push(row);
408 cols.push(col);
409
410 let mut val: f64 = rng.random_range(-1.0..1.0);
412 while val.abs() < 1e-10 {
414 val = rng.random_range(-1.0..1.0);
415 }
416 data.push(T::from(val).unwrap());
417 }
418 }
419 }
420
421 match format.to_lowercase().as_str() {
423 "csr" => CsrArray::from_triplets(&rows, &cols, &data, shape, false)
424 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
425 "coo" => CooArray::from_triplets(&rows, &cols, &data, shape, false)
426 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
427 "dok" => DokArray::from_triplets(&rows, &cols, &data, shape)
428 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
429 "lil" => LilArray::from_triplets(&rows, &cols, &data, shape)
430 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
431 _ => Err(SparseError::ValueError(format!(
432 "Unknown sparse format: {}. Supported formats are 'csr', 'coo', 'dok', and 'lil'",
433 format
434 ))),
435 }
436}
437
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `array.get(row, col)` equals the expected value for every triple.
    fn check_entries(array: &dyn SparseArray<f64>, expected: &[(usize, usize, f64)]) {
        for &(row, col, value) in expected {
            assert_eq!(array.get(row, col), value);
        }
    }

    #[test]
    fn test_eye_array() {
        let csr = eye_array::<f64>(3, "csr").unwrap();
        assert_eq!(csr.shape(), (3, 3));
        assert_eq!(csr.nnz(), 3);
        check_entries(&*csr, &[(0, 0, 1.0), (1, 1, 1.0), (2, 2, 1.0), (0, 1, 0.0)]);

        let coo = eye_array::<f64>(3, "coo").unwrap();
        assert_eq!(coo.shape(), (3, 3));
        assert_eq!(coo.nnz(), 3);

        for fmt in ["dok", "lil"] {
            let eye = eye_array::<f64>(3, fmt).unwrap();
            assert_eq!(eye.shape(), (3, 3));
            assert_eq!(eye.nnz(), 3);
            check_entries(&*eye, &[(0, 0, 1.0), (1, 1, 1.0), (2, 2, 1.0)]);
        }
    }

    #[test]
    fn test_eye_array_k() {
        let identity = eye_array_k::<f64>(3, 3, 0, "csr").unwrap();
        check_entries(&*identity, &[(0, 0, 1.0), (1, 1, 1.0), (2, 2, 1.0)]);

        let above = eye_array_k::<f64>(3, 4, 1, "csr").unwrap();
        check_entries(&*above, &[(0, 1, 1.0), (1, 2, 1.0), (2, 3, 1.0)]);

        let below = eye_array_k::<f64>(4, 3, -1, "csr").unwrap();
        check_entries(&*below, &[(1, 0, 1.0), (2, 1, 1.0), (3, 2, 1.0)]);

        let identity_lil = eye_array_k::<f64>(3, 3, 0, "lil").unwrap();
        check_entries(&*identity_lil, &[(0, 0, 1.0), (1, 1, 1.0), (2, 2, 1.0)]);
    }

    #[test]
    fn test_diags_array() {
        let two_diags = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]),
            Array1::from_vec(vec![4.0, 5.0]),
        ];
        let upper = diags_array::<f64>(&two_diags, &[0, 1], (3, 3), "csr").unwrap();
        assert_eq!(upper.shape(), (3, 3));
        check_entries(
            &*upper,
            &[(0, 0, 1.0), (1, 1, 2.0), (2, 2, 3.0), (0, 1, 4.0), (1, 2, 5.0)],
        );

        let three_diags = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]),
            Array1::from_vec(vec![4.0, 5.0]),
            Array1::from_vec(vec![6.0, 7.0]),
        ];
        let tridiagonal_entries = [
            (0, 0, 1.0),
            (1, 1, 2.0),
            (2, 2, 3.0),
            (0, 1, 4.0),
            (1, 2, 5.0),
            (1, 0, 6.0),
            (2, 1, 7.0),
        ];
        for fmt in ["csr", "lil"] {
            let tri = diags_array::<f64>(&three_diags, &[0, 1, -1], (3, 3), fmt).unwrap();
            assert_eq!(tri.shape(), (3, 3));
            check_entries(&*tri, &tridiagonal_entries);
        }
    }

    #[test]
    fn test_random_array() {
        let shape = (10, 10);
        let density = 0.3;

        let unseeded = random_array::<f64>(shape, density, None, "csr").unwrap();
        assert_eq!(unseeded.shape(), shape);

        let nnz = unseeded.nnz();
        let expected_nnz = (shape.0 * shape.1) as f64 * density;
        assert!((nnz as f64) > expected_nnz * 0.7, "Too few non-zeros: {}", nnz);
        assert!((nnz as f64) < expected_nnz * 1.3, "Too many non-zeros: {}", nnz);

        let seeded = random_array::<f64>(shape, density, Some(42), "csr").unwrap();
        assert_eq!(seeded.shape(), shape);

        let lil = random_array::<f64>((5, 5), 0.5, Some(42), "lil").unwrap();
        assert_eq!(lil.shape(), (5, 5));

        let nnz_lil = lil.nnz();
        let expected_nnz_lil = 25.0 * 0.5;
        assert!(
            (nnz_lil as f64) > expected_nnz_lil * 0.7,
            "Too few non-zeros in LIL: {}",
            nnz_lil
        );
        assert!(
            (nnz_lil as f64) < expected_nnz_lil * 1.3,
            "Too many non-zeros in LIL: {}",
            nnz_lil
        );
    }
}