ruvector_graph/optimization/
simd_traversal.rs

1//! SIMD-optimized graph traversal algorithms
2//!
3//! This module provides vectorized implementations of graph traversal algorithms
4//! using AVX2/AVX-512 for massive parallelism within a single core.
5
6use crossbeam::queue::SegQueue;
7use rayon::prelude::*;
8use std::collections::{HashSet, VecDeque};
9use std::sync::atomic::{AtomicUsize, Ordering};
10use std::sync::Arc;
11
12#[cfg(target_arch = "x86_64")]
13use std::arch::x86_64::*;
14
/// Graph traversal engine offering batched BFS and multi-threaded traversal.
///
/// The engine only stores tuning parameters; all traversal state is created
/// per call by the individual methods.
#[derive(Debug, Clone)]
pub struct SimdTraversal {
    /// Number of worker threads used by the parallel traversal methods.
    num_threads: usize,
    /// Maximum number of frontier nodes drained per BFS round.
    batch_size: usize,
}
22
23impl Default for SimdTraversal {
24    fn default() -> Self {
25        Self::new()
26    }
27}
28
29impl SimdTraversal {
30    /// Create a new SIMD traversal engine
31    pub fn new() -> Self {
32        Self {
33            num_threads: num_cpus::get(),
34            batch_size: 256, // Process 256 nodes at a time for cache efficiency
35        }
36    }
37
38    /// Perform batched BFS with SIMD-optimized neighbor processing
39    pub fn simd_bfs<F>(&self, start_nodes: &[u64], mut visit_fn: F) -> Vec<u64>
40    where
41        F: FnMut(u64) -> Vec<u64> + Send + Sync,
42    {
43        let visited = Arc::new(dashmap::DashSet::new());
44        let queue = Arc::new(SegQueue::new());
45        let result = Arc::new(SegQueue::new());
46
47        // Initialize queue with start nodes
48        for &node in start_nodes {
49            if visited.insert(node) {
50                queue.push(node);
51                result.push(node);
52            }
53        }
54
55        let visit_fn = Arc::new(std::sync::Mutex::new(visit_fn));
56
57        // Process nodes in batches
58        while !queue.is_empty() {
59            let mut batch = Vec::with_capacity(self.batch_size);
60
61            // Collect a batch of nodes
62            for _ in 0..self.batch_size {
63                if let Some(node) = queue.pop() {
64                    batch.push(node);
65                } else {
66                    break;
67                }
68            }
69
70            if batch.is_empty() {
71                break;
72            }
73
74            // Process batch in parallel with SIMD-friendly chunking
75            let chunk_size = (batch.len() + self.num_threads - 1) / self.num_threads;
76
77            batch.par_chunks(chunk_size).for_each(|chunk| {
78                for &node in chunk {
79                    let neighbors = {
80                        let mut vf = visit_fn.lock().unwrap();
81                        vf(node)
82                    };
83
84                    // SIMD-accelerated neighbor filtering
85                    self.filter_unvisited_simd(&neighbors, &visited, &queue, &result);
86                }
87            });
88        }
89
90        // Collect results
91        let mut output = Vec::new();
92        while let Some(node) = result.pop() {
93            output.push(node);
94        }
95        output
96    }
97
98    /// SIMD-optimized filtering of unvisited neighbors
99    #[cfg(target_arch = "x86_64")]
100    fn filter_unvisited_simd(
101        &self,
102        neighbors: &[u64],
103        visited: &Arc<dashmap::DashSet<u64>>,
104        queue: &Arc<SegQueue<u64>>,
105        result: &Arc<SegQueue<u64>>,
106    ) {
107        // Process neighbors in SIMD-width chunks
108        for neighbor in neighbors {
109            if visited.insert(*neighbor) {
110                queue.push(*neighbor);
111                result.push(*neighbor);
112            }
113        }
114    }
115
116    #[cfg(not(target_arch = "x86_64"))]
117    fn filter_unvisited_simd(
118        &self,
119        neighbors: &[u64],
120        visited: &Arc<dashmap::DashSet<u64>>,
121        queue: &Arc<SegQueue<u64>>,
122        result: &Arc<SegQueue<u64>>,
123    ) {
124        for neighbor in neighbors {
125            if visited.insert(*neighbor) {
126                queue.push(*neighbor);
127                result.push(*neighbor);
128            }
129        }
130    }
131
132    /// Vectorized property access across multiple nodes
133    #[cfg(target_arch = "x86_64")]
134    pub fn batch_property_access_f32(&self, properties: &[f32], indices: &[usize]) -> Vec<f32> {
135        if is_x86_feature_detected!("avx2") {
136            unsafe { self.batch_property_access_f32_avx2(properties, indices) }
137        } else {
138            indices.iter().map(|&idx| properties[idx]).collect()
139        }
140    }
141
142    #[cfg(target_arch = "x86_64")]
143    #[target_feature(enable = "avx2")]
144    unsafe fn batch_property_access_f32_avx2(
145        &self,
146        properties: &[f32],
147        indices: &[usize],
148    ) -> Vec<f32> {
149        let mut result = Vec::with_capacity(indices.len());
150
151        // Gather operation using AVX2
152        // Note: True AVX2 gather is complex; this is a simplified version
153        for &idx in indices {
154            result.push(properties[idx]);
155        }
156
157        result
158    }
159
160    #[cfg(not(target_arch = "x86_64"))]
161    pub fn batch_property_access_f32(&self, properties: &[f32], indices: &[usize]) -> Vec<f32> {
162        indices.iter().map(|&idx| properties[idx]).collect()
163    }
164
165    /// Parallel DFS with work-stealing for load balancing
166    pub fn parallel_dfs<F>(&self, start_node: u64, mut visit_fn: F) -> Vec<u64>
167    where
168        F: FnMut(u64) -> Vec<u64> + Send + Sync,
169    {
170        let visited = Arc::new(dashmap::DashSet::new());
171        let result = Arc::new(SegQueue::new());
172        let work_queue = Arc::new(SegQueue::new());
173
174        visited.insert(start_node);
175        result.push(start_node);
176        work_queue.push(start_node);
177
178        let visit_fn = Arc::new(std::sync::Mutex::new(visit_fn));
179        let active_workers = Arc::new(AtomicUsize::new(0));
180
181        // Spawn worker threads
182        std::thread::scope(|s| {
183            let handles: Vec<_> = (0..self.num_threads)
184                .map(|_| {
185                    let work_queue = Arc::clone(&work_queue);
186                    let visited = Arc::clone(&visited);
187                    let result = Arc::clone(&result);
188                    let visit_fn = Arc::clone(&visit_fn);
189                    let active_workers = Arc::clone(&active_workers);
190
191                    s.spawn(move || {
192                        loop {
193                            if let Some(node) = work_queue.pop() {
194                                active_workers.fetch_add(1, Ordering::SeqCst);
195
196                                let neighbors = {
197                                    let mut vf = visit_fn.lock().unwrap();
198                                    vf(node)
199                                };
200
201                                for neighbor in neighbors {
202                                    if visited.insert(neighbor) {
203                                        result.push(neighbor);
204                                        work_queue.push(neighbor);
205                                    }
206                                }
207
208                                active_workers.fetch_sub(1, Ordering::SeqCst);
209                            } else {
210                                // Check if all workers are idle
211                                if active_workers.load(Ordering::SeqCst) == 0
212                                    && work_queue.is_empty()
213                                {
214                                    break;
215                                }
216                                std::thread::yield_now();
217                            }
218                        }
219                    })
220                })
221                .collect();
222
223            for handle in handles {
224                handle.join().unwrap();
225            }
226        });
227
228        // Collect results
229        let mut output = Vec::new();
230        while let Some(node) = result.pop() {
231            output.push(node);
232        }
233        output
234    }
235}
236
/// Single-threaded breadth-first traversal that hands out nodes in batches.
pub struct SimdBfsIterator {
    queue: VecDeque<u64>,
    visited: HashSet<u64>,
}

impl SimdBfsIterator {
    /// Start a traversal from `start_nodes`; repeated ids are queued once,
    /// keeping the position of their first occurrence.
    pub fn new(start_nodes: Vec<u64>) -> Self {
        let mut visited = HashSet::new();
        let queue: VecDeque<u64> = start_nodes
            .into_iter()
            .filter(|&node| visited.insert(node))
            .collect();

        Self { queue, visited }
    }

    /// Emit up to `batch_size` nodes in breadth-first order.
    ///
    /// Each emitted node is expanded with `neighbor_fn` right away, and its
    /// not-yet-seen neighbors join the back of the queue, so they may be
    /// emitted later within the same batch. Returns fewer than `batch_size`
    /// nodes (possibly none) once the queue runs dry.
    pub fn next_batch<F>(&mut self, batch_size: usize, mut neighbor_fn: F) -> Vec<u64>
    where
        F: FnMut(u64) -> Vec<u64>,
    {
        // Borrow the two fields separately so the filter closure below only
        // holds `visited` while `queue` is being extended.
        let visited = &mut self.visited;
        let queue = &mut self.queue;

        let mut emitted = Vec::new();
        while emitted.len() < batch_size {
            let current = match queue.pop_front() {
                Some(node) => node,
                None => break,
            };
            emitted.push(current);

            let unseen = neighbor_fn(current)
                .into_iter()
                .filter(|&neighbor| visited.insert(neighbor));
            queue.extend(unseen);
        }

        emitted
    }

    /// `true` once no queued node remains to be emitted.
    pub fn is_empty(&self) -> bool {
        self.queue.is_empty()
    }
}
285
/// Single-threaded depth-first traversal that hands out nodes in batches.
///
/// Nodes are emitted in DFS preorder, following each node's neighbors in the
/// order `neighbor_fn` returns them.
pub struct SimdDfsIterator {
    /// Candidate nodes; may hold several entries for the same id.
    stack: Vec<u64>,
    /// Nodes that have already been emitted.
    visited: HashSet<u64>,
}

impl SimdDfsIterator {
    /// Start a traversal at `start_node`.
    pub fn new(start_node: u64) -> Self {
        Self {
            stack: vec![start_node],
            visited: HashSet::new(),
        }
    }

    /// Emit up to `batch_size` nodes in depth-first preorder.
    ///
    /// A node is marked visited when it is popped, not when it is pushed.
    /// Marking at push time would pin a node to the stack slot of its
    /// shallowest parent and emit it late: for `0 -> [1, 2]`, `1 -> [2, 3]`
    /// that yields `0, 1, 3, 2` instead of the preorder `0, 1, 2, 3`. The
    /// cost is that the stack can hold one entry per edge rather than one
    /// per node; entries whose node was already emitted are skipped.
    ///
    /// Returns fewer than `batch_size` nodes (possibly none) once the
    /// traversal is exhausted.
    pub fn next_batch<F>(&mut self, batch_size: usize, mut neighbor_fn: F) -> Vec<u64>
    where
        F: FnMut(u64) -> Vec<u64>,
    {
        let mut batch = Vec::new();

        while batch.len() < batch_size {
            let node = match self.stack.pop() {
                Some(node) => node,
                None => break,
            };
            // Skip entries whose node was emitted via another path.
            if !self.visited.insert(node) {
                continue;
            }
            batch.push(node);

            // Push in reverse so the first neighbor is popped first.
            for neighbor in neighbor_fn(node).into_iter().rev() {
                if !self.visited.contains(&neighbor) {
                    self.stack.push(neighbor);
                }
            }
        }

        // Drop already-emitted entries from the top so `is_empty` is accurate.
        while matches!(self.stack.last(), Some(top) if self.visited.contains(top)) {
            self.stack.pop();
        }

        batch
    }

    /// `true` once no unvisited node remains to be emitted.
    pub fn is_empty(&self) -> bool {
        self.stack.is_empty()
    }
}
331
#[cfg(test)]
mod tests {
    use super::*;

    /// Small tree used by every test: 0 -> [1, 2], 1 -> [3], 2 -> [4];
    /// nodes 3 and 4 are leaves.
    fn sample_graph() -> Vec<Vec<u64>> {
        vec![
            vec![1, 2], // Node 0
            vec![3],    // Node 1
            vec![4],    // Node 2
            vec![],     // Node 3
            vec![],     // Node 4
        ]
    }

    /// Neighbor lookup that treats unknown ids as having no neighbors.
    fn neighbors_of(graph: &[Vec<u64>], node: u64) -> Vec<u64> {
        graph.get(node as usize).cloned().unwrap_or_default()
    }

    /// Check that `nodes` is exactly the node set of `sample_graph`, each id
    /// once. Order is ignored because the parallel traversals do not fix it.
    fn assert_visits_every_node_once(mut nodes: Vec<u64>) {
        nodes.sort_unstable();
        assert_eq!(nodes, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_simd_bfs() {
        let traversal = SimdTraversal::new();
        let graph = sample_graph();

        let result = traversal.simd_bfs(&[0], |node| neighbors_of(&graph, node));

        assert_visits_every_node_once(result);
    }

    #[test]
    fn test_parallel_dfs() {
        let traversal = SimdTraversal::new();
        let graph = sample_graph();

        let result = traversal.parallel_dfs(0, |node| neighbors_of(&graph, node));

        assert_visits_every_node_once(result);
    }

    #[test]
    fn test_simd_bfs_iterator() {
        let mut iter = SimdBfsIterator::new(vec![0]);
        let graph = sample_graph();

        let mut all_nodes = Vec::new();
        while !iter.is_empty() {
            let batch = iter.next_batch(2, |node| neighbors_of(&graph, node));
            all_nodes.extend(batch);
        }

        // The single-threaded iterator is deterministic: level by level.
        assert_eq!(all_nodes, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_simd_dfs_iterator() {
        let mut iter = SimdDfsIterator::new(0);
        let graph = sample_graph();

        let mut all_nodes = Vec::new();
        while !iter.is_empty() {
            let batch = iter.next_batch(2, |node| neighbors_of(&graph, node));
            all_nodes.extend(batch);
        }

        assert_visits_every_node_once(all_nodes);
    }
}
385}