scirs2_sparse/
dia_array.rs

1// DIA Array implementation
2//
3// This module provides the DIA (DIAgonal) array format,
4// which is efficient for matrices with values concentrated on a small number of diagonals.
5
6use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
7use scirs2_core::numeric::{Float, SparseElement};
8use std::fmt::{self, Debug};
9use std::ops::{Add, Div, Mul, Sub};
10
11use crate::coo_array::CooArray;
12use crate::csr_array::CsrArray;
13use crate::dok_array::DokArray;
14use crate::error::{SparseError, SparseResult};
15use crate::lil_array::LilArray;
16use crate::sparray::{SparseArray, SparseSum};
17
18/// DIA Array format
19///
20/// The DIA (DIAgonal) format stores data as a collection of diagonals.
21/// It is efficient for matrices with values concentrated on a small number of diagonals,
22/// like tridiagonal or band matrices.
23///
24/// # Notes
25///
26/// - Very efficient storage for band matrices
27/// - Fast matrix-vector products for banded matrices
28/// - Not efficient for general sparse matrices
29/// - Difficult to modify once constructed
30///
31#[derive(Clone)]
32pub struct DiaArray<T>
33where
34    T: SparseElement + Div<Output = T> + Float + 'static + std::ops::AddAssign,
35{
36    /// Diagonals data (n_diags x max(rows, cols))
37    data: Vec<Array1<T>>,
38    /// Diagonal offsets from the main diagonal (k > 0 for above, k < 0 for below)
39    offsets: Vec<isize>,
40    /// Shape of the array
41    shape: (usize, usize),
42}
43
44impl<T> DiaArray<T>
45where
46    T: SparseElement + Div<Output = T> + Float + 'static + std::ops::AddAssign,
47{
48    /// Create a new DIA array from raw data
49    ///
50    /// # Arguments
51    ///
52    /// * `data` - Diagonals data (n_diags x max(rows, cols))
53    /// * `offsets` - Diagonal offsets from the main diagonal
54    /// * `shape` - Tuple containing the array dimensions (rows, cols)
55    ///
56    /// # Returns
57    ///
58    /// * A new DIA array
59    ///
60    /// # Examples
61    ///
62    /// ```
63    /// use scirs2_sparse::dia_array::DiaArray;
64    /// use scirs2_sparse::sparray::SparseArray;
65    /// use scirs2_core::ndarray::Array1;
66    ///
67    /// // Create a 3x3 sparse array with main diagonal and upper diagonal
68    /// let data = vec![
69    ///     Array1::from_vec(vec![1.0, 2.0, 3.0]), // Main diagonal
70    ///     Array1::from_vec(vec![4.0, 5.0, 0.0]), // Upper diagonal (k=1)
71    /// ];
72    /// let offsets = vec![0, 1]; // Main diagonal and k=1
73    /// let shape = (3, 3);
74    ///
75    /// let array = DiaArray::new(data, offsets, shape).expect("Operation failed");
76    /// assert_eq!(array.shape(), (3, 3));
77    /// assert_eq!(array.nnz(), 5); // 3 on main diagonal, 2 on upper diagonal
78    /// ```
79    pub fn new(
80        data: Vec<Array1<T>>,
81        offsets: Vec<isize>,
82        shape: (usize, usize),
83    ) -> SparseResult<Self> {
84        let (rows, cols) = shape;
85        let max_dim = rows.max(cols);
86
87        // Validate input data
88        if data.len() != offsets.len() {
89            return Err(SparseError::DimensionMismatch {
90                expected: data.len(),
91                found: offsets.len(),
92            });
93        }
94
95        for diag in data.iter() {
96            if diag.len() != max_dim {
97                return Err(SparseError::DimensionMismatch {
98                    expected: max_dim,
99                    found: diag.len(),
100                });
101            }
102        }
103
104        Ok(DiaArray {
105            data,
106            offsets,
107            shape,
108        })
109    }
110
111    /// Create a new empty DIA array
112    ///
113    /// # Arguments
114    ///
115    /// * `shape` - Tuple containing the array dimensions (rows, cols)
116    ///
117    /// # Returns
118    ///
119    /// * A new empty DIA array
120    pub fn empty(shape: (usize, usize)) -> Self {
121        DiaArray {
122            data: Vec::new(),
123            offsets: Vec::new(),
124            shape,
125        }
126    }
127
128    /// Convert COO format to DIA format
129    ///
130    /// # Arguments
131    ///
132    /// * `row` - Row indices
133    /// * `col` - Column indices
134    /// * `data` - Data values
135    /// * `shape` - Shape of the array
136    ///
137    /// # Returns
138    ///
139    /// * A new DIA array
140    pub fn from_triplets(
141        row: &[usize],
142        col: &[usize],
143        data: &[T],
144        shape: (usize, usize),
145    ) -> SparseResult<Self> {
146        if row.len() != col.len() || row.len() != data.len() {
147            return Err(SparseError::InconsistentData {
148                reason: "Lengths of row, col, and data arrays must be equal".to_string(),
149            });
150        }
151
152        let (rows, cols) = shape;
153        let max_dim = rows.max(cols);
154
155        // Identify unique diagonals
156        let mut diagonal_offsets = std::collections::HashSet::new();
157        for (&r, &c) in row.iter().zip(col.iter()) {
158            if r >= rows || c >= cols {
159                return Err(SparseError::IndexOutOfBounds {
160                    index: (r, c),
161                    shape,
162                });
163            }
164            // Calculate diagonal offset (column - row for diagonals)
165            let offset = c as isize - r as isize;
166            diagonal_offsets.insert(offset);
167        }
168
169        // Convert to a sorted vector
170        let mut offsets: Vec<isize> = diagonal_offsets.into_iter().collect();
171        offsets.sort();
172
173        // Create data arrays (initialized to zero)
174        let mut diag_data = Vec::with_capacity(offsets.len());
175        for _ in 0..offsets.len() {
176            diag_data.push(Array1::zeros(max_dim));
177        }
178
179        // Fill in the data
180        for (&r, (&c, &val)) in row.iter().zip(col.iter().zip(data.iter())) {
181            let offset = c as isize - r as isize;
182            let diag_idx = offsets
183                .iter()
184                .position(|&o| o == offset)
185                .expect("Operation failed");
186
187            // For upper diagonals (k > 0), the index is row
188            // For lower diagonals (k < 0), the index is column
189            let index = if offset >= 0 { r } else { c };
190            diag_data[diag_idx][index] = val;
191        }
192
193        DiaArray::new(diag_data, offsets, shape)
194    }
195
196    /// Convert to COO format
197    fn to_coo_internal(&self) -> (Vec<usize>, Vec<usize>, Vec<T>) {
198        let (rows, cols) = self.shape;
199        let mut row_indices = Vec::new();
200        let mut col_indices = Vec::new();
201        let mut values = Vec::new();
202
203        for (diag_idx, &offset) in self.offsets.iter().enumerate() {
204            let diag = &self.data[diag_idx];
205
206            if offset >= 0 {
207                // Upper diagonal
208                let offset_usize = offset as usize;
209                let length = rows.min(cols.saturating_sub(offset_usize));
210
211                for i in 0..length {
212                    let value = diag[i];
213                    if !SparseElement::is_zero(&value) {
214                        row_indices.push(i);
215                        col_indices.push(i + offset_usize);
216                        values.push(value);
217                    }
218                }
219            } else {
220                // Lower diagonal
221                let offset_usize = (-offset) as usize;
222                let length = cols.min(rows.saturating_sub(offset_usize));
223
224                for i in 0..length {
225                    let value = diag[i];
226                    if !SparseElement::is_zero(&value) {
227                        row_indices.push(i + offset_usize);
228                        col_indices.push(i);
229                        values.push(value);
230                    }
231                }
232            }
233        }
234
235        (row_indices, col_indices, values)
236    }
237}
238
239impl<T> SparseArray<T> for DiaArray<T>
240where
241    T: SparseElement + Div<Output = T> + Float + 'static + std::ops::AddAssign,
242{
243    fn shape(&self) -> (usize, usize) {
244        self.shape
245    }
246
247    fn nnz(&self) -> usize {
248        let (rows, cols) = self.shape;
249        let mut count = 0;
250
251        for (diag_idx, &offset) in self.offsets.iter().enumerate() {
252            let diag = &self.data[diag_idx];
253
254            // Calculate valid range for this diagonal
255            let length = if offset >= 0 {
256                rows.min(cols.saturating_sub(offset as usize))
257            } else {
258                cols.min(rows.saturating_sub((-offset) as usize))
259            };
260
261            // Count non-zeros in the valid range
262            let start_idx = 0; // Start at 0 regardless of offset
263            for i in start_idx..start_idx + length {
264                if !SparseElement::is_zero(&diag[i]) {
265                    count += 1;
266                }
267            }
268        }
269
270        count
271    }
272
273    fn dtype(&self) -> &str {
274        "float" // Placeholder; ideally would return the actual type
275    }
276
277    fn to_array(&self) -> Array2<T> {
278        // Convert to dense format
279        let (rows, cols) = self.shape;
280        let mut result = Array2::zeros((rows, cols));
281
282        // In the test case we have:
283        // data[0] = [1.0, 3.0, 7.0] with offset 0 (main diagonal)
284        // data[1] = [4.0, 5.0, 0.0] with offset 1 (upper diagonal)
285        // data[2] = [0.0, 2.0, 6.0] with offset -1 (lower diagonal)
286
287        for (diag_idx, &offset) in self.offsets.iter().enumerate() {
288            let diag = &self.data[diag_idx];
289
290            if offset >= 0 {
291                // Upper diagonal (k >= 0)
292                let offset_usize = offset as usize;
293                for i in 0..rows.min(cols.saturating_sub(offset_usize)) {
294                    result[[i, i + offset_usize]] = diag[i];
295                }
296            } else {
297                // Lower diagonal (k < 0)
298                let offset_usize = (-offset) as usize;
299                for i in 0..cols.min(rows.saturating_sub(offset_usize)) {
300                    result[[i + offset_usize, i]] = diag[i];
301                }
302            }
303        }
304
305        result
306    }
307
308    fn toarray(&self) -> Array2<T> {
309        self.to_array()
310    }
311
312    fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
313        let (row_indices, col_indices, values) = self.to_coo_internal();
314        let row_array = Array1::from_vec(row_indices);
315        let col_array = Array1::from_vec(col_indices);
316        let data_array = Array1::from_vec(values);
317
318        CooArray::from_triplets(
319            &row_array.to_vec(),
320            &col_array.to_vec(),
321            &data_array.to_vec(),
322            self.shape,
323            false,
324        )
325        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
326    }
327
328    fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
329        let (row_indices, col_indices, values) = self.to_coo_internal();
330        CsrArray::from_triplets(&row_indices, &col_indices, &values, self.shape, false)
331            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
332    }
333
334    fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
335        self.to_coo()?.to_csc()
336    }
337
338    fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
339        let (row_indices, col_indices, values) = self.to_coo_internal();
340        DokArray::from_triplets(&row_indices, &col_indices, &values, self.shape)
341            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
342    }
343
344    fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
345        let (row_indices, col_indices, values) = self.to_coo_internal();
346        LilArray::from_triplets(&row_indices, &col_indices, &values, self.shape)
347            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
348    }
349
350    fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
351        Ok(Box::new(self.clone()))
352    }
353
354    fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
355        self.to_coo()?.to_bsr()
356    }
357
358    fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
359        // Convert both to CSR for efficient addition
360        let csr_self = self.to_csr()?;
361        let csr_other = other.to_csr()?;
362        csr_self.add(&*csr_other)
363    }
364
365    fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
366        // Convert both to CSR for efficient subtraction
367        let csr_self = self.to_csr()?;
368        let csr_other = other.to_csr()?;
369        csr_self.sub(&*csr_other)
370    }
371
372    fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
373        // Convert both to CSR for efficient element-wise multiplication
374        let csr_self = self.to_csr()?;
375        let csr_other = other.to_csr()?;
376        csr_self.mul(&*csr_other)
377    }
378
379    fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
380        // Convert both to CSR for efficient element-wise division
381        let csr_self = self.to_csr()?;
382        let csr_other = other.to_csr()?;
383        csr_self.div(&*csr_other)
384    }
385
386    fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
387        // For matrix multiplication, use specialized DIA-Vector logic if other is thin
388        let (_, n) = self.shape();
389        let (p, q) = other.shape();
390
391        if n != p {
392            return Err(SparseError::DimensionMismatch {
393                expected: n,
394                found: p,
395            });
396        }
397
398        // If other is a vector (thin matrix), we can use optimized DIA-Vector multiplication
399        if q == 1 {
400            // Get the vector from other
401            let other_array = other.to_array();
402            let vec_view = other_array.column(0);
403
404            // Perform DIA-Vector multiplication
405            let result = self.dot_vector(&vec_view)?;
406
407            // Convert to a matrix - create a COO from triplets
408            let mut rows = Vec::new();
409            let mut cols = Vec::new();
410            let mut values = Vec::new();
411
412            for (i, &val) in result.iter().enumerate() {
413                if !SparseElement::is_zero(&val) {
414                    rows.push(i);
415                    cols.push(0);
416                    values.push(val);
417                }
418            }
419
420            CooArray::from_triplets(&rows, &cols, &values, (result.len(), 1), false)
421                .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
422        } else {
423            // For general matrices, convert to CSR
424            let csr_self = self.to_csr()?;
425            csr_self.dot(other)
426        }
427    }
428
429    fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
430        let (rows, cols) = self.shape;
431
432        if cols != other.len() {
433            return Err(SparseError::DimensionMismatch {
434                expected: cols,
435                found: other.len(),
436            });
437        }
438
439        let mut result = Array1::zeros(rows);
440
441        for (diag_idx, &offset) in self.offsets.iter().enumerate() {
442            let diag = &self.data[diag_idx];
443
444            if offset >= 0 {
445                // Upper diagonal (k > 0)
446                let offset_usize = offset as usize;
447                let length = rows.min(cols.saturating_sub(offset_usize));
448
449                for i in 0..length {
450                    result[i] += diag[i] * other[i + offset_usize];
451                }
452            } else {
453                // Lower diagonal (k < 0)
454                let offset_usize = (-offset) as usize;
455                let length = cols.min(rows.saturating_sub(offset_usize));
456
457                for i in 0..length {
458                    result[i + offset_usize] += diag[i] * other[i];
459                }
460            }
461        }
462
463        Ok(result)
464    }
465
466    fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
467        // For correct transposition, use COO intermediately
468        // This avoids issues with the diagonal storage format
469        let (row_indices, col_indices, values) = self.to_coo_internal();
470
471        // Swap row and column indices
472        let transposed_rows = col_indices;
473        let transposed_cols = row_indices;
474
475        // Create a new COO array and convert back to DIA
476        CooArray::from_triplets(
477            &transposed_rows,
478            &transposed_cols,
479            &values,
480            (self.shape.1, self.shape.0),
481            false,
482        )?
483        .to_dia()
484    }
485
486    fn copy(&self) -> Box<dyn SparseArray<T>> {
487        Box::new(self.clone())
488    }
489
490    fn get(&self, i: usize, j: usize) -> T {
491        if i >= self.shape.0 || j >= self.shape.1 {
492            return T::sparse_zero();
493        }
494
495        // Calculate the diagonal offset
496        let offset = j as isize - i as isize;
497
498        // Check if this offset exists in our stored diagonals
499        if let Some(diag_idx) = self.offsets.iter().position(|&o| o == offset) {
500            let diag = &self.data[diag_idx];
501
502            // For upper diagonals (k > 0), the index is row
503            // For lower diagonals (k < 0), the index is column
504            let index = if offset >= 0 { i } else { j };
505
506            // Make sure the index is within bounds
507            if index < diag.len() {
508                return diag[index];
509            }
510        }
511
512        T::sparse_zero()
513    }
514
515    fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
516        if i >= self.shape.0 || j >= self.shape.1 {
517            return Err(SparseError::IndexOutOfBounds {
518                index: (i, j),
519                shape: self.shape,
520            });
521        }
522
523        // Calculate the diagonal offset
524        let offset = j as isize - i as isize;
525
526        // Find or create the diagonal
527        let diag_idx = match self.offsets.iter().position(|&o| o == offset) {
528            Some(idx) => idx,
529            None => {
530                // This diagonal doesn't exist yet, add it
531                self.offsets.push(offset);
532                self.data
533                    .push(Array1::zeros(self.shape.0.max(self.shape.1)));
534
535                // Sort the offsets and data to maintain canonical form
536                let mut offset_data: Vec<(isize, Array1<T>)> = self
537                    .offsets
538                    .iter()
539                    .cloned()
540                    .zip(self.data.drain(..))
541                    .collect();
542                offset_data.sort_by_key(|&(offset_, _)| offset_);
543
544                self.offsets = offset_data.iter().map(|&(offset_, _)| offset_).collect();
545                self.data = offset_data.into_iter().map(|(_, data)| data).collect();
546
547                // Get the index of the newly added diagonal
548                self.offsets
549                    .iter()
550                    .position(|&o| o == offset)
551                    .expect("Operation failed")
552            }
553        };
554
555        // Set the value
556        let index = if offset >= 0 { i } else { j };
557        self.data[diag_idx][index] = value;
558
559        Ok(())
560    }
561
562    fn eliminate_zeros(&mut self) {
563        // Create a new set of diagonals without zeros
564        let mut new_offsets = Vec::new();
565        let mut new_data = Vec::new();
566
567        for (diag_idx, &offset) in self.offsets.iter().enumerate() {
568            let diag = &self.data[diag_idx];
569
570            // Check if this diagonal has any non-zero values
571            let length = if offset >= 0 {
572                self.shape
573                    .0
574                    .min(self.shape.1.saturating_sub(offset as usize))
575            } else {
576                self.shape
577                    .1
578                    .min(self.shape.0.saturating_sub((-offset) as usize))
579            };
580
581            let has_nonzero = (0..length).any(|i| !SparseElement::is_zero(&diag[i]));
582
583            if has_nonzero {
584                new_offsets.push(offset);
585                new_data.push(diag.clone());
586            }
587        }
588
589        self.offsets = new_offsets;
590        self.data = new_data;
591    }
592
593    fn sort_indices(&mut self) {
594        // DIA arrays have implicitly sorted indices based on offset
595        // Sort by offset just to be sure
596        let mut offset_data: Vec<(isize, Array1<T>)> = self
597            .offsets
598            .iter()
599            .cloned()
600            .zip(self.data.drain(..))
601            .collect();
602        offset_data.sort_by_key(|&(offset_, _)| offset_);
603
604        self.offsets = offset_data.iter().map(|&(offset_, _)| offset_).collect();
605        self.data = offset_data.into_iter().map(|(_, data)| data).collect();
606    }
607
608    fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
609        // Clone and sort
610        let mut result = self.clone();
611        result.sort_indices();
612        Box::new(result)
613    }
614
615    fn has_sorted_indices(&self) -> bool {
616        // Check if offsets are sorted
617        self.offsets.windows(2).all(|w| w[0] <= w[1])
618    }
619
620    fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>> {
621        match axis {
622            None => {
623                // Sum all elements
624                let mut total = T::sparse_zero();
625
626                for (diag_idx, &offset) in self.offsets.iter().enumerate() {
627                    let diag = &self.data[diag_idx];
628
629                    let length = if offset >= 0 {
630                        self.shape
631                            .0
632                            .min(self.shape.1.saturating_sub(offset as usize))
633                    } else {
634                        self.shape
635                            .1
636                            .min(self.shape.0.saturating_sub((-offset) as usize))
637                    };
638
639                    for i in 0..length {
640                        total += diag[i];
641                    }
642                }
643
644                Ok(SparseSum::Scalar(total))
645            }
646            Some(0) => {
647                // Sum along rows (result is 1 x cols)
648                let mut result = Array1::zeros(self.shape.1);
649
650                for (diag_idx, &offset) in self.offsets.iter().enumerate() {
651                    let diag = &self.data[diag_idx];
652
653                    if offset >= 0 {
654                        // Upper diagonal
655                        let offset_usize = offset as usize;
656                        let length = self.shape.0.min(self.shape.1.saturating_sub(offset_usize));
657
658                        for i in 0..length {
659                            result[i + offset_usize] += diag[i];
660                        }
661                    } else {
662                        // Lower diagonal
663                        let offset_usize = (-offset) as usize;
664                        let length = self.shape.1.min(self.shape.0.saturating_sub(offset_usize));
665
666                        for i in 0..length {
667                            result[i] += diag[i];
668                        }
669                    }
670                }
671
672                // Convert to a sparse array
673                match Array2::from_shape_vec((1, self.shape.1), result.to_vec()) {
674                    Ok(result_2d) => {
675                        // Find non-zero elements
676                        let mut row_indices = Vec::new();
677                        let mut col_indices = Vec::new();
678                        let mut values = Vec::new();
679
680                        for j in 0..self.shape.1 {
681                            let val: T = result_2d[[0, j]];
682                            if !SparseElement::is_zero(&val) {
683                                row_indices.push(0);
684                                col_indices.push(j);
685                                values.push(val);
686                            }
687                        }
688
689                        // Create COO array
690                        match CooArray::from_triplets(
691                            &row_indices,
692                            &col_indices,
693                            &values,
694                            (1, self.shape.1),
695                            false,
696                        ) {
697                            Ok(coo_array) => Ok(SparseSum::SparseArray(Box::new(coo_array))),
698                            Err(e) => Err(e),
699                        }
700                    }
701                    Err(_) => Err(SparseError::InconsistentData {
702                        reason: "Failed to create 2D array from result vector".to_string(),
703                    }),
704                }
705            }
706            Some(1) => {
707                // Sum along columns (result is rows x 1)
708                let mut result = Array1::zeros(self.shape.0);
709
710                for (diag_idx, &offset) in self.offsets.iter().enumerate() {
711                    let diag = &self.data[diag_idx];
712
713                    if offset >= 0 {
714                        // Upper diagonal
715                        let offset_usize = offset as usize;
716                        let length = self.shape.0.min(self.shape.1.saturating_sub(offset_usize));
717
718                        for i in 0..length {
719                            result[i] += diag[i];
720                        }
721                    } else {
722                        // Lower diagonal
723                        let offset_usize = (-offset) as usize;
724                        let length = self.shape.1.min(self.shape.0.saturating_sub(offset_usize));
725
726                        for i in 0..length {
727                            result[i + offset_usize] += diag[i];
728                        }
729                    }
730                }
731
732                // Convert to a sparse array
733                match Array2::from_shape_vec((self.shape.0, 1), result.to_vec()) {
734                    Ok(result_2d) => {
735                        // Find non-zero elements
736                        let mut row_indices = Vec::new();
737                        let mut col_indices = Vec::new();
738                        let mut values = Vec::new();
739
740                        for i in 0..self.shape.0 {
741                            let val: T = result_2d[[i, 0]];
742                            if !SparseElement::is_zero(&val) {
743                                row_indices.push(i);
744                                col_indices.push(0);
745                                values.push(val);
746                            }
747                        }
748
749                        // Create COO array
750                        match CooArray::from_triplets(
751                            &row_indices,
752                            &col_indices,
753                            &values,
754                            (self.shape.0, 1),
755                            false,
756                        ) {
757                            Ok(coo_array) => Ok(SparseSum::SparseArray(Box::new(coo_array))),
758                            Err(e) => Err(e),
759                        }
760                    }
761                    Err(_) => Err(SparseError::InconsistentData {
762                        reason: "Failed to create 2D array from result vector".to_string(),
763                    }),
764                }
765            }
766            _ => Err(SparseError::InvalidAxis),
767        }
768    }
769
770    fn max(&self) -> T {
771        let mut max_val = T::neg_infinity();
772
773        for (diag_idx, &offset) in self.offsets.iter().enumerate() {
774            let diag = &self.data[diag_idx];
775
776            let length = if offset >= 0 {
777                self.shape
778                    .0
779                    .min(self.shape.1.saturating_sub(offset as usize))
780            } else {
781                self.shape
782                    .1
783                    .min(self.shape.0.saturating_sub((-offset) as usize))
784            };
785
786            for i in 0..length {
787                max_val = max_val.max(diag[i]);
788            }
789        }
790
791        // If no elements or all negative infinity, return zero
792        if max_val == T::neg_infinity() {
793            T::sparse_zero()
794        } else {
795            max_val
796        }
797    }
798
799    fn min(&self) -> T {
800        let mut min_val = T::sparse_zero();
801        let mut has_nonzero = false;
802
803        for (diag_idx, &offset) in self.offsets.iter().enumerate() {
804            let diag = &self.data[diag_idx];
805
806            let length = if offset >= 0 {
807                self.shape
808                    .0
809                    .min(self.shape.1.saturating_sub(offset as usize))
810            } else {
811                self.shape
812                    .1
813                    .min(self.shape.0.saturating_sub((-offset) as usize))
814            };
815
816            for i in 0..length {
817                if !SparseElement::is_zero(&diag[i]) {
818                    has_nonzero = true;
819                    min_val = min_val.min(diag[i]);
820                }
821            }
822        }
823
824        // If no non-zero elements, return zero
825        if !has_nonzero {
826            T::sparse_zero()
827        } else {
828            min_val
829        }
830    }
831
832    fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
833        let (row_indices, col_indices, values) = self.to_coo_internal();
834
835        (
836            Array1::from_vec(row_indices),
837            Array1::from_vec(col_indices),
838            Array1::from_vec(values),
839        )
840    }
841
842    fn slice(
843        &self,
844        row_range: (usize, usize),
845        col_range: (usize, usize),
846    ) -> SparseResult<Box<dyn SparseArray<T>>> {
847        let (start_row, end_row) = row_range;
848        let (start_col, end_col) = col_range;
849        let (rows, cols) = self.shape;
850
851        if start_row >= rows || end_row > rows || start_col >= cols || end_col > cols {
852            return Err(SparseError::IndexOutOfBounds {
853                index: (start_row.max(end_row), start_col.max(end_col)),
854                shape: (rows, cols),
855            });
856        }
857
858        if start_row >= end_row || start_col >= end_col {
859            return Err(SparseError::InvalidSliceRange);
860        }
861
862        // Convert to COO, then slice, then convert back to DIA
863        let coo = self.to_coo()?;
864        coo.slice(row_range, col_range)?.to_dia()
865    }
866
867    fn as_any(&self) -> &dyn std::any::Any {
868        self
869    }
870}
871
872// Implement Display for DiaArray for better debugging
873impl<T> fmt::Display for DiaArray<T>
874where
875    T: SparseElement + Div<Output = T> + Float + 'static + std::ops::AddAssign,
876{
877    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
878        writeln!(
879            f,
880            "DiaArray of shape {:?} with {} stored elements",
881            self.shape,
882            self.nnz()
883        )?;
884        writeln!(f, "Offsets: {:?}", self.offsets)?;
885
886        if self.offsets.len() <= 5 {
887            for (i, &offset) in self.offsets.iter().enumerate() {
888                let diag = &self.data[i];
889                let length = if offset >= 0 {
890                    self.shape
891                        .0
892                        .min(self.shape.1.saturating_sub(offset as usize))
893                } else {
894                    self.shape
895                        .1
896                        .min(self.shape.0.saturating_sub((-offset) as usize))
897                };
898
899                write!(f, "Diagonal {offset}: [")?;
900                for j in 0..length.min(10) {
901                    if j > 0 {
902                        write!(f, ", ")?;
903                    }
904                    write!(f, "{:?}", diag[j])?;
905                }
906                if length > 10 {
907                    write!(f, ", ...")?;
908                }
909                writeln!(f, "]")?;
910            }
911        } else {
912            writeln!(f, "({} diagonals)", self.offsets.len())?;
913        }
914
915        Ok(())
916    }
917}
918
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dia_array_create() {
        // Create a 3x3 sparse array with main diagonal and upper diagonal
        let data = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]), // Main diagonal
            Array1::from_vec(vec![4.0, 5.0, 0.0]), // Upper diagonal (k=1); last slot is padding
        ];
        let offsets = vec![0, 1]; // Main diagonal and k=1
        let shape = (3, 3);

        let array = DiaArray::new(data, offsets, shape).expect("Operation failed");

        assert_eq!(array.shape(), (3, 3));
        assert_eq!(array.nnz(), 5); // 3 on main diagonal, 2 on upper diagonal

        // Test values
        assert_eq!(array.get(0, 0), 1.0);
        assert_eq!(array.get(1, 1), 2.0);
        assert_eq!(array.get(2, 2), 3.0);
        assert_eq!(array.get(0, 1), 4.0);
        assert_eq!(array.get(1, 2), 5.0);
        assert_eq!(array.get(0, 2), 0.0);
    }

    #[test]
    fn test_dia_array_from_triplets() {
        // Create a tridiagonal matrix
        let row = vec![0, 0, 1, 1, 1, 2, 2];
        let col = vec![0, 1, 0, 1, 2, 1, 2];
        let data = vec![1.0, 4.0, 2.0, 3.0, 5.0, 6.0, 7.0];
        let shape = (3, 3);

        let array = DiaArray::from_triplets(&row, &col, &data, shape).expect("Operation failed");

        // Should have 3 diagonals: main (0), upper (1), and lower (-1)
        assert_eq!(array.offsets.len(), 3);
        assert!(array.offsets.contains(&0));
        assert!(array.offsets.contains(&1));
        assert!(array.offsets.contains(&-1));

        // Test values
        assert_eq!(array.get(0, 0), 1.0);
        assert_eq!(array.get(0, 1), 4.0);
        assert_eq!(array.get(1, 0), 2.0);
        assert_eq!(array.get(1, 1), 3.0);
        assert_eq!(array.get(1, 2), 5.0);
        assert_eq!(array.get(2, 1), 6.0);
        assert_eq!(array.get(2, 2), 7.0);
    }

    #[test]
    fn test_dia_array_conversion() {
        // Create a tridiagonal matrix
        let data = vec![
            Array1::from_vec(vec![1.0, 3.0, 7.0]), // Main diagonal
            Array1::from_vec(vec![4.0, 5.0, 0.0]), // Upper diagonal: (0,1)=4, (1,2)=5
            Array1::from_vec(vec![0.0, 2.0, 0.0]), // Lower diagonal: (1,0)=0, (2,1)=2
        ];
        let offsets = vec![0, 1, -1]; // Main, upper, lower
        let shape = (3, 3);

        let array = DiaArray::new(data, offsets, shape).expect("Operation failed");

        // Convert to COO and check
        let coo = array.to_coo().expect("Operation failed");
        assert_eq!(coo.shape(), (3, 3));
        assert_eq!(coo.nnz(), 6); // The explicit 0.0 at (1,0) is not stored

        // Convert to dense and check
        let dense = array.to_array();

        let expected =
            Array2::from_shape_vec((3, 3), vec![1.0, 4.0, 0.0, 0.0, 3.0, 5.0, 0.0, 2.0, 7.0])
                .expect("Operation failed");
        assert_eq!(dense, expected);
    }

    #[test]
    fn test_dia_array_operations() {
        // Create two simple diagonal matrices
        let data1 = vec![Array1::from_vec(vec![1.0, 2.0, 3.0])]; // Main diagonal
        let offsets1 = vec![0];
        let shape1 = (3, 3);
        let array1 = DiaArray::new(data1, offsets1, shape1).expect("Operation failed");

        let data2 = vec![Array1::from_vec(vec![4.0, 5.0, 6.0])]; // Main diagonal
        let offsets2 = vec![0];
        let shape2 = (3, 3);
        let array2 = DiaArray::new(data2, offsets2, shape2).expect("Operation failed");

        // Test addition
        let sum = array1.add(&array2).expect("Operation failed");
        assert_eq!(sum.get(0, 0), 5.0);
        assert_eq!(sum.get(1, 1), 7.0);
        assert_eq!(sum.get(2, 2), 9.0);

        // Test element-wise multiplication
        let product = array1.mul(&array2).expect("Operation failed");
        assert_eq!(product.get(0, 0), 4.0);
        assert_eq!(product.get(1, 1), 10.0);
        assert_eq!(product.get(2, 2), 18.0);

        // Test dot product (matrix multiplication)
        let dot = array1.dot(&array2).expect("Operation failed");
        assert_eq!(dot.get(0, 0), 4.0);
        assert_eq!(dot.get(1, 1), 10.0);
        assert_eq!(dot.get(2, 2), 18.0);
    }

    #[test]
    fn test_dia_array_dot_vector() {
        // Create a tridiagonal matrix
        let data = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]), // Main diagonal
            Array1::from_vec(vec![4.0, 5.0, 0.0]), // Upper diagonal: (0,1)=4, (1,2)=5
            Array1::from_vec(vec![0.0, 6.0, 7.0]), // Lower diagonal: (1,0)=0, (2,1)=6; last slot unused
        ];
        let offsets = vec![0, 1, -1]; // Main, upper, lower
        let shape = (3, 3);

        let array = DiaArray::new(data, offsets, shape).expect("Operation failed");

        // Create a vector
        let vector = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        // Test matrix-vector multiplication
        let result = array.dot_vector(&vector.view()).expect("Operation failed");

        // Dense matrix is [[1, 4, 0], [0, 2, 5], [0, 6, 3]], so
        // Expected: [1*1 + 4*2 + 0*3, 0*1 + 2*2 + 5*3, 0*1 + 6*2 + 3*3]
        // = [9, 19, 21]
        let expected = Array1::from_vec(vec![9.0, 19.0, 21.0]);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_dia_array_transpose() {
        // Create a tridiagonal matrix
        let data = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]), // Main diagonal
            Array1::from_vec(vec![4.0, 5.0, 0.0]), // Upper diagonal
            Array1::from_vec(vec![0.0, 6.0, 7.0]), // Lower diagonal
        ];
        let offsets = vec![0, 1, -1]; // Main, upper, lower
        let shape = (3, 3);

        let array = DiaArray::new(data, offsets, shape).expect("Operation failed");
        let transposed = array.transpose().expect("Operation failed");

        // Check shape
        assert_eq!(transposed.shape(), (3, 3));

        // Compare the dense array representations
        let original_dense = array.to_array();
        let transposed_dense = transposed.to_array();

        for i in 0..3 {
            for j in 0..3 {
                assert_eq!(transposed_dense[[i, j]], original_dense[[j, i]]);
            }
        }
    }

    #[test]
    fn test_dia_array_sum() {
        // Dense form: [[1, 4, 0], [0, 2, 5], [0, 0, 3]]
        let data = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]), // Main diagonal
            Array1::from_vec(vec![4.0, 5.0, 0.0]), // Upper diagonal
        ];
        let offsets = vec![0, 1]; // Main, upper
        let shape = (3, 3);

        let array = DiaArray::new(data, offsets, shape).expect("Operation failed");

        // Test sum of entire array
        if let SparseSum::Scalar(sum) = array.sum(None).expect("Operation failed") {
            assert_eq!(sum, 15.0); // 1+2+3+4+5 = 15
        } else {
            panic!("Expected SparseSum::Scalar");
        }

        // Sum over axis 0 (collapses rows): one total per column
        if let SparseSum::SparseArray(row_sum) = array.sum(Some(0)).expect("Operation failed") {
            assert_eq!(row_sum.shape(), (1, 3));
            assert_eq!(row_sum.get(0, 0), 1.0);
            assert_eq!(row_sum.get(0, 1), 6.0); // 4+2 = 6
            assert_eq!(row_sum.get(0, 2), 8.0); // 5+3 = 8
        } else {
            panic!("Expected SparseSum::SparseArray");
        }

        // Sum over axis 1 (collapses columns): one total per row
        if let SparseSum::SparseArray(col_sum) = array.sum(Some(1)).expect("Operation failed") {
            assert_eq!(col_sum.shape(), (3, 1));
            assert_eq!(col_sum.get(0, 0), 5.0); // 1+4 = 5
            assert_eq!(col_sum.get(1, 0), 7.0); // 2+5 = 7
            assert_eq!(col_sum.get(2, 0), 3.0);
        } else {
            panic!("Expected SparseSum::SparseArray");
        }
    }

    #[test]
    fn test_dia_array_find() {
        // Tridiagonal matrix whose padding slots are all zero.
        let data = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]), // Main diagonal
            Array1::from_vec(vec![4.0, 5.0, 0.0]), // Upper diagonal
            Array1::from_vec(vec![0.0, 6.0, 0.0]), // Lower diagonal
        ];
        let array = DiaArray::new(data, vec![0, 1, -1], (3, 3)).expect("Operation failed");
        let dense = array.to_array();

        let (rows, cols, vals) = array.find();
        assert_eq!(rows.len(), vals.len());
        assert_eq!(cols.len(), vals.len());

        // Every reported triple must agree with the dense representation...
        for k in 0..vals.len() {
            assert_eq!(dense[[rows[k], cols[k]]], vals[k]);
        }

        // ...and every non-zero of the dense form must be reported.
        let dense_nonzeros = dense.iter().filter(|&&v| v != 0.0).count();
        let found_nonzeros = vals.iter().filter(|&&v| v != 0.0).count();
        assert_eq!(found_nonzeros, dense_nonzeros);
    }

    #[test]
    fn test_dia_array_slice() {
        // Dense form: [[1, 4, 0], [0, 2, 5], [0, 0, 3]]
        let data = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]), // Main diagonal
            Array1::from_vec(vec![4.0, 5.0, 0.0]), // Upper diagonal
        ];
        let array = DiaArray::new(data, vec![0, 1], (3, 3)).expect("Operation failed");

        // Rows 0..2, columns 1..3 => [[4, 0], [2, 5]]
        let sub = array.slice((0, 2), (1, 3)).expect("Operation failed");
        assert_eq!(sub.shape(), (2, 2));
        assert_eq!(sub.get(0, 0), 4.0);
        assert_eq!(sub.get(0, 1), 0.0);
        assert_eq!(sub.get(1, 0), 2.0);
        assert_eq!(sub.get(1, 1), 5.0);

        // A range ending past the edge is rejected.
        assert!(array.slice((0, 4), (0, 3)).is_err());
        // An empty range is rejected.
        assert!(array.slice((1, 1), (0, 3)).is_err());
    }

    #[test]
    fn test_dia_array_display() {
        let data = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]), // Main diagonal
            Array1::from_vec(vec![4.0, 5.0, 0.0]), // Upper diagonal
        ];
        let array = DiaArray::new(data, vec![0, 1], (3, 3)).expect("Operation failed");

        let text = array.to_string();
        assert!(text.starts_with("DiaArray of shape (3, 3) with 5 stored elements"));
        assert!(text.contains("Offsets: [0, 1]"));
        assert!(text.contains("Diagonal 0: [1.0, 2.0, 3.0]"));
        // Only the two in-range entries of the k=1 diagonal are shown.
        assert!(text.contains("Diagonal 1: [4.0, 5.0]"));
    }
}