Skip to main content

uni_algo/algo/algorithms/
random_walk.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! Random Walk Algorithm.
5
6use crate::algo::GraphProjection;
7use crate::algo::algorithms::Algorithm;
8use rand::prelude::*;
9use rayon::prelude::*;
10use std::sync::Mutex;
11use uni_common::core::id::Vid;
12
/// Uniform random-walk sampler over a `GraphProjection`.
///
/// Stateless marker type; the algorithm itself lives in its `Algorithm` implementation
/// and is configured through `RandomWalkConfig`.
#[derive(Debug, Clone, Copy, Default)]
pub struct RandomWalk;
14
15#[derive(Debug, Clone, Default)]
16pub struct RandomWalkConfig {
17    pub walk_length: usize,
18    pub walks_per_node: usize,
19    pub start_nodes: Vec<Vid>, // If empty, all nodes
20    pub return_param: f64,     // p (1/p)
21    pub in_out_param: f64,     // q (1/q)
22}
23
24pub struct RandomWalkResult {
25    pub walks: Vec<Vec<Vid>>,
26}
27
28impl Algorithm for RandomWalk {
29    type Config = RandomWalkConfig;
30    type Result = RandomWalkResult;
31
32    fn name() -> &'static str {
33        "randomWalk"
34    }
35
36    fn needs_reverse() -> bool {
37        // Only if we need to know previous neighbors for Node2Vec (checking if neighbor is connected to prev)
38        // For simple random walk, no.
39        // Implementing simple random walk first.
40        false
41    }
42
43    fn run(graph: &GraphProjection, config: Self::Config) -> Self::Result {
44        let n = graph.vertex_count();
45        if n == 0 {
46            return RandomWalkResult { walks: Vec::new() };
47        }
48
49        let start_slots: Vec<u32> = if config.start_nodes.is_empty() {
50            (0..n as u32).collect()
51        } else {
52            config
53                .start_nodes
54                .iter()
55                .filter_map(|&vid| graph.to_slot(vid))
56                .collect()
57        };
58
59        let mut walks = Vec::new();
60        let walks_mutex = Mutex::new(&mut walks);
61
62        // Chunking to avoid massive mutex contention or single result vector resizing
63        start_slots.par_iter().for_each(|&start_node| {
64            let mut local_walks = Vec::with_capacity(config.walks_per_node);
65            let mut rng = rand::thread_rng();
66
67            for _ in 0..config.walks_per_node {
68                let mut walk = Vec::with_capacity(config.walk_length + 1);
69                let mut curr = start_node;
70                walk.push(graph.to_vid(curr));
71
72                for _ in 0..config.walk_length {
73                    let neighbors = graph.out_neighbors(curr);
74                    if neighbors.is_empty() {
75                        break;
76                    }
77                    // Simple uniform random walk
78                    let next = neighbors.choose(&mut rng).unwrap();
79                    curr = *next;
80                    walk.push(graph.to_vid(curr));
81                }
82                local_walks.push(walk);
83            }
84
85            let mut guard = walks_mutex.lock().unwrap_or_else(|e| e.into_inner());
86            guard.extend(local_walks);
87        });
88
89        RandomWalkResult { walks }
90    }
91}