1use crate::base::{EdgeWeight, Graph, IndexType, Node};
6use std::collections::{HashMap, HashSet};
7use std::hash::Hash;
8
/// The small subgraph patterns ("motifs") this module can search for.
///
/// The note on each variant describes what the corresponding finder in this
/// module matches.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub enum MotifType {
    /// Three mutually adjacent nodes.
    Triangle,
    /// Four nodes forming a 4-cycle with neither diagonal edge present.
    Square,
    /// A centre node plus three of its neighbours, no two of which are adjacent.
    Star3,
    /// Four mutually adjacent nodes.
    Clique4,
    /// Four nodes joined by a three-edge path with no extra edges among them.
    Path3,
    /// Two nodes that share at least two further common neighbours.
    BiFan,
    /// A node `a` with neighbours `b`, `c` where `b -> c` exists and none of
    /// `b -> a`, `c -> a`, `c -> b` do.
    FeedForwardLoop,
    /// A pair of nodes with an edge in each direction.
    BiDirectional,
}
32
33#[allow(dead_code)]
35pub fn find_motifs<N, E, Ix>(graph: &Graph<N, E, Ix>, motiftype: MotifType) -> Vec<Vec<N>>
36where
37 N: Node + Clone + Hash + Eq + std::fmt::Debug + Send + Sync,
38 E: EdgeWeight + Send + Sync,
39 Ix: IndexType + Send + Sync,
40{
41 match motiftype {
42 MotifType::Triangle => find_triangles(graph),
43 MotifType::Square => find_squares(graph),
44 MotifType::Star3 => find_star3s(graph),
45 MotifType::Clique4 => find_clique4s(graph),
46 MotifType::Path3 => find_path3s(graph),
47 MotifType::BiFan => find_bi_fans(graph),
48 MotifType::FeedForwardLoop => find_feed_forward_loops(graph),
49 MotifType::BiDirectional => find_bidirectional_motifs(graph),
50 }
51}
52
53#[allow(dead_code)]
54fn find_triangles<N, E, Ix>(graph: &Graph<N, E, Ix>) -> Vec<Vec<N>>
55where
56 N: Node + Clone + Hash + Eq + std::fmt::Debug + Send + Sync,
57 E: EdgeWeight + Send + Sync,
58 Ix: IndexType + Send + Sync,
59{
60 use scirs2_core::parallel_ops::*;
61 use std::sync::Mutex;
62
63 let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
64 let triangles = Mutex::new(Vec::new());
65
66 nodes.par_iter().enumerate().for_each(|(_i, node_i)| {
68 if let Ok(neighbors_i) = graph.neighbors(node_i) {
69 let neighbors_i: Vec<_> = neighbors_i;
70
71 for (j, node_j) in neighbors_i.iter().enumerate() {
72 for node_k in neighbors_i.iter().skip(j + 1) {
73 if graph.has_edge(node_j, node_k) {
74 let mut triangles_guard = triangles.lock().unwrap();
75 let mut triangle = vec![node_i.clone(), node_j.clone(), node_k.clone()];
76 triangle.sort_by(|a, b| format!("{a:?}").cmp(&format!("{b:?}")));
77
78 if !triangles_guard.iter().any(|t| t == &triangle) {
80 triangles_guard.push(triangle);
81 }
82 }
83 }
84 }
85 }
86 });
87
88 triangles.into_inner().unwrap()
89}
90
91#[allow(dead_code)]
92fn find_squares<N, E, Ix>(graph: &Graph<N, E, Ix>) -> Vec<Vec<N>>
93where
94 N: Node + Clone + Hash + Eq + std::fmt::Debug + Send + Sync,
95 E: EdgeWeight + Send + Sync,
96 Ix: IndexType + Send + Sync,
97{
98 let mut squares = Vec::new();
99 let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
100
101 for i in 0..nodes.len() {
103 for j in i + 1..nodes.len() {
104 if !graph.has_edge(&nodes[i], &nodes[j]) {
105 continue;
106 }
107 for k in j + 1..nodes.len() {
108 if !graph.has_edge(&nodes[j], &nodes[k]) {
109 continue;
110 }
111 for l in k + 1..nodes.len() {
112 if graph.has_edge(&nodes[k], &nodes[l])
113 && graph.has_edge(&nodes[l], &nodes[i])
114 && !graph.has_edge(&nodes[i], &nodes[k])
115 && !graph.has_edge(&nodes[j], &nodes[l])
116 {
117 squares.push(vec![
118 nodes[i].clone(),
119 nodes[j].clone(),
120 nodes[k].clone(),
121 nodes[l].clone(),
122 ]);
123 }
124 }
125 }
126 }
127 }
128
129 squares
130}
131
132#[allow(dead_code)]
133fn find_star3s<N, E, Ix>(graph: &Graph<N, E, Ix>) -> Vec<Vec<N>>
134where
135 N: Node + Clone + Hash + Eq + std::fmt::Debug + Send + Sync,
136 E: EdgeWeight + Send + Sync,
137 Ix: IndexType + Send + Sync,
138{
139 let mut stars = Vec::new();
140 let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
141
142 for center in &nodes {
144 if let Ok(neighbors) = graph.neighbors(center) {
145 let neighbor_list: Vec<N> = neighbors;
146
147 if neighbor_list.len() >= 3 {
148 for i in 0..neighbor_list.len() {
150 for j in i + 1..neighbor_list.len() {
151 for k in j + 1..neighbor_list.len() {
152 if !graph.has_edge(&neighbor_list[i], &neighbor_list[j])
154 && !graph.has_edge(&neighbor_list[j], &neighbor_list[k])
155 && !graph.has_edge(&neighbor_list[i], &neighbor_list[k])
156 {
157 stars.push(vec![
158 center.clone(),
159 neighbor_list[i].clone(),
160 neighbor_list[j].clone(),
161 neighbor_list[k].clone(),
162 ]);
163 }
164 }
165 }
166 }
167 }
168 }
169 }
170
171 stars
172}
173
174#[allow(dead_code)]
175fn find_clique4s<N, E, Ix>(graph: &Graph<N, E, Ix>) -> Vec<Vec<N>>
176where
177 N: Node + Clone + Hash + Eq + std::fmt::Debug + Send + Sync,
178 E: EdgeWeight + Send + Sync,
179 Ix: IndexType + Send + Sync,
180{
181 let mut cliques = Vec::new();
182 let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
183
184 for i in 0..nodes.len() {
186 for j in i + 1..nodes.len() {
187 if !graph.has_edge(&nodes[i], &nodes[j]) {
188 continue;
189 }
190 for k in j + 1..nodes.len() {
191 if !graph.has_edge(&nodes[i], &nodes[k]) || !graph.has_edge(&nodes[j], &nodes[k]) {
192 continue;
193 }
194 for l in k + 1..nodes.len() {
195 if graph.has_edge(&nodes[i], &nodes[l])
196 && graph.has_edge(&nodes[j], &nodes[l])
197 && graph.has_edge(&nodes[k], &nodes[l])
198 {
199 cliques.push(vec![
200 nodes[i].clone(),
201 nodes[j].clone(),
202 nodes[k].clone(),
203 nodes[l].clone(),
204 ]);
205 }
206 }
207 }
208 }
209 }
210
211 cliques
212}
213
214#[allow(dead_code)]
216fn find_path3s<N, E, Ix>(graph: &Graph<N, E, Ix>) -> Vec<Vec<N>>
217where
218 N: Node + Clone + Hash + Eq + std::fmt::Debug + Send + Sync,
219 E: EdgeWeight + Send + Sync,
220 Ix: IndexType + Send + Sync,
221{
222 use scirs2_core::parallel_ops::*;
223 use std::sync::Mutex;
224
225 let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
226 let paths = Mutex::new(Vec::new());
227
228 nodes.par_iter().for_each(|start_node| {
229 if let Ok(neighbors1) = graph.neighbors(start_node) {
230 for middle1 in neighbors1 {
231 if let Ok(neighbors2) = graph.neighbors(&middle1) {
232 for middle2 in neighbors2 {
233 if middle2 == *start_node {
234 continue;
235 }
236
237 if let Ok(neighbors3) = graph.neighbors(&middle2) {
238 for end_node in neighbors3 {
239 if end_node == middle1 || end_node == *start_node {
240 continue;
241 }
242
243 if !graph.has_edge(start_node, &middle2)
245 && !graph.has_edge(start_node, &end_node)
246 && !graph.has_edge(&middle1, &end_node)
247 {
248 let mut path = vec![
249 start_node.clone(),
250 middle1.clone(),
251 middle2.clone(),
252 end_node.clone(),
253 ];
254 path.sort_by(|a, b| format!("{a:?}").cmp(&format!("{b:?}")));
255
256 let mut paths_guard = paths.lock().unwrap();
257 if !paths_guard.iter().any(|p| p == &path) {
258 paths_guard.push(path);
259 }
260 }
261 }
262 }
263 }
264 }
265 }
266 }
267 });
268
269 paths.into_inner().unwrap()
270}
271
272#[allow(dead_code)]
274fn find_bi_fans<N, E, Ix>(graph: &Graph<N, E, Ix>) -> Vec<Vec<N>>
275where
276 N: Node + Clone + Hash + Eq + std::fmt::Debug + Send + Sync,
277 E: EdgeWeight + Send + Sync,
278 Ix: IndexType + Send + Sync,
279{
280 use scirs2_core::parallel_ops::*;
281 use std::sync::Mutex;
282
283 let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
284 let bi_fans = Mutex::new(Vec::new());
285
286 nodes.par_iter().enumerate().for_each(|(i, node1)| {
287 for node2 in nodes.iter().skip(i + 1) {
288 if let (Ok(neighbors1), Ok(neighbors2)) =
289 (graph.neighbors(node1), graph.neighbors(node2))
290 {
291 let neighbors1: HashSet<_> = neighbors1.into_iter().collect();
292 let neighbors2: HashSet<_> = neighbors2.into_iter().collect();
293
294 let common: Vec<_> = neighbors1
296 .intersection(&neighbors2)
297 .filter(|&n| n != node1 && n != node2)
298 .cloned()
299 .collect();
300
301 if common.len() >= 2 {
302 for (j, fan1) in common.iter().enumerate() {
304 for fan2 in common.iter().skip(j + 1) {
305 let mut bi_fan =
306 vec![node1.clone(), node2.clone(), fan1.clone(), fan2.clone()];
307 bi_fan.sort_by(|a, b| format!("{a:?}").cmp(&format!("{b:?}")));
308
309 let mut bi_fans_guard = bi_fans.lock().unwrap();
310 if !bi_fans_guard.iter().any(|bf| bf == &bi_fan) {
311 bi_fans_guard.push(bi_fan);
312 }
313 }
314 }
315 }
316 }
317 }
318 });
319
320 bi_fans.into_inner().unwrap()
321}
322
323#[allow(dead_code)]
325fn find_feed_forward_loops<N, E, Ix>(graph: &Graph<N, E, Ix>) -> Vec<Vec<N>>
326where
327 N: Node + Clone + Hash + Eq + std::fmt::Debug + Send + Sync,
328 E: EdgeWeight + Send + Sync,
329 Ix: IndexType + Send + Sync,
330{
331 use scirs2_core::parallel_ops::*;
332 use std::sync::Mutex;
333
334 let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
335 let ffls = Mutex::new(Vec::new());
336
337 nodes.par_iter().for_each(|node_a| {
339 if let Ok(out_neighbors_a) = graph.neighbors(node_a) {
340 let out_neighbors_a: Vec<_> = out_neighbors_a;
341
342 for (i, node_b) in out_neighbors_a.iter().enumerate() {
343 for node_c in out_neighbors_a.iter().skip(i + 1) {
344 if graph.has_edge(node_b, node_c) {
346 if !graph.has_edge(node_b, node_a)
348 && !graph.has_edge(node_c, node_a)
349 && !graph.has_edge(node_c, node_b)
350 {
351 let mut ffl = vec![node_a.clone(), node_b.clone(), node_c.clone()];
352 ffl.sort_by(|a, b| format!("{a:?}").cmp(&format!("{b:?}")));
353
354 let mut ffls_guard = ffls.lock().unwrap();
355 if !ffls_guard.iter().any(|f| f == &ffl) {
356 ffls_guard.push(ffl);
357 }
358 }
359 }
360 }
361 }
362 }
363 });
364
365 ffls.into_inner().unwrap()
366}
367
368#[allow(dead_code)]
370fn find_bidirectional_motifs<N, E, Ix>(graph: &Graph<N, E, Ix>) -> Vec<Vec<N>>
371where
372 N: Node + Clone + Hash + Eq + std::fmt::Debug + Send + Sync,
373 E: EdgeWeight + Send + Sync,
374 Ix: IndexType + Send + Sync,
375{
376 use scirs2_core::parallel_ops::*;
377 use std::sync::Mutex;
378
379 let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
380 let bidirectionals = Mutex::new(Vec::new());
381
382 nodes.par_iter().enumerate().for_each(|(i, node1)| {
383 for node2 in nodes.iter().skip(i + 1) {
384 if graph.has_edge(node1, node2) && graph.has_edge(node2, node1) {
386 let mut motif = vec![node1.clone(), node2.clone()];
387 motif.sort_by(|a, b| format!("{a:?}").cmp(&format!("{b:?}")));
388
389 let mut bidirectionals_guard = bidirectionals.lock().unwrap();
390 if !bidirectionals_guard.iter().any(|m| m == &motif) {
391 bidirectionals_guard.push(motif);
392 }
393 }
394 }
395 });
396
397 bidirectionals.into_inner().unwrap()
398}
399
400#[allow(dead_code)]
403pub fn count_motif_frequencies<N, E, Ix>(graph: &Graph<N, E, Ix>) -> HashMap<MotifType, usize>
404where
405 N: Node + Clone + Hash + Eq + std::fmt::Debug + Send + Sync,
406 E: EdgeWeight + Send + Sync,
407 Ix: IndexType + Send + Sync,
408{
409 use scirs2_core::parallel_ops::*;
410
411 let motif_types = vec![
412 MotifType::Triangle,
413 MotifType::Square,
414 MotifType::Star3,
415 MotifType::Clique4,
416 MotifType::Path3,
417 MotifType::BiFan,
418 MotifType::FeedForwardLoop,
419 MotifType::BiDirectional,
420 ];
421
422 motif_types
423 .par_iter()
424 .map(|motif_type| {
425 let count = find_motifs(graph, *motif_type).len();
426 (*motif_type, count)
427 })
428 .collect()
429}
430
431#[allow(dead_code)]
434pub fn sample_motif_frequencies<N, E, Ix>(
435 graph: &Graph<N, E, Ix>,
436 sample_size: usize,
437 rng: &mut impl rand::Rng,
438) -> HashMap<MotifType, f64>
439where
440 N: Node + Clone + Hash + Eq + std::fmt::Debug + Send + Sync,
441 E: EdgeWeight + Send + Sync,
442 Ix: IndexType + Send + Sync,
443{
444 use rand::seq::SliceRandom;
445
446 let all_nodes: Vec<_> = graph.nodes().into_iter().cloned().collect();
447 if all_nodes.len() <= sample_size {
448 return count_motif_frequencies(graph)
450 .into_iter()
451 .map(|(k, v)| (k, v as f64))
452 .collect();
453 }
454
455 let mut sampled_nodes = all_nodes.clone();
457 sampled_nodes.shuffle(rng);
458 sampled_nodes.truncate(sample_size);
459
460 let mut subgraph = crate::generators::create_graph::<N, E>();
462 for node in &sampled_nodes {
463 let _ = subgraph.add_node(node.clone());
464 }
465
466 for node1 in &sampled_nodes {
468 if let Ok(neighbors) = graph.neighbors(node1) {
469 for node2 in neighbors {
470 if sampled_nodes.contains(&node2) && node1 != &node2 {
471 if let Ok(weight) = graph.edge_weight(node1, &node2) {
472 let _ = subgraph.add_edge(node1.clone(), node2, weight);
473 }
474 }
475 }
476 }
477 }
478
479 let subgraph_counts = count_motif_frequencies(&subgraph);
481 let scaling_factor = (all_nodes.len() as f64) / (sample_size as f64);
482
483 subgraph_counts
484 .into_iter()
485 .map(|(motif_type, count)| (motif_type, count as f64 * scaling_factor))
486 .collect()
487}
488
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::Result as GraphResult;
    use crate::generators::create_graph;

    #[test]
    fn test_find_triangles() -> GraphResult<()> {
        let mut graph = create_graph::<&str, ()>();
        // Triangle A-B-C plus a pendant edge A-D that closes no triangle.
        for (u, v) in [("A", "B"), ("B", "C"), ("C", "A"), ("A", "D")] {
            graph.add_edge(u, v, ())?;
        }

        let triangles = find_motifs(&graph, MotifType::Triangle);
        assert_eq!(triangles.len(), 1);

        let members = &triangles[0];
        assert_eq!(members.len(), 3);
        for expected in ["A", "B", "C"] {
            assert!(members.contains(&expected));
        }

        Ok(())
    }

    #[test]
    fn test_find_squares() -> GraphResult<()> {
        let mut graph = create_graph::<&str, ()>();
        // A single chordless 4-cycle A-B-C-D-A.
        for (u, v) in [("A", "B"), ("B", "C"), ("C", "D"), ("D", "A")] {
            graph.add_edge(u, v, ())?;
        }

        let squares = find_motifs(&graph, MotifType::Square);
        assert_eq!(squares.len(), 1);
        assert_eq!(squares[0].len(), 4);

        Ok(())
    }

    #[test]
    fn test_find_star3() -> GraphResult<()> {
        let mut graph = create_graph::<&str, ()>();
        // Hub A with three mutually non-adjacent leaves.
        for leaf in ["B", "C", "D"] {
            graph.add_edge("A", leaf, ())?;
        }

        let stars = find_motifs(&graph, MotifType::Star3);
        assert_eq!(stars.len(), 1);
        assert_eq!(stars[0].len(), 4);
        assert!(stars[0].contains(&"A"));

        Ok(())
    }

    #[test]
    fn test_find_clique4() -> GraphResult<()> {
        let mut graph = create_graph::<&str, ()>();
        // Complete graph on A, B, C, D.
        let labels = ["A", "B", "C", "D"];
        for (i, &u) in labels.iter().enumerate() {
            for &v in &labels[i + 1..] {
                graph.add_edge(u, v, ())?;
            }
        }

        let cliques = find_motifs(&graph, MotifType::Clique4);
        assert_eq!(cliques.len(), 1);
        assert_eq!(cliques[0].len(), 4);

        Ok(())
    }
}