scirs2_optimize/sparse_numdiff/
coloring.rs1use crate::error::OptimizeError;
7use rand::{rngs::StdRng, seq::SliceRandom, SeedableRng};
8use scirs2_sparse::csr_array::CsrArray;
9use std::collections::{HashMap, HashSet};
10
11#[allow(dead_code)]
13fn get_nonzero_cols_in_row(matrix: &CsrArray<f64>, row: usize) -> Vec<usize> {
14 let row_start = matrix.get_indptr()[row];
16 let row_end = matrix.get_indptr()[row + 1];
17
18 let indices = matrix.get_indices();
20 let mut cols = Vec::new();
21 for i in row_start..row_end {
22 cols.push(indices[i]);
23 }
24 cols
25}
26
27#[allow(dead_code)]
40pub fn determine_column_groups(
41 sparsity: &CsrArray<f64>,
42 seed: Option<u64>,
43 max_group_size: Option<usize>,
44) -> Result<Vec<Vec<usize>>, OptimizeError> {
45 let (m, n) = sparsity.shape();
46
47 let mut conflicts: Vec<HashSet<usize>> = vec![HashSet::new(); n];
50
51 for row in 0..m {
53 let cols = get_nonzero_cols_in_row(sparsity, row);
54
55 for &col1 in &cols {
57 for &col2 in &cols {
58 if col1 != col2 {
59 conflicts[col1].insert(col2);
60 conflicts[col2].insert(col1);
61 }
62 }
63 }
64 }
65
66 let mut order: Vec<usize> = (0..n).collect();
68
69 order.sort_by_key(|&v| conflicts[v].len());
71
72 let mut rng = match seed {
74 Some(s) => StdRng::seed_from_u64(s),
75 None => {
76 StdRng::seed_from_u64(0)
79 }
80 };
81
82 let mut i = 0;
84 while i < order.len() {
85 let degree = conflicts[order[i]].len();
86 let mut j = i + 1;
87 while j < order.len() && conflicts[order[j]].len() == degree {
88 j += 1;
89 }
90
91 order[i..j].shuffle(&mut rng);
93
94 i = j;
95 }
96
97 let mut vertex_colors: HashMap<usize, usize> = HashMap::new();
99
100 for &v in &order {
101 let mut neighbor_colors: HashSet<usize> = HashSet::new();
103 for &neighbor in &conflicts[v] {
104 if let Some(&color) = vertex_colors.get(&neighbor) {
105 neighbor_colors.insert(color);
106 }
107 }
108
109 let mut color = 0;
111 while neighbor_colors.contains(&color) {
112 color += 1;
113 }
114
115 vertex_colors.insert(v, color);
116 }
117
118 let max_color = vertex_colors.values().max().cloned().unwrap_or(0);
120 let mut color_groups: Vec<Vec<usize>> = vec![Vec::new(); max_color + 1];
121
122 for (vertex, &color) in &vertex_colors {
123 color_groups[color].push(*vertex);
124 }
125
126 let max_size = max_group_size.unwrap_or(usize::MAX);
128 if max_size < n {
129 let mut final_groups = Vec::new();
130
131 for group in color_groups {
132 if group.len() <= max_size {
133 final_groups.push(group);
134 } else {
135 for chunk in group.chunks(max_size) {
137 final_groups.push(chunk.to_vec());
138 }
139 }
140 }
141
142 Ok(final_groups)
143 } else {
144 Ok(color_groups.into_iter().filter(|g| !g.is_empty()).collect())
146 }
147}