Skip to main content

uni_algo/algo/algorithms/
node_similarity.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! Node Similarity Algorithm.
5
6use crate::algo::GraphProjection;
7use crate::algo::algorithms::Algorithm;
8use fxhash::FxHashMap;
9use rayon::prelude::*;
10use std::sync::Mutex;
11use uni_common::core::id::Vid;
12
/// Node similarity: scores pairs of nodes by how many outgoing neighbours they
/// share (see [`SimilarityMetric`]) and keeps the best-scoring matches per node.
pub struct NodeSimilarity;
14
15#[derive(Debug, Clone)]
16pub struct NodeSimilarityConfig {
17    pub similarity_metric: SimilarityMetric,
18    pub similarity_cutoff: f64,
19    pub top_k: usize,
20}
21
/// Set-similarity measure applied to the out-neighbour sets of two nodes.
///
/// In the formulas below, `|N(u) ∩ N(v)|` is the number of shared
/// out-neighbours and `|N(x)|` is the out-degree of `x`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SimilarityMetric {
    /// `|N(u) ∩ N(v)| / (|N(u)| + |N(v)| - |N(u) ∩ N(v)|)`.
    Jaccard,
    /// `|N(u) ∩ N(v)| / min(|N(u)|, |N(v)|)`.
    Overlap,
    /// `|N(u) ∩ N(v)| / sqrt(|N(u)| * |N(v)|)`.
    Cosine,
}
28
29impl Default for NodeSimilarityConfig {
30    fn default() -> Self {
31        Self {
32            similarity_metric: SimilarityMetric::Jaccard,
33            similarity_cutoff: 0.1,
34            top_k: 10,
35        }
36    }
37}
38
39pub struct NodeSimilarityResult {
40    pub similar_pairs: Vec<(Vid, Vid, f64)>,
41}
42
43impl Algorithm for NodeSimilarity {
44    type Config = NodeSimilarityConfig;
45    type Result = NodeSimilarityResult;
46
47    fn name() -> &'static str {
48        "nodeSimilarity"
49    }
50
51    fn needs_reverse() -> bool {
52        true
53    }
54
55    fn run(graph: &GraphProjection, config: Self::Config) -> Self::Result {
56        let n = graph.vertex_count();
57        if n == 0 {
58            return NodeSimilarityResult {
59                similar_pairs: Vec::new(),
60            };
61        }
62
63        // We compute similarity based on OUTGOING neighbors.
64        // Two nodes are similar if they point to the same targets.
65        // Intersection is computed by iterating over TARGETS and looking at their INCOMING neighbors.
66
67        // Map (u, v) -> intersection_count
68        // Using partitioned approach to allow parallelism and reduce contention?
69        // Or just straightforward MapReduce.
70
71        // Naive approach: Map<(u32, u32), u32> can become huge.
72        // Better: Process per-node?
73        // No, processing by target is efficient for intersection.
74
75        // Let's implement a chunked approach or simple parallel accumulation if memory permits.
76        // Given we are in-memory, we assume we can fit `E * avg_degree` roughly?
77        // Let's use a concurrent map (DashMap would be good, but we use Mutex<HashMap> for std).
78
79        // Optimization: iterate source nodes `u`. For each `u`, collect neighbors `N(u)`.
80        // Then for each `n` in `N(u)`, find other `v` in `in_neighbors(n)`.
81        // Accumulate `intersection[v]`.
82        // Compute similarity for `u` vs all `v`. Keep TopK for `u`.
83
84        // This is O(V * D * D_in).
85
86        let mut results = Vec::new();
87        let results_mutex = Mutex::new(&mut results);
88
89        (0..n as u32).into_par_iter().for_each(|u| {
90            let u_out = graph.out_neighbors(u);
91            let degree_u = u_out.len() as f64;
92
93            if degree_u == 0.0 {
94                return;
95            }
96
97            let mut intersections: FxHashMap<u32, u32> = FxHashMap::default();
98
99            for &target in u_out {
100                for &v in graph.in_neighbors(target) {
101                    if v != u {
102                        *intersections.entry(v).or_insert(0) += 1;
103                    }
104                }
105            }
106
107            let mut local_results = Vec::new();
108
109            for (v, count) in intersections {
110                let degree_v = graph.out_degree(v) as f64;
111                let overlap = count as f64;
112
113                let score = match config.similarity_metric {
114                    SimilarityMetric::Jaccard => overlap / (degree_u + degree_v - overlap),
115                    SimilarityMetric::Overlap => overlap / f64::min(degree_u, degree_v),
116                    SimilarityMetric::Cosine => overlap / (degree_u * degree_v).sqrt(),
117                };
118
119                if score >= config.similarity_cutoff {
120                    local_results.push((v, score));
121                }
122            }
123
124            // Top K
125            local_results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
126            local_results.truncate(config.top_k);
127
128            if !local_results.is_empty() {
129                let mut guard = results_mutex
130                    .lock()
131                    .expect("Results mutex poisoned - a thread panicked while holding it");
132                let u_vid = graph.to_vid(u);
133                for (v, score) in local_results {
134                    guard.push((u_vid, graph.to_vid(v), score));
135                }
136            }
137        });
138
139        NodeSimilarityResult {
140            similar_pairs: results,
141        }
142    }
143}