Skip to main content

torsh_sparse/
pattern_analysis.rs

1//! Advanced pattern analysis for sparse matrices
2//!
3//! This module provides sophisticated algorithms for analyzing, detecting,
4//! and optimizing sparsity patterns in sparse matrices.
5
6use crate::TorshResult;
7use std::collections::{HashMap, HashSet, VecDeque};
8use torsh_core::Shape;
9
/// Advanced sparsity patterns with detailed characteristics.
///
/// Produced by [`PatternAnalyzer::analyze_advanced_pattern`]. `PartialEq` is derived so
/// detected patterns can be compared directly (e.g. with `assert_eq!`); `Eq`/`Hash` are not
/// derived because several variants carry `f32` ratios.
#[derive(Debug, Clone, PartialEq)]
pub enum AdvancedSparsityPattern {
    /// Every stored entry lies on the main diagonal.
    Diagonal {
        /// Stored main-diagonal entries divided by `min(rows, cols)`.
        fill_ratio: f32,
    },
    /// Entries confined to a few diagonals (tridiagonal, pentadiagonal, etc.).
    MultiDiagonal {
        /// Number of distinct diagonals holding entries.
        num_diagonals: usize,
        /// Diagonal offsets `row - col` in ascending order; negative offsets lie above the
        /// main diagonal.
        offsets: Vec<i32>,
    },
    /// Block diagonal with detected block structure.
    BlockDiagonal {
        /// `(height, width)` of each detected block.
        block_sizes: Vec<(usize, usize)>,
        /// `(row, col)` of each block's top-left corner, parallel to `block_sizes`.
        block_positions: Vec<(usize, usize)>,
    },
    /// Banded matrix with upper and lower bandwidth.
    Banded {
        /// Largest `row - col` over entries below the main diagonal.
        lower_bandwidth: usize,
        /// Largest `col - row` over entries on or above the main diagonal.
        upper_bandwidth: usize,
        /// Stored entries divided by the number of matrix positions inside the band.
        fill_ratio: f32,
    },
    /// (Mostly) symmetric pattern.
    Symmetric {
        /// Fraction of stored entries `(r, c)` whose mirror `(c, r)` is also stored
        /// (0.0 = no mirrored entries, 1.0 = perfectly symmetric pattern).
        symmetry_ratio: f32,
        /// Underlying pattern.
        base_pattern: Box<AdvancedSparsityPattern>,
    },
    /// Arrow-head pattern (dense first row and/or first column, sparse elsewhere).
    ArrowHead {
        /// Number of entries in the denser of the first row and the first column.
        head_size: usize,
    },
    /// Random/unstructured pattern.
    Random {
        /// Average local clustering coefficient of the pattern's graph.
        clustering_coefficient: f32,
    },
}
59
/// Matrix reordering algorithms.
///
/// A plain selector: [`MatrixReorderer`] currently implements only
/// [`ReorderingAlgorithm::ReverseCuthillMcKee`] (via `MatrixReorderer::reverse_cuthill_mckee`).
/// The enum is `Copy + Eq + Hash` so it can be passed by value and used as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReorderingAlgorithm {
    /// Reverse Cuthill-McKee ordering
    ReverseCuthillMcKee,
    /// Approximate Minimum Degree ordering
    ApproximateMinimumDegree,
    /// Nested Dissection ordering
    NestedDissection,
    /// King ordering (variation of RCM)
    King,
    /// Random ordering (for comparison)
    Random,
}
74
/// Matrix clustering algorithms.
///
/// Every variant carries the requested number of clusters; use
/// [`ClusteringAlgorithm::num_clusters`] to read it without matching on the variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ClusteringAlgorithm {
    /// Spectral clustering
    Spectral {
        /// Requested number of clusters.
        num_clusters: usize,
    },
    /// K-means based on matrix structure
    KMeans {
        /// Requested number of clusters.
        num_clusters: usize,
    },
    /// Hierarchical clustering
    Hierarchical {
        /// Requested number of clusters.
        num_clusters: usize,
    },
    /// Graph-based clustering
    GraphBased {
        /// Requested number of clusters.
        num_clusters: usize,
    },
}

impl ClusteringAlgorithm {
    /// Requested number of clusters, whichever algorithm variant this is.
    pub fn num_clusters(&self) -> usize {
        match *self {
            Self::Spectral { num_clusters }
            | Self::KMeans { num_clusters }
            | Self::Hierarchical { num_clusters }
            | Self::GraphBased { num_clusters } => num_clusters,
        }
    }
}
87
/// Pattern statistics and characteristics.
///
/// Produced by `PatternAnalyzer::compute_pattern_statistics`. Only entry positions are
/// considered; values are ignored. `PartialEq` is derived for direct comparison (no `Eq`,
/// since several fields are `f32`).
#[derive(Debug, Clone, PartialEq)]
pub struct PatternStatistics {
    /// Number of stored triplets (duplicates and explicitly stored zeros included).
    pub nnz: usize,
    /// Matrix dimensions as `(rows, cols)`.
    pub dimensions: (usize, usize),
    /// Sparsity ratio: `1 - nnz / (rows * cols)` (fraction of positions without an entry).
    pub sparsity: f32,
    /// Maximum number of entries in any row.
    pub max_nnz_per_row: usize,
    /// Average number of entries per row.
    pub avg_nnz_per_row: f32,
    /// Population standard deviation of the number of entries per row.
    pub std_nnz_per_row: f32,
    /// Bandwidth: maximum `|row - col|` over all stored entries.
    pub bandwidth: usize,
    /// Sum of `|row - col|` over all stored entries. Note that this is not the classical
    /// envelope "profile" of a matrix.
    pub profile: usize,
    /// Number of connected components of the undirected graph with one edge per stored entry
    /// (row and column indices share one vertex set). Every row index is a vertex, so rows
    /// without entries form components of their own.
    pub connected_components: usize,
    /// Average local clustering coefficient over rows with at least two distinct off-diagonal
    /// neighbours (0.0 when there are none).
    pub clustering_coefficient: f32,
}
112
113/// Advanced pattern analyzer
114pub struct PatternAnalyzer {
115    /// Cached analysis results
116    cache: HashMap<String, AdvancedSparsityPattern>,
117}
118
119impl Default for PatternAnalyzer {
120    fn default() -> Self {
121        Self::new()
122    }
123}
124
125impl PatternAnalyzer {
126    /// Create a new pattern analyzer
127    pub fn new() -> Self {
128        Self {
129            cache: HashMap::new(),
130        }
131    }
132
133    /// Analyze sparsity pattern with advanced detection
134    pub fn analyze_advanced_pattern(
135        &mut self,
136        triplets: &[(usize, usize, f32)],
137        shape: &Shape,
138    ) -> TorshResult<AdvancedSparsityPattern> {
139        let cache_key = self.create_cache_key(triplets, shape);
140
141        if let Some(cached_pattern) = self.cache.get(&cache_key) {
142            return Ok(cached_pattern.clone());
143        }
144
145        let pattern = self.detect_pattern(triplets, shape)?;
146        self.cache.insert(cache_key, pattern.clone());
147        Ok(pattern)
148    }
149
150    /// Detect the underlying sparsity pattern
151    fn detect_pattern(
152        &self,
153        triplets: &[(usize, usize, f32)],
154        shape: &Shape,
155    ) -> TorshResult<AdvancedSparsityPattern> {
156        let (rows, cols) = (shape.dims()[0], shape.dims()[1]);
157
158        // Check for diagonal patterns first
159        if let Some(diagonal_pattern) = self.detect_diagonal_pattern(triplets, rows, cols) {
160            return Ok(diagonal_pattern);
161        }
162
163        // Check for banded patterns
164        if let Some(banded_pattern) = self.detect_banded_pattern(triplets, rows, cols) {
165            return Ok(banded_pattern);
166        }
167
168        // Check for block diagonal patterns
169        if let Some(block_pattern) = self.detect_block_diagonal_pattern(triplets, rows, cols) {
170            return Ok(block_pattern);
171        }
172
173        // Check for arrow-head patterns
174        if let Some(arrow_pattern) = self.detect_arrow_head_pattern(triplets, rows, cols) {
175            return Ok(arrow_pattern);
176        }
177
178        // Check for symmetry
179        if let Some(symmetric_pattern) = self.detect_symmetric_pattern(triplets, rows, cols) {
180            return Ok(symmetric_pattern);
181        }
182
183        // Default to random pattern with clustering analysis
184        let clustering_coefficient = self.compute_clustering_coefficient(triplets, rows, cols);
185        Ok(AdvancedSparsityPattern::Random {
186            clustering_coefficient,
187        })
188    }
189
190    /// Detect diagonal and multi-diagonal patterns
191    fn detect_diagonal_pattern(
192        &self,
193        triplets: &[(usize, usize, f32)],
194        rows: usize,
195        cols: usize,
196    ) -> Option<AdvancedSparsityPattern> {
197        let mut diagonal_counts: HashMap<i32, usize> = HashMap::new();
198
199        for (r, c, _) in triplets {
200            let offset = *r as i32 - *c as i32;
201            *diagonal_counts.entry(offset).or_insert(0) += 1;
202        }
203
204        let total_nnz = triplets.len();
205        let main_diagonal_count = diagonal_counts.get(&0).unwrap_or(&0);
206
207        // Check for pure diagonal matrix
208        if diagonal_counts.len() == 1 && diagonal_counts.contains_key(&0) {
209            let fill_ratio = *main_diagonal_count as f32 / std::cmp::min(rows, cols) as f32;
210            return Some(AdvancedSparsityPattern::Diagonal { fill_ratio });
211        }
212
213        // Check for multi-diagonal pattern
214        if diagonal_counts.len() <= 5 {
215            let diagonal_nnz: usize = diagonal_counts.values().sum();
216            if diagonal_nnz as f32 / total_nnz as f32 > 0.9 {
217                let mut offsets: Vec<i32> = diagonal_counts.keys().copied().collect();
218                offsets.sort();
219                return Some(AdvancedSparsityPattern::MultiDiagonal {
220                    num_diagonals: diagonal_counts.len(),
221                    offsets,
222                });
223            }
224        }
225
226        None
227    }
228
229    /// Detect banded patterns
230    fn detect_banded_pattern(
231        &self,
232        triplets: &[(usize, usize, f32)],
233        rows: usize,
234        cols: usize,
235    ) -> Option<AdvancedSparsityPattern> {
236        let mut max_lower_bandwidth = 0;
237        let mut max_upper_bandwidth = 0;
238
239        for (r, c, _) in triplets {
240            let diff = *r as i32 - *c as i32;
241            if diff > 0 {
242                max_lower_bandwidth = std::cmp::max(max_lower_bandwidth, diff as usize);
243            } else {
244                max_upper_bandwidth = std::cmp::max(max_upper_bandwidth, (-diff) as usize);
245            }
246        }
247
248        let total_bandwidth = max_lower_bandwidth + max_upper_bandwidth + 1;
249        let max_possible_bandwidth = std::cmp::min(rows, cols);
250
251        // Consider it banded if bandwidth is significantly smaller than matrix size
252        if total_bandwidth < max_possible_bandwidth / 4 {
253            let band_elements = std::cmp::min(rows, cols) * total_bandwidth
254                - (total_bandwidth * (total_bandwidth - 1)) / 2;
255            let fill_ratio = triplets.len() as f32 / band_elements as f32;
256
257            return Some(AdvancedSparsityPattern::Banded {
258                lower_bandwidth: max_lower_bandwidth,
259                upper_bandwidth: max_upper_bandwidth,
260                fill_ratio,
261            });
262        }
263
264        None
265    }
266
267    /// Detect block diagonal patterns using graph analysis
268    fn detect_block_diagonal_pattern(
269        &self,
270        triplets: &[(usize, usize, f32)],
271        rows: usize,
272        _cols: usize,
273    ) -> Option<AdvancedSparsityPattern> {
274        // Build adjacency representation
275        let mut adjacency: HashMap<usize, HashSet<usize>> = HashMap::new();
276
277        for (r, c, _) in triplets {
278            adjacency.entry(*r).or_default().insert(*c);
279            adjacency.entry(*c).or_default().insert(*r);
280        }
281
282        // Find connected components
283        let components = self.find_connected_components(&adjacency, rows);
284
285        if components.len() > 1 {
286            // Analyze block structure
287            let mut block_sizes = Vec::new();
288            let mut block_positions = Vec::new();
289
290            for component in &components {
291                if component.len() > 1 {
292                    let min_idx = *component
293                        .iter()
294                        .min()
295                        .expect("component should not be empty");
296                    let max_idx = *component
297                        .iter()
298                        .max()
299                        .expect("component should not be empty");
300                    let block_size = max_idx - min_idx + 1;
301
302                    block_sizes.push((block_size, block_size));
303                    block_positions.push((min_idx, min_idx));
304                }
305            }
306
307            if !block_sizes.is_empty() {
308                return Some(AdvancedSparsityPattern::BlockDiagonal {
309                    block_sizes,
310                    block_positions,
311                });
312            }
313        }
314
315        None
316    }
317
318    /// Detect arrow-head patterns
319    fn detect_arrow_head_pattern(
320        &self,
321        triplets: &[(usize, usize, f32)],
322        rows: usize,
323        cols: usize,
324    ) -> Option<AdvancedSparsityPattern> {
325        let mut first_row_count = 0;
326        let mut first_col_count = 0;
327
328        for (r, c, _) in triplets {
329            if *r == 0 {
330                first_row_count += 1;
331            }
332            if *c == 0 {
333                first_col_count += 1;
334            }
335        }
336
337        let first_row_density = first_row_count as f32 / cols as f32;
338        let first_col_density = first_col_count as f32 / rows as f32;
339
340        // Check if first row or column is significantly denser
341        if first_row_density > 0.5 || first_col_density > 0.5 {
342            let head_size = std::cmp::max(first_row_count, first_col_count);
343            return Some(AdvancedSparsityPattern::ArrowHead { head_size });
344        }
345
346        None
347    }
348
349    /// Detect symmetric patterns
350    fn detect_symmetric_pattern(
351        &self,
352        triplets: &[(usize, usize, f32)],
353        rows: usize,
354        cols: usize,
355    ) -> Option<AdvancedSparsityPattern> {
356        if rows != cols {
357            return None; // Can't be symmetric if not square
358        }
359
360        let mut pattern_set: HashSet<(usize, usize)> = HashSet::new();
361        let mut symmetric_count = 0;
362
363        for (r, c, _) in triplets {
364            pattern_set.insert((*r, *c));
365        }
366
367        for (r, c, _) in triplets {
368            if pattern_set.contains(&(*c, *r)) {
369                symmetric_count += 1;
370            }
371        }
372
373        let symmetry_ratio = symmetric_count as f32 / triplets.len() as f32;
374
375        if symmetry_ratio > 0.8 {
376            // Recursively detect underlying pattern
377            let base_pattern = Box::new(AdvancedSparsityPattern::Random {
378                clustering_coefficient: self.compute_clustering_coefficient(triplets, rows, cols),
379            });
380
381            return Some(AdvancedSparsityPattern::Symmetric {
382                symmetry_ratio,
383                base_pattern,
384            });
385        }
386
387        None
388    }
389
390    /// Compute clustering coefficient for graph representation
391    fn compute_clustering_coefficient(
392        &self,
393        triplets: &[(usize, usize, f32)],
394        rows: usize,
395        _cols: usize,
396    ) -> f32 {
397        let mut adjacency: HashMap<usize, HashSet<usize>> = HashMap::new();
398
399        for (r, c, _) in triplets {
400            if r != c {
401                // Ignore self-loops
402                adjacency.entry(*r).or_default().insert(*c);
403                adjacency.entry(*c).or_default().insert(*r);
404            }
405        }
406
407        let mut total_clustering = 0.0;
408        let mut nodes_with_neighbors = 0;
409
410        for node in 0..rows {
411            if let Some(neighbors) = adjacency.get(&node) {
412                if neighbors.len() >= 2 {
413                    let mut triangles = 0;
414                    let neighbor_vec: Vec<_> = neighbors.iter().collect();
415
416                    for i in 0..neighbor_vec.len() {
417                        for j in (i + 1)..neighbor_vec.len() {
418                            if adjacency
419                                .get(neighbor_vec[i])
420                                .is_some_and(|adj| adj.contains(neighbor_vec[j]))
421                            {
422                                triangles += 1;
423                            }
424                        }
425                    }
426
427                    let possible_edges = neighbors.len() * (neighbors.len() - 1) / 2;
428                    if possible_edges > 0 {
429                        total_clustering += triangles as f32 / possible_edges as f32;
430                        nodes_with_neighbors += 1;
431                    }
432                }
433            }
434        }
435
436        if nodes_with_neighbors > 0 {
437            total_clustering / nodes_with_neighbors as f32
438        } else {
439            0.0
440        }
441    }
442
443    /// Find connected components in graph
444    fn find_connected_components(
445        &self,
446        adjacency: &HashMap<usize, HashSet<usize>>,
447        num_nodes: usize,
448    ) -> Vec<Vec<usize>> {
449        let mut visited = vec![false; num_nodes];
450        let mut components = Vec::new();
451
452        for node in 0..num_nodes {
453            if !visited[node] {
454                let mut component = Vec::new();
455                let mut queue = VecDeque::new();
456                queue.push_back(node);
457                visited[node] = true;
458
459                while let Some(current) = queue.pop_front() {
460                    component.push(current);
461
462                    if let Some(neighbors) = adjacency.get(&current) {
463                        for &neighbor in neighbors {
464                            if !visited[neighbor] {
465                                visited[neighbor] = true;
466                                queue.push_back(neighbor);
467                            }
468                        }
469                    }
470                }
471
472                components.push(component);
473            }
474        }
475
476        components
477    }
478
479    /// Create cache key for memoization
480    fn create_cache_key(&self, triplets: &[(usize, usize, f32)], shape: &Shape) -> String {
481        format!(
482            "{}_{}_{}_{}",
483            shape.dims()[0],
484            shape.dims()[1],
485            triplets.len(),
486            triplets
487                .iter()
488                .take(10)
489                .map(|(r, c, _)| format!("{r}_{c}"))
490                .collect::<Vec<_>>()
491                .join("_")
492        )
493    }
494
495    /// Compute detailed pattern statistics
496    pub fn compute_pattern_statistics(
497        &self,
498        triplets: &[(usize, usize, f32)],
499        shape: &Shape,
500    ) -> TorshResult<PatternStatistics> {
501        let (rows, cols) = (shape.dims()[0], shape.dims()[1]);
502        let nnz = triplets.len();
503        let sparsity = 1.0 - (nnz as f32 / (rows * cols) as f32);
504
505        // Compute row-wise statistics
506        let mut row_counts = vec![0; rows];
507        let mut max_bandwidth = 0;
508        let mut profile = 0;
509
510        for (r, c, _) in triplets {
511            row_counts[*r] += 1;
512            let distance = (*r as i32 - *c as i32).unsigned_abs() as usize;
513            max_bandwidth = std::cmp::max(max_bandwidth, distance);
514            profile += distance;
515        }
516
517        let max_nnz_per_row = *row_counts.iter().max().unwrap_or(&0);
518        let avg_nnz_per_row = nnz as f32 / rows as f32;
519        let variance = row_counts
520            .iter()
521            .map(|&count| (count as f32 - avg_nnz_per_row).powi(2))
522            .sum::<f32>()
523            / rows as f32;
524        let std_nnz_per_row = variance.sqrt();
525
526        // Build adjacency for connected components
527        let mut adjacency: HashMap<usize, HashSet<usize>> = HashMap::new();
528        for (r, c, _) in triplets {
529            adjacency.entry(*r).or_default().insert(*c);
530            adjacency.entry(*c).or_default().insert(*r);
531        }
532
533        let components = self.find_connected_components(&adjacency, rows);
534        let connected_components = components.len();
535
536        let clustering_coefficient = self.compute_clustering_coefficient(triplets, rows, cols);
537
538        Ok(PatternStatistics {
539            nnz,
540            dimensions: (rows, cols),
541            sparsity,
542            max_nnz_per_row,
543            avg_nnz_per_row,
544            std_nnz_per_row,
545            bandwidth: max_bandwidth,
546            profile,
547            connected_components,
548            clustering_coefficient,
549        })
550    }
551}
552
553/// Matrix reordering algorithms implementation
554pub struct MatrixReorderer;
555
556impl MatrixReorderer {
557    /// Apply Reverse Cuthill-McKee reordering
558    pub fn reverse_cuthill_mckee(
559        triplets: &[(usize, usize, f32)],
560        num_rows: usize,
561    ) -> TorshResult<Vec<usize>> {
562        // Build adjacency list representation
563        let mut adjacency: HashMap<usize, HashSet<usize>> = HashMap::new();
564        for (r, c, _) in triplets {
565            if r != c {
566                adjacency.entry(*r).or_default().insert(*c);
567                adjacency.entry(*c).or_default().insert(*r);
568            }
569        }
570
571        // Find peripheral vertex (vertex with minimum degree, furthest from center)
572        let start_vertex = Self::find_peripheral_vertex(&adjacency, num_rows)?;
573
574        // BFS ordering
575        let mut ordering = Vec::new();
576        let mut visited = vec![false; num_rows];
577        let mut queue = VecDeque::new();
578
579        queue.push_back(start_vertex);
580        visited[start_vertex] = true;
581
582        while let Some(vertex) = queue.pop_front() {
583            ordering.push(vertex);
584
585            // Get neighbors and sort by degree (ascending)
586            if let Some(neighbors) = adjacency.get(&vertex) {
587                let mut neighbor_degrees: Vec<_> = neighbors
588                    .iter()
589                    .filter(|&&neighbor| !visited[neighbor])
590                    .map(|&neighbor| {
591                        let degree = adjacency.get(&neighbor).map_or(0, |adj| adj.len());
592                        (degree, neighbor)
593                    })
594                    .collect();
595
596                neighbor_degrees.sort_by_key(|&(degree, _)| degree);
597
598                for (_, neighbor) in neighbor_degrees {
599                    if !visited[neighbor] {
600                        visited[neighbor] = true;
601                        queue.push_back(neighbor);
602                    }
603                }
604            }
605        }
606
607        // Add any remaining unvisited vertices
608        for (i, &is_visited) in visited.iter().enumerate() {
609            if !is_visited {
610                ordering.push(i);
611            }
612        }
613
614        // Reverse the ordering (Reverse Cuthill-McKee)
615        ordering.reverse();
616
617        Ok(ordering)
618    }
619
620    /// Find a peripheral vertex (good starting point for RCM)
621    fn find_peripheral_vertex(
622        adjacency: &HashMap<usize, HashSet<usize>>,
623        num_rows: usize,
624    ) -> TorshResult<usize> {
625        let mut min_degree = usize::MAX;
626        let mut peripheral_candidates = Vec::new();
627
628        // Find vertices with minimum degree
629        for i in 0..num_rows {
630            let degree = adjacency.get(&i).map_or(0, |adj| adj.len());
631            if degree < min_degree {
632                min_degree = degree;
633                peripheral_candidates.clear();
634                peripheral_candidates.push(i);
635            } else if degree == min_degree {
636                peripheral_candidates.push(i);
637            }
638        }
639
640        if peripheral_candidates.is_empty() {
641            return Ok(0); // Fallback to first vertex
642        }
643
644        // Among minimum degree vertices, find the one with maximum distance to others
645        let mut best_vertex = peripheral_candidates[0];
646        let mut max_distance = 0;
647
648        for &candidate in &peripheral_candidates {
649            let distance = Self::compute_eccentricity(adjacency, candidate, num_rows);
650            if distance > max_distance {
651                max_distance = distance;
652                best_vertex = candidate;
653            }
654        }
655
656        Ok(best_vertex)
657    }
658
659    /// Compute eccentricity (maximum distance to any other vertex)
660    fn compute_eccentricity(
661        adjacency: &HashMap<usize, HashSet<usize>>,
662        start: usize,
663        num_rows: usize,
664    ) -> usize {
665        let mut distances = vec![usize::MAX; num_rows];
666        let mut queue = VecDeque::new();
667
668        distances[start] = 0;
669        queue.push_back(start);
670
671        while let Some(vertex) = queue.pop_front() {
672            if let Some(neighbors) = adjacency.get(&vertex) {
673                for &neighbor in neighbors {
674                    if distances[neighbor] == usize::MAX {
675                        distances[neighbor] = distances[vertex] + 1;
676                        queue.push_back(neighbor);
677                    }
678                }
679            }
680        }
681
682        distances
683            .iter()
684            .filter(|&&d| d != usize::MAX)
685            .max()
686            .copied()
687            .unwrap_or(0)
688    }
689
690    /// Apply reordering to triplets
691    pub fn apply_reordering(
692        triplets: &[(usize, usize, f32)],
693        ordering: &[usize],
694    ) -> Vec<(usize, usize, f32)> {
695        let mut inverse_ordering = vec![0; ordering.len()];
696        for (new_idx, &old_idx) in ordering.iter().enumerate() {
697            inverse_ordering[old_idx] = new_idx;
698        }
699
700        triplets
701            .iter()
702            .map(|(r, c, v)| (inverse_ordering[*r], inverse_ordering[*c], *v))
703            .collect()
704    }
705}
706
707/// Visualization utilities for sparsity patterns
708pub struct PatternVisualizer;
709
710impl PatternVisualizer {
711    /// Generate ASCII art visualization of sparsity pattern
712    pub fn ascii_pattern(
713        triplets: &[(usize, usize, f32)],
714        shape: &Shape,
715        max_size: Option<(usize, usize)>,
716    ) -> String {
717        let (rows, cols) = (shape.dims()[0], shape.dims()[1]);
718        let (display_rows, display_cols) = max_size.unwrap_or((50, 50));
719
720        let row_scale = if rows > display_rows {
721            rows / display_rows
722        } else {
723            1
724        };
725        let col_scale = if cols > display_cols {
726            cols / display_cols
727        } else {
728            1
729        };
730
731        let scaled_rows = rows.div_ceil(row_scale);
732        let scaled_cols = cols.div_ceil(col_scale);
733
734        let mut pattern = vec![vec![' '; scaled_cols]; scaled_rows];
735
736        for (r, c, _) in triplets {
737            let scaled_r = r / row_scale;
738            let scaled_c = c / col_scale;
739            if scaled_r < scaled_rows && scaled_c < scaled_cols {
740                pattern[scaled_r][scaled_c] = '*';
741            }
742        }
743
744        let mut result = String::new();
745        result.push_str(&format!(
746            "Sparsity Pattern ({rows}x{cols}, scaled to {scaled_rows}x{scaled_cols})\n"
747        ));
748        result.push_str(&"-".repeat(scaled_cols + 2));
749        result.push('\n');
750
751        for row in pattern {
752            result.push('|');
753            for cell in row {
754                result.push(cell);
755            }
756            result.push_str("|\n");
757        }
758
759        result.push_str(&"-".repeat(scaled_cols + 2));
760        result.push('\n');
761
762        result
763    }
764
765    /// Generate pattern histogram (distribution of non-zeros per row/column)
766    pub fn pattern_histogram(
767        triplets: &[(usize, usize, f32)],
768        shape: &Shape,
769    ) -> (Vec<usize>, Vec<usize>) {
770        let (rows, cols) = (shape.dims()[0], shape.dims()[1]);
771        let mut row_counts = vec![0; rows];
772        let mut col_counts = vec![0; cols];
773
774        for (r, c, _) in triplets {
775            row_counts[*r] += 1;
776            col_counts[*c] += 1;
777        }
778
779        (row_counts, col_counts)
780    }
781}
782
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_advanced_pattern_detection() {
        let mut analyzer = PatternAnalyzer::new();

        // Test diagonal pattern
        let triplets = vec![(0, 0, 1.0), (1, 1, 1.0), (2, 2, 1.0)];
        let shape = Shape::new(vec![3, 3]);
        let pattern = analyzer
            .analyze_advanced_pattern(&triplets, &shape)
            .unwrap();

        // `matches!` only evaluates to a bool; it has to be asserted to test anything.
        assert!(
            matches!(pattern, AdvancedSparsityPattern::Diagonal { fill_ratio } if fill_ratio == 1.0),
            "expected a fully filled diagonal, got {pattern:?}"
        );
    }

    #[test]
    fn test_tridiagonal_detection() {
        let mut analyzer = PatternAnalyzer::new();
        let triplets: Vec<(usize, usize, f32)> = (0..4usize)
            .flat_map(|r| {
                (0..4usize)
                    .filter(move |&c| r.abs_diff(c) <= 1)
                    .map(move |c| (r, c, 1.0))
            })
            .collect();
        let shape = Shape::new(vec![4, 4]);
        let pattern = analyzer
            .analyze_advanced_pattern(&triplets, &shape)
            .unwrap();

        assert!(
            matches!(
                &pattern,
                AdvancedSparsityPattern::MultiDiagonal { num_diagonals: 3, offsets }
                    if *offsets == [-1, 0, 1]
            ),
            "expected a tridiagonal pattern, got {pattern:?}"
        );
    }

    #[test]
    fn test_rcm_reordering() {
        let triplets = vec![
            (0, 1, 1.0),
            (1, 0, 1.0),
            (1, 2, 1.0),
            (2, 1, 1.0),
            (2, 3, 1.0),
            (3, 2, 1.0),
        ];

        let ordering = MatrixReorderer::reverse_cuthill_mckee(&triplets, 4).unwrap();
        assert_eq!(ordering.len(), 4);

        // The ordering must be a permutation of 0..4.
        let mut sorted = ordering.clone();
        sorted.sort_unstable();
        assert_eq!(sorted, vec![0, 1, 2, 3]);

        let reordered = MatrixReorderer::apply_reordering(&triplets, &ordering);
        assert_eq!(reordered.len(), triplets.len());
    }

    #[test]
    fn test_pattern_statistics() {
        let analyzer = PatternAnalyzer::new();
        let triplets = vec![(0, 0, 1.0), (1, 1, 1.0), (2, 2, 1.0)];
        let shape = Shape::new(vec![3, 3]);

        let stats = analyzer
            .compute_pattern_statistics(&triplets, &shape)
            .unwrap();
        assert_eq!(stats.nnz, 3);
        assert_eq!(stats.dimensions, (3, 3));
        assert_eq!(stats.bandwidth, 0); // Diagonal matrix
        assert_eq!(stats.profile, 0);
        assert_eq!(stats.max_nnz_per_row, 1);
        assert_eq!(stats.avg_nnz_per_row, 1.0);
        assert_eq!(stats.std_nnz_per_row, 0.0);
        assert_eq!(stats.connected_components, 3); // each row only touches itself
        assert_eq!(stats.clustering_coefficient, 0.0);
        assert!((stats.sparsity - 2.0 / 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_pattern_visualization() {
        let triplets = vec![(0, 0, 1.0), (1, 1, 1.0), (2, 2, 1.0)];
        let shape = Shape::new(vec![3, 3]);

        let ascii = PatternVisualizer::ascii_pattern(&triplets, &shape, Some((10, 10)));
        assert!(ascii.contains("*"));
        assert!(ascii.starts_with("Sparsity Pattern (3x3, scaled to 3x3)"));

        let (row_hist, col_hist) = PatternVisualizer::pattern_histogram(&triplets, &shape);
        assert_eq!(row_hist, vec![1, 1, 1]);
        assert_eq!(col_hist, vec![1, 1, 1]);
    }
}