scirs2_sparse/
dia_array.rs

1// DIA Array implementation
2//
3// This module provides the DIA (DIAgonal) array format,
4// which is efficient for matrices with values concentrated on a small number of diagonals.
5
6use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
7use scirs2_core::numeric::{Float, SparseElement};
8use std::fmt::{self, Debug};
9use std::ops::{Add, Div, Mul, Sub};
10
11use crate::coo_array::CooArray;
12use crate::csr_array::CsrArray;
13use crate::dok_array::DokArray;
14use crate::error::{SparseError, SparseResult};
15use crate::lil_array::LilArray;
16use crate::sparray::{SparseArray, SparseSum};
17
18/// DIA Array format
19///
20/// The DIA (DIAgonal) format stores data as a collection of diagonals.
21/// It is efficient for matrices with values concentrated on a small number of diagonals,
22/// like tridiagonal or band matrices.
23///
24/// # Notes
25///
26/// - Very efficient storage for band matrices
27/// - Fast matrix-vector products for banded matrices
28/// - Not efficient for general sparse matrices
29/// - Difficult to modify once constructed
30///
31#[derive(Clone)]
32pub struct DiaArray<T>
33where
34    T: SparseElement + Div<Output = T> + Float + 'static + std::ops::AddAssign,
35{
36    /// Diagonals data (n_diags x max(rows, cols))
37    data: Vec<Array1<T>>,
38    /// Diagonal offsets from the main diagonal (k > 0 for above, k < 0 for below)
39    offsets: Vec<isize>,
40    /// Shape of the array
41    shape: (usize, usize),
42}
43
44impl<T> DiaArray<T>
45where
46    T: SparseElement + Div<Output = T> + Float + 'static + std::ops::AddAssign,
47{
48    /// Create a new DIA array from raw data
49    ///
50    /// # Arguments
51    ///
52    /// * `data` - Diagonals data (n_diags x max(rows, cols))
53    /// * `offsets` - Diagonal offsets from the main diagonal
54    /// * `shape` - Tuple containing the array dimensions (rows, cols)
55    ///
56    /// # Returns
57    ///
58    /// * A new DIA array
59    ///
60    /// # Examples
61    ///
62    /// ```
63    /// use scirs2_sparse::dia_array::DiaArray;
64    /// use scirs2_sparse::sparray::SparseArray;
65    /// use scirs2_core::ndarray::Array1;
66    ///
67    /// // Create a 3x3 sparse array with main diagonal and upper diagonal
68    /// let data = vec![
69    ///     Array1::from_vec(vec![1.0, 2.0, 3.0]), // Main diagonal
70    ///     Array1::from_vec(vec![4.0, 5.0, 0.0]), // Upper diagonal (k=1)
71    /// ];
72    /// let offsets = vec![0, 1]; // Main diagonal and k=1
73    /// let shape = (3, 3);
74    ///
75    /// let array = DiaArray::new(data, offsets, shape).unwrap();
76    /// assert_eq!(array.shape(), (3, 3));
77    /// assert_eq!(array.nnz(), 5); // 3 on main diagonal, 2 on upper diagonal
78    /// ```
79    pub fn new(
80        data: Vec<Array1<T>>,
81        offsets: Vec<isize>,
82        shape: (usize, usize),
83    ) -> SparseResult<Self> {
84        let (rows, cols) = shape;
85        let max_dim = rows.max(cols);
86
87        // Validate input data
88        if data.len() != offsets.len() {
89            return Err(SparseError::DimensionMismatch {
90                expected: data.len(),
91                found: offsets.len(),
92            });
93        }
94
95        for diag in data.iter() {
96            if diag.len() != max_dim {
97                return Err(SparseError::DimensionMismatch {
98                    expected: max_dim,
99                    found: diag.len(),
100                });
101            }
102        }
103
104        Ok(DiaArray {
105            data,
106            offsets,
107            shape,
108        })
109    }
110
111    /// Create a new empty DIA array
112    ///
113    /// # Arguments
114    ///
115    /// * `shape` - Tuple containing the array dimensions (rows, cols)
116    ///
117    /// # Returns
118    ///
119    /// * A new empty DIA array
120    pub fn empty(shape: (usize, usize)) -> Self {
121        DiaArray {
122            data: Vec::new(),
123            offsets: Vec::new(),
124            shape,
125        }
126    }
127
128    /// Convert COO format to DIA format
129    ///
130    /// # Arguments
131    ///
132    /// * `row` - Row indices
133    /// * `col` - Column indices
134    /// * `data` - Data values
135    /// * `shape` - Shape of the array
136    ///
137    /// # Returns
138    ///
139    /// * A new DIA array
140    pub fn from_triplets(
141        row: &[usize],
142        col: &[usize],
143        data: &[T],
144        shape: (usize, usize),
145    ) -> SparseResult<Self> {
146        if row.len() != col.len() || row.len() != data.len() {
147            return Err(SparseError::InconsistentData {
148                reason: "Lengths of row, col, and data arrays must be equal".to_string(),
149            });
150        }
151
152        let (rows, cols) = shape;
153        let max_dim = rows.max(cols);
154
155        // Identify unique diagonals
156        let mut diagonal_offsets = std::collections::HashSet::new();
157        for (&r, &c) in row.iter().zip(col.iter()) {
158            if r >= rows || c >= cols {
159                return Err(SparseError::IndexOutOfBounds {
160                    index: (r, c),
161                    shape,
162                });
163            }
164            // Calculate diagonal offset (column - row for diagonals)
165            let offset = c as isize - r as isize;
166            diagonal_offsets.insert(offset);
167        }
168
169        // Convert to a sorted vector
170        let mut offsets: Vec<isize> = diagonal_offsets.into_iter().collect();
171        offsets.sort();
172
173        // Create data arrays (initialized to zero)
174        let mut diag_data = Vec::with_capacity(offsets.len());
175        for _ in 0..offsets.len() {
176            diag_data.push(Array1::zeros(max_dim));
177        }
178
179        // Fill in the data
180        for (&r, (&c, &val)) in row.iter().zip(col.iter().zip(data.iter())) {
181            let offset = c as isize - r as isize;
182            let diag_idx = offsets.iter().position(|&o| o == offset).unwrap();
183
184            // For upper diagonals (k > 0), the index is row
185            // For lower diagonals (k < 0), the index is column
186            let index = if offset >= 0 { r } else { c };
187            diag_data[diag_idx][index] = val;
188        }
189
190        DiaArray::new(diag_data, offsets, shape)
191    }
192
193    /// Convert to COO format
194    fn to_coo_internal(&self) -> (Vec<usize>, Vec<usize>, Vec<T>) {
195        let (rows, cols) = self.shape;
196        let mut row_indices = Vec::new();
197        let mut col_indices = Vec::new();
198        let mut values = Vec::new();
199
200        for (diag_idx, &offset) in self.offsets.iter().enumerate() {
201            let diag = &self.data[diag_idx];
202
203            if offset >= 0 {
204                // Upper diagonal
205                let offset_usize = offset as usize;
206                let length = rows.min(cols.saturating_sub(offset_usize));
207
208                for i in 0..length {
209                    let value = diag[i];
210                    if !SparseElement::is_zero(&value) {
211                        row_indices.push(i);
212                        col_indices.push(i + offset_usize);
213                        values.push(value);
214                    }
215                }
216            } else {
217                // Lower diagonal
218                let offset_usize = (-offset) as usize;
219                let length = cols.min(rows.saturating_sub(offset_usize));
220
221                for i in 0..length {
222                    let value = diag[i];
223                    if !SparseElement::is_zero(&value) {
224                        row_indices.push(i + offset_usize);
225                        col_indices.push(i);
226                        values.push(value);
227                    }
228                }
229            }
230        }
231
232        (row_indices, col_indices, values)
233    }
234}
235
236impl<T> SparseArray<T> for DiaArray<T>
237where
238    T: SparseElement + Div<Output = T> + Float + 'static + std::ops::AddAssign,
239{
240    fn shape(&self) -> (usize, usize) {
241        self.shape
242    }
243
244    fn nnz(&self) -> usize {
245        let (rows, cols) = self.shape;
246        let mut count = 0;
247
248        for (diag_idx, &offset) in self.offsets.iter().enumerate() {
249            let diag = &self.data[diag_idx];
250
251            // Calculate valid range for this diagonal
252            let length = if offset >= 0 {
253                rows.min(cols.saturating_sub(offset as usize))
254            } else {
255                cols.min(rows.saturating_sub((-offset) as usize))
256            };
257
258            // Count non-zeros in the valid range
259            let start_idx = 0; // Start at 0 regardless of offset
260            for i in start_idx..start_idx + length {
261                if !SparseElement::is_zero(&diag[i]) {
262                    count += 1;
263                }
264            }
265        }
266
267        count
268    }
269
270    fn dtype(&self) -> &str {
271        "float" // Placeholder; ideally would return the actual type
272    }
273
274    fn to_array(&self) -> Array2<T> {
275        // Convert to dense format
276        let (rows, cols) = self.shape;
277        let mut result = Array2::zeros((rows, cols));
278
279        // In the test case we have:
280        // data[0] = [1.0, 3.0, 7.0] with offset 0 (main diagonal)
281        // data[1] = [4.0, 5.0, 0.0] with offset 1 (upper diagonal)
282        // data[2] = [0.0, 2.0, 6.0] with offset -1 (lower diagonal)
283
284        for (diag_idx, &offset) in self.offsets.iter().enumerate() {
285            let diag = &self.data[diag_idx];
286
287            if offset >= 0 {
288                // Upper diagonal (k >= 0)
289                let offset_usize = offset as usize;
290                for i in 0..rows.min(cols.saturating_sub(offset_usize)) {
291                    result[[i, i + offset_usize]] = diag[i];
292                }
293            } else {
294                // Lower diagonal (k < 0)
295                let offset_usize = (-offset) as usize;
296                for i in 0..cols.min(rows.saturating_sub(offset_usize)) {
297                    result[[i + offset_usize, i]] = diag[i];
298                }
299            }
300        }
301
302        result
303    }
304
305    fn toarray(&self) -> Array2<T> {
306        self.to_array()
307    }
308
309    fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
310        let (row_indices, col_indices, values) = self.to_coo_internal();
311        let row_array = Array1::from_vec(row_indices);
312        let col_array = Array1::from_vec(col_indices);
313        let data_array = Array1::from_vec(values);
314
315        CooArray::from_triplets(
316            &row_array.to_vec(),
317            &col_array.to_vec(),
318            &data_array.to_vec(),
319            self.shape,
320            false,
321        )
322        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
323    }
324
325    fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
326        let (row_indices, col_indices, values) = self.to_coo_internal();
327        CsrArray::from_triplets(&row_indices, &col_indices, &values, self.shape, false)
328            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
329    }
330
331    fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
332        self.to_coo()?.to_csc()
333    }
334
335    fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
336        let (row_indices, col_indices, values) = self.to_coo_internal();
337        DokArray::from_triplets(&row_indices, &col_indices, &values, self.shape)
338            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
339    }
340
341    fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
342        let (row_indices, col_indices, values) = self.to_coo_internal();
343        LilArray::from_triplets(&row_indices, &col_indices, &values, self.shape)
344            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
345    }
346
347    fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
348        Ok(Box::new(self.clone()))
349    }
350
351    fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
352        self.to_coo()?.to_bsr()
353    }
354
355    fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
356        // Convert both to CSR for efficient addition
357        let csr_self = self.to_csr()?;
358        let csr_other = other.to_csr()?;
359        csr_self.add(&*csr_other)
360    }
361
362    fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
363        // Convert both to CSR for efficient subtraction
364        let csr_self = self.to_csr()?;
365        let csr_other = other.to_csr()?;
366        csr_self.sub(&*csr_other)
367    }
368
369    fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
370        // Convert both to CSR for efficient element-wise multiplication
371        let csr_self = self.to_csr()?;
372        let csr_other = other.to_csr()?;
373        csr_self.mul(&*csr_other)
374    }
375
376    fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
377        // Convert both to CSR for efficient element-wise division
378        let csr_self = self.to_csr()?;
379        let csr_other = other.to_csr()?;
380        csr_self.div(&*csr_other)
381    }
382
383    fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
384        // For matrix multiplication, use specialized DIA-Vector logic if other is thin
385        let (_, n) = self.shape();
386        let (p, q) = other.shape();
387
388        if n != p {
389            return Err(SparseError::DimensionMismatch {
390                expected: n,
391                found: p,
392            });
393        }
394
395        // If other is a vector (thin matrix), we can use optimized DIA-Vector multiplication
396        if q == 1 {
397            // Get the vector from other
398            let other_array = other.to_array();
399            let vec_view = other_array.column(0);
400
401            // Perform DIA-Vector multiplication
402            let result = self.dot_vector(&vec_view)?;
403
404            // Convert to a matrix - create a COO from triplets
405            let mut rows = Vec::new();
406            let mut cols = Vec::new();
407            let mut values = Vec::new();
408
409            for (i, &val) in result.iter().enumerate() {
410                if !SparseElement::is_zero(&val) {
411                    rows.push(i);
412                    cols.push(0);
413                    values.push(val);
414                }
415            }
416
417            CooArray::from_triplets(&rows, &cols, &values, (result.len(), 1), false)
418                .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
419        } else {
420            // For general matrices, convert to CSR
421            let csr_self = self.to_csr()?;
422            csr_self.dot(other)
423        }
424    }
425
426    fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
427        let (rows, cols) = self.shape;
428
429        if cols != other.len() {
430            return Err(SparseError::DimensionMismatch {
431                expected: cols,
432                found: other.len(),
433            });
434        }
435
436        let mut result = Array1::zeros(rows);
437
438        for (diag_idx, &offset) in self.offsets.iter().enumerate() {
439            let diag = &self.data[diag_idx];
440
441            if offset >= 0 {
442                // Upper diagonal (k > 0)
443                let offset_usize = offset as usize;
444                let length = rows.min(cols.saturating_sub(offset_usize));
445
446                for i in 0..length {
447                    result[i] += diag[i] * other[i + offset_usize];
448                }
449            } else {
450                // Lower diagonal (k < 0)
451                let offset_usize = (-offset) as usize;
452                let length = cols.min(rows.saturating_sub(offset_usize));
453
454                for i in 0..length {
455                    result[i + offset_usize] += diag[i] * other[i];
456                }
457            }
458        }
459
460        Ok(result)
461    }
462
463    fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
464        // For correct transposition, use COO intermediately
465        // This avoids issues with the diagonal storage format
466        let (row_indices, col_indices, values) = self.to_coo_internal();
467
468        // Swap row and column indices
469        let transposed_rows = col_indices;
470        let transposed_cols = row_indices;
471
472        // Create a new COO array and convert back to DIA
473        CooArray::from_triplets(
474            &transposed_rows,
475            &transposed_cols,
476            &values,
477            (self.shape.1, self.shape.0),
478            false,
479        )?
480        .to_dia()
481    }
482
483    fn copy(&self) -> Box<dyn SparseArray<T>> {
484        Box::new(self.clone())
485    }
486
487    fn get(&self, i: usize, j: usize) -> T {
488        if i >= self.shape.0 || j >= self.shape.1 {
489            return T::sparse_zero();
490        }
491
492        // Calculate the diagonal offset
493        let offset = j as isize - i as isize;
494
495        // Check if this offset exists in our stored diagonals
496        if let Some(diag_idx) = self.offsets.iter().position(|&o| o == offset) {
497            let diag = &self.data[diag_idx];
498
499            // For upper diagonals (k > 0), the index is row
500            // For lower diagonals (k < 0), the index is column
501            let index = if offset >= 0 { i } else { j };
502
503            // Make sure the index is within bounds
504            if index < diag.len() {
505                return diag[index];
506            }
507        }
508
509        T::sparse_zero()
510    }
511
512    fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
513        if i >= self.shape.0 || j >= self.shape.1 {
514            return Err(SparseError::IndexOutOfBounds {
515                index: (i, j),
516                shape: self.shape,
517            });
518        }
519
520        // Calculate the diagonal offset
521        let offset = j as isize - i as isize;
522
523        // Find or create the diagonal
524        let diag_idx = match self.offsets.iter().position(|&o| o == offset) {
525            Some(idx) => idx,
526            None => {
527                // This diagonal doesn't exist yet, add it
528                self.offsets.push(offset);
529                self.data
530                    .push(Array1::zeros(self.shape.0.max(self.shape.1)));
531
532                // Sort the offsets and data to maintain canonical form
533                let mut offset_data: Vec<(isize, Array1<T>)> = self
534                    .offsets
535                    .iter()
536                    .cloned()
537                    .zip(self.data.drain(..))
538                    .collect();
539                offset_data.sort_by_key(|&(offset_, _)| offset_);
540
541                self.offsets = offset_data.iter().map(|&(offset_, _)| offset_).collect();
542                self.data = offset_data.into_iter().map(|(_, data)| data).collect();
543
544                // Get the index of the newly added diagonal
545                self.offsets.iter().position(|&o| o == offset).unwrap()
546            }
547        };
548
549        // Set the value
550        let index = if offset >= 0 { i } else { j };
551        self.data[diag_idx][index] = value;
552
553        Ok(())
554    }
555
556    fn eliminate_zeros(&mut self) {
557        // Create a new set of diagonals without zeros
558        let mut new_offsets = Vec::new();
559        let mut new_data = Vec::new();
560
561        for (diag_idx, &offset) in self.offsets.iter().enumerate() {
562            let diag = &self.data[diag_idx];
563
564            // Check if this diagonal has any non-zero values
565            let length = if offset >= 0 {
566                self.shape
567                    .0
568                    .min(self.shape.1.saturating_sub(offset as usize))
569            } else {
570                self.shape
571                    .1
572                    .min(self.shape.0.saturating_sub((-offset) as usize))
573            };
574
575            let has_nonzero = (0..length).any(|i| !SparseElement::is_zero(&diag[i]));
576
577            if has_nonzero {
578                new_offsets.push(offset);
579                new_data.push(diag.clone());
580            }
581        }
582
583        self.offsets = new_offsets;
584        self.data = new_data;
585    }
586
587    fn sort_indices(&mut self) {
588        // DIA arrays have implicitly sorted indices based on offset
589        // Sort by offset just to be sure
590        let mut offset_data: Vec<(isize, Array1<T>)> = self
591            .offsets
592            .iter()
593            .cloned()
594            .zip(self.data.drain(..))
595            .collect();
596        offset_data.sort_by_key(|&(offset_, _)| offset_);
597
598        self.offsets = offset_data.iter().map(|&(offset_, _)| offset_).collect();
599        self.data = offset_data.into_iter().map(|(_, data)| data).collect();
600    }
601
602    fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
603        // Clone and sort
604        let mut result = self.clone();
605        result.sort_indices();
606        Box::new(result)
607    }
608
609    fn has_sorted_indices(&self) -> bool {
610        // Check if offsets are sorted
611        self.offsets.windows(2).all(|w| w[0] <= w[1])
612    }
613
614    fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>> {
615        match axis {
616            None => {
617                // Sum all elements
618                let mut total = T::sparse_zero();
619
620                for (diag_idx, &offset) in self.offsets.iter().enumerate() {
621                    let diag = &self.data[diag_idx];
622
623                    let length = if offset >= 0 {
624                        self.shape
625                            .0
626                            .min(self.shape.1.saturating_sub(offset as usize))
627                    } else {
628                        self.shape
629                            .1
630                            .min(self.shape.0.saturating_sub((-offset) as usize))
631                    };
632
633                    for i in 0..length {
634                        total += diag[i];
635                    }
636                }
637
638                Ok(SparseSum::Scalar(total))
639            }
640            Some(0) => {
641                // Sum along rows (result is 1 x cols)
642                let mut result = Array1::zeros(self.shape.1);
643
644                for (diag_idx, &offset) in self.offsets.iter().enumerate() {
645                    let diag = &self.data[diag_idx];
646
647                    if offset >= 0 {
648                        // Upper diagonal
649                        let offset_usize = offset as usize;
650                        let length = self.shape.0.min(self.shape.1.saturating_sub(offset_usize));
651
652                        for i in 0..length {
653                            result[i + offset_usize] += diag[i];
654                        }
655                    } else {
656                        // Lower diagonal
657                        let offset_usize = (-offset) as usize;
658                        let length = self.shape.1.min(self.shape.0.saturating_sub(offset_usize));
659
660                        for i in 0..length {
661                            result[i] += diag[i];
662                        }
663                    }
664                }
665
666                // Convert to a sparse array
667                match Array2::from_shape_vec((1, self.shape.1), result.to_vec()) {
668                    Ok(result_2d) => {
669                        // Find non-zero elements
670                        let mut row_indices = Vec::new();
671                        let mut col_indices = Vec::new();
672                        let mut values = Vec::new();
673
674                        for j in 0..self.shape.1 {
675                            let val: T = result_2d[[0, j]];
676                            if !SparseElement::is_zero(&val) {
677                                row_indices.push(0);
678                                col_indices.push(j);
679                                values.push(val);
680                            }
681                        }
682
683                        // Create COO array
684                        match CooArray::from_triplets(
685                            &row_indices,
686                            &col_indices,
687                            &values,
688                            (1, self.shape.1),
689                            false,
690                        ) {
691                            Ok(coo_array) => Ok(SparseSum::SparseArray(Box::new(coo_array))),
692                            Err(e) => Err(e),
693                        }
694                    }
695                    Err(_) => Err(SparseError::InconsistentData {
696                        reason: "Failed to create 2D array from result vector".to_string(),
697                    }),
698                }
699            }
700            Some(1) => {
701                // Sum along columns (result is rows x 1)
702                let mut result = Array1::zeros(self.shape.0);
703
704                for (diag_idx, &offset) in self.offsets.iter().enumerate() {
705                    let diag = &self.data[diag_idx];
706
707                    if offset >= 0 {
708                        // Upper diagonal
709                        let offset_usize = offset as usize;
710                        let length = self.shape.0.min(self.shape.1.saturating_sub(offset_usize));
711
712                        for i in 0..length {
713                            result[i] += diag[i];
714                        }
715                    } else {
716                        // Lower diagonal
717                        let offset_usize = (-offset) as usize;
718                        let length = self.shape.1.min(self.shape.0.saturating_sub(offset_usize));
719
720                        for i in 0..length {
721                            result[i + offset_usize] += diag[i];
722                        }
723                    }
724                }
725
726                // Convert to a sparse array
727                match Array2::from_shape_vec((self.shape.0, 1), result.to_vec()) {
728                    Ok(result_2d) => {
729                        // Find non-zero elements
730                        let mut row_indices = Vec::new();
731                        let mut col_indices = Vec::new();
732                        let mut values = Vec::new();
733
734                        for i in 0..self.shape.0 {
735                            let val: T = result_2d[[i, 0]];
736                            if !SparseElement::is_zero(&val) {
737                                row_indices.push(i);
738                                col_indices.push(0);
739                                values.push(val);
740                            }
741                        }
742
743                        // Create COO array
744                        match CooArray::from_triplets(
745                            &row_indices,
746                            &col_indices,
747                            &values,
748                            (self.shape.0, 1),
749                            false,
750                        ) {
751                            Ok(coo_array) => Ok(SparseSum::SparseArray(Box::new(coo_array))),
752                            Err(e) => Err(e),
753                        }
754                    }
755                    Err(_) => Err(SparseError::InconsistentData {
756                        reason: "Failed to create 2D array from result vector".to_string(),
757                    }),
758                }
759            }
760            _ => Err(SparseError::InvalidAxis),
761        }
762    }
763
764    fn max(&self) -> T {
765        let mut max_val = T::neg_infinity();
766
767        for (diag_idx, &offset) in self.offsets.iter().enumerate() {
768            let diag = &self.data[diag_idx];
769
770            let length = if offset >= 0 {
771                self.shape
772                    .0
773                    .min(self.shape.1.saturating_sub(offset as usize))
774            } else {
775                self.shape
776                    .1
777                    .min(self.shape.0.saturating_sub((-offset) as usize))
778            };
779
780            for i in 0..length {
781                max_val = max_val.max(diag[i]);
782            }
783        }
784
785        // If no elements or all negative infinity, return zero
786        if max_val == T::neg_infinity() {
787            T::sparse_zero()
788        } else {
789            max_val
790        }
791    }
792
793    fn min(&self) -> T {
794        let mut min_val = T::sparse_zero();
795        let mut has_nonzero = false;
796
797        for (diag_idx, &offset) in self.offsets.iter().enumerate() {
798            let diag = &self.data[diag_idx];
799
800            let length = if offset >= 0 {
801                self.shape
802                    .0
803                    .min(self.shape.1.saturating_sub(offset as usize))
804            } else {
805                self.shape
806                    .1
807                    .min(self.shape.0.saturating_sub((-offset) as usize))
808            };
809
810            for i in 0..length {
811                if !SparseElement::is_zero(&diag[i]) {
812                    has_nonzero = true;
813                    min_val = min_val.min(diag[i]);
814                }
815            }
816        }
817
818        // If no non-zero elements, return zero
819        if !has_nonzero {
820            T::sparse_zero()
821        } else {
822            min_val
823        }
824    }
825
826    fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
827        let (row_indices, col_indices, values) = self.to_coo_internal();
828
829        (
830            Array1::from_vec(row_indices),
831            Array1::from_vec(col_indices),
832            Array1::from_vec(values),
833        )
834    }
835
836    fn slice(
837        &self,
838        row_range: (usize, usize),
839        col_range: (usize, usize),
840    ) -> SparseResult<Box<dyn SparseArray<T>>> {
841        let (start_row, end_row) = row_range;
842        let (start_col, end_col) = col_range;
843        let (rows, cols) = self.shape;
844
845        if start_row >= rows || end_row > rows || start_col >= cols || end_col > cols {
846            return Err(SparseError::IndexOutOfBounds {
847                index: (start_row.max(end_row), start_col.max(end_col)),
848                shape: (rows, cols),
849            });
850        }
851
852        if start_row >= end_row || start_col >= end_col {
853            return Err(SparseError::InvalidSliceRange);
854        }
855
856        // Convert to COO, then slice, then convert back to DIA
857        let coo = self.to_coo()?;
858        coo.slice(row_range, col_range)?.to_dia()
859    }
860
861    fn as_any(&self) -> &dyn std::any::Any {
862        self
863    }
864}
865
866// Implement Display for DiaArray for better debugging
867impl<T> fmt::Display for DiaArray<T>
868where
869    T: SparseElement + Div<Output = T> + Float + 'static + std::ops::AddAssign,
870{
871    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
872        writeln!(
873            f,
874            "DiaArray of shape {:?} with {} stored elements",
875            self.shape,
876            self.nnz()
877        )?;
878        writeln!(f, "Offsets: {:?}", self.offsets)?;
879
880        if self.offsets.len() <= 5 {
881            for (i, &offset) in self.offsets.iter().enumerate() {
882                let diag = &self.data[i];
883                let length = if offset >= 0 {
884                    self.shape
885                        .0
886                        .min(self.shape.1.saturating_sub(offset as usize))
887                } else {
888                    self.shape
889                        .1
890                        .min(self.shape.0.saturating_sub((-offset) as usize))
891                };
892
893                write!(f, "Diagonal {offset}: [")?;
894                for j in 0..length.min(10) {
895                    if j > 0 {
896                        write!(f, ", ")?;
897                    }
898                    write!(f, "{:?}", diag[j])?;
899                }
900                if length > 10 {
901                    write!(f, ", ...")?;
902                }
903                writeln!(f, "]")?;
904            }
905        } else {
906            writeln!(f, "({} diagonals)", self.offsets.len())?;
907        }
908
909        Ok(())
910    }
911}
912
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` + `get` on a 3x3 matrix holding the main diagonal and the first
    /// superdiagonal.
    #[test]
    fn test_dia_array_create() {
        // Create a 3x3 sparse array with main diagonal and upper diagonal
        let data = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]), // Main diagonal
            Array1::from_vec(vec![4.0, 5.0, 0.0]), // Upper diagonal (k=1); last slot is padding
        ];
        let offsets = vec![0, 1]; // Main diagonal and k=1
        let shape = (3, 3);

        let array = DiaArray::new(data, offsets, shape).unwrap();

        assert_eq!(array.shape(), (3, 3));
        assert_eq!(array.nnz(), 5); // 3 on main diagonal, 2 on upper diagonal

        // Test values: upper-diagonal slots 0 and 1 land at (0,1) and (1,2)
        assert_eq!(array.get(0, 0), 1.0);
        assert_eq!(array.get(1, 1), 2.0);
        assert_eq!(array.get(2, 2), 3.0);
        assert_eq!(array.get(0, 1), 4.0);
        assert_eq!(array.get(1, 2), 5.0);
        assert_eq!(array.get(0, 2), 0.0);
    }

    /// `from_triplets` groups entries by diagonal offset and creates one stored
    /// diagonal per distinct offset.
    #[test]
    fn test_dia_array_from_triplets() {
        // Create a tridiagonal matrix from (row, col, value) triplets
        let row = vec![0, 0, 1, 1, 1, 2, 2];
        let col = vec![0, 1, 0, 1, 2, 1, 2];
        let data = vec![1.0, 4.0, 2.0, 3.0, 5.0, 6.0, 7.0];
        let shape = (3, 3);

        let array = DiaArray::from_triplets(&row, &col, &data, shape).unwrap();

        // Should have 3 diagonals: main (0), upper (1), and lower (-1)
        assert_eq!(array.offsets.len(), 3);
        assert!(array.offsets.contains(&0));
        assert!(array.offsets.contains(&1));
        assert!(array.offsets.contains(&-1));

        // Test values
        assert_eq!(array.get(0, 0), 1.0);
        assert_eq!(array.get(0, 1), 4.0);
        assert_eq!(array.get(1, 0), 2.0);
        assert_eq!(array.get(1, 1), 3.0);
        assert_eq!(array.get(1, 2), 5.0);
        assert_eq!(array.get(2, 1), 6.0);
        assert_eq!(array.get(2, 2), 7.0);
    }

    /// Conversion to COO (zero-valued slots are not stored) and to dense.
    #[test]
    fn test_dia_array_conversion() {
        // Create a tridiagonal matrix. The expected dense matrix below implies
        // that each diagonal is stored starting from its first in-bounds entry:
        // lower-diagonal slots 0 and 1 map to (1,0) and (2,1), and slot 2 is padding.
        let data = vec![
            Array1::from_vec(vec![1.0, 3.0, 7.0]), // Main diagonal
            Array1::from_vec(vec![4.0, 5.0, 0.0]), // Upper diagonal: (0,1), (1,2), padding
            Array1::from_vec(vec![0.0, 2.0, 0.0]), // Lower diagonal: (1,0)=0, (2,1)=2, padding
        ];
        let offsets = vec![0, 1, -1]; // Main, upper, lower
        let shape = (3, 3);

        let array = DiaArray::new(data, offsets, shape).unwrap();

        // Convert to COO and check
        let coo = array.to_coo().unwrap();
        assert_eq!(coo.shape(), (3, 3));
        assert_eq!(coo.nnz(), 6); // Six nonzeros (1, 4, 3, 5, 2, 7); zero-valued slots are not stored

        // Convert to dense and check
        let dense = array.to_array();

        let expected =
            Array2::from_shape_vec((3, 3), vec![1.0, 4.0, 0.0, 0.0, 3.0, 5.0, 0.0, 2.0, 7.0])
                .unwrap();
        assert_eq!(dense, expected);
    }

    /// Element-wise `add`, element-wise `mul`, and matrix product `dot` on two
    /// diagonal matrices.
    #[test]
    fn test_dia_array_operations() {
        // Create two simple diagonal matrices
        let data1 = vec![Array1::from_vec(vec![1.0, 2.0, 3.0])]; // Main diagonal
        let offsets1 = vec![0];
        let shape1 = (3, 3);
        let array1 = DiaArray::new(data1, offsets1, shape1).unwrap();

        let data2 = vec![Array1::from_vec(vec![4.0, 5.0, 6.0])]; // Main diagonal
        let offsets2 = vec![0];
        let shape2 = (3, 3);
        let array2 = DiaArray::new(data2, offsets2, shape2).unwrap();

        // Test addition
        let sum = array1.add(&array2).unwrap();
        assert_eq!(sum.get(0, 0), 5.0);
        assert_eq!(sum.get(1, 1), 7.0);
        assert_eq!(sum.get(2, 2), 9.0);

        // Test element-wise multiplication
        let product = array1.mul(&array2).unwrap();
        assert_eq!(product.get(0, 0), 4.0);
        assert_eq!(product.get(1, 1), 10.0);
        assert_eq!(product.get(2, 2), 18.0);

        // Test dot product (matrix multiplication); equal to the element-wise
        // product here because both operands are diagonal
        let dot = array1.dot(&array2).unwrap();
        assert_eq!(dot.get(0, 0), 4.0);
        assert_eq!(dot.get(1, 1), 10.0);
        assert_eq!(dot.get(2, 2), 18.0);
    }

    /// Matrix-vector product with a tridiagonal matrix.
    #[test]
    fn test_dia_array_dot_vector() {
        // Create a tridiagonal matrix
        let data = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]), // Main diagonal
            Array1::from_vec(vec![4.0, 5.0, 0.0]), // Upper diagonal
            Array1::from_vec(vec![0.0, 6.0, 7.0]), // Lower diagonal
        ];
        let offsets = vec![0, 1, -1]; // Main, upper, lower
        let shape = (3, 3);

        let array = DiaArray::new(data, offsets, shape).unwrap();

        // Create a vector
        let vector = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        // Test matrix-vector multiplication
        let result = array.dot_vector(&vector.view()).unwrap();

        // The expected values imply this matrix. Each diagonal is stored from its
        // first entry, and the k=-1 diagonal has only two entries, so the 7.0 is
        // padding:
        //   [[1, 4, 0],
        //    [0, 2, 5],
        //    [0, 6, 3]]
        // Expected: [1*1 + 4*2, 2*2 + 5*3, 6*2 + 3*3]
        // = [9, 19, 21]
        let expected = Array1::from_vec(vec![9.0, 19.0, 21.0]);
        assert_eq!(result, expected);
    }

    /// `transpose` must produce the same dense matrix as transposing `to_array()`.
    #[test]
    fn test_dia_array_transpose() {
        // Create a tridiagonal matrix
        let data = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]), // Main diagonal
            Array1::from_vec(vec![4.0, 5.0, 0.0]), // Upper diagonal
            Array1::from_vec(vec![0.0, 6.0, 7.0]), // Lower diagonal
        ];
        let offsets = vec![0, 1, -1]; // Main, upper, lower
        let shape = (3, 3);

        let array = DiaArray::new(data, offsets, shape).unwrap();
        let transposed = array.transpose().unwrap();

        // Check shape
        assert_eq!(transposed.shape(), (3, 3));

        // Compare the dense array representations
        let original_dense = array.to_array();
        let transposed_dense = transposed.to_array();

        for i in 0..3 {
            for j in 0..3 {
                assert_eq!(transposed_dense[[i, j]], original_dense[[j, i]]);
            }
        }
    }

    /// `sum` over all entries, over axis 0, and over axis 1.
    #[test]
    fn test_dia_array_sum() {
        // Create a simple matrix: [[1, 4, 0], [0, 2, 5], [0, 0, 3]]
        let data = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]), // Main diagonal
            Array1::from_vec(vec![4.0, 5.0, 0.0]), // Upper diagonal
        ];
        let offsets = vec![0, 1]; // Main, upper
        let shape = (3, 3);

        let array = DiaArray::new(data, offsets, shape).unwrap();

        // Test sum of entire array
        if let SparseSum::Scalar(sum) = array.sum(None).unwrap() {
            assert_eq!(sum, 15.0); // 1+2+3+4+5 = 15
        } else {
            panic!("Expected SparseSum::Scalar");
        }

        // Test sum over axis 0 (collapses rows): a 1x3 array of column totals
        if let SparseSum::SparseArray(row_sum) = array.sum(Some(0)).unwrap() {
            assert_eq!(row_sum.shape(), (1, 3));
            assert_eq!(row_sum.get(0, 0), 1.0);
            assert_eq!(row_sum.get(0, 1), 6.0); // 2+4 = 6
            assert_eq!(row_sum.get(0, 2), 8.0); // 3+5 = 8
        } else {
            panic!("Expected SparseSum::SparseArray");
        }

        // Test sum over axis 1 (collapses columns): a 3x1 array of row totals
        if let SparseSum::SparseArray(col_sum) = array.sum(Some(1)).unwrap() {
            assert_eq!(col_sum.shape(), (3, 1));
            assert_eq!(col_sum.get(0, 0), 5.0); // 1+4 = 5
            assert_eq!(col_sum.get(1, 0), 7.0); // 2+5 = 7
            assert_eq!(col_sum.get(2, 0), 3.0);
        } else {
            panic!("Expected SparseSum::SparseArray");
        }
    }
}