Skip to main content

uni_algo/algo/algorithms/
label_propagation.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! Label Propagation Community Detection Algorithm.
5
6use crate::algo::GraphProjection;
7use crate::algo::algorithms::Algorithm;
8use rand::prelude::*;
9use std::collections::HashMap;
10use uni_common::core::id::Vid;
11
/// Marker type for the label propagation community detection algorithm.
///
/// Carries no state; the algorithm itself is provided through this type's
/// `Algorithm` implementation in this file.
pub struct LabelPropagation;
13
/// Configuration for Label Propagation.
///
/// The `Default` implementation yields 10 iterations, no seed property,
/// `write = false` and a write property named `"community"`.
#[derive(Debug, Clone)]
pub struct LabelPropagationConfig {
    /// Upper bound on the number of propagation passes `run` performs.
    /// With `0`, no pass is executed.
    pub max_iterations: usize,
    /// NOTE(review): presumably the name of a node property providing initial
    /// labels. `LabelPropagation::run` in this file never reads it and always
    /// seeds labels from the vertex ids — TODO confirm where/if it is honored.
    pub seed_property: Option<String>,
    /// NOTE(review): presumably requests that results be written back to the
    /// graph. Not consulted by `LabelPropagation::run` in this file — TODO
    /// confirm that the caller handles it.
    pub write: bool,
    /// NOTE(review): presumably the property name used when `write` is set.
    /// Not consulted by `LabelPropagation::run` in this file — TODO confirm.
    pub write_property: String,
}
22
23impl Default for LabelPropagationConfig {
24    fn default() -> Self {
25        Self {
26            max_iterations: 10,
27            seed_property: None,
28            write: false,
29            write_property: "community".to_string(),
30        }
31    }
32}
33
/// Result of Label Propagation.
#[derive(Debug)]
pub struct LabelPropagationResult {
    /// One `(VID, CommunityID)` entry per vertex, in vertex slot order.
    /// Community ids start out as the `u64` form of each vertex's own VID and
    /// are only ever copied between neighbors, so every community id is the
    /// `u64` form of some vertex's VID.
    pub communities: Vec<(Vid, u64)>,
    /// Number of propagation passes executed (`0` for an empty graph).
    pub iterations: usize,
    /// `true` when a pass finished without changing any label (or the graph
    /// was empty); `false` when the iteration limit was reached first.
    pub converged: bool,
}
44
45impl Algorithm for LabelPropagation {
46    type Config = LabelPropagationConfig;
47    type Result = LabelPropagationResult;
48
49    fn name() -> &'static str {
50        "labelPropagation"
51    }
52
53    fn needs_reverse() -> bool {
54        // We typically treat graph as undirected for communities
55        true
56    }
57
58    fn run(graph: &GraphProjection, config: Self::Config) -> Self::Result {
59        let num_nodes = graph.vertex_count();
60        if num_nodes == 0 {
61            return LabelPropagationResult {
62                communities: Vec::new(),
63                iterations: 0,
64                converged: true,
65            };
66        }
67
68        let mut labels = vec![0u64; num_nodes];
69
70        // 1. Initialize labels with VID
71        for (i, label) in labels.iter_mut().enumerate().take(num_nodes) {
72            *label = graph.to_vid(i as u32).as_u64();
73        }
74
75        let mut converged = false;
76        let mut iterations = 0;
77        let mut node_indices: Vec<u32> = (0..num_nodes as u32).collect();
78        let mut rng = rand::thread_rng();
79
80        while iterations < config.max_iterations {
81            let mut changes = 0;
82
83            // Shuffle processing order to prevent oscillation
84            node_indices.shuffle(&mut rng);
85
86            for &node_idx in &node_indices {
87                // Collect all neighbors (undirected view if reverse edges present)
88                let out_neighbors = graph.out_neighbors(node_idx);
89
90                let mut label_counts: HashMap<u64, usize> = HashMap::new();
91
92                for &neighbor_idx in out_neighbors {
93                    let neighbor_label = labels[neighbor_idx as usize];
94                    *label_counts.entry(neighbor_label).or_insert(0) += 1;
95                }
96
97                if graph.has_reverse() {
98                    let in_neighbors = graph.in_neighbors(node_idx);
99                    for &neighbor_idx in in_neighbors {
100                        let neighbor_label = labels[neighbor_idx as usize];
101                        *label_counts.entry(neighbor_label).or_insert(0) += 1;
102                    }
103                }
104
105                if label_counts.is_empty() {
106                    continue;
107                }
108
109                // Find max frequency
110                let mut max_count = 0;
111                for &count in label_counts.values() {
112                    if count > max_count {
113                        max_count = count;
114                    }
115                }
116
117                // Collect best labels (ties)
118                let best_labels: Vec<u64> = label_counts
119                    .iter()
120                    .filter(|(_, count)| **count == max_count)
121                    .map(|(label, _)| *label)
122                    .collect();
123
124                // Pick one randomly
125                let new_label = if best_labels.len() == 1 {
126                    best_labels[0]
127                } else {
128                    *best_labels.choose(&mut rng).unwrap()
129                };
130
131                if labels[node_idx as usize] != new_label {
132                    labels[node_idx as usize] = new_label;
133                    changes += 1;
134                }
135            }
136
137            iterations += 1;
138            if changes == 0 {
139                converged = true;
140                break;
141            }
142        }
143
144        // Map results back to VIDs
145        let communities = labels
146            .into_iter()
147            .enumerate()
148            .map(|(slot, label)| (graph.to_vid(slot as u32), label))
149            .collect();
150
151        LabelPropagationResult {
152            communities,
153            iterations,
154            converged,
155        }
156    }
157}