scirs2_sparse/
dia_array.rs

1// DIA Array implementation
2//
3// This module provides the DIA (DIAgonal) array format,
4// which is efficient for matrices with values concentrated on a small number of diagonals.
5
6use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
7use scirs2_core::numeric::Float;
8use std::fmt::{self, Debug};
9use std::ops::{Add, Div, Mul, Sub};
10
11use crate::coo_array::CooArray;
12use crate::csr_array::CsrArray;
13use crate::dok_array::DokArray;
14use crate::error::{SparseError, SparseResult};
15use crate::lil_array::LilArray;
16use crate::sparray::{SparseArray, SparseSum};
17
18/// DIA Array format
19///
20/// The DIA (DIAgonal) format stores data as a collection of diagonals.
21/// It is efficient for matrices with values concentrated on a small number of diagonals,
22/// like tridiagonal or band matrices.
23///
24/// # Notes
25///
26/// - Very efficient storage for band matrices
27/// - Fast matrix-vector products for banded matrices
28/// - Not efficient for general sparse matrices
29/// - Difficult to modify once constructed
30///
31#[derive(Clone)]
32pub struct DiaArray<T>
33where
34    T: Float
35        + Add<Output = T>
36        + Sub<Output = T>
37        + Mul<Output = T>
38        + Div<Output = T>
39        + Debug
40        + Copy
41        + 'static
42        + std::ops::AddAssign,
43{
44    /// Diagonals data (n_diags x max(rows, cols))
45    data: Vec<Array1<T>>,
46    /// Diagonal offsets from the main diagonal (k > 0 for above, k < 0 for below)
47    offsets: Vec<isize>,
48    /// Shape of the array
49    shape: (usize, usize),
50}
51
52impl<T> DiaArray<T>
53where
54    T: Float
55        + Add<Output = T>
56        + Sub<Output = T>
57        + Mul<Output = T>
58        + Div<Output = T>
59        + Debug
60        + Copy
61        + 'static
62        + std::ops::AddAssign,
63{
64    /// Create a new DIA array from raw data
65    ///
66    /// # Arguments
67    ///
68    /// * `data` - Diagonals data (n_diags x max(rows, cols))
69    /// * `offsets` - Diagonal offsets from the main diagonal
70    /// * `shape` - Tuple containing the array dimensions (rows, cols)
71    ///
72    /// # Returns
73    ///
74    /// * A new DIA array
75    ///
76    /// # Examples
77    ///
78    /// ```
79    /// use scirs2_sparse::dia_array::DiaArray;
80    /// use scirs2_sparse::sparray::SparseArray;
81    /// use scirs2_core::ndarray::Array1;
82    ///
83    /// // Create a 3x3 sparse array with main diagonal and upper diagonal
84    /// let data = vec![
85    ///     Array1::from_vec(vec![1.0, 2.0, 3.0]), // Main diagonal
86    ///     Array1::from_vec(vec![4.0, 5.0, 0.0]), // Upper diagonal (k=1)
87    /// ];
88    /// let offsets = vec![0, 1]; // Main diagonal and k=1
89    /// let shape = (3, 3);
90    ///
91    /// let array = DiaArray::new(data, offsets, shape).unwrap();
92    /// assert_eq!(array.shape(), (3, 3));
93    /// assert_eq!(array.nnz(), 5); // 3 on main diagonal, 2 on upper diagonal
94    /// ```
95    pub fn new(
96        data: Vec<Array1<T>>,
97        offsets: Vec<isize>,
98        shape: (usize, usize),
99    ) -> SparseResult<Self> {
100        let (rows, cols) = shape;
101        let max_dim = rows.max(cols);
102
103        // Validate input data
104        if data.len() != offsets.len() {
105            return Err(SparseError::DimensionMismatch {
106                expected: data.len(),
107                found: offsets.len(),
108            });
109        }
110
111        for diag in data.iter() {
112            if diag.len() != max_dim {
113                return Err(SparseError::DimensionMismatch {
114                    expected: max_dim,
115                    found: diag.len(),
116                });
117            }
118        }
119
120        Ok(DiaArray {
121            data,
122            offsets,
123            shape,
124        })
125    }
126
127    /// Create a new empty DIA array
128    ///
129    /// # Arguments
130    ///
131    /// * `shape` - Tuple containing the array dimensions (rows, cols)
132    ///
133    /// # Returns
134    ///
135    /// * A new empty DIA array
136    pub fn empty(shape: (usize, usize)) -> Self {
137        DiaArray {
138            data: Vec::new(),
139            offsets: Vec::new(),
140            shape,
141        }
142    }
143
144    /// Convert COO format to DIA format
145    ///
146    /// # Arguments
147    ///
148    /// * `row` - Row indices
149    /// * `col` - Column indices
150    /// * `data` - Data values
151    /// * `shape` - Shape of the array
152    ///
153    /// # Returns
154    ///
155    /// * A new DIA array
156    pub fn from_triplets(
157        row: &[usize],
158        col: &[usize],
159        data: &[T],
160        shape: (usize, usize),
161    ) -> SparseResult<Self> {
162        if row.len() != col.len() || row.len() != data.len() {
163            return Err(SparseError::InconsistentData {
164                reason: "Lengths of row, col, and data arrays must be equal".to_string(),
165            });
166        }
167
168        let (rows, cols) = shape;
169        let max_dim = rows.max(cols);
170
171        // Identify unique diagonals
172        let mut diagonal_offsets = std::collections::HashSet::new();
173        for (&r, &c) in row.iter().zip(col.iter()) {
174            if r >= rows || c >= cols {
175                return Err(SparseError::IndexOutOfBounds {
176                    index: (r, c),
177                    shape,
178                });
179            }
180            // Calculate diagonal offset (column - row for diagonals)
181            let offset = c as isize - r as isize;
182            diagonal_offsets.insert(offset);
183        }
184
185        // Convert to a sorted vector
186        let mut offsets: Vec<isize> = diagonal_offsets.into_iter().collect();
187        offsets.sort();
188
189        // Create data arrays (initialized to zero)
190        let mut diag_data = Vec::with_capacity(offsets.len());
191        for _ in 0..offsets.len() {
192            diag_data.push(Array1::zeros(max_dim));
193        }
194
195        // Fill in the data
196        for (&r, (&c, &val)) in row.iter().zip(col.iter().zip(data.iter())) {
197            let offset = c as isize - r as isize;
198            let diag_idx = offsets.iter().position(|&o| o == offset).unwrap();
199
200            // For upper diagonals (k > 0), the index is row
201            // For lower diagonals (k < 0), the index is column
202            let index = if offset >= 0 { r } else { c };
203            diag_data[diag_idx][index] = val;
204        }
205
206        DiaArray::new(diag_data, offsets, shape)
207    }
208
209    /// Convert to COO format
210    fn to_coo_internal(&self) -> (Vec<usize>, Vec<usize>, Vec<T>) {
211        let (rows, cols) = self.shape;
212        let mut row_indices = Vec::new();
213        let mut col_indices = Vec::new();
214        let mut values = Vec::new();
215
216        for (diag_idx, &offset) in self.offsets.iter().enumerate() {
217            let diag = &self.data[diag_idx];
218
219            if offset >= 0 {
220                // Upper diagonal
221                let offset_usize = offset as usize;
222                let length = rows.min(cols.saturating_sub(offset_usize));
223
224                for i in 0..length {
225                    let value = diag[i];
226                    if !value.is_zero() {
227                        row_indices.push(i);
228                        col_indices.push(i + offset_usize);
229                        values.push(value);
230                    }
231                }
232            } else {
233                // Lower diagonal
234                let offset_usize = (-offset) as usize;
235                let length = cols.min(rows.saturating_sub(offset_usize));
236
237                for i in 0..length {
238                    let value = diag[i];
239                    if !value.is_zero() {
240                        row_indices.push(i + offset_usize);
241                        col_indices.push(i);
242                        values.push(value);
243                    }
244                }
245            }
246        }
247
248        (row_indices, col_indices, values)
249    }
250}
251
252impl<T> SparseArray<T> for DiaArray<T>
253where
254    T: Float
255        + Add<Output = T>
256        + Sub<Output = T>
257        + Mul<Output = T>
258        + Div<Output = T>
259        + Debug
260        + Copy
261        + 'static
262        + std::ops::AddAssign,
263{
264    fn shape(&self) -> (usize, usize) {
265        self.shape
266    }
267
268    fn nnz(&self) -> usize {
269        let (rows, cols) = self.shape;
270        let mut count = 0;
271
272        for (diag_idx, &offset) in self.offsets.iter().enumerate() {
273            let diag = &self.data[diag_idx];
274
275            // Calculate valid range for this diagonal
276            let length = if offset >= 0 {
277                rows.min(cols.saturating_sub(offset as usize))
278            } else {
279                cols.min(rows.saturating_sub((-offset) as usize))
280            };
281
282            // Count non-zeros in the valid range
283            let start_idx = 0; // Start at 0 regardless of offset
284            for i in start_idx..start_idx + length {
285                if !diag[i].is_zero() {
286                    count += 1;
287                }
288            }
289        }
290
291        count
292    }
293
294    fn dtype(&self) -> &str {
295        "float" // Placeholder; ideally would return the actual type
296    }
297
298    fn to_array(&self) -> Array2<T> {
299        // Convert to dense format
300        let (rows, cols) = self.shape;
301        let mut result = Array2::zeros((rows, cols));
302
303        // In the test case we have:
304        // data[0] = [1.0, 3.0, 7.0] with offset 0 (main diagonal)
305        // data[1] = [4.0, 5.0, 0.0] with offset 1 (upper diagonal)
306        // data[2] = [0.0, 2.0, 6.0] with offset -1 (lower diagonal)
307
308        for (diag_idx, &offset) in self.offsets.iter().enumerate() {
309            let diag = &self.data[diag_idx];
310
311            if offset >= 0 {
312                // Upper diagonal (k >= 0)
313                let offset_usize = offset as usize;
314                for i in 0..rows.min(cols.saturating_sub(offset_usize)) {
315                    result[[i, i + offset_usize]] = diag[i];
316                }
317            } else {
318                // Lower diagonal (k < 0)
319                let offset_usize = (-offset) as usize;
320                for i in 0..cols.min(rows.saturating_sub(offset_usize)) {
321                    result[[i + offset_usize, i]] = diag[i];
322                }
323            }
324        }
325
326        result
327    }
328
329    fn toarray(&self) -> Array2<T> {
330        self.to_array()
331    }
332
333    fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
334        let (row_indices, col_indices, values) = self.to_coo_internal();
335        let row_array = Array1::from_vec(row_indices);
336        let col_array = Array1::from_vec(col_indices);
337        let data_array = Array1::from_vec(values);
338
339        CooArray::from_triplets(
340            &row_array.to_vec(),
341            &col_array.to_vec(),
342            &data_array.to_vec(),
343            self.shape,
344            false,
345        )
346        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
347    }
348
349    fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
350        let (row_indices, col_indices, values) = self.to_coo_internal();
351        CsrArray::from_triplets(&row_indices, &col_indices, &values, self.shape, false)
352            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
353    }
354
355    fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
356        self.to_coo()?.to_csc()
357    }
358
359    fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
360        let (row_indices, col_indices, values) = self.to_coo_internal();
361        DokArray::from_triplets(&row_indices, &col_indices, &values, self.shape)
362            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
363    }
364
365    fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
366        let (row_indices, col_indices, values) = self.to_coo_internal();
367        LilArray::from_triplets(&row_indices, &col_indices, &values, self.shape)
368            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
369    }
370
371    fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
372        Ok(Box::new(self.clone()))
373    }
374
375    fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
376        self.to_coo()?.to_bsr()
377    }
378
379    fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
380        // Convert both to CSR for efficient addition
381        let csr_self = self.to_csr()?;
382        let csr_other = other.to_csr()?;
383        csr_self.add(&*csr_other)
384    }
385
386    fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
387        // Convert both to CSR for efficient subtraction
388        let csr_self = self.to_csr()?;
389        let csr_other = other.to_csr()?;
390        csr_self.sub(&*csr_other)
391    }
392
393    fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
394        // Convert both to CSR for efficient element-wise multiplication
395        let csr_self = self.to_csr()?;
396        let csr_other = other.to_csr()?;
397        csr_self.mul(&*csr_other)
398    }
399
400    fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
401        // Convert both to CSR for efficient element-wise division
402        let csr_self = self.to_csr()?;
403        let csr_other = other.to_csr()?;
404        csr_self.div(&*csr_other)
405    }
406
407    fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
408        // For matrix multiplication, use specialized DIA-Vector logic if other is thin
409        let (_, n) = self.shape();
410        let (p, q) = other.shape();
411
412        if n != p {
413            return Err(SparseError::DimensionMismatch {
414                expected: n,
415                found: p,
416            });
417        }
418
419        // If other is a vector (thin matrix), we can use optimized DIA-Vector multiplication
420        if q == 1 {
421            // Get the vector from other
422            let other_array = other.to_array();
423            let vec_view = other_array.column(0);
424
425            // Perform DIA-Vector multiplication
426            let result = self.dot_vector(&vec_view)?;
427
428            // Convert to a matrix - create a COO from triplets
429            let mut rows = Vec::new();
430            let mut cols = Vec::new();
431            let mut values = Vec::new();
432
433            for (i, &val) in result.iter().enumerate() {
434                if !val.is_zero() {
435                    rows.push(i);
436                    cols.push(0);
437                    values.push(val);
438                }
439            }
440
441            CooArray::from_triplets(&rows, &cols, &values, (result.len(), 1), false)
442                .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
443        } else {
444            // For general matrices, convert to CSR
445            let csr_self = self.to_csr()?;
446            csr_self.dot(other)
447        }
448    }
449
450    fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
451        let (rows, cols) = self.shape;
452
453        if cols != other.len() {
454            return Err(SparseError::DimensionMismatch {
455                expected: cols,
456                found: other.len(),
457            });
458        }
459
460        let mut result = Array1::zeros(rows);
461
462        for (diag_idx, &offset) in self.offsets.iter().enumerate() {
463            let diag = &self.data[diag_idx];
464
465            if offset >= 0 {
466                // Upper diagonal (k > 0)
467                let offset_usize = offset as usize;
468                let length = rows.min(cols.saturating_sub(offset_usize));
469
470                for i in 0..length {
471                    result[i] += diag[i] * other[i + offset_usize];
472                }
473            } else {
474                // Lower diagonal (k < 0)
475                let offset_usize = (-offset) as usize;
476                let length = cols.min(rows.saturating_sub(offset_usize));
477
478                for i in 0..length {
479                    result[i + offset_usize] += diag[i] * other[i];
480                }
481            }
482        }
483
484        Ok(result)
485    }
486
487    fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
488        // For correct transposition, use COO intermediately
489        // This avoids issues with the diagonal storage format
490        let (row_indices, col_indices, values) = self.to_coo_internal();
491
492        // Swap row and column indices
493        let transposed_rows = col_indices;
494        let transposed_cols = row_indices;
495
496        // Create a new COO array and convert back to DIA
497        CooArray::from_triplets(
498            &transposed_rows,
499            &transposed_cols,
500            &values,
501            (self.shape.1, self.shape.0),
502            false,
503        )?
504        .to_dia()
505    }
506
507    fn copy(&self) -> Box<dyn SparseArray<T>> {
508        Box::new(self.clone())
509    }
510
511    fn get(&self, i: usize, j: usize) -> T {
512        if i >= self.shape.0 || j >= self.shape.1 {
513            return T::zero();
514        }
515
516        // Calculate the diagonal offset
517        let offset = j as isize - i as isize;
518
519        // Check if this offset exists in our stored diagonals
520        if let Some(diag_idx) = self.offsets.iter().position(|&o| o == offset) {
521            let diag = &self.data[diag_idx];
522
523            // For upper diagonals (k > 0), the index is row
524            // For lower diagonals (k < 0), the index is column
525            let index = if offset >= 0 { i } else { j };
526
527            // Make sure the index is within bounds
528            if index < diag.len() {
529                return diag[index];
530            }
531        }
532
533        T::zero()
534    }
535
536    fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
537        if i >= self.shape.0 || j >= self.shape.1 {
538            return Err(SparseError::IndexOutOfBounds {
539                index: (i, j),
540                shape: self.shape,
541            });
542        }
543
544        // Calculate the diagonal offset
545        let offset = j as isize - i as isize;
546
547        // Find or create the diagonal
548        let diag_idx = match self.offsets.iter().position(|&o| o == offset) {
549            Some(idx) => idx,
550            None => {
551                // This diagonal doesn't exist yet, add it
552                self.offsets.push(offset);
553                self.data
554                    .push(Array1::zeros(self.shape.0.max(self.shape.1)));
555
556                // Sort the offsets and data to maintain canonical form
557                let mut offset_data: Vec<(isize, Array1<T>)> = self
558                    .offsets
559                    .iter()
560                    .cloned()
561                    .zip(self.data.drain(..))
562                    .collect();
563                offset_data.sort_by_key(|&(offset_, _)| offset_);
564
565                self.offsets = offset_data.iter().map(|&(offset_, _)| offset_).collect();
566                self.data = offset_data.into_iter().map(|(_, data)| data).collect();
567
568                // Get the index of the newly added diagonal
569                self.offsets.iter().position(|&o| o == offset).unwrap()
570            }
571        };
572
573        // Set the value
574        let index = if offset >= 0 { i } else { j };
575        self.data[diag_idx][index] = value;
576
577        Ok(())
578    }
579
580    fn eliminate_zeros(&mut self) {
581        // Create a new set of diagonals without zeros
582        let mut new_offsets = Vec::new();
583        let mut new_data = Vec::new();
584
585        for (diag_idx, &offset) in self.offsets.iter().enumerate() {
586            let diag = &self.data[diag_idx];
587
588            // Check if this diagonal has any non-zero values
589            let length = if offset >= 0 {
590                self.shape
591                    .0
592                    .min(self.shape.1.saturating_sub(offset as usize))
593            } else {
594                self.shape
595                    .1
596                    .min(self.shape.0.saturating_sub((-offset) as usize))
597            };
598
599            let has_nonzero = (0..length).any(|i| !diag[i].is_zero());
600
601            if has_nonzero {
602                new_offsets.push(offset);
603                new_data.push(diag.clone());
604            }
605        }
606
607        self.offsets = new_offsets;
608        self.data = new_data;
609    }
610
611    fn sort_indices(&mut self) {
612        // DIA arrays have implicitly sorted indices based on offset
613        // Sort by offset just to be sure
614        let mut offset_data: Vec<(isize, Array1<T>)> = self
615            .offsets
616            .iter()
617            .cloned()
618            .zip(self.data.drain(..))
619            .collect();
620        offset_data.sort_by_key(|&(offset_, _)| offset_);
621
622        self.offsets = offset_data.iter().map(|&(offset_, _)| offset_).collect();
623        self.data = offset_data.into_iter().map(|(_, data)| data).collect();
624    }
625
626    fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
627        // Clone and sort
628        let mut result = self.clone();
629        result.sort_indices();
630        Box::new(result)
631    }
632
633    fn has_sorted_indices(&self) -> bool {
634        // Check if offsets are sorted
635        self.offsets.windows(2).all(|w| w[0] <= w[1])
636    }
637
638    fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>> {
639        match axis {
640            None => {
641                // Sum all elements
642                let mut total = T::zero();
643
644                for (diag_idx, &offset) in self.offsets.iter().enumerate() {
645                    let diag = &self.data[diag_idx];
646
647                    let length = if offset >= 0 {
648                        self.shape
649                            .0
650                            .min(self.shape.1.saturating_sub(offset as usize))
651                    } else {
652                        self.shape
653                            .1
654                            .min(self.shape.0.saturating_sub((-offset) as usize))
655                    };
656
657                    for i in 0..length {
658                        total += diag[i];
659                    }
660                }
661
662                Ok(SparseSum::Scalar(total))
663            }
664            Some(0) => {
665                // Sum along rows (result is 1 x cols)
666                let mut result = Array1::zeros(self.shape.1);
667
668                for (diag_idx, &offset) in self.offsets.iter().enumerate() {
669                    let diag = &self.data[diag_idx];
670
671                    if offset >= 0 {
672                        // Upper diagonal
673                        let offset_usize = offset as usize;
674                        let length = self.shape.0.min(self.shape.1.saturating_sub(offset_usize));
675
676                        for i in 0..length {
677                            result[i + offset_usize] += diag[i];
678                        }
679                    } else {
680                        // Lower diagonal
681                        let offset_usize = (-offset) as usize;
682                        let length = self.shape.1.min(self.shape.0.saturating_sub(offset_usize));
683
684                        for i in 0..length {
685                            result[i] += diag[i];
686                        }
687                    }
688                }
689
690                // Convert to a sparse array
691                match Array2::from_shape_vec((1, self.shape.1), result.to_vec()) {
692                    Ok(result_2d) => {
693                        // Find non-zero elements
694                        let mut row_indices = Vec::new();
695                        let mut col_indices = Vec::new();
696                        let mut values = Vec::new();
697
698                        for j in 0..self.shape.1 {
699                            let val: T = result_2d[[0, j]];
700                            if !val.is_zero() {
701                                row_indices.push(0);
702                                col_indices.push(j);
703                                values.push(val);
704                            }
705                        }
706
707                        // Create COO array
708                        match CooArray::from_triplets(
709                            &row_indices,
710                            &col_indices,
711                            &values,
712                            (1, self.shape.1),
713                            false,
714                        ) {
715                            Ok(coo_array) => Ok(SparseSum::SparseArray(Box::new(coo_array))),
716                            Err(e) => Err(e),
717                        }
718                    }
719                    Err(_) => Err(SparseError::InconsistentData {
720                        reason: "Failed to create 2D array from result vector".to_string(),
721                    }),
722                }
723            }
724            Some(1) => {
725                // Sum along columns (result is rows x 1)
726                let mut result = Array1::zeros(self.shape.0);
727
728                for (diag_idx, &offset) in self.offsets.iter().enumerate() {
729                    let diag = &self.data[diag_idx];
730
731                    if offset >= 0 {
732                        // Upper diagonal
733                        let offset_usize = offset as usize;
734                        let length = self.shape.0.min(self.shape.1.saturating_sub(offset_usize));
735
736                        for i in 0..length {
737                            result[i] += diag[i];
738                        }
739                    } else {
740                        // Lower diagonal
741                        let offset_usize = (-offset) as usize;
742                        let length = self.shape.1.min(self.shape.0.saturating_sub(offset_usize));
743
744                        for i in 0..length {
745                            result[i + offset_usize] += diag[i];
746                        }
747                    }
748                }
749
750                // Convert to a sparse array
751                match Array2::from_shape_vec((self.shape.0, 1), result.to_vec()) {
752                    Ok(result_2d) => {
753                        // Find non-zero elements
754                        let mut row_indices = Vec::new();
755                        let mut col_indices = Vec::new();
756                        let mut values = Vec::new();
757
758                        for i in 0..self.shape.0 {
759                            let val: T = result_2d[[i, 0]];
760                            if !val.is_zero() {
761                                row_indices.push(i);
762                                col_indices.push(0);
763                                values.push(val);
764                            }
765                        }
766
767                        // Create COO array
768                        match CooArray::from_triplets(
769                            &row_indices,
770                            &col_indices,
771                            &values,
772                            (self.shape.0, 1),
773                            false,
774                        ) {
775                            Ok(coo_array) => Ok(SparseSum::SparseArray(Box::new(coo_array))),
776                            Err(e) => Err(e),
777                        }
778                    }
779                    Err(_) => Err(SparseError::InconsistentData {
780                        reason: "Failed to create 2D array from result vector".to_string(),
781                    }),
782                }
783            }
784            _ => Err(SparseError::InvalidAxis),
785        }
786    }
787
788    fn max(&self) -> T {
789        let mut max_val = T::neg_infinity();
790
791        for (diag_idx, &offset) in self.offsets.iter().enumerate() {
792            let diag = &self.data[diag_idx];
793
794            let length = if offset >= 0 {
795                self.shape
796                    .0
797                    .min(self.shape.1.saturating_sub(offset as usize))
798            } else {
799                self.shape
800                    .1
801                    .min(self.shape.0.saturating_sub((-offset) as usize))
802            };
803
804            for i in 0..length {
805                max_val = max_val.max(diag[i]);
806            }
807        }
808
809        // If no elements or all negative infinity, return zero
810        if max_val == T::neg_infinity() {
811            T::zero()
812        } else {
813            max_val
814        }
815    }
816
817    fn min(&self) -> T {
818        let mut min_val = T::infinity();
819        let mut has_nonzero = false;
820
821        for (diag_idx, &offset) in self.offsets.iter().enumerate() {
822            let diag = &self.data[diag_idx];
823
824            let length = if offset >= 0 {
825                self.shape
826                    .0
827                    .min(self.shape.1.saturating_sub(offset as usize))
828            } else {
829                self.shape
830                    .1
831                    .min(self.shape.0.saturating_sub((-offset) as usize))
832            };
833
834            for i in 0..length {
835                if !diag[i].is_zero() {
836                    has_nonzero = true;
837                    min_val = min_val.min(diag[i]);
838                }
839            }
840        }
841
842        // If no non-zero elements, return zero
843        if !has_nonzero {
844            T::zero()
845        } else {
846            min_val
847        }
848    }
849
850    fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
851        let (row_indices, col_indices, values) = self.to_coo_internal();
852
853        (
854            Array1::from_vec(row_indices),
855            Array1::from_vec(col_indices),
856            Array1::from_vec(values),
857        )
858    }
859
860    fn slice(
861        &self,
862        row_range: (usize, usize),
863        col_range: (usize, usize),
864    ) -> SparseResult<Box<dyn SparseArray<T>>> {
865        let (start_row, end_row) = row_range;
866        let (start_col, end_col) = col_range;
867        let (rows, cols) = self.shape;
868
869        if start_row >= rows || end_row > rows || start_col >= cols || end_col > cols {
870            return Err(SparseError::IndexOutOfBounds {
871                index: (start_row.max(end_row), start_col.max(end_col)),
872                shape: (rows, cols),
873            });
874        }
875
876        if start_row >= end_row || start_col >= end_col {
877            return Err(SparseError::InvalidSliceRange);
878        }
879
880        // Convert to COO, then slice, then convert back to DIA
881        let coo = self.to_coo()?;
882        coo.slice(row_range, col_range)?.to_dia()
883    }
884
885    fn as_any(&self) -> &dyn std::any::Any {
886        self
887    }
888}
889
890// Implement Display for DiaArray for better debugging
891impl<T> fmt::Display for DiaArray<T>
892where
893    T: Float
894        + Add<Output = T>
895        + Sub<Output = T>
896        + Mul<Output = T>
897        + Div<Output = T>
898        + Debug
899        + Copy
900        + 'static
901        + std::ops::AddAssign,
902{
903    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
904        writeln!(
905            f,
906            "DiaArray of shape {:?} with {} stored elements",
907            self.shape,
908            self.nnz()
909        )?;
910        writeln!(f, "Offsets: {:?}", self.offsets)?;
911
912        if self.offsets.len() <= 5 {
913            for (i, &offset) in self.offsets.iter().enumerate() {
914                let diag = &self.data[i];
915                let length = if offset >= 0 {
916                    self.shape
917                        .0
918                        .min(self.shape.1.saturating_sub(offset as usize))
919                } else {
920                    self.shape
921                        .1
922                        .min(self.shape.0.saturating_sub((-offset) as usize))
923                };
924
925                write!(f, "Diagonal {offset}: [")?;
926                for j in 0..length.min(10) {
927                    if j > 0 {
928                        write!(f, ", ")?;
929                    }
930                    write!(f, "{:?}", diag[j])?;
931                }
932                if length > 10 {
933                    write!(f, ", ...")?;
934                }
935                writeln!(f, "]")?;
936            }
937        } else {
938            writeln!(f, "({} diagonals)", self.offsets.len())?;
939        }
940
941        Ok(())
942    }
943}
944
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dia_array_create() {
        // Create a 3x3 sparse array with main diagonal and upper diagonal
        let data = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]), // Main diagonal
            Array1::from_vec(vec![4.0, 5.0, 0.0]), // Upper diagonal (k=1); trailing 0.0 is padding
        ];
        let offsets = vec![0, 1]; // Main diagonal and k=1
        let shape = (3, 3);

        let array = DiaArray::new(data, offsets, shape).unwrap();

        assert_eq!(array.shape(), (3, 3));
        assert_eq!(array.nnz(), 5); // 3 on main diagonal, 2 on upper diagonal

        // Test values
        assert_eq!(array.get(0, 0), 1.0);
        assert_eq!(array.get(1, 1), 2.0);
        assert_eq!(array.get(2, 2), 3.0);
        assert_eq!(array.get(0, 1), 4.0);
        assert_eq!(array.get(1, 2), 5.0);
        assert_eq!(array.get(0, 2), 0.0);
    }

    #[test]
    fn test_dia_array_from_triplets() {
        // Create a tridiagonal matrix
        let row = vec![0, 0, 1, 1, 1, 2, 2];
        let col = vec![0, 1, 0, 1, 2, 1, 2];
        let data = vec![1.0, 4.0, 2.0, 3.0, 5.0, 6.0, 7.0];
        let shape = (3, 3);

        let array = DiaArray::from_triplets(&row, &col, &data, shape).unwrap();

        // Should have 3 diagonals: main (0), upper (1), and lower (-1)
        assert_eq!(array.offsets.len(), 3);
        assert!(array.offsets.contains(&0));
        assert!(array.offsets.contains(&1));
        assert!(array.offsets.contains(&-1));

        // Test values
        assert_eq!(array.get(0, 0), 1.0);
        assert_eq!(array.get(0, 1), 4.0);
        assert_eq!(array.get(1, 0), 2.0);
        assert_eq!(array.get(1, 1), 3.0);
        assert_eq!(array.get(1, 2), 5.0);
        assert_eq!(array.get(2, 1), 6.0);
        assert_eq!(array.get(2, 2), 7.0);
    }

    #[test]
    fn test_dia_array_conversion() {
        // Create a tridiagonal matrix
        let data = vec![
            Array1::from_vec(vec![1.0, 3.0, 7.0]), // Main diagonal
            Array1::from_vec(vec![4.0, 5.0, 0.0]), // Upper diagonal
            // Lower diagonal (k=-1): entry 0 is A[1,0] = 0.0, entry 1 is A[2,1] = 2.0;
            // the trailing 0.0 lies outside the 2-element diagonal.
            Array1::from_vec(vec![0.0, 2.0, 0.0]),
        ];
        let offsets = vec![0, 1, -1]; // Main, upper, lower
        let shape = (3, 3);

        let array = DiaArray::new(data, offsets, shape).unwrap();

        assert_eq!(array.get(1, 0), 0.0);
        assert_eq!(array.get(2, 1), 2.0);

        // Convert to COO and check
        let coo = array.to_coo().unwrap();
        assert_eq!(coo.shape(), (3, 3));
        assert_eq!(coo.nnz(), 6); // The explicit zero at (1,0) is not stored

        // Convert to dense and check
        let dense = array.to_array();

        let expected =
            Array2::from_shape_vec((3, 3), vec![1.0, 4.0, 0.0, 0.0, 3.0, 5.0, 0.0, 2.0, 7.0])
                .unwrap();
        assert_eq!(dense, expected);
    }

    #[test]
    fn test_dia_array_operations() {
        // Create two simple diagonal matrices
        let data1 = vec![Array1::from_vec(vec![1.0, 2.0, 3.0])]; // Main diagonal
        let offsets1 = vec![0];
        let shape1 = (3, 3);
        let array1 = DiaArray::new(data1, offsets1, shape1).unwrap();

        let data2 = vec![Array1::from_vec(vec![4.0, 5.0, 6.0])]; // Main diagonal
        let offsets2 = vec![0];
        let shape2 = (3, 3);
        let array2 = DiaArray::new(data2, offsets2, shape2).unwrap();

        // Test addition
        let sum = array1.add(&array2).unwrap();
        assert_eq!(sum.get(0, 0), 5.0);
        assert_eq!(sum.get(1, 1), 7.0);
        assert_eq!(sum.get(2, 2), 9.0);

        // Test multiplication
        let product = array1.mul(&array2).unwrap();
        assert_eq!(product.get(0, 0), 4.0);
        assert_eq!(product.get(1, 1), 10.0);
        assert_eq!(product.get(2, 2), 18.0);

        // Test dot product (matrix multiplication)
        let dot = array1.dot(&array2).unwrap();
        assert_eq!(dot.get(0, 0), 4.0);
        assert_eq!(dot.get(1, 1), 10.0);
        assert_eq!(dot.get(2, 2), 18.0);
    }

    #[test]
    fn test_dia_array_dot_vector() {
        // Create a tridiagonal matrix
        let data = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]), // Main diagonal
            Array1::from_vec(vec![4.0, 5.0, 0.0]), // Upper diagonal
            // Lower diagonal (k=-1): A[1,0] = 0.0, A[2,1] = 6.0; the trailing 7.0
            // lies outside the 2-element diagonal and does not contribute.
            Array1::from_vec(vec![0.0, 6.0, 7.0]),
        ];
        let offsets = vec![0, 1, -1]; // Main, upper, lower
        let shape = (3, 3);

        let array = DiaArray::new(data, offsets, shape).unwrap();

        // Create a vector
        let vector = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        // Test matrix-vector multiplication
        let result = array.dot_vector(&vector.view()).unwrap();

        // Expected: [1*1 + 4*2, 0*1 + 2*2 + 5*3, 6*2 + 3*3]
        // = [9, 19, 21]
        let expected = Array1::from_vec(vec![9.0, 19.0, 21.0]);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_dia_array_transpose() {
        // Create a tridiagonal matrix
        let data = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]), // Main diagonal
            Array1::from_vec(vec![4.0, 5.0, 0.0]), // Upper diagonal
            Array1::from_vec(vec![0.0, 6.0, 7.0]), // Lower diagonal
        ];
        let offsets = vec![0, 1, -1]; // Main, upper, lower
        let shape = (3, 3);

        let array = DiaArray::new(data, offsets, shape).unwrap();
        let transposed = array.transpose().unwrap();

        // Check shape
        assert_eq!(transposed.shape(), (3, 3));

        // Compare the dense array representations
        let original_dense = array.to_array();
        let transposed_dense = transposed.to_array();

        for i in 0..3 {
            for j in 0..3 {
                assert_eq!(transposed_dense[[i, j]], original_dense[[j, i]]);
            }
        }
    }

    #[test]
    fn test_dia_array_sum() {
        // Create a simple matrix
        let data = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]), // Main diagonal
            Array1::from_vec(vec![4.0, 5.0, 0.0]), // Upper diagonal
        ];
        let offsets = vec![0, 1]; // Main, upper
        let shape = (3, 3);

        let array = DiaArray::new(data, offsets, shape).unwrap();

        // Test sum of entire array
        if let SparseSum::Scalar(sum) = array.sum(None).unwrap() {
            assert_eq!(sum, 15.0); // 1+2+3+4+5 = 15
        } else {
            panic!("Expected SparseSum::Scalar");
        }

        // Sum over axis 0 (collapsing rows) gives per-column totals, shape (1, 3)
        if let SparseSum::SparseArray(row_sum) = array.sum(Some(0)).unwrap() {
            assert_eq!(row_sum.shape(), (1, 3));
            assert_eq!(row_sum.get(0, 0), 1.0);
            assert_eq!(row_sum.get(0, 1), 6.0); // 2+4 = 6
            assert_eq!(row_sum.get(0, 2), 8.0); // 3+5 = 8
        } else {
            panic!("Expected SparseSum::SparseArray");
        }

        // Sum over axis 1 (collapsing columns) gives per-row totals, shape (3, 1)
        if let SparseSum::SparseArray(col_sum) = array.sum(Some(1)).unwrap() {
            assert_eq!(col_sum.shape(), (3, 1));
            assert_eq!(col_sum.get(0, 0), 5.0); // 1+4 = 5
            assert_eq!(col_sum.get(1, 0), 7.0); // 2+5 = 7
            assert_eq!(col_sum.get(2, 0), 3.0);
        } else {
            panic!("Expected SparseSum::SparseArray");
        }
    }

    #[test]
    fn test_dia_array_display() {
        let data = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]), // Main diagonal
            Array1::from_vec(vec![4.0, 5.0, 0.0]), // Upper diagonal
        ];
        let array = DiaArray::new(data, vec![0, 1], (3, 3)).unwrap();

        let text = format!("{array}");
        assert!(text.starts_with("DiaArray of shape (3, 3) with 5 stored elements\n"));
        assert!(text.contains("Diagonal 0: [1.0, 2.0, 3.0]\n"));
        // The k=1 diagonal has only two in-bounds entries; the padding is not shown.
        assert!(text.contains("Diagonal 1: [4.0, 5.0]\n"));
    }
}