1use crate::csc_array::CscArray;
8use crate::csr::CsrMatrix;
9use crate::csr_array::CsrArray;
10use crate::error::{SparseError, SparseResult};
11use crate::sparray::SparseArray;
12use scirs2_core::numeric::{SparseElement, Zero};
13use std::fmt::Debug;
14use std::ops::Div;
15
/// Graph stored as one adjacency list per node.
///
/// `adj[u]` holds the neighbours of node `u`. The constructors in this module
/// exclude self-loops and keep every list sorted and free of duplicates, so
/// membership can be tested with a binary search. `num_edges` counts each
/// edge once on the assumption that it is stored in both endpoint lists.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AdjacencyGraph {
    adj: Vec<Vec<usize>>,
}
25
26impl AdjacencyGraph {
27 pub fn from_adjacency_list(mut adj: Vec<Vec<usize>>) -> Self {
31 let n = adj.len();
32 for (u, nbrs) in adj.iter_mut().enumerate() {
33 nbrs.retain(|&v| v != u && v < n);
34 nbrs.sort_unstable();
35 nbrs.dedup();
36 }
37 Self { adj }
38 }
39
40 pub fn from_csr_matrix<T>(mat: &CsrMatrix<T>) -> SparseResult<Self>
45 where
46 T: Clone + Copy + SparseElement + Zero + PartialEq + Debug,
47 {
48 let (n, nc) = mat.shape();
49 if n != nc {
50 return Err(SparseError::ValueError(
51 "adjacency graph requires a square matrix".to_string(),
52 ));
53 }
54
55 let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
56 for row in 0..n {
57 for j in mat.indptr[row]..mat.indptr[row + 1] {
58 let col = mat.indices[j];
59 if col != row {
60 adj[row].push(col);
61 adj[col].push(row);
62 }
63 }
64 }
65 for nbrs in adj.iter_mut() {
66 nbrs.sort_unstable();
67 nbrs.dedup();
68 }
69 Ok(Self { adj })
70 }
71
72 pub fn from_csr_array<T>(arr: &CsrArray<T>) -> SparseResult<Self>
74 where
75 T: SparseElement + Div<Output = T> + Zero + PartialOrd + 'static,
76 {
77 let (n, nc) = arr.shape();
78 if n != nc {
79 return Err(SparseError::ValueError(
80 "adjacency graph requires a square matrix".to_string(),
81 ));
82 }
83
84 let dense = <CsrArray<T> as SparseArray<T>>::to_array(arr);
86 let zero = <T as Zero>::zero();
87 let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
88 for row in 0..n {
89 for col in 0..n {
90 if row != col && dense[[row, col]] != zero {
91 adj[row].push(col);
92 }
93 }
94 }
95 for nbrs in adj.iter_mut() {
96 nbrs.sort_unstable();
97 nbrs.dedup();
98 }
99 Ok(Self { adj })
100 }
101
102 pub fn from_csc_array<T>(arr: &CscArray<T>) -> SparseResult<Self>
104 where
105 T: SparseElement
106 + Div<Output = T>
107 + Zero
108 + PartialOrd
109 + scirs2_core::numeric::Float
110 + 'static,
111 {
112 let (n, nc) = <CscArray<T> as SparseArray<T>>::shape(arr);
113 if n != nc {
114 return Err(SparseError::ValueError(
115 "adjacency graph requires a square matrix".to_string(),
116 ));
117 }
118
119 let dense = <CscArray<T> as SparseArray<T>>::to_array(arr);
120 let zero = <T as Zero>::zero();
121 let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
122 for row in 0..n {
123 for col in 0..n {
124 if row != col && dense[[row, col]] != zero {
125 adj[row].push(col);
126 }
127 }
128 }
129 for nbrs in adj.iter_mut() {
130 nbrs.sort_unstable();
131 nbrs.dedup();
132 }
133 Ok(Self { adj })
134 }
135
136 #[inline]
138 pub fn num_nodes(&self) -> usize {
139 self.adj.len()
140 }
141
142 #[inline]
144 pub fn degree(&self, u: usize) -> usize {
145 self.adj.get(u).map_or(0, |v| v.len())
146 }
147
148 #[inline]
150 pub fn neighbors(&self, u: usize) -> &[usize] {
151 self.adj.get(u).map_or(&[], |v| v.as_slice())
152 }
153
154 pub fn num_edges(&self) -> usize {
156 let total: usize = self.adj.iter().map(|v| v.len()).sum();
157 total / 2
158 }
159
160 pub fn has_edge(&self, u: usize, v: usize) -> bool {
162 if u >= self.adj.len() || v >= self.adj.len() {
163 return false;
164 }
165 self.adj[u].binary_search(&v).is_ok()
166 }
167
168 pub fn subgraph(&self, nodes: &[usize]) -> (AdjacencyGraph, Vec<usize>) {
173 let n = nodes.len();
174 let mut rev = vec![usize::MAX; self.adj.len()];
176 for (new_i, &old_i) in nodes.iter().enumerate() {
177 if old_i < rev.len() {
178 rev[old_i] = new_i;
179 }
180 }
181
182 let mut adj = vec![Vec::new(); n];
183 for (new_u, &old_u) in nodes.iter().enumerate() {
184 if old_u >= self.adj.len() {
185 continue;
186 }
187 for &old_v in &self.adj[old_u] {
188 if old_v < rev.len() && rev[old_v] != usize::MAX {
189 adj[new_u].push(rev[old_v]);
190 }
191 }
192 adj[new_u].sort_unstable();
193 adj[new_u].dedup();
194 }
195
196 (AdjacencyGraph { adj }, nodes.to_vec())
197 }
198
199 pub(crate) fn raw_adj(&self) -> &[Vec<usize>] {
201 &self.adj
202 }
203}
204
205pub fn apply_permutation<T>(mat: &CsrMatrix<T>, perm: &[usize]) -> SparseResult<CsrMatrix<T>>
210where
211 T: Clone + Copy + SparseElement + Zero + PartialEq + Debug,
212{
213 let (n, nc) = mat.shape();
214 if n != nc {
215 return Err(SparseError::ValueError(
216 "apply_permutation requires a square matrix".to_string(),
217 ));
218 }
219 if perm.len() != n {
220 return Err(SparseError::DimensionMismatch {
221 expected: n,
222 found: perm.len(),
223 });
224 }
225
226 let mut inv_perm = vec![0usize; n];
228 for (new_i, &old_i) in perm.iter().enumerate() {
229 if old_i >= n {
230 return Err(SparseError::ValueError(format!(
231 "permutation index {} out of range (n={})",
232 old_i, n
233 )));
234 }
235 inv_perm[old_i] = new_i;
236 }
237
238 let mut rows = Vec::with_capacity(mat.nnz());
239 let mut cols = Vec::with_capacity(mat.nnz());
240 let mut data = Vec::with_capacity(mat.nnz());
241
242 for new_row in 0..n {
243 let old_row = perm[new_row];
244 for j in mat.indptr[old_row]..mat.indptr[old_row + 1] {
245 let old_col = mat.indices[j];
246 let new_col = inv_perm[old_col];
247 rows.push(new_row);
248 cols.push(new_col);
249 data.push(mat.data[j]);
250 }
251 }
252
253 CsrMatrix::new(data, rows, cols, (n, n))
254}
255
256pub fn apply_permutation_csr_array<T>(
258 arr: &CsrArray<T>,
259 perm: &[usize],
260) -> SparseResult<CsrArray<T>>
261where
262 T: SparseElement + Div<Output = T> + Zero + PartialOrd + 'static,
263{
264 let (n, nc) = arr.shape();
265 if n != nc {
266 return Err(SparseError::ValueError(
267 "apply_permutation requires a square matrix".to_string(),
268 ));
269 }
270 if perm.len() != n {
271 return Err(SparseError::DimensionMismatch {
272 expected: n,
273 found: perm.len(),
274 });
275 }
276
277 let mut inv_perm = vec![0usize; n];
278 for (new_i, &old_i) in perm.iter().enumerate() {
279 if old_i >= n {
280 return Err(SparseError::ValueError(format!(
281 "permutation index {} out of range (n={})",
282 old_i, n
283 )));
284 }
285 inv_perm[old_i] = new_i;
286 }
287
288 let dense = <CsrArray<T> as SparseArray<T>>::to_array(arr);
290 let mut rows = Vec::new();
291 let mut cols = Vec::new();
292 let mut data = Vec::new();
293 let zero = <T as Zero>::zero();
294
295 for new_row in 0..n {
296 let old_row = perm[new_row];
297 for old_col in 0..n {
298 let val = dense[[old_row, old_col]];
299 if val != zero {
300 let new_col = inv_perm[old_col];
301 rows.push(new_row);
302 cols.push(new_col);
303 data.push(val);
304 }
305 }
306 }
307
308 CsrArray::from_triplets(&rows, &cols, &data, (n, n), false)
309}
310
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_adjacency_from_list() {
        let adj = vec![vec![1, 2], vec![0, 2], vec![0, 1]];
        let graph = AdjacencyGraph::from_adjacency_list(adj);
        assert_eq!(graph.num_nodes(), 3);
        assert_eq!(graph.degree(0), 2);
        assert_eq!(graph.num_edges(), 3);
        assert!(graph.has_edge(0, 1));
        assert!(!graph.has_edge(0, 0));
    }

    #[test]
    fn test_adjacency_from_csr_matrix() {
        let rows = vec![0, 0, 1, 1, 2, 2];
        let cols = vec![1, 2, 0, 2, 0, 1];
        let data = vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
        let mat = CsrMatrix::new(data, rows, cols, (3, 3)).expect("csr");
        let graph = AdjacencyGraph::from_csr_matrix(&mat).expect("adj");
        assert_eq!(graph.num_nodes(), 3);
        assert_eq!(graph.degree(0), 2);
    }

    #[test]
    fn test_adjacency_from_csr_matrix_symmetrizes_one_sided_entries() {
        // Only (0, 1) is stored; the graph must still report the edge 1 -> 0.
        let mat = CsrMatrix::new(vec![1.0], vec![0], vec![1], (2, 2)).expect("csr");
        let graph = AdjacencyGraph::from_csr_matrix(&mat).expect("adj");
        assert!(graph.has_edge(0, 1));
        assert!(graph.has_edge(1, 0));
        assert_eq!(graph.num_edges(), 1);
    }

    #[test]
    fn test_adjacency_from_csr_matrix_rejects_non_square() {
        let mat = CsrMatrix::new(vec![1.0], vec![0], vec![1], (2, 3)).expect("csr");
        assert!(AdjacencyGraph::from_csr_matrix(&mat).is_err());
    }

    #[test]
    fn test_subgraph() {
        let adj = vec![vec![1, 2, 3], vec![0, 2], vec![0, 1], vec![0]];
        let graph = AdjacencyGraph::from_adjacency_list(adj);
        let (sub, mapping) = graph.subgraph(&[0, 1, 2]);
        assert_eq!(sub.num_nodes(), 3);
        assert_eq!(mapping, vec![0, 1, 2]);
        assert_eq!(sub.degree(0), 2);
    }

    #[test]
    fn test_apply_permutation() {
        let rows = vec![0, 0, 1, 1, 1, 2, 2];
        let cols = vec![0, 1, 0, 1, 2, 1, 2];
        let data = vec![2.0, -1.0, -1.0, 2.0, -1.0, -1.0, 2.0];
        let mat = CsrMatrix::new(data, rows, cols, (3, 3)).expect("csr");

        let perm = vec![2, 1, 0];
        let permuted = apply_permutation(&mat, &perm).expect("apply");
        assert_eq!(permuted.shape(), (3, 3));
        assert_eq!(permuted.nnz(), 7);
    }

    #[test]
    fn test_apply_permutation_rejects_bad_permutations() {
        let mat = CsrMatrix::new(vec![1.0, 2.0], vec![0, 1], vec![0, 1], (2, 2)).expect("csr");
        // A valid swap keeps every stored entry.
        assert_eq!(apply_permutation(&mat, &[1, 0]).expect("apply").nnz(), 2);
        // Wrong length.
        assert!(apply_permutation(&mat, &[0]).is_err());
        // Out-of-range index.
        assert!(apply_permutation(&mat, &[0, 2]).is_err());
    }
}