scirs2_optimize/sparse_numdiff/
coloring.rs1use crate::error::OptimizeError;
7use rand::{rngs::StdRng, seq::SliceRandom, SeedableRng};
8use scirs2_sparse::csr_array::CsrArray;
9use std::collections::{HashMap, HashSet};
10
11fn get_nonzero_cols_in_row(matrix: &CsrArray<f64>, row: usize) -> Vec<usize> {
13 let row_start = matrix.get_indptr()[row];
15 let row_end = matrix.get_indptr()[row + 1];
16
17 let indices = matrix.get_indices();
19 let mut cols = Vec::new();
20 for i in row_start..row_end {
21 cols.push(indices[i]);
22 }
23 cols
24}
25
26pub fn determine_column_groups(
39 sparsity: &CsrArray<f64>,
40 seed: Option<u64>,
41 max_group_size: Option<usize>,
42) -> Result<Vec<Vec<usize>>, OptimizeError> {
43 let (m, n) = sparsity.shape();
44
45 let mut conflicts: Vec<HashSet<usize>> = vec![HashSet::new(); n];
48
49 for row in 0..m {
51 let cols = get_nonzero_cols_in_row(sparsity, row);
52
53 for &col1 in &cols {
55 for &col2 in &cols {
56 if col1 != col2 {
57 conflicts[col1].insert(col2);
58 conflicts[col2].insert(col1);
59 }
60 }
61 }
62 }
63
64 let mut order: Vec<usize> = (0..n).collect();
66
67 order.sort_by_key(|&v| conflicts[v].len());
69
70 let mut rng = match seed {
72 Some(s) => StdRng::seed_from_u64(s),
73 None => {
74 StdRng::seed_from_u64(0)
77 }
78 };
79
80 let mut i = 0;
82 while i < order.len() {
83 let degree = conflicts[order[i]].len();
84 let mut j = i + 1;
85 while j < order.len() && conflicts[order[j]].len() == degree {
86 j += 1;
87 }
88
89 order[i..j].shuffle(&mut rng);
91
92 i = j;
93 }
94
95 let mut vertex_colors: HashMap<usize, usize> = HashMap::new();
97
98 for &v in &order {
99 let mut neighbor_colors: HashSet<usize> = HashSet::new();
101 for &neighbor in &conflicts[v] {
102 if let Some(&color) = vertex_colors.get(&neighbor) {
103 neighbor_colors.insert(color);
104 }
105 }
106
107 let mut color = 0;
109 while neighbor_colors.contains(&color) {
110 color += 1;
111 }
112
113 vertex_colors.insert(v, color);
114 }
115
116 let max_color = vertex_colors.values().max().cloned().unwrap_or(0);
118 let mut color_groups: Vec<Vec<usize>> = vec![Vec::new(); max_color + 1];
119
120 for (vertex, &color) in &vertex_colors {
121 color_groups[color].push(*vertex);
122 }
123
124 let max_size = max_group_size.unwrap_or(usize::MAX);
126 if max_size < n {
127 let mut final_groups = Vec::new();
128
129 for group in color_groups {
130 if group.len() <= max_size {
131 final_groups.push(group);
132 } else {
133 for chunk in group.chunks(max_size) {
135 final_groups.push(chunk.to_vec());
136 }
137 }
138 }
139
140 Ok(final_groups)
141 } else {
142 Ok(color_groups.into_iter().filter(|g| !g.is_empty()).collect())
144 }
145}