// scirs2_sparse/coo_array.rs

// COO Array implementation
//
// This module provides the COO (COOrdinate) sparse array format,
// which is efficient for incrementally constructing a sparse array.
5
6use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
7use scirs2_core::numeric::Float;
8use std::fmt::{self, Debug};
9use std::ops::{Add, Div, Mul, Sub};
10
11use crate::csr_array::CsrArray;
12use crate::error::{SparseError, SparseResult};
13use crate::sparray::{SparseArray, SparseSum};
14
15/// COO Array format
16///
17/// The COO (COOrdinate) format stores data as a triplet of arrays:
18/// - row: array of row indices
19/// - col: array of column indices
20/// - data: array of corresponding non-zero values
21///
22/// # Notes
23///
24/// - Efficient for incrementally constructing a sparse array or matrix
25/// - Allows for duplicate entries (summed when converted to other formats)
26/// - Fast conversion to other formats
27/// - Not efficient for arithmetic operations
28/// - Not efficient for slicing operations
29///
30#[derive(Clone)]
31pub struct CooArray<T>
32where
33    T: Float
34        + Add<Output = T>
35        + Sub<Output = T>
36        + Mul<Output = T>
37        + Div<Output = T>
38        + Debug
39        + Copy
40        + 'static,
41{
42    /// Row indices
43    row: Array1<usize>,
44    /// Column indices
45    col: Array1<usize>,
46    /// Data values
47    data: Array1<T>,
48    /// Shape of the array
49    shape: (usize, usize),
50    /// Whether entries are sorted by row
51    has_canonical_format: bool,
52}
53
54impl<T> CooArray<T>
55where
56    T: Float
57        + Add<Output = T>
58        + Sub<Output = T>
59        + Mul<Output = T>
60        + Div<Output = T>
61        + Debug
62        + Copy
63        + 'static,
64{
65    /// Creates a new COO array
66    ///
67    /// # Arguments
68    /// * `data` - Array of non-zero values
69    /// * `row` - Array of row indices
70    /// * `col` - Array of column indices
71    /// * `shape` - Shape of the sparse array
72    /// * `has_canonical_format` - Whether entries are sorted by row
73    ///
74    /// # Returns
75    /// A new `CooArray`
76    ///
77    /// # Errors
78    /// Returns an error if the data is not consistent
79    pub fn new(
80        data: Array1<T>,
81        row: Array1<usize>,
82        col: Array1<usize>,
83        shape: (usize, usize),
84        has_canonical_format: bool,
85    ) -> SparseResult<Self> {
86        // Validation
87        if data.len() != row.len() || data.len() != col.len() {
88            return Err(SparseError::InconsistentData {
89                reason: "data, row, and col must have the same length".to_string(),
90            });
91        }
92
93        if let Some(&max_row) = row.iter().max() {
94            if max_row >= shape.0 {
95                return Err(SparseError::IndexOutOfBounds {
96                    index: (max_row, 0),
97                    shape,
98                });
99            }
100        }
101
102        if let Some(&max_col) = col.iter().max() {
103            if max_col >= shape.1 {
104                return Err(SparseError::IndexOutOfBounds {
105                    index: (0, max_col),
106                    shape,
107                });
108            }
109        }
110
111        Ok(Self {
112            data,
113            row,
114            col,
115            shape,
116            has_canonical_format,
117        })
118    }
119
120    /// Create a COO array from (row, col, data) triplets
121    ///
122    /// # Arguments
123    /// * `row` - Row indices
124    /// * `col` - Column indices
125    /// * `data` - Values
126    /// * `shape` - Shape of the sparse array
127    /// * `sorted` - Whether the triplets are already sorted
128    ///
129    /// # Returns
130    /// A new `CooArray`
131    ///
132    /// # Errors
133    /// Returns an error if the data is not consistent
134    pub fn from_triplets(
135        row: &[usize],
136        col: &[usize],
137        data: &[T],
138        shape: (usize, usize),
139        sorted: bool,
140    ) -> SparseResult<Self> {
141        let row_array = Array1::from_vec(row.to_vec());
142        let col_array = Array1::from_vec(col.to_vec());
143        let data_array = Array1::from_vec(data.to_vec());
144
145        Self::new(data_array, row_array, col_array, shape, sorted)
146    }
147
148    /// Get the rows array
149    pub fn get_rows(&self) -> &Array1<usize> {
150        &self.row
151    }
152
153    /// Get the cols array
154    pub fn get_cols(&self) -> &Array1<usize> {
155        &self.col
156    }
157
158    /// Get the data array
159    pub fn get_data(&self) -> &Array1<T> {
160        &self.data
161    }
162
163    /// Put the array in canonical format (sort by row index, then column index)
164    pub fn canonical_format(&mut self) {
165        if self.has_canonical_format {
166            return;
167        }
168
169        let n = self.data.len();
170        let mut triplets: Vec<(usize, usize, T)> = Vec::with_capacity(n);
171
172        for i in 0..n {
173            triplets.push((self.row[i], self.col[i], self.data[i]));
174        }
175
176        triplets.sort_by(|&(r1, c1_, _), &(r2, c2_, _)| (r1, c1_).cmp(&(r2, c2_)));
177
178        for (i, &(r, c, v)) in triplets.iter().enumerate() {
179            self.row[i] = r;
180            self.col[i] = c;
181            self.data[i] = v;
182        }
183
184        self.has_canonical_format = true;
185    }
186
187    /// Converts to a COO with summed duplicate entries
188    pub fn sum_duplicates(&mut self) {
189        self.canonical_format();
190
191        let n = self.data.len();
192        if n == 0 {
193            return;
194        }
195
196        let mut new_data = Vec::new();
197        let mut new_row = Vec::new();
198        let mut new_col = Vec::new();
199
200        let mut curr_row = self.row[0];
201        let mut curr_col = self.col[0];
202        let mut curr_sum = self.data[0];
203
204        for i in 1..n {
205            if self.row[i] == curr_row && self.col[i] == curr_col {
206                curr_sum = curr_sum + self.data[i];
207            } else {
208                if !curr_sum.is_zero() {
209                    new_data.push(curr_sum);
210                    new_row.push(curr_row);
211                    new_col.push(curr_col);
212                }
213                curr_row = self.row[i];
214                curr_col = self.col[i];
215                curr_sum = self.data[i];
216            }
217        }
218
219        // Add the last element
220        if !curr_sum.is_zero() {
221            new_data.push(curr_sum);
222            new_row.push(curr_row);
223            new_col.push(curr_col);
224        }
225
226        self.data = Array1::from_vec(new_data);
227        self.row = Array1::from_vec(new_row);
228        self.col = Array1::from_vec(new_col);
229    }
230}
231
232impl<T> SparseArray<T> for CooArray<T>
233where
234    T: Float
235        + Add<Output = T>
236        + Sub<Output = T>
237        + Mul<Output = T>
238        + Div<Output = T>
239        + Debug
240        + Copy
241        + 'static,
242{
243    fn shape(&self) -> (usize, usize) {
244        self.shape
245    }
246
247    fn nnz(&self) -> usize {
248        self.data.len()
249    }
250
251    fn dtype(&self) -> &str {
252        "float" // Placeholder
253    }
254
255    fn to_array(&self) -> Array2<T> {
256        let (rows, cols) = self.shape;
257        let mut result = Array2::zeros((rows, cols));
258
259        for i in 0..self.data.len() {
260            let r = self.row[i];
261            let c = self.col[i];
262            result[[r, c]] = result[[r, c]] + self.data[i]; // Sum duplicates
263        }
264
265        result
266    }
267
268    fn toarray(&self) -> Array2<T> {
269        self.to_array()
270    }
271
272    fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
273        // Create a copy with summed duplicates
274        let mut new_coo = self.clone();
275        new_coo.sum_duplicates();
276        Ok(Box::new(new_coo))
277    }
278
279    fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
280        // Convert to CSR format
281        let mut data_vec = self.data.to_vec();
282        let mut row_vec = self.row.to_vec();
283        let mut col_vec = self.col.to_vec();
284
285        // Sort by row, then column
286        let mut triplets: Vec<(usize, usize, T)> = Vec::with_capacity(data_vec.len());
287        for i in 0..data_vec.len() {
288            triplets.push((row_vec[i], col_vec[i], data_vec[i]));
289        }
290        triplets.sort_by(|&(r1, c1_, _), &(r2, c2_, _)| (r1, c1_).cmp(&(r2, c2_)));
291
292        for (i, &(r, c, v)) in triplets.iter().enumerate() {
293            row_vec[i] = r;
294            col_vec[i] = c;
295            data_vec[i] = v;
296        }
297
298        // Convert to CSR format using CsrArray::from_triplets
299        CsrArray::from_triplets(&row_vec, &col_vec, &data_vec, self.shape, true)
300            .map(|csr| Box::new(csr) as Box<dyn SparseArray<T>>)
301    }
302
303    fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
304        // For now, convert to CSR and then transpose
305        // In an actual implementation, this would directly convert to CSC
306        let csr = self.to_csr()?;
307        csr.transpose()
308    }
309
310    fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
311        // In a real implementation, this would convert directly to DOK format
312        Ok(Box::new(self.clone()))
313    }
314
315    fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
316        // In a real implementation, this would convert directly to LIL format
317        Ok(Box::new(self.clone()))
318    }
319
320    fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
321        // In a real implementation, this would convert directly to DIA format
322        Ok(Box::new(self.clone()))
323    }
324
325    fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
326        // In a real implementation, this would convert directly to BSR format
327        Ok(Box::new(self.clone()))
328    }
329
330    fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
331        // For efficiency, convert to CSR format and then add
332        let self_csr = self.to_csr()?;
333        self_csr.add(other)
334    }
335
336    fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
337        // For efficiency, convert to CSR format and then subtract
338        let self_csr = self.to_csr()?;
339        self_csr.sub(other)
340    }
341
342    fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
343        // For efficiency, convert to CSR format and then multiply
344        let self_csr = self.to_csr()?;
345        self_csr.mul(other)
346    }
347
348    fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
349        // For efficiency, convert to CSR format and then divide
350        let self_csr = self.to_csr()?;
351        self_csr.div(other)
352    }
353
354    fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
355        // For efficiency, convert to CSR format and then do matrix multiplication
356        let self_csr = self.to_csr()?;
357        self_csr.dot(other)
358    }
359
360    fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
361        let (m, n) = self.shape();
362        if n != other.len() {
363            return Err(SparseError::DimensionMismatch {
364                expected: n,
365                found: other.len(),
366            });
367        }
368
369        let mut result = Array1::zeros(m);
370
371        for i in 0..self.data.len() {
372            let row = self.row[i];
373            let col = self.col[i];
374            result[row] = result[row] + self.data[i] * other[col];
375        }
376
377        Ok(result)
378    }
379
380    fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
381        // Swap row and column arrays
382        CooArray::new(
383            self.data.clone(),
384            self.col.clone(),             // Note the swap
385            self.row.clone(),             // Note the swap
386            (self.shape.1, self.shape.0), // Swap shape dimensions
387            self.has_canonical_format,
388        )
389        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
390    }
391
392    fn copy(&self) -> Box<dyn SparseArray<T>> {
393        Box::new(self.clone())
394    }
395
396    fn get(&self, i: usize, j: usize) -> T {
397        if i >= self.shape.0 || j >= self.shape.1 {
398            return T::zero();
399        }
400
401        let mut sum = T::zero();
402        for idx in 0..self.data.len() {
403            if self.row[idx] == i && self.col[idx] == j {
404                sum = sum + self.data[idx];
405            }
406        }
407
408        sum
409    }
410
411    fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
412        if i >= self.shape.0 || j >= self.shape.1 {
413            return Err(SparseError::IndexOutOfBounds {
414                index: (i, j),
415                shape: self.shape,
416            });
417        }
418
419        if value.is_zero() {
420            // Remove existing entries at (i, j)
421            let mut new_data = Vec::new();
422            let mut new_row = Vec::new();
423            let mut new_col = Vec::new();
424
425            for idx in 0..self.data.len() {
426                if !(self.row[idx] == i && self.col[idx] == j) {
427                    new_data.push(self.data[idx]);
428                    new_row.push(self.row[idx]);
429                    new_col.push(self.col[idx]);
430                }
431            }
432
433            self.data = Array1::from_vec(new_data);
434            self.row = Array1::from_vec(new_row);
435            self.col = Array1::from_vec(new_col);
436        } else {
437            // First remove any existing entries
438            self.set(i, j, T::zero())?;
439
440            // Then add the new value
441            let mut new_data = self.data.to_vec();
442            let mut new_row = self.row.to_vec();
443            let mut new_col = self.col.to_vec();
444
445            new_data.push(value);
446            new_row.push(i);
447            new_col.push(j);
448
449            self.data = Array1::from_vec(new_data);
450            self.row = Array1::from_vec(new_row);
451            self.col = Array1::from_vec(new_col);
452
453            // No longer in canonical format
454            self.has_canonical_format = false;
455        }
456
457        Ok(())
458    }
459
460    fn eliminate_zeros(&mut self) {
461        let mut new_data = Vec::new();
462        let mut new_row = Vec::new();
463        let mut new_col = Vec::new();
464
465        for i in 0..self.data.len() {
466            if !self.data[i].is_zero() {
467                new_data.push(self.data[i]);
468                new_row.push(self.row[i]);
469                new_col.push(self.col[i]);
470            }
471        }
472
473        self.data = Array1::from_vec(new_data);
474        self.row = Array1::from_vec(new_row);
475        self.col = Array1::from_vec(new_col);
476    }
477
478    fn sort_indices(&mut self) {
479        self.canonical_format();
480    }
481
482    fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
483        if self.has_canonical_format {
484            return Box::new(self.clone());
485        }
486
487        let mut sorted = self.clone();
488        sorted.canonical_format();
489        Box::new(sorted)
490    }
491
492    fn has_sorted_indices(&self) -> bool {
493        self.has_canonical_format
494    }
495
496    fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>> {
497        // For efficiency, convert to CSR format and then sum
498        let self_csr = self.to_csr()?;
499        self_csr.sum(axis)
500    }
501
502    fn max(&self) -> T {
503        if self.data.is_empty() {
504            return T::neg_infinity();
505        }
506
507        let mut max_val = self.data[0];
508        for &val in self.data.iter().skip(1) {
509            if val > max_val {
510                max_val = val;
511            }
512        }
513
514        // Check if max_val is less than zero, as zeros aren't explicitly stored
515        if max_val < T::zero() && self.nnz() < self.shape.0 * self.shape.1 {
516            max_val = T::zero();
517        }
518
519        max_val
520    }
521
522    fn min(&self) -> T {
523        if self.data.is_empty() {
524            return T::infinity();
525        }
526
527        let mut min_val = self.data[0];
528        for &val in self.data.iter().skip(1) {
529            if val < min_val {
530                min_val = val;
531            }
532        }
533
534        // Check if min_val is greater than zero, as zeros aren't explicitly stored
535        if min_val > T::zero() && self.nnz() < self.shape.0 * self.shape.1 {
536            min_val = T::zero();
537        }
538
539        min_val
540    }
541
542    fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
543        // Return copies of the row, col, and data arrays
544        let data_vec = self.data.to_vec();
545        let row_vec = self.row.to_vec();
546        let col_vec = self.col.to_vec();
547
548        // If there are duplicate entries, sum them
549        if self.has_canonical_format {
550            // We can use a more efficient algorithm if already sorted
551            (self.row.clone(), self.col.clone(), self.data.clone())
552        } else {
553            // Convert to canonical form with summed duplicates
554            let mut triplets: Vec<(usize, usize, T)> = Vec::with_capacity(data_vec.len());
555            for i in 0..data_vec.len() {
556                triplets.push((row_vec[i], col_vec[i], data_vec[i]));
557            }
558            triplets.sort_by(|&(r1, c1_, _), &(r2, c2_, _)| (r1, c1_).cmp(&(r2, c2_)));
559
560            let mut result_row = Vec::new();
561            let mut result_col = Vec::new();
562            let mut result_data = Vec::new();
563
564            if !triplets.is_empty() {
565                let mut curr_row = triplets[0].0;
566                let mut curr_col = triplets[0].1;
567                let mut curr_sum = triplets[0].2;
568
569                for &(r, c, v) in triplets.iter().skip(1) {
570                    if r == curr_row && c == curr_col {
571                        curr_sum = curr_sum + v;
572                    } else {
573                        if !curr_sum.is_zero() {
574                            result_row.push(curr_row);
575                            result_col.push(curr_col);
576                            result_data.push(curr_sum);
577                        }
578                        curr_row = r;
579                        curr_col = c;
580                        curr_sum = v;
581                    }
582                }
583
584                // Add the last element
585                if !curr_sum.is_zero() {
586                    result_row.push(curr_row);
587                    result_col.push(curr_col);
588                    result_data.push(curr_sum);
589                }
590            }
591
592            (
593                Array1::from_vec(result_row),
594                Array1::from_vec(result_col),
595                Array1::from_vec(result_data),
596            )
597        }
598    }
599
600    fn slice(
601        &self,
602        row_range: (usize, usize),
603        col_range: (usize, usize),
604    ) -> SparseResult<Box<dyn SparseArray<T>>> {
605        let (start_row, end_row) = row_range;
606        let (start_col, end_col) = col_range;
607
608        if start_row >= self.shape.0
609            || end_row > self.shape.0
610            || start_col >= self.shape.1
611            || end_col > self.shape.1
612        {
613            return Err(SparseError::InvalidSliceRange);
614        }
615
616        if start_row >= end_row || start_col >= end_col {
617            return Err(SparseError::InvalidSliceRange);
618        }
619
620        let mut new_data = Vec::new();
621        let mut new_row = Vec::new();
622        let mut new_col = Vec::new();
623
624        for i in 0..self.data.len() {
625            let r = self.row[i];
626            let c = self.col[i];
627
628            if r >= start_row && r < end_row && c >= start_col && c < end_col {
629                new_data.push(self.data[i]);
630                new_row.push(r - start_row); // Adjust indices
631                new_col.push(c - start_col); // Adjust indices
632            }
633        }
634
635        CooArray::new(
636            Array1::from_vec(new_data),
637            Array1::from_vec(new_row),
638            Array1::from_vec(new_col),
639            (end_row - start_row, end_col - start_col),
640            false,
641        )
642        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
643    }
644
645    fn as_any(&self) -> &dyn std::any::Any {
646        self
647    }
648}
649
650impl<T> fmt::Debug for CooArray<T>
651where
652    T: Float
653        + Add<Output = T>
654        + Sub<Output = T>
655        + Mul<Output = T>
656        + Div<Output = T>
657        + Debug
658        + Copy
659        + 'static,
660{
661    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
662        write!(
663            f,
664            "CooArray<{}x{}, nnz={}>",
665            self.shape.0,
666            self.shape.1,
667            self.nnz()
668        )
669    }
670}
671
#[cfg(test)]
mod tests {
    use super::*;

    // Triplets of the 3x3 matrix [[1, 0, 2], [0, 3, 0], [4, 0, 5]] shared by several tests.
    const ROWS: [usize; 5] = [0, 0, 1, 2, 2];
    const COLS: [usize; 5] = [0, 2, 1, 0, 2];
    const VALS: [f64; 5] = [1.0, 2.0, 3.0, 4.0, 5.0];
    const DENSE: [[f64; 3]; 3] = [[1.0, 0.0, 2.0], [0.0, 3.0, 0.0], [4.0, 0.0, 5.0]];

    /// Build the shared 3x3 fixture (unsorted flag, as the original tests did).
    fn sample_coo() -> CooArray<f64> {
        CooArray::new(
            Array1::from_vec(VALS.to_vec()),
            Array1::from_vec(ROWS.to_vec()),
            Array1::from_vec(COLS.to_vec()),
            (3, 3),
            false,
        )
        .unwrap()
    }

    /// Iterate the fixture as `(row, col, value)` triplets.
    fn sample_triplets() -> impl Iterator<Item = (usize, usize, f64)> {
        ROWS.into_iter()
            .zip(COLS)
            .zip(VALS)
            .map(|((r, c), v)| (r, c, v))
    }

    #[test]
    fn test_coo_array_construction() {
        let coo = sample_coo();

        assert_eq!(coo.shape(), (3, 3));
        assert_eq!(coo.nnz(), 5);
        for (r, c, v) in sample_triplets() {
            assert_eq!(coo.get(r, c), v);
        }
        assert_eq!(coo.get(0, 1), 0.0);
    }

    #[test]
    fn test_coo_from_triplets() {
        let coo = CooArray::from_triplets(&ROWS, &COLS, &VALS, (3, 3), false).unwrap();

        assert_eq!(coo.shape(), (3, 3));
        assert_eq!(coo.nnz(), 5);
        for (r, c, v) in sample_triplets() {
            assert_eq!(coo.get(r, c), v);
        }
        assert_eq!(coo.get(0, 1), 0.0);
    }

    #[test]
    fn test_coo_array_to_array() {
        let dense = sample_coo().to_array();

        assert_eq!(dense.shape(), &[3, 3]);
        for (i, expected_row) in DENSE.iter().enumerate() {
            for (j, &expected) in expected_row.iter().enumerate() {
                assert_eq!(dense[[i, j]], expected);
            }
        }
    }

    #[test]
    fn test_coo_array_duplicate_entries() {
        let mut coo = CooArray::new(
            Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]),
            Array1::from_vec(vec![0, 0, 0, 1, 1]),
            Array1::from_vec(vec![0, 0, 1, 0, 0]),
            (2, 2),
            false,
        )
        .unwrap();

        coo.sum_duplicates();

        // (0,0) and (1,0) each collapse two entries into one.
        assert_eq!(coo.nnz(), 3);
        assert_eq!(coo.get(0, 0), 3.0); // 1.0 + 2.0
        assert_eq!(coo.get(0, 1), 3.0);
        assert_eq!(coo.get(1, 0), 9.0); // 4.0 + 5.0
    }

    #[test]
    fn test_coo_set_get() {
        let mut coo = CooArray::new(
            Array1::from_vec(vec![1.0, 2.0]),
            Array1::from_vec(vec![0, 1]),
            Array1::from_vec(vec![0, 1]),
            (2, 2),
            false,
        )
        .unwrap();

        // Insert at a previously empty position.
        coo.set(0, 1, 3.0).unwrap();
        assert_eq!(coo.get(0, 1), 3.0);

        // Overwrite an existing position.
        coo.set(0, 0, 4.0).unwrap();
        assert_eq!(coo.get(0, 0), 4.0);

        // Writing zero removes the position.
        coo.set(0, 0, 0.0).unwrap();
        assert_eq!(coo.get(0, 0), 0.0);

        // Remaining: 2.0 at (1,1) and 3.0 at (0,1).
        assert_eq!(coo.nnz(), 2);
    }

    #[test]
    fn test_coo_canonical_format() {
        let mut coo = CooArray::new(
            Array1::from_vec(vec![3.0, 1.0, 5.0, 2.0]),
            Array1::from_vec(vec![1, 0, 2, 0]),
            Array1::from_vec(vec![1, 0, 2, 2]),
            (3, 3),
            false,
        )
        .unwrap();

        assert!(!coo.has_canonical_format);
        coo.canonical_format();
        assert!(coo.has_canonical_format);

        let expected: [(usize, usize, f64); 4] =
            [(0, 0, 1.0), (0, 2, 2.0), (1, 1, 3.0), (2, 2, 5.0)];
        for (k, &(r, c, v)) in expected.iter().enumerate() {
            assert_eq!(coo.row[k], r);
            assert_eq!(coo.col[k], c);
            assert_eq!(coo.data[k], v);
        }
    }

    #[test]
    fn test_coo_to_csr() {
        let csr = sample_coo().to_csr().unwrap();

        let dense = csr.to_array();
        for (r, c, v) in sample_triplets() {
            assert_eq!(dense[[r, c]], v);
        }
    }

    #[test]
    fn test_coo_transpose() {
        let transposed = sample_coo().transpose().unwrap();

        assert_eq!(transposed.shape(), (3, 3));

        // Each original (r, c) entry must appear at (c, r).
        let dense = transposed.to_array();
        for (r, c, v) in sample_triplets() {
            assert_eq!(dense[[c, r]], v);
        }
    }
}
860}