scirs2_graph/algorithms/
motifs.rs

1//! Graph motif finding algorithms
2//!
3//! This module contains algorithms for finding small recurring subgraph patterns (motifs).
4
5use crate::base::{EdgeWeight, Graph, IndexType, Node};
6use std::collections::{HashMap, HashSet};
7use std::hash::Hash;
8
/// The small recurring subgraph patterns (motifs) that [`find_motifs`] can search for.
///
/// Every motif spans at most four distinct nodes; each variant documents the
/// exact pattern its finder matches.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MotifType {
    /// Triangle: three mutually adjacent nodes (a 3-cycle).
    Triangle,
    /// Square: a chordless 4-cycle, i.e. four nodes in a cycle whose two diagonals are absent.
    Square,
    /// Star with 3 leaves: a centre adjacent to three leaves, no two of which are adjacent.
    Star3,
    /// Clique of size 4: four mutually adjacent nodes.
    Clique4,
    /// Path of length 3: four nodes joined in a line by three edges, with no shortcut edges.
    Path3,
    /// Bi-fan: two nodes that are both connected to the same two other nodes.
    BiFan,
    /// Feed-forward loop: edges `A->B`, `A->C`, `B->C` with no reverse edge among the three nodes.
    FeedForwardLoop,
    /// Bi-directional pair: two nodes connected by edges in both directions.
    BiDirectional,
}
32
33/// Find all occurrences of a specified motif in the graph
34#[allow(dead_code)]
35pub fn find_motifs<N, E, Ix>(graph: &Graph<N, E, Ix>, motiftype: MotifType) -> Vec<Vec<N>>
36where
37    N: Node + Clone + Hash + Eq + std::fmt::Debug + Send + Sync,
38    E: EdgeWeight + Send + Sync,
39    Ix: IndexType + Send + Sync,
40{
41    match motiftype {
42        MotifType::Triangle => find_triangles(graph),
43        MotifType::Square => find_squares(graph),
44        MotifType::Star3 => find_star3s(graph),
45        MotifType::Clique4 => find_clique4s(graph),
46        MotifType::Path3 => find_path3s(graph),
47        MotifType::BiFan => find_bi_fans(graph),
48        MotifType::FeedForwardLoop => find_feed_forward_loops(graph),
49        MotifType::BiDirectional => find_bidirectional_motifs(graph),
50    }
51}
52
53#[allow(dead_code)]
54fn find_triangles<N, E, Ix>(graph: &Graph<N, E, Ix>) -> Vec<Vec<N>>
55where
56    N: Node + Clone + Hash + Eq + std::fmt::Debug + Send + Sync,
57    E: EdgeWeight + Send + Sync,
58    Ix: IndexType + Send + Sync,
59{
60    use scirs2_core::parallel_ops::*;
61    use std::sync::Mutex;
62
63    let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
64    let triangles = Mutex::new(Vec::new());
65
66    // Parallel triangle finding using edge-based approach for better performance
67    nodes.par_iter().enumerate().for_each(|(_i, node_i)| {
68        if let Ok(neighbors_i) = graph.neighbors(node_i) {
69            let neighbors_i: Vec<_> = neighbors_i;
70
71            for (j, node_j) in neighbors_i.iter().enumerate() {
72                for node_k in neighbors_i.iter().skip(j + 1) {
73                    if graph.has_edge(node_j, node_k) {
74                        let mut triangles_guard = triangles.lock().unwrap();
75                        let mut triangle = vec![node_i.clone(), node_j.clone(), node_k.clone()];
76                        triangle.sort_by(|a, b| format!("{a:?}").cmp(&format!("{b:?}")));
77
78                        // Avoid duplicates
79                        if !triangles_guard.iter().any(|t| t == &triangle) {
80                            triangles_guard.push(triangle);
81                        }
82                    }
83                }
84            }
85        }
86    });
87
88    triangles.into_inner().unwrap()
89}
90
91#[allow(dead_code)]
92fn find_squares<N, E, Ix>(graph: &Graph<N, E, Ix>) -> Vec<Vec<N>>
93where
94    N: Node + Clone + Hash + Eq + std::fmt::Debug + Send + Sync,
95    E: EdgeWeight + Send + Sync,
96    Ix: IndexType + Send + Sync,
97{
98    let mut squares = Vec::new();
99    let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
100
101    // For each quadruplet of nodes, check if they form a square
102    for i in 0..nodes.len() {
103        for j in i + 1..nodes.len() {
104            if !graph.has_edge(&nodes[i], &nodes[j]) {
105                continue;
106            }
107            for k in j + 1..nodes.len() {
108                if !graph.has_edge(&nodes[j], &nodes[k]) {
109                    continue;
110                }
111                for l in k + 1..nodes.len() {
112                    if graph.has_edge(&nodes[k], &nodes[l])
113                        && graph.has_edge(&nodes[l], &nodes[i])
114                        && !graph.has_edge(&nodes[i], &nodes[k])
115                        && !graph.has_edge(&nodes[j], &nodes[l])
116                    {
117                        squares.push(vec![
118                            nodes[i].clone(),
119                            nodes[j].clone(),
120                            nodes[k].clone(),
121                            nodes[l].clone(),
122                        ]);
123                    }
124                }
125            }
126        }
127    }
128
129    squares
130}
131
132#[allow(dead_code)]
133fn find_star3s<N, E, Ix>(graph: &Graph<N, E, Ix>) -> Vec<Vec<N>>
134where
135    N: Node + Clone + Hash + Eq + std::fmt::Debug + Send + Sync,
136    E: EdgeWeight + Send + Sync,
137    Ix: IndexType + Send + Sync,
138{
139    let mut stars = Vec::new();
140    let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
141
142    // For each node as center, find if it has exactly 3 neighbors that aren't connected
143    for center in &nodes {
144        if let Ok(neighbors) = graph.neighbors(center) {
145            let neighbor_list: Vec<N> = neighbors;
146
147            if neighbor_list.len() >= 3 {
148                // Check all combinations of 3 neighbors
149                for i in 0..neighbor_list.len() {
150                    for j in i + 1..neighbor_list.len() {
151                        for k in j + 1..neighbor_list.len() {
152                            // Check that the neighbors aren't connected to each other
153                            if !graph.has_edge(&neighbor_list[i], &neighbor_list[j])
154                                && !graph.has_edge(&neighbor_list[j], &neighbor_list[k])
155                                && !graph.has_edge(&neighbor_list[i], &neighbor_list[k])
156                            {
157                                stars.push(vec![
158                                    center.clone(),
159                                    neighbor_list[i].clone(),
160                                    neighbor_list[j].clone(),
161                                    neighbor_list[k].clone(),
162                                ]);
163                            }
164                        }
165                    }
166                }
167            }
168        }
169    }
170
171    stars
172}
173
174#[allow(dead_code)]
175fn find_clique4s<N, E, Ix>(graph: &Graph<N, E, Ix>) -> Vec<Vec<N>>
176where
177    N: Node + Clone + Hash + Eq + std::fmt::Debug + Send + Sync,
178    E: EdgeWeight + Send + Sync,
179    Ix: IndexType + Send + Sync,
180{
181    let mut cliques = Vec::new();
182    let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
183
184    // For each quadruplet of nodes, check if they form a complete graph
185    for i in 0..nodes.len() {
186        for j in i + 1..nodes.len() {
187            if !graph.has_edge(&nodes[i], &nodes[j]) {
188                continue;
189            }
190            for k in j + 1..nodes.len() {
191                if !graph.has_edge(&nodes[i], &nodes[k]) || !graph.has_edge(&nodes[j], &nodes[k]) {
192                    continue;
193                }
194                for l in k + 1..nodes.len() {
195                    if graph.has_edge(&nodes[i], &nodes[l])
196                        && graph.has_edge(&nodes[j], &nodes[l])
197                        && graph.has_edge(&nodes[k], &nodes[l])
198                    {
199                        cliques.push(vec![
200                            nodes[i].clone(),
201                            nodes[j].clone(),
202                            nodes[k].clone(),
203                            nodes[l].clone(),
204                        ]);
205                    }
206                }
207            }
208        }
209    }
210
211    cliques
212}
213
214/// Find all path motifs of length 3 (4 nodes in a line)
215#[allow(dead_code)]
216fn find_path3s<N, E, Ix>(graph: &Graph<N, E, Ix>) -> Vec<Vec<N>>
217where
218    N: Node + Clone + Hash + Eq + std::fmt::Debug + Send + Sync,
219    E: EdgeWeight + Send + Sync,
220    Ix: IndexType + Send + Sync,
221{
222    use scirs2_core::parallel_ops::*;
223    use std::sync::Mutex;
224
225    let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
226    let paths = Mutex::new(Vec::new());
227
228    nodes.par_iter().for_each(|start_node| {
229        if let Ok(neighbors1) = graph.neighbors(start_node) {
230            for middle1 in neighbors1 {
231                if let Ok(neighbors2) = graph.neighbors(&middle1) {
232                    for middle2 in neighbors2 {
233                        if middle2 == *start_node {
234                            continue;
235                        }
236
237                        if let Ok(neighbors3) = graph.neighbors(&middle2) {
238                            for end_node in neighbors3 {
239                                if end_node == middle1 || end_node == *start_node {
240                                    continue;
241                                }
242
243                                // Check it's a path (no shortcuts)
244                                if !graph.has_edge(start_node, &middle2)
245                                    && !graph.has_edge(start_node, &end_node)
246                                    && !graph.has_edge(&middle1, &end_node)
247                                {
248                                    let mut path = vec![
249                                        start_node.clone(),
250                                        middle1.clone(),
251                                        middle2.clone(),
252                                        end_node.clone(),
253                                    ];
254                                    path.sort_by(|a, b| format!("{a:?}").cmp(&format!("{b:?}")));
255
256                                    let mut paths_guard = paths.lock().unwrap();
257                                    if !paths_guard.iter().any(|p| p == &path) {
258                                        paths_guard.push(path);
259                                    }
260                                }
261                            }
262                        }
263                    }
264                }
265            }
266        }
267    });
268
269    paths.into_inner().unwrap()
270}
271
272/// Find bi-fan motifs (2 nodes connected to the same 2 other nodes)
273#[allow(dead_code)]
274fn find_bi_fans<N, E, Ix>(graph: &Graph<N, E, Ix>) -> Vec<Vec<N>>
275where
276    N: Node + Clone + Hash + Eq + std::fmt::Debug + Send + Sync,
277    E: EdgeWeight + Send + Sync,
278    Ix: IndexType + Send + Sync,
279{
280    use scirs2_core::parallel_ops::*;
281    use std::sync::Mutex;
282
283    let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
284    let bi_fans = Mutex::new(Vec::new());
285
286    nodes.par_iter().enumerate().for_each(|(i, node1)| {
287        for node2 in nodes.iter().skip(i + 1) {
288            if let (Ok(neighbors1), Ok(neighbors2)) =
289                (graph.neighbors(node1), graph.neighbors(node2))
290            {
291                let neighbors1: HashSet<_> = neighbors1.into_iter().collect();
292                let neighbors2: HashSet<_> = neighbors2.into_iter().collect();
293
294                // Find common neighbors (excluding node1 and node2)
295                let common: Vec<_> = neighbors1
296                    .intersection(&neighbors2)
297                    .filter(|&n| n != node1 && n != node2)
298                    .cloned()
299                    .collect();
300
301                if common.len() >= 2 {
302                    // For each pair of common neighbors, create a bi-fan
303                    for (j, fan1) in common.iter().enumerate() {
304                        for fan2 in common.iter().skip(j + 1) {
305                            let mut bi_fan =
306                                vec![node1.clone(), node2.clone(), fan1.clone(), fan2.clone()];
307                            bi_fan.sort_by(|a, b| format!("{a:?}").cmp(&format!("{b:?}")));
308
309                            let mut bi_fans_guard = bi_fans.lock().unwrap();
310                            if !bi_fans_guard.iter().any(|bf| bf == &bi_fan) {
311                                bi_fans_guard.push(bi_fan);
312                            }
313                        }
314                    }
315                }
316            }
317        }
318    });
319
320    bi_fans.into_inner().unwrap()
321}
322
323/// Find feed-forward loop motifs (3 nodes with specific directed pattern)
324#[allow(dead_code)]
325fn find_feed_forward_loops<N, E, Ix>(graph: &Graph<N, E, Ix>) -> Vec<Vec<N>>
326where
327    N: Node + Clone + Hash + Eq + std::fmt::Debug + Send + Sync,
328    E: EdgeWeight + Send + Sync,
329    Ix: IndexType + Send + Sync,
330{
331    use scirs2_core::parallel_ops::*;
332    use std::sync::Mutex;
333
334    let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
335    let ffls = Mutex::new(Vec::new());
336
337    // Feed-forward loop: A->B, A->C, B->C (but not A<-B, A<-C, B<-C)
338    nodes.par_iter().for_each(|node_a| {
339        if let Ok(out_neighbors_a) = graph.neighbors(node_a) {
340            let out_neighbors_a: Vec<_> = out_neighbors_a;
341
342            for (i, node_b) in out_neighbors_a.iter().enumerate() {
343                for node_c in out_neighbors_a.iter().skip(i + 1) {
344                    // Check if B->C exists and no back edges exist
345                    if graph.has_edge(node_b, node_c) {
346                        // Ensure it's a true feed-forward (no cycles back)
347                        if !graph.has_edge(node_b, node_a)
348                            && !graph.has_edge(node_c, node_a)
349                            && !graph.has_edge(node_c, node_b)
350                        {
351                            let mut ffl = vec![node_a.clone(), node_b.clone(), node_c.clone()];
352                            ffl.sort_by(|a, b| format!("{a:?}").cmp(&format!("{b:?}")));
353
354                            let mut ffls_guard = ffls.lock().unwrap();
355                            if !ffls_guard.iter().any(|f| f == &ffl) {
356                                ffls_guard.push(ffl);
357                            }
358                        }
359                    }
360                }
361            }
362        }
363    });
364
365    ffls.into_inner().unwrap()
366}
367
368/// Find bi-directional motifs (mutual connections between pairs of nodes)
369#[allow(dead_code)]
370fn find_bidirectional_motifs<N, E, Ix>(graph: &Graph<N, E, Ix>) -> Vec<Vec<N>>
371where
372    N: Node + Clone + Hash + Eq + std::fmt::Debug + Send + Sync,
373    E: EdgeWeight + Send + Sync,
374    Ix: IndexType + Send + Sync,
375{
376    use scirs2_core::parallel_ops::*;
377    use std::sync::Mutex;
378
379    let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
380    let bidirectionals = Mutex::new(Vec::new());
381
382    nodes.par_iter().enumerate().for_each(|(i, node1)| {
383        for node2 in nodes.iter().skip(i + 1) {
384            // Check for bidirectional connection
385            if graph.has_edge(node1, node2) && graph.has_edge(node2, node1) {
386                let mut motif = vec![node1.clone(), node2.clone()];
387                motif.sort_by(|a, b| format!("{a:?}").cmp(&format!("{b:?}")));
388
389                let mut bidirectionals_guard = bidirectionals.lock().unwrap();
390                if !bidirectionals_guard.iter().any(|m| m == &motif) {
391                    bidirectionals_guard.push(motif);
392                }
393            }
394        }
395    });
396
397    bidirectionals.into_inner().unwrap()
398}
399
400/// Advanced motif counting with frequency analysis
401/// Returns a map of motif patterns to their occurrence counts
402#[allow(dead_code)]
403pub fn count_motif_frequencies<N, E, Ix>(graph: &Graph<N, E, Ix>) -> HashMap<MotifType, usize>
404where
405    N: Node + Clone + Hash + Eq + std::fmt::Debug + Send + Sync,
406    E: EdgeWeight + Send + Sync,
407    Ix: IndexType + Send + Sync,
408{
409    use scirs2_core::parallel_ops::*;
410
411    let motif_types = vec![
412        MotifType::Triangle,
413        MotifType::Square,
414        MotifType::Star3,
415        MotifType::Clique4,
416        MotifType::Path3,
417        MotifType::BiFan,
418        MotifType::FeedForwardLoop,
419        MotifType::BiDirectional,
420    ];
421
422    motif_types
423        .par_iter()
424        .map(|motif_type| {
425            let count = find_motifs(graph, *motif_type).len();
426            (*motif_type, count)
427        })
428        .collect()
429}
430
431/// Efficient motif detection using sampling for large graphs
432/// Returns estimated motif counts based on random sampling
433#[allow(dead_code)]
434pub fn sample_motif_frequencies<N, E, Ix>(
435    graph: &Graph<N, E, Ix>,
436    sample_size: usize,
437    rng: &mut impl rand::Rng,
438) -> HashMap<MotifType, f64>
439where
440    N: Node + Clone + Hash + Eq + std::fmt::Debug + Send + Sync,
441    E: EdgeWeight + Send + Sync,
442    Ix: IndexType + Send + Sync,
443{
444    use rand::seq::SliceRandom;
445
446    let all_nodes: Vec<_> = graph.nodes().into_iter().cloned().collect();
447    if all_nodes.len() <= sample_size {
448        // If graph is small, do exact counting
449        return count_motif_frequencies(graph)
450            .into_iter()
451            .map(|(k, v)| (k, v as f64))
452            .collect();
453    }
454
455    // Sample nodes
456    let mut sampled_nodes = all_nodes.clone();
457    sampled_nodes.shuffle(rng);
458    sampled_nodes.truncate(sample_size);
459
460    // Create subgraph from sampled nodes
461    let mut subgraph = crate::generators::create_graph::<N, E>();
462    for node in &sampled_nodes {
463        let _ = subgraph.add_node(node.clone());
464    }
465
466    // Add edges between sampled nodes
467    for node1 in &sampled_nodes {
468        if let Ok(neighbors) = graph.neighbors(node1) {
469            for node2 in neighbors {
470                if sampled_nodes.contains(&node2) && node1 != &node2 {
471                    if let Ok(weight) = graph.edge_weight(node1, &node2) {
472                        let _ = subgraph.add_edge(node1.clone(), node2, weight);
473                    }
474                }
475            }
476        }
477    }
478
479    // Count motifs in subgraph and extrapolate
480    let subgraph_counts = count_motif_frequencies(&subgraph);
481    let scaling_factor = (all_nodes.len() as f64) / (sample_size as f64);
482
483    subgraph_counts
484        .into_iter()
485        .map(|(motif_type, count)| (motif_type, count as f64 * scaling_factor))
486        .collect()
487}
488
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::Result as GraphResult;
    use crate::generators::create_graph;

    #[test]
    fn test_find_triangles() -> GraphResult<()> {
        let mut graph = create_graph::<&str, ()>();

        // Create a triangle ABC
        graph.add_edge("A", "B", ())?;
        graph.add_edge("B", "C", ())?;
        graph.add_edge("C", "A", ())?;

        // Add another node D connected to A (not forming new triangles)
        graph.add_edge("A", "D", ())?;

        let triangles = find_motifs(&graph, MotifType::Triangle);
        assert_eq!(triangles.len(), 1);

        // The triangle should contain A, B, and C
        let triangle = &triangles[0];
        assert_eq!(triangle.len(), 3);
        assert!(triangle.contains(&"A"));
        assert!(triangle.contains(&"B"));
        assert!(triangle.contains(&"C"));

        Ok(())
    }

    #[test]
    fn test_find_triangles_in_k4() -> GraphResult<()> {
        let mut graph = create_graph::<&str, ()>();

        let nodes = ["A", "B", "C", "D"];
        for i in 0..nodes.len() {
            for j in i + 1..nodes.len() {
                graph.add_edge(nodes[i], nodes[j], ())?;
            }
        }

        // K4 contains C(4, 3) = 4 triangles, each of which must be reported once.
        let triangles = find_motifs(&graph, MotifType::Triangle);
        assert_eq!(triangles.len(), 4);

        Ok(())
    }

    #[test]
    fn test_find_squares() -> GraphResult<()> {
        let mut graph = create_graph::<&str, ()>();

        // Create a square ABCD
        graph.add_edge("A", "B", ())?;
        graph.add_edge("B", "C", ())?;
        graph.add_edge("C", "D", ())?;
        graph.add_edge("D", "A", ())?;

        let squares = find_motifs(&graph, MotifType::Square);
        assert_eq!(squares.len(), 1);

        let square = &squares[0];
        assert_eq!(square.len(), 4);

        Ok(())
    }

    #[test]
    fn test_find_star3() -> GraphResult<()> {
        let mut graph = create_graph::<&str, ()>();

        // Create a star with center A and leaves B, C, D
        graph.add_edge("A", "B", ())?;
        graph.add_edge("A", "C", ())?;
        graph.add_edge("A", "D", ())?;

        let stars = find_motifs(&graph, MotifType::Star3);
        assert_eq!(stars.len(), 1);

        let star = &stars[0];
        assert_eq!(star.len(), 4);
        assert!(star.contains(&"A")); // Center should be included

        Ok(())
    }

    #[test]
    fn test_find_clique4() -> GraphResult<()> {
        let mut graph = create_graph::<&str, ()>();

        // Create a complete graph K4
        let nodes = ["A", "B", "C", "D"];
        for i in 0..nodes.len() {
            for j in i + 1..nodes.len() {
                graph.add_edge(nodes[i], nodes[j], ())?;
            }
        }

        let cliques = find_motifs(&graph, MotifType::Clique4);
        assert_eq!(cliques.len(), 1);

        let clique = &cliques[0];
        assert_eq!(clique.len(), 4);

        Ok(())
    }

    #[test]
    fn test_find_path3() -> GraphResult<()> {
        let mut graph = create_graph::<&str, ()>();

        // A simple line A-B-C-D is exactly one induced 3-edge path.
        graph.add_edge("A", "B", ())?;
        graph.add_edge("B", "C", ())?;
        graph.add_edge("C", "D", ())?;

        let paths = find_motifs(&graph, MotifType::Path3);
        assert_eq!(paths.len(), 1);

        let path = &paths[0];
        assert_eq!(path.len(), 4);
        for node in ["A", "B", "C", "D"].iter() {
            assert!(path.contains(node));
        }

        Ok(())
    }

    #[test]
    fn test_find_bi_fan() -> GraphResult<()> {
        let mut graph = create_graph::<&str, ()>();

        // In the square A-B-C-D, {A, C} and {B, D} are each connected to both
        // members of the other pair; both views describe the same node set.
        graph.add_edge("A", "B", ())?;
        graph.add_edge("B", "C", ())?;
        graph.add_edge("C", "D", ())?;
        graph.add_edge("D", "A", ())?;

        let bi_fans = find_motifs(&graph, MotifType::BiFan);
        assert_eq!(bi_fans.len(), 1);
        assert_eq!(bi_fans[0].len(), 4);

        Ok(())
    }

    #[test]
    fn test_count_motif_frequencies() -> GraphResult<()> {
        let mut graph = create_graph::<&str, ()>();

        graph.add_edge("A", "B", ())?;
        graph.add_edge("B", "C", ())?;
        graph.add_edge("C", "A", ())?;

        // Every motif kind gets an entry, even when its count is zero.
        let counts = count_motif_frequencies(&graph);
        assert_eq!(counts.len(), 8);
        assert_eq!(counts[&MotifType::Triangle], 1);
        assert_eq!(counts[&MotifType::Square], 0);
        assert_eq!(counts[&MotifType::Clique4], 0);

        Ok(())
    }

    #[test]
    fn test_empty_graph_has_no_motifs() {
        let graph = create_graph::<&str, ()>();

        let kinds = [
            MotifType::Triangle,
            MotifType::Square,
            MotifType::Star3,
            MotifType::Clique4,
            MotifType::Path3,
            MotifType::BiFan,
            MotifType::FeedForwardLoop,
            MotifType::BiDirectional,
        ];
        for kind in kinds.iter() {
            assert!(find_motifs(&graph, *kind).is_empty(), "{kind:?}");
        }
    }
}