scirs2_optimize/sparse_numdiff/
coloring.rs

//! Graph coloring algorithms for efficient sparse differentiation
//!
//! This module provides implementations of graph coloring algorithms
//! used for grouping columns in sparse finite difference calculations.
5
6use crate::error::OptimizeError;
7use rand::{rngs::StdRng, seq::SliceRandom, SeedableRng};
8use scirs2_sparse::csr_array::CsrArray;
9use std::collections::{HashMap, HashSet};
10
11/// Get the list of non-zero columns in a specific row of a sparse matrix
12fn get_nonzero_cols_in_row(matrix: &CsrArray<f64>, row: usize) -> Vec<usize> {
13    // Get row indices
14    let row_start = matrix.get_indptr()[row];
15    let row_end = matrix.get_indptr()[row + 1];
16
17    // Collect all column indices in this row
18    let indices = matrix.get_indices();
19    let mut cols = Vec::new();
20    for i in row_start..row_end {
21        cols.push(indices[i]);
22    }
23    cols
24}
25
26/// Implements a greedy graph coloring algorithm to group columns that can be
27/// perturbed simultaneously during sparse finite differences
28///
29/// # Arguments
30///
31/// * `sparsity` - Sparse matrix representing the sparsity pattern
32/// * `seed` - Optional random seed for reproducibility
33/// * `max_group_size` - Maximum number of columns per group
34///
35/// # Returns
36///
37/// * Vector of column groups, where each group contains columns that don't conflict
38pub fn determine_column_groups(
39    sparsity: &CsrArray<f64>,
40    seed: Option<u64>,
41    max_group_size: Option<usize>,
42) -> Result<Vec<Vec<usize>>, OptimizeError> {
43    let (m, n) = sparsity.shape();
44
45    // Create a conflict graph represented as an adjacency list
46    // Two columns conflict if they have nonzeros in the same row
47    let mut conflicts: Vec<HashSet<usize>> = vec![HashSet::new(); n];
48
49    // Build the conflict graph
50    for row in 0..m {
51        let cols = get_nonzero_cols_in_row(sparsity, row);
52
53        // All columns with nonzeros in this row conflict with each other
54        for &col1 in &cols {
55            for &col2 in &cols {
56                if col1 != col2 {
57                    conflicts[col1].insert(col2);
58                    conflicts[col2].insert(col1);
59                }
60            }
61        }
62    }
63
64    // Order vertices for coloring (by degree, randomized)
65    let mut order: Vec<usize> = (0..n).collect();
66
67    // Sort vertices by degree (number of conflicts) for better coloring
68    order.sort_by_key(|&v| conflicts[v].len());
69
70    // Randomize the order of vertices with the same degree
71    let mut rng = match seed {
72        Some(s) => StdRng::seed_from_u64(s),
73        None => {
74            // Use a constant seed for reproducibility in case thread_rng fails
75            // This is a fallback case, so using a fixed seed is acceptable
76            StdRng::seed_from_u64(0)
77        }
78    };
79
80    // Group vertices with the same degree and shuffle each group
81    let mut i = 0;
82    while i < order.len() {
83        let degree = conflicts[order[i]].len();
84        let mut j = i + 1;
85        while j < order.len() && conflicts[order[j]].len() == degree {
86            j += 1;
87        }
88
89        // Shuffle this group
90        order[i..j].shuffle(&mut rng);
91
92        i = j;
93    }
94
95    // Apply greedy coloring
96    let mut vertex_colors: HashMap<usize, usize> = HashMap::new();
97
98    for &v in &order {
99        // Find the lowest color not used by neighbors
100        let mut neighbor_colors: HashSet<usize> = HashSet::new();
101        for &neighbor in &conflicts[v] {
102            if let Some(&color) = vertex_colors.get(&neighbor) {
103                neighbor_colors.insert(color);
104            }
105        }
106
107        // Find smallest available color
108        let mut color = 0;
109        while neighbor_colors.contains(&color) {
110            color += 1;
111        }
112
113        vertex_colors.insert(v, color);
114    }
115
116    // Group vertices by color
117    let max_color = vertex_colors.values().max().cloned().unwrap_or(0);
118    let mut color_groups: Vec<Vec<usize>> = vec![Vec::new(); max_color + 1];
119
120    for (vertex, &color) in &vertex_colors {
121        color_groups[color].push(*vertex);
122    }
123
124    // Apply max group size constraint if specified
125    let max_size = max_group_size.unwrap_or(usize::MAX);
126    if max_size < n {
127        let mut final_groups = Vec::new();
128
129        for group in color_groups {
130            if group.len() <= max_size {
131                final_groups.push(group);
132            } else {
133                // Split into smaller groups
134                for chunk in group.chunks(max_size) {
135                    final_groups.push(chunk.to_vec());
136                }
137            }
138        }
139
140        Ok(final_groups)
141    } else {
142        // Filter out empty groups
143        Ok(color_groups.into_iter().filter(|g| !g.is_empty()).collect())
144    }
145}