scirs2_sparse/
coo_array.rs

1// COO Array implementation
2//
3// This module provides the COO (COOrdinate) array format,
4// which is efficient for incrementally constructing a sparse array.
5
6use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
7use scirs2_core::numeric::{Float, SparseElement, Zero};
8use std::fmt::{self, Debug};
9use std::ops::{Add, Div, Mul, Sub};
10
11use crate::csr_array::CsrArray;
12use crate::error::{SparseError, SparseResult};
13use crate::sparray::{SparseArray, SparseSum};
14
/// COO Array format
///
/// The COO (COOrdinate) format stores data as a triplet of arrays:
/// - row: array of row indices
/// - col: array of column indices
/// - data: array of corresponding non-zero values
///
/// Stored entry `k` is the value `data[k]` at position `(row[k], col[k])`. The three
/// arrays always have the same length (checked by `CooArray::new`), and that length is
/// what `nnz()` reports, so duplicates and explicitly stored zeros are counted.
///
/// # Notes
///
/// - Efficient for incrementally constructing a sparse array or matrix
/// - Allows for duplicate entries (summed when converted to other formats)
/// - Fast conversion to other formats
/// - Not efficient for arithmetic operations
/// - Not efficient for slicing operations
///
#[derive(Clone)]
pub struct CooArray<T>
where
    T: SparseElement + Div<Output = T> + PartialOrd + 'static,
{
    /// Row index of each stored entry
    row: Array1<usize>,
    /// Column index of each stored entry
    col: Array1<usize>,
    /// Value of each stored entry; several entries may share one position
    data: Array1<T>,
    /// Shape of the array as `(rows, cols)`
    shape: (usize, usize),
    /// Whether entries are sorted by `(row, col)`. Being sorted does NOT mean that
    /// duplicate positions have been merged (see `sum_duplicates` for that).
    has_canonical_format: bool,
}
46
47impl<T> CooArray<T>
48where
49    T: SparseElement + Div<Output = T> + PartialOrd + Zero + 'static,
50{
51    /// Creates a new COO array
52    ///
53    /// # Arguments
54    /// * `data` - Array of non-zero values
55    /// * `row` - Array of row indices
56    /// * `col` - Array of column indices
57    /// * `shape` - Shape of the sparse array
58    /// * `has_canonical_format` - Whether entries are sorted by row
59    ///
60    /// # Returns
61    /// A new `CooArray`
62    ///
63    /// # Errors
64    /// Returns an error if the data is not consistent
65    pub fn new(
66        data: Array1<T>,
67        row: Array1<usize>,
68        col: Array1<usize>,
69        shape: (usize, usize),
70        has_canonical_format: bool,
71    ) -> SparseResult<Self> {
72        // Validation
73        if data.len() != row.len() || data.len() != col.len() {
74            return Err(SparseError::InconsistentData {
75                reason: "data, row, and col must have the same length".to_string(),
76            });
77        }
78
79        if let Some(&max_row) = row.iter().max() {
80            if max_row >= shape.0 {
81                return Err(SparseError::IndexOutOfBounds {
82                    index: (max_row, 0),
83                    shape,
84                });
85            }
86        }
87
88        if let Some(&max_col) = col.iter().max() {
89            if max_col >= shape.1 {
90                return Err(SparseError::IndexOutOfBounds {
91                    index: (0, max_col),
92                    shape,
93                });
94            }
95        }
96
97        Ok(Self {
98            data,
99            row,
100            col,
101            shape,
102            has_canonical_format,
103        })
104    }
105
106    /// Create a COO array from (row, col, data) triplets
107    ///
108    /// # Arguments
109    /// * `row` - Row indices
110    /// * `col` - Column indices
111    /// * `data` - Values
112    /// * `shape` - Shape of the sparse array
113    /// * `sorted` - Whether the triplets are already sorted
114    ///
115    /// # Returns
116    /// A new `CooArray`
117    ///
118    /// # Errors
119    /// Returns an error if the data is not consistent
120    pub fn from_triplets(
121        row: &[usize],
122        col: &[usize],
123        data: &[T],
124        shape: (usize, usize),
125        sorted: bool,
126    ) -> SparseResult<Self> {
127        let row_array = Array1::from_vec(row.to_vec());
128        let col_array = Array1::from_vec(col.to_vec());
129        let data_array = Array1::from_vec(data.to_vec());
130
131        Self::new(data_array, row_array, col_array, shape, sorted)
132    }
133
134    /// Get the rows array
135    pub fn get_rows(&self) -> &Array1<usize> {
136        &self.row
137    }
138
139    /// Get the cols array
140    pub fn get_cols(&self) -> &Array1<usize> {
141        &self.col
142    }
143
144    /// Get the data array
145    pub fn get_data(&self) -> &Array1<T> {
146        &self.data
147    }
148
149    /// Put the array in canonical format (sort by row index, then column index)
150    pub fn canonical_format(&mut self) {
151        if self.has_canonical_format {
152            return;
153        }
154
155        let n = self.data.len();
156        let mut triplets: Vec<(usize, usize, T)> = Vec::with_capacity(n);
157
158        for i in 0..n {
159            triplets.push((self.row[i], self.col[i], self.data[i]));
160        }
161
162        triplets.sort_by(|&(r1, c1_, _), &(r2, c2_, _)| (r1, c1_).cmp(&(r2, c2_)));
163
164        for (i, &(r, c, v)) in triplets.iter().enumerate() {
165            self.row[i] = r;
166            self.col[i] = c;
167            self.data[i] = v;
168        }
169
170        self.has_canonical_format = true;
171    }
172
173    /// Converts to a COO with summed duplicate entries
174    pub fn sum_duplicates(&mut self) {
175        self.canonical_format();
176
177        let n = self.data.len();
178        if n == 0 {
179            return;
180        }
181
182        let mut new_data = Vec::new();
183        let mut new_row = Vec::new();
184        let mut new_col = Vec::new();
185
186        let mut curr_row = self.row[0];
187        let mut curr_col = self.col[0];
188        let mut curr_sum = self.data[0];
189
190        for i in 1..n {
191            if self.row[i] == curr_row && self.col[i] == curr_col {
192                curr_sum = curr_sum + self.data[i];
193            } else {
194                if curr_sum != T::sparse_zero() {
195                    new_data.push(curr_sum);
196                    new_row.push(curr_row);
197                    new_col.push(curr_col);
198                }
199                curr_row = self.row[i];
200                curr_col = self.col[i];
201                curr_sum = self.data[i];
202            }
203        }
204
205        // Add the last element
206        if curr_sum != T::sparse_zero() {
207            new_data.push(curr_sum);
208            new_row.push(curr_row);
209            new_col.push(curr_col);
210        }
211
212        self.data = Array1::from_vec(new_data);
213        self.row = Array1::from_vec(new_row);
214        self.col = Array1::from_vec(new_col);
215    }
216}
217
218impl<T> SparseArray<T> for CooArray<T>
219where
220    T: SparseElement + Div<Output = T> + PartialOrd + Zero + 'static,
221{
222    fn shape(&self) -> (usize, usize) {
223        self.shape
224    }
225
226    fn nnz(&self) -> usize {
227        self.data.len()
228    }
229
230    fn dtype(&self) -> &str {
231        "float" // Placeholder
232    }
233
234    fn to_array(&self) -> Array2<T> {
235        let (rows, cols) = self.shape;
236        let mut result = Array2::zeros((rows, cols));
237
238        for i in 0..self.data.len() {
239            let r = self.row[i];
240            let c = self.col[i];
241            result[[r, c]] = result[[r, c]] + self.data[i]; // Sum duplicates
242        }
243
244        result
245    }
246
247    fn toarray(&self) -> Array2<T> {
248        self.to_array()
249    }
250
251    fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
252        // Create a copy with summed duplicates
253        let mut new_coo = self.clone();
254        new_coo.sum_duplicates();
255        Ok(Box::new(new_coo))
256    }
257
258    fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
259        // Convert to CSR format
260        let mut data_vec = self.data.to_vec();
261        let mut row_vec = self.row.to_vec();
262        let mut col_vec = self.col.to_vec();
263
264        // Sort by row, then column
265        let mut triplets: Vec<(usize, usize, T)> = Vec::with_capacity(data_vec.len());
266        for i in 0..data_vec.len() {
267            triplets.push((row_vec[i], col_vec[i], data_vec[i]));
268        }
269        triplets.sort_by(|&(r1, c1_, _), &(r2, c2_, _)| (r1, c1_).cmp(&(r2, c2_)));
270
271        for (i, &(r, c, v)) in triplets.iter().enumerate() {
272            row_vec[i] = r;
273            col_vec[i] = c;
274            data_vec[i] = v;
275        }
276
277        // Convert to CSR format using CsrArray::from_triplets
278        CsrArray::from_triplets(&row_vec, &col_vec, &data_vec, self.shape, true)
279            .map(|csr| Box::new(csr) as Box<dyn SparseArray<T>>)
280    }
281
282    fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
283        // For now, convert to CSR and then transpose
284        // In an actual implementation, this would directly convert to CSC
285        let csr = self.to_csr()?;
286        csr.transpose()
287    }
288
289    fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
290        // In a real implementation, this would convert directly to DOK format
291        Ok(Box::new(self.clone()))
292    }
293
294    fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
295        // In a real implementation, this would convert directly to LIL format
296        Ok(Box::new(self.clone()))
297    }
298
299    fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
300        // In a real implementation, this would convert directly to DIA format
301        Ok(Box::new(self.clone()))
302    }
303
304    fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
305        // In a real implementation, this would convert directly to BSR format
306        Ok(Box::new(self.clone()))
307    }
308
309    fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
310        // For efficiency, convert to CSR format and then add
311        let self_csr = self.to_csr()?;
312        self_csr.add(other)
313    }
314
315    fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
316        // For efficiency, convert to CSR format and then subtract
317        let self_csr = self.to_csr()?;
318        self_csr.sub(other)
319    }
320
321    fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
322        // For efficiency, convert to CSR format and then multiply
323        let self_csr = self.to_csr()?;
324        self_csr.mul(other)
325    }
326
327    fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
328        // For efficiency, convert to CSR format and then divide
329        let self_csr = self.to_csr()?;
330        self_csr.div(other)
331    }
332
333    fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
334        // For efficiency, convert to CSR format and then do matrix multiplication
335        let self_csr = self.to_csr()?;
336        self_csr.dot(other)
337    }
338
339    fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
340        let (m, n) = self.shape();
341        if n != other.len() {
342            return Err(SparseError::DimensionMismatch {
343                expected: n,
344                found: other.len(),
345            });
346        }
347
348        let mut result = Array1::zeros(m);
349
350        for i in 0..self.data.len() {
351            let row = self.row[i];
352            let col = self.col[i];
353            result[row] = result[row] + self.data[i] * other[col];
354        }
355
356        Ok(result)
357    }
358
359    fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
360        // Swap row and column arrays
361        CooArray::new(
362            self.data.clone(),
363            self.col.clone(),             // Note the swap
364            self.row.clone(),             // Note the swap
365            (self.shape.1, self.shape.0), // Swap shape dimensions
366            self.has_canonical_format,
367        )
368        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
369    }
370
371    fn copy(&self) -> Box<dyn SparseArray<T>> {
372        Box::new(self.clone())
373    }
374
375    fn get(&self, i: usize, j: usize) -> T {
376        if i >= self.shape.0 || j >= self.shape.1 {
377            return T::sparse_zero();
378        }
379
380        let mut sum = T::sparse_zero();
381        for idx in 0..self.data.len() {
382            if self.row[idx] == i && self.col[idx] == j {
383                sum = sum + self.data[idx];
384            }
385        }
386
387        sum
388    }
389
390    fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
391        if i >= self.shape.0 || j >= self.shape.1 {
392            return Err(SparseError::IndexOutOfBounds {
393                index: (i, j),
394                shape: self.shape,
395            });
396        }
397
398        if value == T::sparse_zero() {
399            // Remove existing entries at (i, j)
400            let mut new_data = Vec::new();
401            let mut new_row = Vec::new();
402            let mut new_col = Vec::new();
403
404            for idx in 0..self.data.len() {
405                if !(self.row[idx] == i && self.col[idx] == j) {
406                    new_data.push(self.data[idx]);
407                    new_row.push(self.row[idx]);
408                    new_col.push(self.col[idx]);
409                }
410            }
411
412            self.data = Array1::from_vec(new_data);
413            self.row = Array1::from_vec(new_row);
414            self.col = Array1::from_vec(new_col);
415        } else {
416            // First remove any existing entries
417            self.set(i, j, T::sparse_zero())?;
418
419            // Then add the new value
420            let mut new_data = self.data.to_vec();
421            let mut new_row = self.row.to_vec();
422            let mut new_col = self.col.to_vec();
423
424            new_data.push(value);
425            new_row.push(i);
426            new_col.push(j);
427
428            self.data = Array1::from_vec(new_data);
429            self.row = Array1::from_vec(new_row);
430            self.col = Array1::from_vec(new_col);
431
432            // No longer in canonical format
433            self.has_canonical_format = false;
434        }
435
436        Ok(())
437    }
438
439    fn eliminate_zeros(&mut self) {
440        let mut new_data = Vec::new();
441        let mut new_row = Vec::new();
442        let mut new_col = Vec::new();
443
444        for i in 0..self.data.len() {
445            if !SparseElement::is_zero(&self.data[i]) {
446                new_data.push(self.data[i]);
447                new_row.push(self.row[i]);
448                new_col.push(self.col[i]);
449            }
450        }
451
452        self.data = Array1::from_vec(new_data);
453        self.row = Array1::from_vec(new_row);
454        self.col = Array1::from_vec(new_col);
455    }
456
457    fn sort_indices(&mut self) {
458        self.canonical_format();
459    }
460
461    fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
462        if self.has_canonical_format {
463            return Box::new(self.clone());
464        }
465
466        let mut sorted = self.clone();
467        sorted.canonical_format();
468        Box::new(sorted)
469    }
470
471    fn has_sorted_indices(&self) -> bool {
472        self.has_canonical_format
473    }
474
475    fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>> {
476        // For efficiency, convert to CSR format and then sum
477        let self_csr = self.to_csr()?;
478        self_csr.sum(axis)
479    }
480
481    fn max(&self) -> T {
482        if self.data.is_empty() {
483            // Empty sparse matrix - all elements are implicitly zero
484            return T::sparse_zero();
485        }
486
487        let mut max_val = self.data[0];
488        for &val in self.data.iter().skip(1) {
489            if val > max_val {
490                max_val = val;
491            }
492        }
493
494        // Check if max_val is less than zero, as zeros aren't explicitly stored
495        let zero = T::sparse_zero();
496        if max_val < zero && self.nnz() < self.shape.0 * self.shape.1 {
497            max_val = zero;
498        }
499
500        max_val
501    }
502
503    fn min(&self) -> T {
504        if self.data.is_empty() {
505            // Empty sparse matrix - all elements are implicitly zero
506            return T::sparse_zero();
507        }
508
509        let mut min_val = self.data[0];
510        for &val in self.data.iter().skip(1) {
511            if val < min_val {
512                min_val = val;
513            }
514        }
515
516        // Check if min_val is greater than zero, as zeros aren't explicitly stored
517        if min_val > T::sparse_zero() && self.nnz() < self.shape.0 * self.shape.1 {
518            min_val = T::sparse_zero();
519        }
520
521        min_val
522    }
523
524    fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
525        // Return copies of the row, col, and data arrays
526        let data_vec = self.data.to_vec();
527        let row_vec = self.row.to_vec();
528        let col_vec = self.col.to_vec();
529
530        // If there are duplicate entries, sum them
531        if self.has_canonical_format {
532            // We can use a more efficient algorithm if already sorted
533            (self.row.clone(), self.col.clone(), self.data.clone())
534        } else {
535            // Convert to canonical form with summed duplicates
536            let mut triplets: Vec<(usize, usize, T)> = Vec::with_capacity(data_vec.len());
537            for i in 0..data_vec.len() {
538                triplets.push((row_vec[i], col_vec[i], data_vec[i]));
539            }
540            triplets.sort_by(|&(r1, c1_, _), &(r2, c2_, _)| (r1, c1_).cmp(&(r2, c2_)));
541
542            let mut result_row = Vec::new();
543            let mut result_col = Vec::new();
544            let mut result_data = Vec::new();
545
546            if !triplets.is_empty() {
547                let mut curr_row = triplets[0].0;
548                let mut curr_col = triplets[0].1;
549                let mut curr_sum = triplets[0].2;
550
551                for &(r, c, v) in triplets.iter().skip(1) {
552                    if r == curr_row && c == curr_col {
553                        curr_sum = curr_sum + v;
554                    } else {
555                        if curr_sum != T::sparse_zero() {
556                            result_row.push(curr_row);
557                            result_col.push(curr_col);
558                            result_data.push(curr_sum);
559                        }
560                        curr_row = r;
561                        curr_col = c;
562                        curr_sum = v;
563                    }
564                }
565
566                // Add the last element
567                if curr_sum != T::sparse_zero() {
568                    result_row.push(curr_row);
569                    result_col.push(curr_col);
570                    result_data.push(curr_sum);
571                }
572            }
573
574            (
575                Array1::from_vec(result_row),
576                Array1::from_vec(result_col),
577                Array1::from_vec(result_data),
578            )
579        }
580    }
581
582    fn slice(
583        &self,
584        row_range: (usize, usize),
585        col_range: (usize, usize),
586    ) -> SparseResult<Box<dyn SparseArray<T>>> {
587        let (start_row, end_row) = row_range;
588        let (start_col, end_col) = col_range;
589
590        if start_row >= self.shape.0
591            || end_row > self.shape.0
592            || start_col >= self.shape.1
593            || end_col > self.shape.1
594        {
595            return Err(SparseError::InvalidSliceRange);
596        }
597
598        if start_row >= end_row || start_col >= end_col {
599            return Err(SparseError::InvalidSliceRange);
600        }
601
602        let mut new_data = Vec::new();
603        let mut new_row = Vec::new();
604        let mut new_col = Vec::new();
605
606        for i in 0..self.data.len() {
607            let r = self.row[i];
608            let c = self.col[i];
609
610            if r >= start_row && r < end_row && c >= start_col && c < end_col {
611                new_data.push(self.data[i]);
612                new_row.push(r - start_row); // Adjust indices
613                new_col.push(c - start_col); // Adjust indices
614            }
615        }
616
617        CooArray::new(
618            Array1::from_vec(new_data),
619            Array1::from_vec(new_row),
620            Array1::from_vec(new_col),
621            (end_row - start_row, end_col - start_col),
622            false,
623        )
624        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
625    }
626
627    fn as_any(&self) -> &dyn std::any::Any {
628        self
629    }
630}
631
632impl<T> fmt::Debug for CooArray<T>
633where
634    T: SparseElement + Div<Output = T> + PartialOrd + Zero + 'static,
635{
636    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
637        write!(
638            f,
639            "CooArray<{}x{}, nnz={}>",
640            self.shape.0,
641            self.shape.1,
642            self.nnz()
643        )
644    }
645}
646
#[cfg(test)]
mod tests {
    use super::*;

    // Shared 3x3 fixture with five stored entries:
    //
    //     [1 0 2]
    //     [0 3 0]
    //     [4 0 5]
    const ROWS: [usize; 5] = [0, 0, 1, 2, 2];
    const COLS: [usize; 5] = [0, 2, 1, 0, 2];
    const VALUES: [f64; 5] = [1.0, 2.0, 3.0, 4.0, 5.0];

    // Every element of the fixture that the lookup tests probe, as ((row, col), value).
    const FIXTURE_LOOKUPS: [((usize, usize), f64); 6] = [
        ((0, 0), 1.0),
        ((0, 2), 2.0),
        ((1, 1), 3.0),
        ((2, 0), 4.0),
        ((2, 2), 5.0),
        ((0, 1), 0.0),
    ];

    fn sample_3x3() -> CooArray<f64> {
        CooArray::new(
            Array1::from_vec(VALUES.to_vec()),
            Array1::from_vec(ROWS.to_vec()),
            Array1::from_vec(COLS.to_vec()),
            (3, 3),
            false,
        )
        .unwrap()
    }

    // Checks `get(i, j)` for each ((i, j), value) pair, in order.
    fn assert_entries(array: &dyn SparseArray<f64>, expected: &[((usize, usize), f64)]) {
        for &((i, j), value) in expected {
            assert_eq!(array.get(i, j), value);
        }
    }

    // Checks `dense[[i, j]]` for each ((i, j), value) pair, in order.
    fn assert_dense(dense: &Array2<f64>, expected: &[((usize, usize), f64)]) {
        for &((i, j), value) in expected {
            assert_eq!(dense[[i, j]], value);
        }
    }

    #[test]
    fn test_coo_array_construction() {
        let coo = sample_3x3();

        assert_eq!(coo.shape(), (3, 3));
        assert_eq!(coo.nnz(), 5);
        assert_entries(&coo, &FIXTURE_LOOKUPS);
    }

    #[test]
    fn test_coo_from_triplets() {
        let coo: CooArray<f64> =
            CooArray::from_triplets(&ROWS, &COLS, &VALUES, (3, 3), false).unwrap();

        assert_eq!(coo.shape(), (3, 3));
        assert_eq!(coo.nnz(), 5);
        assert_entries(&coo, &FIXTURE_LOOKUPS);
    }

    #[test]
    fn test_coo_array_to_array() {
        let dense = sample_3x3().to_array();

        assert_eq!(dense.shape(), &[3, 3]);
        let expected: [[f64; 3]; 3] = [[1.0, 0.0, 2.0], [0.0, 3.0, 0.0], [4.0, 0.0, 5.0]];
        for (i, row) in expected.iter().enumerate() {
            for (j, &value) in row.iter().enumerate() {
                assert_eq!(dense[[i, j]], value);
            }
        }
    }

    #[test]
    fn test_coo_array_duplicate_entries() {
        // (0, 0) and (1, 0) are each stored twice.
        let mut coo: CooArray<f64> = CooArray::new(
            Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]),
            Array1::from_vec(vec![0, 0, 0, 1, 1]),
            Array1::from_vec(vec![0, 0, 1, 0, 0]),
            (2, 2),
            false,
        )
        .unwrap();

        coo.sum_duplicates();

        // Three distinct positions remain: 1+2 at (0,0), 3 at (0,1), 4+5 at (1,0).
        assert_eq!(coo.nnz(), 3);
        assert_entries(&coo, &[((0, 0), 3.0), ((0, 1), 3.0), ((1, 0), 9.0)]);
    }

    #[test]
    fn test_coo_set_get() {
        let mut coo: CooArray<f64> = CooArray::new(
            Array1::from_vec(vec![1.0, 2.0]),
            Array1::from_vec(vec![0, 1]),
            Array1::from_vec(vec![0, 1]),
            (2, 2),
            false,
        )
        .unwrap();

        // Insert at a previously empty position, then overwrite a stored one.
        let updates: [((usize, usize), f64); 2] = [((0, 1), 3.0), ((0, 0), 4.0)];
        for &((i, j), value) in &updates {
            coo.set(i, j, value).unwrap();
            assert_eq!(coo.get(i, j), value);
        }

        // Assigning zero removes the stored entry.
        coo.set(0, 0, 0.0).unwrap();
        assert_eq!(coo.get(0, 0), 0.0);

        // Remaining entries: 2.0 at (1, 1) and 3.0 at (0, 1).
        assert_eq!(coo.nnz(), 2);
    }

    #[test]
    fn test_coo_canonical_format() {
        let mut coo: CooArray<f64> = CooArray::new(
            Array1::from_vec(vec![3.0, 1.0, 5.0, 2.0]),
            Array1::from_vec(vec![1, 0, 2, 0]),
            Array1::from_vec(vec![1, 0, 2, 2]),
            (3, 3),
            false,
        )
        .unwrap();

        assert!(!coo.has_canonical_format);

        coo.canonical_format();

        assert!(coo.has_canonical_format);

        // Storage order after sorting: (0,0), (0,2), (1,1), (2,2).
        let expected: [(usize, usize, f64); 4] =
            [(0, 0, 1.0), (0, 2, 2.0), (1, 1, 3.0), (2, 2, 5.0)];
        for (k, &(r, c, v)) in expected.iter().enumerate() {
            assert_eq!(coo.row[k], r);
            assert_eq!(coo.col[k], c);
            assert_eq!(coo.data[k], v);
        }
    }

    #[test]
    fn test_coo_to_csr() {
        let csr = sample_3x3().to_csr().unwrap();

        assert_dense(
            &csr.to_array(),
            &[
                ((0, 0), 1.0),
                ((0, 2), 2.0),
                ((1, 1), 3.0),
                ((2, 0), 4.0),
                ((2, 2), 5.0),
            ],
        );
    }

    #[test]
    fn test_coo_transpose() {
        let transposed = sample_3x3().transpose().unwrap();

        assert_eq!(transposed.shape(), (3, 3));
        assert_dense(
            &transposed.to_array(),
            &[
                ((0, 0), 1.0),
                ((2, 0), 2.0),
                ((1, 1), 3.0),
                ((0, 2), 4.0),
                ((2, 2), 5.0),
            ],
        );
    }
}