uni_algo/algo/algorithms/
label_propagation.rs1use crate::algo::GraphProjection;
7use crate::algo::algorithms::Algorithm;
8use rand::prelude::*;
9use std::collections::HashMap;
10use uni_common::core::id::Vid;
11
/// Label-propagation community detection.
///
/// Configured by [`LabelPropagationConfig`]; produces a [`LabelPropagationResult`].
#[derive(Debug, Clone, Copy, Default)]
pub struct LabelPropagation;
13
/// Configuration for [`LabelPropagation`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LabelPropagationConfig {
    /// Upper bound on the number of propagation sweeps over all vertices.
    pub max_iterations: usize,
    /// Presumably the name of a vertex property supplying initial labels.
    ///
    /// NOTE(review): not read by `LabelPropagation::run` in this file; confirm
    /// whether it is honoured elsewhere.
    pub seed_property: Option<String>,
    /// Presumably whether the computed labels are written back to the graph.
    ///
    /// NOTE(review): not read by `LabelPropagation::run`; presumably consumed
    /// by the caller — confirm.
    pub write: bool,
    /// Property name associated with write-back (default `"community"`).
    pub write_property: String,
}

impl Default for LabelPropagationConfig {
    /// Returns 10 iterations, no seed property, write-back disabled, and
    /// write property `"community"`.
    fn default() -> Self {
        Self {
            max_iterations: 10,
            seed_property: None,
            write: false,
            write_property: String::from("community"),
        }
    }
}
33
/// Output of a [`LabelPropagation`] run.
#[derive(Debug)]
pub struct LabelPropagationResult {
    /// One `(vertex, label)` pair per vertex slot in the projection.
    ///
    /// Labels start as each vertex's own id (`Vid::as_u64`). Vertices sharing
    /// a label belong to the same community.
    pub communities: Vec<(Vid, u64)>,
    /// Number of propagation sweeps that were executed.
    pub iterations: usize,
    /// `true` if the run reached a fixed point before the iteration cap.
    ///
    /// That is, some sweep changed no label, or the graph was empty. `false`
    /// if the cap was reached first.
    pub converged: bool,
}
44
45impl Algorithm for LabelPropagation {
46 type Config = LabelPropagationConfig;
47 type Result = LabelPropagationResult;
48
49 fn name() -> &'static str {
50 "labelPropagation"
51 }
52
53 fn needs_reverse() -> bool {
54 true
56 }
57
58 fn run(graph: &GraphProjection, config: Self::Config) -> Self::Result {
59 let num_nodes = graph.vertex_count();
60 if num_nodes == 0 {
61 return LabelPropagationResult {
62 communities: Vec::new(),
63 iterations: 0,
64 converged: true,
65 };
66 }
67
68 let mut labels = vec![0u64; num_nodes];
69
70 for (i, label) in labels.iter_mut().enumerate().take(num_nodes) {
72 *label = graph.to_vid(i as u32).as_u64();
73 }
74
75 let mut converged = false;
76 let mut iterations = 0;
77 let mut node_indices: Vec<u32> = (0..num_nodes as u32).collect();
78 let mut rng = rand::thread_rng();
79
80 while iterations < config.max_iterations {
81 let mut changes = 0;
82
83 node_indices.shuffle(&mut rng);
85
86 for &node_idx in &node_indices {
87 let out_neighbors = graph.out_neighbors(node_idx);
89
90 let mut label_counts: HashMap<u64, usize> = HashMap::new();
91
92 for &neighbor_idx in out_neighbors {
93 let neighbor_label = labels[neighbor_idx as usize];
94 *label_counts.entry(neighbor_label).or_insert(0) += 1;
95 }
96
97 if graph.has_reverse() {
98 let in_neighbors = graph.in_neighbors(node_idx);
99 for &neighbor_idx in in_neighbors {
100 let neighbor_label = labels[neighbor_idx as usize];
101 *label_counts.entry(neighbor_label).or_insert(0) += 1;
102 }
103 }
104
105 if label_counts.is_empty() {
106 continue;
107 }
108
109 let mut max_count = 0;
111 for &count in label_counts.values() {
112 if count > max_count {
113 max_count = count;
114 }
115 }
116
117 let best_labels: Vec<u64> = label_counts
119 .iter()
120 .filter(|(_, count)| **count == max_count)
121 .map(|(label, _)| *label)
122 .collect();
123
124 let new_label = if best_labels.len() == 1 {
126 best_labels[0]
127 } else {
128 *best_labels.choose(&mut rng).unwrap()
129 };
130
131 if labels[node_idx as usize] != new_label {
132 labels[node_idx as usize] = new_label;
133 changes += 1;
134 }
135 }
136
137 iterations += 1;
138 if changes == 0 {
139 converged = true;
140 break;
141 }
142 }
143
144 let communities = labels
146 .into_iter()
147 .enumerate()
148 .map(|(slot, label)| (graph.to_vid(slot as u32), label))
149 .collect();
150
151 LabelPropagationResult {
152 communities,
153 iterations,
154 converged,
155 }
156 }
157}