1use crate::TorshResult;
7use std::collections::{HashMap, HashSet, VecDeque};
8use torsh_core::Shape;
9
/// High-level structural classification of a sparse matrix's non-zero layout,
/// as produced by [`PatternAnalyzer::analyze_advanced_pattern`].
///
/// Only entry positions contribute to a classification; stored values do not.
#[derive(Debug, Clone, PartialEq)]
pub enum AdvancedSparsityPattern {
    /// Every stored entry lies on the main diagonal.
    Diagonal {
        /// Stored main-diagonal entries divided by `min(rows, cols)`.
        fill_ratio: f32,
    },
    /// Entries occupy a small number (at most five) of distinct diagonals.
    MultiDiagonal {
        /// Number of distinct diagonals holding at least one entry.
        num_diagonals: usize,
        /// Ascending diagonal offsets `row - col`; negative offsets lie above the
        /// main diagonal, positive ones below it.
        offsets: Vec<i32>,
    },
    /// The graph linking each entry's row and column splits into several connected
    /// components, each multi-node component reported as one square block.
    BlockDiagonal {
        /// `(height, width)` of each block.
        block_sizes: Vec<(usize, usize)>,
        /// `(row, col)` of each block's top-left corner, parallel to `block_sizes`.
        block_positions: Vec<(usize, usize)>,
    },
    /// All entries lie inside a band that is narrow relative to the matrix.
    Banded {
        /// Largest `row - col` over entries below the main diagonal.
        lower_bandwidth: usize,
        /// Largest `col - row` over entries on or above the main diagonal.
        upper_bandwidth: usize,
        /// Stored entries relative to the number of positions inside the band.
        fill_ratio: f32,
    },
    /// Most stored entries of a square matrix have their transposed position
    /// stored as well.
    Symmetric {
        /// Fraction of stored entries whose transposed position is also stored.
        symmetry_ratio: f32,
        /// Classification of the underlying structure.
        base_pattern: Box<AdvancedSparsityPattern>,
    },
    /// The first row or the first column is more than half populated.
    ArrowHead {
        /// The larger of the entry counts in the first row and the first column.
        head_size: usize,
    },
    /// No more specific structure was recognised.
    Random {
        /// Mean local clustering coefficient of the off-diagonal entry graph.
        clustering_coefficient: f32,
    },
}
59
/// Fill-reducing / bandwidth-reducing matrix reordering strategies.
///
/// In this file only Reverse Cuthill–McKee has an implementation
/// ([`MatrixReorderer::reverse_cuthill_mckee`]).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ReorderingAlgorithm {
    /// Reverse Cuthill–McKee breadth-first ordering.
    ReverseCuthillMcKee,
    /// Approximate Minimum Degree ordering.
    ApproximateMinimumDegree,
    /// Nested dissection ordering.
    NestedDissection,
    /// King's ordering.
    King,
    /// Random permutation.
    Random,
}
74
/// Strategies for grouping matrix rows/columns into clusters.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ClusteringAlgorithm {
    /// Spectral clustering into `num_clusters` groups.
    Spectral { num_clusters: usize },
    /// k-means clustering into `num_clusters` groups.
    KMeans { num_clusters: usize },
    /// Hierarchical clustering cut at `num_clusters` groups.
    Hierarchical { num_clusters: usize },
    /// Graph-based clustering into `num_clusters` groups.
    GraphBased { num_clusters: usize },
}
87
/// Summary statistics of a sparse matrix's non-zero layout, as computed by
/// [`PatternAnalyzer::compute_pattern_statistics`].
#[derive(Debug, Clone, PartialEq)]
pub struct PatternStatistics {
    /// Number of stored triplets (duplicates counted individually).
    pub nnz: usize,
    /// `(rows, cols)` of the matrix.
    pub dimensions: (usize, usize),
    /// `1 - nnz / (rows * cols)`.
    pub sparsity: f32,
    /// Largest number of stored entries in any single row.
    pub max_nnz_per_row: usize,
    /// `nnz / rows`.
    pub avg_nnz_per_row: f32,
    /// Population standard deviation of the per-row entry counts.
    pub std_nnz_per_row: f32,
    /// Largest `|row - col|` over all stored entries.
    pub bandwidth: usize,
    /// Sum of `|row - col|` over all stored entries.
    pub profile: usize,
    /// Connected components of the graph linking each entry's row and column index.
    pub connected_components: usize,
    /// Mean local clustering coefficient of the off-diagonal entry graph.
    pub clustering_coefficient: f32,
}
112
113pub struct PatternAnalyzer {
115 cache: HashMap<String, AdvancedSparsityPattern>,
117}
118
119impl Default for PatternAnalyzer {
120 fn default() -> Self {
121 Self::new()
122 }
123}
124
125impl PatternAnalyzer {
126 pub fn new() -> Self {
128 Self {
129 cache: HashMap::new(),
130 }
131 }
132
133 pub fn analyze_advanced_pattern(
135 &mut self,
136 triplets: &[(usize, usize, f32)],
137 shape: &Shape,
138 ) -> TorshResult<AdvancedSparsityPattern> {
139 let cache_key = self.create_cache_key(triplets, shape);
140
141 if let Some(cached_pattern) = self.cache.get(&cache_key) {
142 return Ok(cached_pattern.clone());
143 }
144
145 let pattern = self.detect_pattern(triplets, shape)?;
146 self.cache.insert(cache_key, pattern.clone());
147 Ok(pattern)
148 }
149
150 fn detect_pattern(
152 &self,
153 triplets: &[(usize, usize, f32)],
154 shape: &Shape,
155 ) -> TorshResult<AdvancedSparsityPattern> {
156 let (rows, cols) = (shape.dims()[0], shape.dims()[1]);
157
158 if let Some(diagonal_pattern) = self.detect_diagonal_pattern(triplets, rows, cols) {
160 return Ok(diagonal_pattern);
161 }
162
163 if let Some(banded_pattern) = self.detect_banded_pattern(triplets, rows, cols) {
165 return Ok(banded_pattern);
166 }
167
168 if let Some(block_pattern) = self.detect_block_diagonal_pattern(triplets, rows, cols) {
170 return Ok(block_pattern);
171 }
172
173 if let Some(arrow_pattern) = self.detect_arrow_head_pattern(triplets, rows, cols) {
175 return Ok(arrow_pattern);
176 }
177
178 if let Some(symmetric_pattern) = self.detect_symmetric_pattern(triplets, rows, cols) {
180 return Ok(symmetric_pattern);
181 }
182
183 let clustering_coefficient = self.compute_clustering_coefficient(triplets, rows, cols);
185 Ok(AdvancedSparsityPattern::Random {
186 clustering_coefficient,
187 })
188 }
189
190 fn detect_diagonal_pattern(
192 &self,
193 triplets: &[(usize, usize, f32)],
194 rows: usize,
195 cols: usize,
196 ) -> Option<AdvancedSparsityPattern> {
197 let mut diagonal_counts: HashMap<i32, usize> = HashMap::new();
198
199 for (r, c, _) in triplets {
200 let offset = *r as i32 - *c as i32;
201 *diagonal_counts.entry(offset).or_insert(0) += 1;
202 }
203
204 let total_nnz = triplets.len();
205 let main_diagonal_count = diagonal_counts.get(&0).unwrap_or(&0);
206
207 if diagonal_counts.len() == 1 && diagonal_counts.contains_key(&0) {
209 let fill_ratio = *main_diagonal_count as f32 / std::cmp::min(rows, cols) as f32;
210 return Some(AdvancedSparsityPattern::Diagonal { fill_ratio });
211 }
212
213 if diagonal_counts.len() <= 5 {
215 let diagonal_nnz: usize = diagonal_counts.values().sum();
216 if diagonal_nnz as f32 / total_nnz as f32 > 0.9 {
217 let mut offsets: Vec<i32> = diagonal_counts.keys().copied().collect();
218 offsets.sort();
219 return Some(AdvancedSparsityPattern::MultiDiagonal {
220 num_diagonals: diagonal_counts.len(),
221 offsets,
222 });
223 }
224 }
225
226 None
227 }
228
229 fn detect_banded_pattern(
231 &self,
232 triplets: &[(usize, usize, f32)],
233 rows: usize,
234 cols: usize,
235 ) -> Option<AdvancedSparsityPattern> {
236 let mut max_lower_bandwidth = 0;
237 let mut max_upper_bandwidth = 0;
238
239 for (r, c, _) in triplets {
240 let diff = *r as i32 - *c as i32;
241 if diff > 0 {
242 max_lower_bandwidth = std::cmp::max(max_lower_bandwidth, diff as usize);
243 } else {
244 max_upper_bandwidth = std::cmp::max(max_upper_bandwidth, (-diff) as usize);
245 }
246 }
247
248 let total_bandwidth = max_lower_bandwidth + max_upper_bandwidth + 1;
249 let max_possible_bandwidth = std::cmp::min(rows, cols);
250
251 if total_bandwidth < max_possible_bandwidth / 4 {
253 let band_elements = std::cmp::min(rows, cols) * total_bandwidth
254 - (total_bandwidth * (total_bandwidth - 1)) / 2;
255 let fill_ratio = triplets.len() as f32 / band_elements as f32;
256
257 return Some(AdvancedSparsityPattern::Banded {
258 lower_bandwidth: max_lower_bandwidth,
259 upper_bandwidth: max_upper_bandwidth,
260 fill_ratio,
261 });
262 }
263
264 None
265 }
266
267 fn detect_block_diagonal_pattern(
269 &self,
270 triplets: &[(usize, usize, f32)],
271 rows: usize,
272 _cols: usize,
273 ) -> Option<AdvancedSparsityPattern> {
274 let mut adjacency: HashMap<usize, HashSet<usize>> = HashMap::new();
276
277 for (r, c, _) in triplets {
278 adjacency.entry(*r).or_default().insert(*c);
279 adjacency.entry(*c).or_default().insert(*r);
280 }
281
282 let components = self.find_connected_components(&adjacency, rows);
284
285 if components.len() > 1 {
286 let mut block_sizes = Vec::new();
288 let mut block_positions = Vec::new();
289
290 for component in &components {
291 if component.len() > 1 {
292 let min_idx = *component
293 .iter()
294 .min()
295 .expect("component should not be empty");
296 let max_idx = *component
297 .iter()
298 .max()
299 .expect("component should not be empty");
300 let block_size = max_idx - min_idx + 1;
301
302 block_sizes.push((block_size, block_size));
303 block_positions.push((min_idx, min_idx));
304 }
305 }
306
307 if !block_sizes.is_empty() {
308 return Some(AdvancedSparsityPattern::BlockDiagonal {
309 block_sizes,
310 block_positions,
311 });
312 }
313 }
314
315 None
316 }
317
318 fn detect_arrow_head_pattern(
320 &self,
321 triplets: &[(usize, usize, f32)],
322 rows: usize,
323 cols: usize,
324 ) -> Option<AdvancedSparsityPattern> {
325 let mut first_row_count = 0;
326 let mut first_col_count = 0;
327
328 for (r, c, _) in triplets {
329 if *r == 0 {
330 first_row_count += 1;
331 }
332 if *c == 0 {
333 first_col_count += 1;
334 }
335 }
336
337 let first_row_density = first_row_count as f32 / cols as f32;
338 let first_col_density = first_col_count as f32 / rows as f32;
339
340 if first_row_density > 0.5 || first_col_density > 0.5 {
342 let head_size = std::cmp::max(first_row_count, first_col_count);
343 return Some(AdvancedSparsityPattern::ArrowHead { head_size });
344 }
345
346 None
347 }
348
349 fn detect_symmetric_pattern(
351 &self,
352 triplets: &[(usize, usize, f32)],
353 rows: usize,
354 cols: usize,
355 ) -> Option<AdvancedSparsityPattern> {
356 if rows != cols {
357 return None; }
359
360 let mut pattern_set: HashSet<(usize, usize)> = HashSet::new();
361 let mut symmetric_count = 0;
362
363 for (r, c, _) in triplets {
364 pattern_set.insert((*r, *c));
365 }
366
367 for (r, c, _) in triplets {
368 if pattern_set.contains(&(*c, *r)) {
369 symmetric_count += 1;
370 }
371 }
372
373 let symmetry_ratio = symmetric_count as f32 / triplets.len() as f32;
374
375 if symmetry_ratio > 0.8 {
376 let base_pattern = Box::new(AdvancedSparsityPattern::Random {
378 clustering_coefficient: self.compute_clustering_coefficient(triplets, rows, cols),
379 });
380
381 return Some(AdvancedSparsityPattern::Symmetric {
382 symmetry_ratio,
383 base_pattern,
384 });
385 }
386
387 None
388 }
389
390 fn compute_clustering_coefficient(
392 &self,
393 triplets: &[(usize, usize, f32)],
394 rows: usize,
395 _cols: usize,
396 ) -> f32 {
397 let mut adjacency: HashMap<usize, HashSet<usize>> = HashMap::new();
398
399 for (r, c, _) in triplets {
400 if r != c {
401 adjacency.entry(*r).or_default().insert(*c);
403 adjacency.entry(*c).or_default().insert(*r);
404 }
405 }
406
407 let mut total_clustering = 0.0;
408 let mut nodes_with_neighbors = 0;
409
410 for node in 0..rows {
411 if let Some(neighbors) = adjacency.get(&node) {
412 if neighbors.len() >= 2 {
413 let mut triangles = 0;
414 let neighbor_vec: Vec<_> = neighbors.iter().collect();
415
416 for i in 0..neighbor_vec.len() {
417 for j in (i + 1)..neighbor_vec.len() {
418 if adjacency
419 .get(neighbor_vec[i])
420 .is_some_and(|adj| adj.contains(neighbor_vec[j]))
421 {
422 triangles += 1;
423 }
424 }
425 }
426
427 let possible_edges = neighbors.len() * (neighbors.len() - 1) / 2;
428 if possible_edges > 0 {
429 total_clustering += triangles as f32 / possible_edges as f32;
430 nodes_with_neighbors += 1;
431 }
432 }
433 }
434 }
435
436 if nodes_with_neighbors > 0 {
437 total_clustering / nodes_with_neighbors as f32
438 } else {
439 0.0
440 }
441 }
442
443 fn find_connected_components(
445 &self,
446 adjacency: &HashMap<usize, HashSet<usize>>,
447 num_nodes: usize,
448 ) -> Vec<Vec<usize>> {
449 let mut visited = vec![false; num_nodes];
450 let mut components = Vec::new();
451
452 for node in 0..num_nodes {
453 if !visited[node] {
454 let mut component = Vec::new();
455 let mut queue = VecDeque::new();
456 queue.push_back(node);
457 visited[node] = true;
458
459 while let Some(current) = queue.pop_front() {
460 component.push(current);
461
462 if let Some(neighbors) = adjacency.get(¤t) {
463 for &neighbor in neighbors {
464 if !visited[neighbor] {
465 visited[neighbor] = true;
466 queue.push_back(neighbor);
467 }
468 }
469 }
470 }
471
472 components.push(component);
473 }
474 }
475
476 components
477 }
478
479 fn create_cache_key(&self, triplets: &[(usize, usize, f32)], shape: &Shape) -> String {
481 format!(
482 "{}_{}_{}_{}",
483 shape.dims()[0],
484 shape.dims()[1],
485 triplets.len(),
486 triplets
487 .iter()
488 .take(10)
489 .map(|(r, c, _)| format!("{r}_{c}"))
490 .collect::<Vec<_>>()
491 .join("_")
492 )
493 }
494
495 pub fn compute_pattern_statistics(
497 &self,
498 triplets: &[(usize, usize, f32)],
499 shape: &Shape,
500 ) -> TorshResult<PatternStatistics> {
501 let (rows, cols) = (shape.dims()[0], shape.dims()[1]);
502 let nnz = triplets.len();
503 let sparsity = 1.0 - (nnz as f32 / (rows * cols) as f32);
504
505 let mut row_counts = vec![0; rows];
507 let mut max_bandwidth = 0;
508 let mut profile = 0;
509
510 for (r, c, _) in triplets {
511 row_counts[*r] += 1;
512 let distance = (*r as i32 - *c as i32).unsigned_abs() as usize;
513 max_bandwidth = std::cmp::max(max_bandwidth, distance);
514 profile += distance;
515 }
516
517 let max_nnz_per_row = *row_counts.iter().max().unwrap_or(&0);
518 let avg_nnz_per_row = nnz as f32 / rows as f32;
519 let variance = row_counts
520 .iter()
521 .map(|&count| (count as f32 - avg_nnz_per_row).powi(2))
522 .sum::<f32>()
523 / rows as f32;
524 let std_nnz_per_row = variance.sqrt();
525
526 let mut adjacency: HashMap<usize, HashSet<usize>> = HashMap::new();
528 for (r, c, _) in triplets {
529 adjacency.entry(*r).or_default().insert(*c);
530 adjacency.entry(*c).or_default().insert(*r);
531 }
532
533 let components = self.find_connected_components(&adjacency, rows);
534 let connected_components = components.len();
535
536 let clustering_coefficient = self.compute_clustering_coefficient(triplets, rows, cols);
537
538 Ok(PatternStatistics {
539 nnz,
540 dimensions: (rows, cols),
541 sparsity,
542 max_nnz_per_row,
543 avg_nnz_per_row,
544 std_nnz_per_row,
545 bandwidth: max_bandwidth,
546 profile,
547 connected_components,
548 clustering_coefficient,
549 })
550 }
551}
552
/// Stateless namespace for matrix reordering routines operating on
/// `(row, col, value)` triplets.
#[derive(Debug, Clone, Copy, Default)]
pub struct MatrixReorderer;
555
556impl MatrixReorderer {
557 pub fn reverse_cuthill_mckee(
559 triplets: &[(usize, usize, f32)],
560 num_rows: usize,
561 ) -> TorshResult<Vec<usize>> {
562 let mut adjacency: HashMap<usize, HashSet<usize>> = HashMap::new();
564 for (r, c, _) in triplets {
565 if r != c {
566 adjacency.entry(*r).or_default().insert(*c);
567 adjacency.entry(*c).or_default().insert(*r);
568 }
569 }
570
571 let start_vertex = Self::find_peripheral_vertex(&adjacency, num_rows)?;
573
574 let mut ordering = Vec::new();
576 let mut visited = vec![false; num_rows];
577 let mut queue = VecDeque::new();
578
579 queue.push_back(start_vertex);
580 visited[start_vertex] = true;
581
582 while let Some(vertex) = queue.pop_front() {
583 ordering.push(vertex);
584
585 if let Some(neighbors) = adjacency.get(&vertex) {
587 let mut neighbor_degrees: Vec<_> = neighbors
588 .iter()
589 .filter(|&&neighbor| !visited[neighbor])
590 .map(|&neighbor| {
591 let degree = adjacency.get(&neighbor).map_or(0, |adj| adj.len());
592 (degree, neighbor)
593 })
594 .collect();
595
596 neighbor_degrees.sort_by_key(|&(degree, _)| degree);
597
598 for (_, neighbor) in neighbor_degrees {
599 if !visited[neighbor] {
600 visited[neighbor] = true;
601 queue.push_back(neighbor);
602 }
603 }
604 }
605 }
606
607 for (i, &is_visited) in visited.iter().enumerate() {
609 if !is_visited {
610 ordering.push(i);
611 }
612 }
613
614 ordering.reverse();
616
617 Ok(ordering)
618 }
619
620 fn find_peripheral_vertex(
622 adjacency: &HashMap<usize, HashSet<usize>>,
623 num_rows: usize,
624 ) -> TorshResult<usize> {
625 let mut min_degree = usize::MAX;
626 let mut peripheral_candidates = Vec::new();
627
628 for i in 0..num_rows {
630 let degree = adjacency.get(&i).map_or(0, |adj| adj.len());
631 if degree < min_degree {
632 min_degree = degree;
633 peripheral_candidates.clear();
634 peripheral_candidates.push(i);
635 } else if degree == min_degree {
636 peripheral_candidates.push(i);
637 }
638 }
639
640 if peripheral_candidates.is_empty() {
641 return Ok(0); }
643
644 let mut best_vertex = peripheral_candidates[0];
646 let mut max_distance = 0;
647
648 for &candidate in &peripheral_candidates {
649 let distance = Self::compute_eccentricity(adjacency, candidate, num_rows);
650 if distance > max_distance {
651 max_distance = distance;
652 best_vertex = candidate;
653 }
654 }
655
656 Ok(best_vertex)
657 }
658
659 fn compute_eccentricity(
661 adjacency: &HashMap<usize, HashSet<usize>>,
662 start: usize,
663 num_rows: usize,
664 ) -> usize {
665 let mut distances = vec![usize::MAX; num_rows];
666 let mut queue = VecDeque::new();
667
668 distances[start] = 0;
669 queue.push_back(start);
670
671 while let Some(vertex) = queue.pop_front() {
672 if let Some(neighbors) = adjacency.get(&vertex) {
673 for &neighbor in neighbors {
674 if distances[neighbor] == usize::MAX {
675 distances[neighbor] = distances[vertex] + 1;
676 queue.push_back(neighbor);
677 }
678 }
679 }
680 }
681
682 distances
683 .iter()
684 .filter(|&&d| d != usize::MAX)
685 .max()
686 .copied()
687 .unwrap_or(0)
688 }
689
690 pub fn apply_reordering(
692 triplets: &[(usize, usize, f32)],
693 ordering: &[usize],
694 ) -> Vec<(usize, usize, f32)> {
695 let mut inverse_ordering = vec![0; ordering.len()];
696 for (new_idx, &old_idx) in ordering.iter().enumerate() {
697 inverse_ordering[old_idx] = new_idx;
698 }
699
700 triplets
701 .iter()
702 .map(|(r, c, v)| (inverse_ordering[*r], inverse_ordering[*c], *v))
703 .collect()
704 }
705}
706
/// Stateless namespace for textual renderings of sparsity patterns.
#[derive(Debug, Clone, Copy, Default)]
pub struct PatternVisualizer;
709
710impl PatternVisualizer {
711 pub fn ascii_pattern(
713 triplets: &[(usize, usize, f32)],
714 shape: &Shape,
715 max_size: Option<(usize, usize)>,
716 ) -> String {
717 let (rows, cols) = (shape.dims()[0], shape.dims()[1]);
718 let (display_rows, display_cols) = max_size.unwrap_or((50, 50));
719
720 let row_scale = if rows > display_rows {
721 rows / display_rows
722 } else {
723 1
724 };
725 let col_scale = if cols > display_cols {
726 cols / display_cols
727 } else {
728 1
729 };
730
731 let scaled_rows = rows.div_ceil(row_scale);
732 let scaled_cols = cols.div_ceil(col_scale);
733
734 let mut pattern = vec![vec![' '; scaled_cols]; scaled_rows];
735
736 for (r, c, _) in triplets {
737 let scaled_r = r / row_scale;
738 let scaled_c = c / col_scale;
739 if scaled_r < scaled_rows && scaled_c < scaled_cols {
740 pattern[scaled_r][scaled_c] = '*';
741 }
742 }
743
744 let mut result = String::new();
745 result.push_str(&format!(
746 "Sparsity Pattern ({rows}x{cols}, scaled to {scaled_rows}x{scaled_cols})\n"
747 ));
748 result.push_str(&"-".repeat(scaled_cols + 2));
749 result.push('\n');
750
751 for row in pattern {
752 result.push('|');
753 for cell in row {
754 result.push(cell);
755 }
756 result.push_str("|\n");
757 }
758
759 result.push_str(&"-".repeat(scaled_cols + 2));
760 result.push('\n');
761
762 result
763 }
764
765 pub fn pattern_histogram(
767 triplets: &[(usize, usize, f32)],
768 shape: &Shape,
769 ) -> (Vec<usize>, Vec<usize>) {
770 let (rows, cols) = (shape.dims()[0], shape.dims()[1]);
771 let mut row_counts = vec![0; rows];
772 let mut col_counts = vec![0; cols];
773
774 for (r, c, _) in triplets {
775 row_counts[*r] += 1;
776 col_counts[*c] += 1;
777 }
778
779 (row_counts, col_counts)
780 }
781}
782
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_advanced_pattern_detection() {
        let mut analyzer = PatternAnalyzer::new();
        let triplets = vec![(0, 0, 1.0), (1, 1, 1.0), (2, 2, 1.0)];
        let shape = Shape::new(vec![3, 3]);

        let pattern = analyzer
            .analyze_advanced_pattern(&triplets, &shape)
            .unwrap();

        // `matches!` only produces a bool, so it must be wrapped in `assert!`
        assert!(
            matches!(
                pattern,
                AdvancedSparsityPattern::Diagonal { fill_ratio }
                    if (fill_ratio - 1.0).abs() < f32::EPSILON
            ),
            "unexpected pattern: {pattern:?}"
        );

        // A second call is served from the cache and must agree
        let again = analyzer
            .analyze_advanced_pattern(&triplets, &shape)
            .unwrap();
        assert_eq!(format!("{again:?}"), format!("{pattern:?}"));
    }

    #[test]
    fn test_tridiagonal_is_multi_diagonal() {
        let mut analyzer = PatternAnalyzer::new();
        let mut triplets: Vec<(usize, usize, f32)> = Vec::new();
        for i in 0..4 {
            triplets.push((i, i, 2.0));
            if i + 1 < 4 {
                triplets.push((i, i + 1, -1.0));
                triplets.push((i + 1, i, -1.0));
            }
        }
        let shape = Shape::new(vec![4, 4]);

        match analyzer.analyze_advanced_pattern(&triplets, &shape).unwrap() {
            AdvancedSparsityPattern::MultiDiagonal {
                num_diagonals,
                offsets,
            } => {
                assert_eq!(num_diagonals, 3);
                assert_eq!(offsets, vec![-1, 0, 1]);
            }
            other => panic!("expected MultiDiagonal, got {other:?}"),
        }
    }

    #[test]
    fn test_rcm_reordering() {
        let triplets = vec![
            (0, 1, 1.0),
            (1, 0, 1.0),
            (1, 2, 1.0),
            (2, 1, 1.0),
            (2, 3, 1.0),
            (3, 2, 1.0),
        ];

        let ordering = MatrixReorderer::reverse_cuthill_mckee(&triplets, 4).unwrap();
        assert_eq!(ordering, vec![3, 2, 1, 0]);

        let reordered = MatrixReorderer::apply_reordering(&triplets, &ordering);
        assert_eq!(reordered.len(), triplets.len());
        // The path stays a path: every entry is still next to the diagonal
        assert!(reordered.iter().all(|&(r, c, _)| r.abs_diff(c) == 1));
    }

    #[test]
    fn test_pattern_statistics() {
        let analyzer = PatternAnalyzer::new();
        let triplets = vec![(0, 0, 1.0), (1, 1, 1.0), (2, 2, 1.0)];
        let shape = Shape::new(vec![3, 3]);

        let stats = analyzer
            .compute_pattern_statistics(&triplets, &shape)
            .unwrap();
        assert_eq!(stats.nnz, 3);
        assert_eq!(stats.dimensions, (3, 3));
        assert_eq!(stats.bandwidth, 0);
        assert_eq!(stats.profile, 0);
        assert_eq!(stats.max_nnz_per_row, 1);
        assert_eq!(stats.connected_components, 3);
        assert!((stats.sparsity - 2.0 / 3.0).abs() < 1e-6);
        assert!((stats.avg_nnz_per_row - 1.0).abs() < 1e-6);
        assert!(stats.std_nnz_per_row.abs() < 1e-6);
        assert!(stats.clustering_coefficient.abs() < 1e-6);
    }

    #[test]
    fn test_pattern_visualization() {
        let triplets = vec![(0, 0, 1.0), (1, 1, 1.0), (2, 2, 1.0)];
        let shape = Shape::new(vec![3, 3]);

        let ascii = PatternVisualizer::ascii_pattern(&triplets, &shape, Some((10, 10)));
        assert_eq!(
            ascii,
            "Sparsity Pattern (3x3, scaled to 3x3)\n-----\n|*  |\n| * |\n|  *|\n-----\n"
        );

        let (row_hist, col_hist) = PatternVisualizer::pattern_histogram(&triplets, &shape);
        assert_eq!(row_hist, vec![1, 1, 1]);
        assert_eq!(col_hist, vec![1, 1, 1]);
    }
}