ruvector_graph/optimization/
simd_traversal.rs

1//! SIMD-optimized graph traversal algorithms
2//!
//! This module provides batch-oriented and multi-threaded graph traversal algorithms.
//! Despite the "SIMD" naming, neighbor filtering and property gathers are currently
//! scalar loops; parallelism comes from rayon and scoped worker threads.
5
6use crossbeam::queue::SegQueue;
7use rayon::prelude::*;
8use std::collections::{HashSet, VecDeque};
9use std::sync::atomic::{AtomicUsize, Ordering};
10use std::sync::Arc;
11
12#[cfg(target_arch = "x86_64")]
13use std::arch::x86_64::*;
14
/// Configuration for the batched and multi-threaded traversal routines.
///
/// Construct with [`SimdTraversal::new`] (or `Default::default()`).
#[derive(Debug, Clone)]
pub struct SimdTraversal {
    /// Number of worker threads used by the parallel traversals.
    num_threads: usize,
    /// Maximum number of frontier nodes taken from the queue per BFS batch.
    batch_size: usize,
}
22
23impl Default for SimdTraversal {
24    fn default() -> Self {
25        Self::new()
26    }
27}
28
29impl SimdTraversal {
30    /// Create a new SIMD traversal engine
31    pub fn new() -> Self {
32        Self {
33            num_threads: num_cpus::get(),
34            batch_size: 256, // Process 256 nodes at a time for cache efficiency
35        }
36    }
37
38    /// Perform batched BFS with SIMD-optimized neighbor processing
39    pub fn simd_bfs<F>(&self, start_nodes: &[u64], mut visit_fn: F) -> Vec<u64>
40    where
41        F: FnMut(u64) -> Vec<u64> + Send + Sync,
42    {
43        let visited = Arc::new(dashmap::DashSet::new());
44        let queue = Arc::new(SegQueue::new());
45        let result = Arc::new(SegQueue::new());
46
47        // Initialize queue with start nodes
48        for &node in start_nodes {
49            if visited.insert(node) {
50                queue.push(node);
51                result.push(node);
52            }
53        }
54
55        let visit_fn = Arc::new(std::sync::Mutex::new(visit_fn));
56
57        // Process nodes in batches
58        while !queue.is_empty() {
59            let mut batch = Vec::with_capacity(self.batch_size);
60
61            // Collect a batch of nodes
62            for _ in 0..self.batch_size {
63                if let Some(node) = queue.pop() {
64                    batch.push(node);
65                } else {
66                    break;
67                }
68            }
69
70            if batch.is_empty() {
71                break;
72            }
73
74            // Process batch in parallel with SIMD-friendly chunking
75            let chunk_size = (batch.len() + self.num_threads - 1) / self.num_threads;
76
77            batch.par_chunks(chunk_size).for_each(|chunk| {
78                for &node in chunk {
79                    let neighbors = {
80                        let mut vf = visit_fn.lock().unwrap();
81                        vf(node)
82                    };
83
84                    // SIMD-accelerated neighbor filtering
85                    self.filter_unvisited_simd(&neighbors, &visited, &queue, &result);
86                }
87            });
88        }
89
90        // Collect results
91        let mut output = Vec::new();
92        while let Some(node) = result.pop() {
93            output.push(node);
94        }
95        output
96    }
97
98    /// SIMD-optimized filtering of unvisited neighbors
99    #[cfg(target_arch = "x86_64")]
100    fn filter_unvisited_simd(
101        &self,
102        neighbors: &[u64],
103        visited: &Arc<dashmap::DashSet<u64>>,
104        queue: &Arc<SegQueue<u64>>,
105        result: &Arc<SegQueue<u64>>,
106    ) {
107        // Process neighbors in SIMD-width chunks
108        for neighbor in neighbors {
109            if visited.insert(*neighbor) {
110                queue.push(*neighbor);
111                result.push(*neighbor);
112            }
113        }
114    }
115
116    #[cfg(not(target_arch = "x86_64"))]
117    fn filter_unvisited_simd(
118        &self,
119        neighbors: &[u64],
120        visited: &Arc<dashmap::DashSet<u64>>,
121        queue: &Arc<SegQueue<u64>>,
122        result: &Arc<SegQueue<u64>>,
123    ) {
124        for neighbor in neighbors {
125            if visited.insert(*neighbor) {
126                queue.push(*neighbor);
127                result.push(*neighbor);
128            }
129        }
130    }
131
132    /// Vectorized property access across multiple nodes
133    #[cfg(target_arch = "x86_64")]
134    pub fn batch_property_access_f32(&self, properties: &[f32], indices: &[usize]) -> Vec<f32> {
135        if is_x86_feature_detected!("avx2") {
136            unsafe { self.batch_property_access_f32_avx2(properties, indices) }
137        } else {
138            // SECURITY: Bounds check for scalar fallback
139            indices.iter().map(|&idx| {
140                assert!(idx < properties.len(), "Index out of bounds: {} >= {}", idx, properties.len());
141                properties[idx]
142            }).collect()
143        }
144    }
145
146    #[cfg(target_arch = "x86_64")]
147    #[target_feature(enable = "avx2")]
148    unsafe fn batch_property_access_f32_avx2(
149        &self,
150        properties: &[f32],
151        indices: &[usize],
152    ) -> Vec<f32> {
153        let mut result = Vec::with_capacity(indices.len());
154
155        // Gather operation using AVX2
156        // Note: True AVX2 gather is complex; this is a simplified version
157        // SECURITY: Bounds check each index before access
158        for &idx in indices {
159            assert!(idx < properties.len(), "Index out of bounds: {} >= {}", idx, properties.len());
160            result.push(properties[idx]);
161        }
162
163        result
164    }
165
166    #[cfg(not(target_arch = "x86_64"))]
167    pub fn batch_property_access_f32(&self, properties: &[f32], indices: &[usize]) -> Vec<f32> {
168        // SECURITY: Bounds check for non-x86 platforms
169        indices.iter().map(|&idx| {
170            assert!(idx < properties.len(), "Index out of bounds: {} >= {}", idx, properties.len());
171            properties[idx]
172        }).collect()
173    }
174
175    /// Parallel DFS with work-stealing for load balancing
176    pub fn parallel_dfs<F>(&self, start_node: u64, mut visit_fn: F) -> Vec<u64>
177    where
178        F: FnMut(u64) -> Vec<u64> + Send + Sync,
179    {
180        let visited = Arc::new(dashmap::DashSet::new());
181        let result = Arc::new(SegQueue::new());
182        let work_queue = Arc::new(SegQueue::new());
183
184        visited.insert(start_node);
185        result.push(start_node);
186        work_queue.push(start_node);
187
188        let visit_fn = Arc::new(std::sync::Mutex::new(visit_fn));
189        let active_workers = Arc::new(AtomicUsize::new(0));
190
191        // Spawn worker threads
192        std::thread::scope(|s| {
193            let handles: Vec<_> = (0..self.num_threads)
194                .map(|_| {
195                    let work_queue = Arc::clone(&work_queue);
196                    let visited = Arc::clone(&visited);
197                    let result = Arc::clone(&result);
198                    let visit_fn = Arc::clone(&visit_fn);
199                    let active_workers = Arc::clone(&active_workers);
200
201                    s.spawn(move || {
202                        loop {
203                            if let Some(node) = work_queue.pop() {
204                                active_workers.fetch_add(1, Ordering::SeqCst);
205
206                                let neighbors = {
207                                    let mut vf = visit_fn.lock().unwrap();
208                                    vf(node)
209                                };
210
211                                for neighbor in neighbors {
212                                    if visited.insert(neighbor) {
213                                        result.push(neighbor);
214                                        work_queue.push(neighbor);
215                                    }
216                                }
217
218                                active_workers.fetch_sub(1, Ordering::SeqCst);
219                            } else {
220                                // Check if all workers are idle
221                                if active_workers.load(Ordering::SeqCst) == 0
222                                    && work_queue.is_empty()
223                                {
224                                    break;
225                                }
226                                std::thread::yield_now();
227                            }
228                        }
229                    })
230                })
231                .collect();
232
233            for handle in handles {
234                handle.join().unwrap();
235            }
236        });
237
238        // Collect results
239        let mut output = Vec::new();
240        while let Some(node) = result.pop() {
241            output.push(node);
242        }
243        output
244    }
245}
246
/// Single-threaded breadth-first iterator that yields nodes in batches.
///
/// Nodes are marked visited when they are enqueued, so every reachable node is
/// yielded exactly once, in FIFO order of discovery.
#[derive(Debug, Clone)]
pub struct SimdBfsIterator {
    /// Discovered nodes waiting to be yielded and expanded.
    queue: VecDeque<u64>,
    /// Every node that has ever been enqueued.
    visited: HashSet<u64>,
}

impl SimdBfsIterator {
    /// Creates an iterator seeded with `start_nodes`, skipping duplicates and
    /// keeping the order of first occurrence.
    pub fn new(start_nodes: Vec<u64>) -> Self {
        let mut visited = HashSet::with_capacity(start_nodes.len());
        let mut queue = VecDeque::with_capacity(start_nodes.len());

        for node in start_nodes {
            if visited.insert(node) {
                queue.push_back(node);
            }
        }

        Self { queue, visited }
    }

    /// Dequeues up to `batch_size` nodes and returns them in dequeue order.
    ///
    /// Each dequeued node is expanded immediately with `neighbor_fn`. Neighbors
    /// not seen before are appended to the queue, so they may appear later in
    /// the same batch. Returns an empty `Vec` once the traversal is exhausted.
    pub fn next_batch<F>(&mut self, batch_size: usize, mut neighbor_fn: F) -> Vec<u64>
    where
        F: FnMut(u64) -> Vec<u64>,
    {
        // The queue can grow during expansion; this is only an initial size hint.
        let mut batch = Vec::with_capacity(batch_size.min(self.queue.len()));

        while batch.len() < batch_size {
            let node = match self.queue.pop_front() {
                Some(node) => node,
                None => break,
            };
            batch.push(node);

            for neighbor in neighbor_fn(node) {
                if self.visited.insert(neighbor) {
                    self.queue.push_back(neighbor);
                }
            }
        }

        batch
    }

    /// Returns `true` when no discovered nodes remain to be yielded.
    pub fn is_empty(&self) -> bool {
        self.queue.is_empty()
    }
}
295
/// Single-threaded, depth-first-style iterator that yields nodes in batches.
///
/// Nodes are marked visited when they are *pushed*, not when they are popped.
/// This keeps the stack free of duplicates and yields each reachable node
/// exactly once. The order can differ from a recursive DFS preorder: for edges
/// `0 -> [1, 2]`, `1 -> [2, 3]` it yields `0, 1, 3, 2` rather than `0, 1, 2, 3`.
#[derive(Debug, Clone)]
pub struct SimdDfsIterator {
    /// Discovered nodes waiting to be yielded; the top is yielded next.
    stack: Vec<u64>,
    /// Every node that has ever been pushed.
    visited: HashSet<u64>,
}

impl SimdDfsIterator {
    /// Creates an iterator rooted at `start_node`.
    pub fn new(start_node: u64) -> Self {
        let mut visited = HashSet::new();
        visited.insert(start_node);

        Self {
            stack: vec![start_node],
            visited,
        }
    }

    /// Pops up to `batch_size` nodes and returns them in pop order.
    ///
    /// Each popped node is expanded immediately with `neighbor_fn`. Unvisited
    /// neighbors are pushed in reverse, so the first unvisited neighbor is
    /// popped next. Returns an empty `Vec` once the traversal is exhausted.
    pub fn next_batch<F>(&mut self, batch_size: usize, mut neighbor_fn: F) -> Vec<u64>
    where
        F: FnMut(u64) -> Vec<u64>,
    {
        // The stack can grow during expansion; this is only an initial size hint.
        let mut batch = Vec::with_capacity(batch_size.min(self.stack.len()));

        while batch.len() < batch_size {
            let node = match self.stack.pop() {
                Some(node) => node,
                None => break,
            };
            batch.push(node);

            for neighbor in neighbor_fn(node).into_iter().rev() {
                if self.visited.insert(neighbor) {
                    self.stack.push(neighbor);
                }
            }
        }

        batch
    }

    /// Returns `true` when no discovered nodes remain to be yielded.
    pub fn is_empty(&self) -> bool {
        self.stack.is_empty()
    }
}
341
#[cfg(test)]
mod tests {
    use super::*;

    /// 0 -> [1, 2], 1 -> [3], 2 -> [4]; nodes 3 and 4 are leaves.
    fn sample_graph() -> Vec<Vec<u64>> {
        vec![vec![1, 2], vec![3], vec![4], vec![], vec![]]
    }

    /// Sorts a traversal result so set-like outputs can be compared exactly.
    fn sorted(mut nodes: Vec<u64>) -> Vec<u64> {
        nodes.sort_unstable();
        nodes
    }

    #[test]
    fn test_simd_bfs() {
        let traversal = SimdTraversal::new();
        let graph = sample_graph();

        let result = traversal.simd_bfs(&[0], |node| {
            graph.get(node as usize).cloned().unwrap_or_default()
        });

        assert_eq!(result.len(), 5);
        assert_eq!(result[0], 0);
        assert_eq!(sorted(result), vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_simd_bfs_dedups_start_nodes_and_handles_empty_input() {
        let traversal = SimdTraversal::new();
        let graph = sample_graph();

        let result = traversal.simd_bfs(&[0, 0, 3], |node| {
            graph.get(node as usize).cloned().unwrap_or_default()
        });
        assert_eq!(&result[..2], &[0, 3]);
        assert_eq!(sorted(result), vec![0, 1, 2, 3, 4]);

        let empty = traversal.simd_bfs(&[], |_| Vec::new());
        assert!(empty.is_empty());
    }

    #[test]
    fn test_parallel_dfs() {
        let traversal = SimdTraversal::new();
        let graph = sample_graph();

        let result = traversal.parallel_dfs(0, |node| {
            graph.get(node as usize).cloned().unwrap_or_default()
        });

        assert_eq!(result.len(), 5);
        assert_eq!(result[0], 0);
        assert_eq!(sorted(result), vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_parallel_dfs_with_cycle() {
        let traversal = SimdTraversal::new();
        // 0 -> 1 -> 2 -> 0, plus a back edge 1 -> 0 and a self loop on 2.
        let graph: Vec<Vec<u64>> = vec![vec![1], vec![0, 2], vec![0, 2]];

        let result = traversal.parallel_dfs(0, |node| {
            graph.get(node as usize).cloned().unwrap_or_default()
        });

        assert_eq!(sorted(result), vec![0, 1, 2]);
    }

    #[test]
    fn test_batch_property_access() {
        let traversal = SimdTraversal::new();
        let props = [1.0_f32, 2.0, 3.0];

        assert_eq!(
            traversal.batch_property_access_f32(&props, &[2, 0, 2]),
            vec![3.0, 1.0, 3.0]
        );
        assert!(traversal.batch_property_access_f32(&props, &[]).is_empty());
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn test_batch_property_access_out_of_bounds() {
        let traversal = SimdTraversal::new();
        traversal.batch_property_access_f32(&[1.0, 2.0], &[0, 2]);
    }

    #[test]
    fn test_simd_bfs_iterator() {
        let mut iter = SimdBfsIterator::new(vec![0]);
        let graph = sample_graph();

        let mut all_nodes = Vec::new();
        while !iter.is_empty() {
            let batch = iter.next_batch(2, |node| {
                graph.get(node as usize).cloned().unwrap_or_default()
            });
            all_nodes.extend(batch);
        }

        assert_eq!(all_nodes, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_simd_dfs_iterator() {
        let mut iter = SimdDfsIterator::new(0);
        let graph = sample_graph();

        let mut all_nodes = Vec::new();
        while !iter.is_empty() {
            let batch = iter.next_batch(2, |node| {
                graph.get(node as usize).cloned().unwrap_or_default()
            });
            all_nodes.extend(batch);
        }

        // Nodes are marked on push, so 3 (child of 1) is yielded before 2.
        assert_eq!(all_nodes, vec![0, 1, 3, 2, 4]);
    }
}