scirs2_optimize/sparse_numdiff/
coloring.rs

//! Graph coloring algorithms for efficient sparse differentiation
//!
//! This module provides implementations of graph coloring algorithms
//! used for grouping columns in sparse finite difference calculations.
5
6use crate::error::OptimizeError;
7use rand::{rngs::StdRng, seq::SliceRandom, SeedableRng};
8use scirs2_sparse::csr_array::CsrArray;
9use std::collections::{HashMap, HashSet};
10
11/// Get the list of non-zero columns in a specific row of a sparse matrix
12#[allow(dead_code)]
13fn get_nonzero_cols_in_row(matrix: &CsrArray<f64>, row: usize) -> Vec<usize> {
14    // Get row indices
15    let row_start = matrix.get_indptr()[row];
16    let row_end = matrix.get_indptr()[row + 1];
17
18    // Collect all column indices in this row
19    let indices = matrix.get_indices();
20    let mut cols = Vec::new();
21    for i in row_start..row_end {
22        cols.push(indices[i]);
23    }
24    cols
25}
26
27/// Implements a greedy graph coloring algorithm to group columns that can be
28/// perturbed simultaneously during sparse finite differences
29///
30/// # Arguments
31///
32/// * `sparsity` - Sparse matrix representing the sparsity pattern
33/// * `seed` - Optional random seed for reproducibility
34/// * `max_group_size` - Maximum number of columns per group
35///
36/// # Returns
37///
38/// * Vector of column groups, where each group contains columns that don't conflict
39#[allow(dead_code)]
40pub fn determine_column_groups(
41    sparsity: &CsrArray<f64>,
42    seed: Option<u64>,
43    max_group_size: Option<usize>,
44) -> Result<Vec<Vec<usize>>, OptimizeError> {
45    let (m, n) = sparsity.shape();
46
47    // Create a conflict graph represented as an adjacency list
48    // Two columns conflict if they have nonzeros in the same row
49    let mut conflicts: Vec<HashSet<usize>> = vec![HashSet::new(); n];
50
51    // Build the conflict graph
52    for row in 0..m {
53        let cols = get_nonzero_cols_in_row(sparsity, row);
54
55        // All columns with nonzeros in this row conflict with each other
56        for &col1 in &cols {
57            for &col2 in &cols {
58                if col1 != col2 {
59                    conflicts[col1].insert(col2);
60                    conflicts[col2].insert(col1);
61                }
62            }
63        }
64    }
65
66    // Order vertices for coloring (by degree, randomized)
67    let mut order: Vec<usize> = (0..n).collect();
68
69    // Sort vertices by degree (number of conflicts) for better coloring
70    order.sort_by_key(|&v| conflicts[v].len());
71
72    // Randomize the order of vertices with the same degree
73    let mut rng = match seed {
74        Some(s) => StdRng::seed_from_u64(s),
75        None => {
76            // Use a constant seed for reproducibility in case rng fails
77            // This is a fallback case, so using a fixed seed is acceptable
78            StdRng::seed_from_u64(0)
79        }
80    };
81
82    // Group vertices with the same degree and shuffle each group
83    let mut i = 0;
84    while i < order.len() {
85        let degree = conflicts[order[i]].len();
86        let mut j = i + 1;
87        while j < order.len() && conflicts[order[j]].len() == degree {
88            j += 1;
89        }
90
91        // Shuffle this group
92        order[i..j].shuffle(&mut rng);
93
94        i = j;
95    }
96
97    // Apply greedy coloring
98    let mut vertex_colors: HashMap<usize, usize> = HashMap::new();
99
100    for &v in &order {
101        // Find the lowest color not used by neighbors
102        let mut neighbor_colors: HashSet<usize> = HashSet::new();
103        for &neighbor in &conflicts[v] {
104            if let Some(&color) = vertex_colors.get(&neighbor) {
105                neighbor_colors.insert(color);
106            }
107        }
108
109        // Find smallest available color
110        let mut color = 0;
111        while neighbor_colors.contains(&color) {
112            color += 1;
113        }
114
115        vertex_colors.insert(v, color);
116    }
117
118    // Group vertices by color
119    let max_color = vertex_colors.values().max().cloned().unwrap_or(0);
120    let mut color_groups: Vec<Vec<usize>> = vec![Vec::new(); max_color + 1];
121
122    for (vertex, &color) in &vertex_colors {
123        color_groups[color].push(*vertex);
124    }
125
126    // Apply max group _size constraint if specified
127    let max_size = max_group_size.unwrap_or(usize::MAX);
128    if max_size < n {
129        let mut final_groups = Vec::new();
130
131        for group in color_groups {
132            if group.len() <= max_size {
133                final_groups.push(group);
134            } else {
135                // Split into smaller groups
136                for chunk in group.chunks(max_size) {
137                    final_groups.push(chunk.to_vec());
138                }
139            }
140        }
141
142        Ok(final_groups)
143    } else {
144        // Filter out empty groups
145        Ok(color_groups.into_iter().filter(|g| !g.is_empty()).collect())
146    }
147}