1use ndarray::Array1;
7use num_traits::Float;
8use rand::seq::SliceRandom;
9use rand::{Rng, SeedableRng};
10use std::fmt::Debug;
11use std::ops::{Add, Div, Mul, Sub};
12
13use crate::coo_array::CooArray;
14use crate::csr_array::CsrArray;
15use crate::dok_array::DokArray;
16use crate::error::{SparseError, SparseResult};
17use crate::lil_array::LilArray;
18use crate::sparray::SparseArray;
19
20pub fn eye_array<T>(n: usize, format: &str) -> SparseResult<Box<dyn SparseArray<T>>>
43where
44 T: Float
45 + Add<Output = T>
46 + Sub<Output = T>
47 + Mul<Output = T>
48 + Div<Output = T>
49 + Debug
50 + Copy
51 + 'static,
52{
53 if n == 0 {
54 return Err(SparseError::ValueError(
55 "Matrix dimension must be positive".to_string(),
56 ));
57 }
58
59 let mut rows = Vec::with_capacity(n);
60 let mut cols = Vec::with_capacity(n);
61 let mut data = Vec::with_capacity(n);
62
63 for i in 0..n {
64 rows.push(i);
65 cols.push(i);
66 data.push(T::one());
67 }
68
69 match format.to_lowercase().as_str() {
70 "csr" => CsrArray::from_triplets(&rows, &cols, &data, (n, n), true)
71 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
72 "coo" => CooArray::from_triplets(&rows, &cols, &data, (n, n), true)
73 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
74 "dok" => DokArray::from_triplets(&rows, &cols, &data, (n, n))
75 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
76 "lil" => LilArray::from_triplets(&rows, &cols, &data, (n, n))
77 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
78 _ => Err(SparseError::ValueError(format!(
79 "Unknown sparse format: {}. Supported formats are 'csr', 'coo', 'dok', and 'lil'",
80 format
81 ))),
82 }
83}
84
85pub fn eye_array_k<T>(
120 m: usize,
121 n: usize,
122 k: isize,
123 format: &str,
124) -> SparseResult<Box<dyn SparseArray<T>>>
125where
126 T: Float
127 + Add<Output = T>
128 + Sub<Output = T>
129 + Mul<Output = T>
130 + Div<Output = T>
131 + Debug
132 + Copy
133 + 'static,
134{
135 if m == 0 || n == 0 {
136 return Err(SparseError::ValueError(
137 "Matrix dimensions must be positive".to_string(),
138 ));
139 }
140
141 let mut rows = Vec::new();
142 let mut cols = Vec::new();
143 let mut data = Vec::new();
144
145 if k >= 0 {
147 let k_usize = k as usize;
148 let len = std::cmp::min(m, n.saturating_sub(k_usize));
149
150 for i in 0..len {
151 rows.push(i);
152 cols.push(i + k_usize);
153 data.push(T::one());
154 }
155 } else {
156 let k_abs = (-k) as usize;
157 let len = std::cmp::min(m.saturating_sub(k_abs), n);
158
159 for i in 0..len {
160 rows.push(i + k_abs);
161 cols.push(i);
162 data.push(T::one());
163 }
164 }
165
166 match format.to_lowercase().as_str() {
167 "csr" => CsrArray::from_triplets(&rows, &cols, &data, (m, n), true)
168 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
169 "coo" => CooArray::from_triplets(&rows, &cols, &data, (m, n), true)
170 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
171 "dok" => DokArray::from_triplets(&rows, &cols, &data, (m, n))
172 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
173 "lil" => LilArray::from_triplets(&rows, &cols, &data, (m, n))
174 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
175 _ => Err(SparseError::ValueError(format!(
176 "Unknown sparse format: {}. Supported formats are 'csr', 'coo', 'dok', and 'lil'",
177 format
178 ))),
179 }
180}
181
182pub fn diags_array<T>(
215 diagonals: &[Array1<T>],
216 offsets: &[isize],
217 shape: (usize, usize),
218 format: &str,
219) -> SparseResult<Box<dyn SparseArray<T>>>
220where
221 T: Float
222 + Add<Output = T>
223 + Sub<Output = T>
224 + Mul<Output = T>
225 + Div<Output = T>
226 + Debug
227 + Copy
228 + 'static,
229{
230 if diagonals.len() != offsets.len() {
231 return Err(SparseError::InconsistentData {
232 reason: "Number of diagonals must match number of offsets".to_string(),
233 });
234 }
235
236 if shape.0 == 0 || shape.1 == 0 {
237 return Err(SparseError::ValueError(
238 "Matrix dimensions must be positive".to_string(),
239 ));
240 }
241
242 let (m, n) = shape;
243 let mut rows = Vec::new();
244 let mut cols = Vec::new();
245 let mut data = Vec::new();
246
247 for (i, (diag, &offset)) in diagonals.iter().zip(offsets.iter()).enumerate() {
248 if offset >= 0 {
249 let offset_usize = offset as usize;
250 let max_len = std::cmp::min(m, n.saturating_sub(offset_usize));
251
252 if diag.len() > max_len {
253 return Err(SparseError::InconsistentData {
254 reason: format!("Diagonal {} is too long ({} > {})", i, diag.len(), max_len),
255 });
256 }
257
258 for (j, &value) in diag.iter().enumerate() {
259 if !value.is_zero() {
260 rows.push(j);
261 cols.push(j + offset_usize);
262 data.push(value);
263 }
264 }
265 } else {
266 let offset_abs = (-offset) as usize;
267 let max_len = std::cmp::min(m.saturating_sub(offset_abs), n);
268
269 if diag.len() > max_len {
270 return Err(SparseError::InconsistentData {
271 reason: format!("Diagonal {} is too long ({} > {})", i, diag.len(), max_len),
272 });
273 }
274
275 for (j, &value) in diag.iter().enumerate() {
276 if !value.is_zero() {
277 rows.push(j + offset_abs);
278 cols.push(j);
279 data.push(value);
280 }
281 }
282 }
283 }
284
285 match format.to_lowercase().as_str() {
286 "csr" => CsrArray::from_triplets(&rows, &cols, &data, shape, false)
287 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
288 "coo" => CooArray::from_triplets(&rows, &cols, &data, shape, false)
289 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
290 "dok" => DokArray::from_triplets(&rows, &cols, &data, shape)
291 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
292 "lil" => LilArray::from_triplets(&rows, &cols, &data, shape)
293 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
294 _ => Err(SparseError::ValueError(format!(
295 "Unknown sparse format: {}. Supported formats are 'csr', 'coo', 'dok', and 'lil'",
296 format
297 ))),
298 }
299}
300
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that two sparse arrays have the same shape and identical entries.
    fn assert_same_entries(a: &dyn SparseArray<f64>, b: &dyn SparseArray<f64>) {
        assert_eq!(a.shape(), b.shape());
        let (rows, cols) = a.shape();
        for i in 0..rows {
            for j in 0..cols {
                assert_eq!(a.get(i, j), b.get(i, j), "mismatch at ({}, {})", i, j);
            }
        }
    }

    #[test]
    fn test_eye_array() {
        let eye = eye_array::<f64>(3, "csr").unwrap();

        assert_eq!(eye.shape(), (3, 3));
        assert_eq!(eye.nnz(), 3);
        assert_eq!(eye.get(0, 0), 1.0);
        assert_eq!(eye.get(1, 1), 1.0);
        assert_eq!(eye.get(2, 2), 1.0);
        assert_eq!(eye.get(0, 1), 0.0);

        let eye_coo = eye_array::<f64>(3, "coo").unwrap();
        assert_eq!(eye_coo.shape(), (3, 3));
        assert_eq!(eye_coo.nnz(), 3);

        let eye_dok = eye_array::<f64>(3, "dok").unwrap();
        assert_eq!(eye_dok.shape(), (3, 3));
        assert_eq!(eye_dok.nnz(), 3);
        assert_eq!(eye_dok.get(0, 0), 1.0);
        assert_eq!(eye_dok.get(1, 1), 1.0);
        assert_eq!(eye_dok.get(2, 2), 1.0);

        let eye_lil = eye_array::<f64>(3, "lil").unwrap();
        assert_eq!(eye_lil.shape(), (3, 3));
        assert_eq!(eye_lil.nnz(), 3);
        assert_eq!(eye_lil.get(0, 0), 1.0);
        assert_eq!(eye_lil.get(1, 1), 1.0);
        assert_eq!(eye_lil.get(2, 2), 1.0);
    }

    #[test]
    fn test_eye_array_errors_and_format_case() {
        assert!(matches!(
            eye_array::<f64>(0, "csr"),
            Err(SparseError::ValueError(_))
        ));
        assert!(matches!(
            eye_array::<f64>(3, "dense"),
            Err(SparseError::ValueError(_))
        ));
        // Format names are matched case-insensitively.
        assert_eq!(eye_array::<f64>(2, "CSR").unwrap().shape(), (2, 2));
    }

    #[test]
    fn test_eye_array_k() {
        let eye = eye_array_k::<f64>(3, 3, 0, "csr").unwrap();
        assert_eq!(eye.get(0, 0), 1.0);
        assert_eq!(eye.get(1, 1), 1.0);
        assert_eq!(eye.get(2, 2), 1.0);

        let superdiag = eye_array_k::<f64>(3, 4, 1, "csr").unwrap();
        assert_eq!(superdiag.shape(), (3, 4));
        assert_eq!(superdiag.nnz(), 3);
        assert_eq!(superdiag.get(0, 1), 1.0);
        assert_eq!(superdiag.get(1, 2), 1.0);
        assert_eq!(superdiag.get(2, 3), 1.0);
        assert_eq!(superdiag.get(0, 0), 0.0);

        let subdiag = eye_array_k::<f64>(4, 3, -1, "csr").unwrap();
        assert_eq!(subdiag.shape(), (4, 3));
        assert_eq!(subdiag.nnz(), 3);
        assert_eq!(subdiag.get(1, 0), 1.0);
        assert_eq!(subdiag.get(2, 1), 1.0);
        assert_eq!(subdiag.get(3, 2), 1.0);
        assert_eq!(subdiag.get(0, 0), 0.0);

        let eye_lil = eye_array_k::<f64>(3, 3, 0, "lil").unwrap();
        assert_eq!(eye_lil.get(0, 0), 1.0);
        assert_eq!(eye_lil.get(1, 1), 1.0);
        assert_eq!(eye_lil.get(2, 2), 1.0);

        // An offset beyond the matrix yields an empty matrix of the requested shape.
        let empty = eye_array_k::<f64>(2, 2, 5, "coo").unwrap();
        assert_eq!(empty.shape(), (2, 2));
        assert_eq!(empty.nnz(), 0);

        assert!(matches!(
            eye_array_k::<f64>(0, 3, 0, "csr"),
            Err(SparseError::ValueError(_))
        ));
        assert!(matches!(
            eye_array_k::<f64>(3, 3, 0, "dense"),
            Err(SparseError::ValueError(_))
        ));
    }

    #[test]
    fn test_diags_array() {
        let diags = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]),
            Array1::from_vec(vec![4.0, 5.0]),
        ];
        let offsets = vec![0, 1];
        let shape = (3, 3);

        let result = diags_array(&diags, &offsets, shape, "csr").unwrap();
        assert_eq!(result.shape(), (3, 3));
        assert_eq!(result.get(0, 0), 1.0);
        assert_eq!(result.get(1, 1), 2.0);
        assert_eq!(result.get(2, 2), 3.0);
        assert_eq!(result.get(0, 1), 4.0);
        assert_eq!(result.get(1, 2), 5.0);

        let diags = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]),
            Array1::from_vec(vec![4.0, 5.0]),
            Array1::from_vec(vec![6.0, 7.0]),
        ];
        let offsets = vec![0, 1, -1];
        let shape = (3, 3);

        let result = diags_array(&diags, &offsets, shape, "csr").unwrap();
        assert_eq!(result.shape(), (3, 3));
        assert_eq!(result.get(0, 0), 1.0);
        assert_eq!(result.get(1, 1), 2.0);
        assert_eq!(result.get(2, 2), 3.0);
        assert_eq!(result.get(0, 1), 4.0);
        assert_eq!(result.get(1, 2), 5.0);
        assert_eq!(result.get(1, 0), 6.0);
        assert_eq!(result.get(2, 1), 7.0);

        let result_lil = diags_array(&diags, &offsets, shape, "lil").unwrap();
        assert_eq!(result_lil.shape(), (3, 3));
        assert_eq!(result_lil.get(0, 0), 1.0);
        assert_eq!(result_lil.get(1, 1), 2.0);
        assert_eq!(result_lil.get(2, 2), 3.0);
        assert_eq!(result_lil.get(0, 1), 4.0);
        assert_eq!(result_lil.get(1, 2), 5.0);
        assert_eq!(result_lil.get(1, 0), 6.0);
        assert_eq!(result_lil.get(2, 1), 7.0);
    }

    #[test]
    fn test_diags_array_skips_zeros() {
        let diags = vec![Array1::from_vec(vec![1.0, 0.0, 2.0])];
        let result = diags_array(&diags, &[0], (3, 3), "csr").unwrap();
        assert_eq!(result.nnz(), 2);
        assert_eq!(result.get(0, 0), 1.0);
        assert_eq!(result.get(1, 1), 0.0);
        assert_eq!(result.get(2, 2), 2.0);
    }

    #[test]
    fn test_diags_array_errors() {
        let diags = vec![Array1::from_vec(vec![1.0, 2.0])];
        assert!(matches!(
            diags_array(&diags, &[0, 1], (2, 2), "csr"),
            Err(SparseError::InconsistentData { .. })
        ));
        assert!(matches!(
            diags_array(&diags, &[0], (0, 2), "csr"),
            Err(SparseError::ValueError(_))
        ));
        assert!(matches!(
            diags_array(&diags, &[0], (2, 2), "dense"),
            Err(SparseError::ValueError(_))
        ));

        // Offset 1 in a 3x3 matrix leaves room for only two entries.
        let too_long = vec![Array1::from_vec(vec![1.0, 2.0, 3.0])];
        assert!(matches!(
            diags_array(&too_long, &[1], (3, 3), "csr"),
            Err(SparseError::InconsistentData { .. })
        ));
    }

    #[test]
    fn test_random_array() {
        let shape = (10, 10);
        let density = 0.3;

        let random = random_array::<f64>(shape, density, None, "csr").unwrap();
        assert_eq!(random.shape(), shape);
        // The number of stored entries is round(m * n * density), whatever the seed.
        assert_eq!(random.nnz(), 30);

        // Every stored entry is a non-zero value drawn from (-1, 1).
        let mut stored = 0;
        for i in 0..shape.0 {
            for j in 0..shape.1 {
                let value = random.get(i, j);
                assert!(value > -1.0 && value < 1.0, "out of range: {}", value);
                if value != 0.0 {
                    stored += 1;
                }
            }
        }
        assert_eq!(stored, 30);

        let random_seeded = random_array::<f64>(shape, density, Some(42), "csr").unwrap();
        assert_eq!(random_seeded.shape(), shape);

        let random_lil = random_array::<f64>((5, 5), 0.5, Some(42), "lil").unwrap();
        assert_eq!(random_lil.shape(), (5, 5));
        // round(25 * 0.5) = round(12.5) = 13 (ties round away from zero).
        assert_eq!(random_lil.nnz(), 13);
    }

    #[test]
    fn test_random_array_is_reproducible_with_seed() {
        // Sparse sampling branch (density <= 0.4).
        let a = random_array::<f64>((6, 7), 0.3, Some(7), "csr").unwrap();
        let b = random_array::<f64>((6, 7), 0.3, Some(7), "csr").unwrap();
        assert_same_entries(&*a, &*b);

        // Shuffle-based branch (density > 0.4).
        let a = random_array::<f64>((6, 7), 0.8, Some(7), "coo").unwrap();
        let b = random_array::<f64>((6, 7), 0.8, Some(7), "coo").unwrap();
        assert_same_entries(&*a, &*b);
    }

    #[test]
    fn test_random_array_density_extremes() {
        let full = random_array::<f64>((4, 5), 1.0, Some(3), "coo").unwrap();
        assert_eq!(full.nnz(), 20);
        for i in 0..4 {
            for j in 0..5 {
                assert_ne!(full.get(i, j), 0.0);
            }
        }

        let empty = random_array::<f64>((4, 4), 0.0, Some(1), "coo").unwrap();
        assert_eq!(empty.shape(), (4, 4));
        assert_eq!(empty.nnz(), 0);
    }

    #[test]
    fn test_random_array_errors() {
        for bad_density in [-0.1, 1.5, f64::NAN] {
            assert!(matches!(
                random_array::<f64>((2, 2), bad_density, None, "csr"),
                Err(SparseError::ValueError(_))
            ));
        }
        assert!(matches!(
            random_array::<f64>((0, 2), 0.5, None, "csr"),
            Err(SparseError::ValueError(_))
        ));
        assert!(matches!(
            random_array::<f64>((2, 2), 0.5, Some(1), "dense"),
            Err(SparseError::ValueError(_))
        ));
    }
}
458
459pub fn random_array<T>(
484 shape: (usize, usize),
485 density: f64,
486 seed: Option<u64>,
487 format: &str,
488) -> SparseResult<Box<dyn SparseArray<T>>>
489where
490 T: Float
491 + Add<Output = T>
492 + Sub<Output = T>
493 + Mul<Output = T>
494 + Div<Output = T>
495 + Debug
496 + Copy
497 + 'static,
498{
499 let (m, n) = shape;
500
501 if !(0.0..=1.0).contains(&density) {
502 return Err(SparseError::ValueError(
503 "Density must be between 0.0 and 1.0".to_string(),
504 ));
505 }
506
507 if m == 0 || n == 0 {
508 return Err(SparseError::ValueError(
509 "Matrix dimensions must be positive".to_string(),
510 ));
511 }
512
513 let nnz = (m * n) as f64 * density;
515 let nnz = nnz.round() as usize;
516
517 let mut rows = Vec::with_capacity(nnz);
519 let mut cols = Vec::with_capacity(nnz);
520 let mut data = Vec::with_capacity(nnz);
521
522 let mut rng = if let Some(seed_value) = seed {
524 rand::rngs::StdRng::seed_from_u64(seed_value)
525 } else {
526 let seed = rand::Rng::random::<u64>(&mut rand::rng());
528 rand::rngs::StdRng::seed_from_u64(seed)
529 };
530
531 let total = m * n;
533
534 if density > 0.4 {
535 let mut indices: Vec<usize> = (0..total).collect();
537 indices.shuffle(&mut rng);
538
539 for &idx in indices.iter().take(nnz) {
540 let row = idx / n;
541 let col = idx % n;
542
543 rows.push(row);
544 cols.push(col);
545
546 let mut val: f64 = rng.random_range(-1.0..1.0);
549 while val.abs() < 1e-10 {
551 val = rng.random_range(-1.0..1.0);
552 }
553 data.push(T::from(val).unwrap());
554 }
555 } else {
556 let mut positions = std::collections::HashSet::with_capacity(nnz);
558
559 while positions.len() < nnz {
560 let row = rng.random_range(0..m);
561 let col = rng.random_range(0..n);
562 let pos = row * n + col; if positions.insert(pos) {
565 rows.push(row);
566 cols.push(col);
567
568 let mut val: f64 = rng.random_range(-1.0..1.0);
570 while val.abs() < 1e-10 {
572 val = rng.random_range(-1.0..1.0);
573 }
574 data.push(T::from(val).unwrap());
575 }
576 }
577 }
578
579 match format.to_lowercase().as_str() {
581 "csr" => CsrArray::from_triplets(&rows, &cols, &data, shape, false)
582 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
583 "coo" => CooArray::from_triplets(&rows, &cols, &data, shape, false)
584 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
585 "dok" => DokArray::from_triplets(&rows, &cols, &data, shape)
586 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
587 "lil" => LilArray::from_triplets(&rows, &cols, &data, shape)
588 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>),
589 _ => Err(SparseError::ValueError(format!(
590 "Unknown sparse format: {}. Supported formats are 'csr', 'coo', 'dok', and 'lil'",
591 format
592 ))),
593 }
594}