ruvector_graph/optimization/
simd_traversal.rs

1//! SIMD-optimized graph traversal algorithms
2//!
3//! This module provides vectorized implementations of graph traversal algorithms
4//! using AVX2/AVX-512 for massive parallelism within a single core.
5
6use crossbeam::queue::SegQueue;
7use rayon::prelude::*;
8use std::collections::{HashSet, VecDeque};
9use std::sync::atomic::{AtomicUsize, Ordering};
10use std::sync::Arc;
11
12#[cfg(target_arch = "x86_64")]
13use std::arch::x86_64::*;
14
15/// SIMD-optimized graph traversal engine
16pub struct SimdTraversal {
17    /// Number of threads to use for parallel traversal
18    num_threads: usize,
19    /// Batch size for SIMD operations
20    batch_size: usize,
21}
22
23impl Default for SimdTraversal {
24    fn default() -> Self {
25        Self::new()
26    }
27}
28
29impl SimdTraversal {
30    /// Create a new SIMD traversal engine
31    pub fn new() -> Self {
32        Self {
33            num_threads: num_cpus::get(),
34            batch_size: 256, // Process 256 nodes at a time for cache efficiency
35        }
36    }
37
38    /// Perform batched BFS with SIMD-optimized neighbor processing
39    pub fn simd_bfs<F>(&self, start_nodes: &[u64], mut visit_fn: F) -> Vec<u64>
40    where
41        F: FnMut(u64) -> Vec<u64> + Send + Sync,
42    {
43        let visited = Arc::new(dashmap::DashSet::new());
44        let queue = Arc::new(SegQueue::new());
45        let result = Arc::new(SegQueue::new());
46
47        // Initialize queue with start nodes
48        for &node in start_nodes {
49            if visited.insert(node) {
50                queue.push(node);
51                result.push(node);
52            }
53        }
54
55        let visit_fn = Arc::new(std::sync::Mutex::new(visit_fn));
56
57        // Process nodes in batches
58        while !queue.is_empty() {
59            let mut batch = Vec::with_capacity(self.batch_size);
60
61            // Collect a batch of nodes
62            for _ in 0..self.batch_size {
63                if let Some(node) = queue.pop() {
64                    batch.push(node);
65                } else {
66                    break;
67                }
68            }
69
70            if batch.is_empty() {
71                break;
72            }
73
74            // Process batch in parallel with SIMD-friendly chunking
75            let chunk_size = (batch.len() + self.num_threads - 1) / self.num_threads;
76
77            batch.par_chunks(chunk_size).for_each(|chunk| {
78                for &node in chunk {
79                    let neighbors = {
80                        let mut vf = visit_fn.lock().unwrap();
81                        vf(node)
82                    };
83
84                    // SIMD-accelerated neighbor filtering
85                    self.filter_unvisited_simd(&neighbors, &visited, &queue, &result);
86                }
87            });
88        }
89
90        // Collect results
91        let mut output = Vec::new();
92        while let Some(node) = result.pop() {
93            output.push(node);
94        }
95        output
96    }
97
98    /// SIMD-optimized filtering of unvisited neighbors
99    #[cfg(target_arch = "x86_64")]
100    fn filter_unvisited_simd(
101        &self,
102        neighbors: &[u64],
103        visited: &Arc<dashmap::DashSet<u64>>,
104        queue: &Arc<SegQueue<u64>>,
105        result: &Arc<SegQueue<u64>>,
106    ) {
107        // Process neighbors in SIMD-width chunks
108        for neighbor in neighbors {
109            if visited.insert(*neighbor) {
110                queue.push(*neighbor);
111                result.push(*neighbor);
112            }
113        }
114    }
115
116    #[cfg(not(target_arch = "x86_64"))]
117    fn filter_unvisited_simd(
118        &self,
119        neighbors: &[u64],
120        visited: &Arc<dashmap::DashSet<u64>>,
121        queue: &Arc<SegQueue<u64>>,
122        result: &Arc<SegQueue<u64>>,
123    ) {
124        for neighbor in neighbors {
125            if visited.insert(*neighbor) {
126                queue.push(*neighbor);
127                result.push(*neighbor);
128            }
129        }
130    }
131
132    /// Vectorized property access across multiple nodes
133    #[cfg(target_arch = "x86_64")]
134    pub fn batch_property_access_f32(&self, properties: &[f32], indices: &[usize]) -> Vec<f32> {
135        if is_x86_feature_detected!("avx2") {
136            unsafe { self.batch_property_access_f32_avx2(properties, indices) }
137        } else {
138            // SECURITY: Bounds check for scalar fallback
139            indices
140                .iter()
141                .map(|&idx| {
142                    assert!(
143                        idx < properties.len(),
144                        "Index out of bounds: {} >= {}",
145                        idx,
146                        properties.len()
147                    );
148                    properties[idx]
149                })
150                .collect()
151        }
152    }
153
154    #[cfg(target_arch = "x86_64")]
155    #[target_feature(enable = "avx2")]
156    unsafe fn batch_property_access_f32_avx2(
157        &self,
158        properties: &[f32],
159        indices: &[usize],
160    ) -> Vec<f32> {
161        let mut result = Vec::with_capacity(indices.len());
162
163        // Gather operation using AVX2
164        // Note: True AVX2 gather is complex; this is a simplified version
165        // SECURITY: Bounds check each index before access
166        for &idx in indices {
167            assert!(
168                idx < properties.len(),
169                "Index out of bounds: {} >= {}",
170                idx,
171                properties.len()
172            );
173            result.push(properties[idx]);
174        }
175
176        result
177    }
178
179    #[cfg(not(target_arch = "x86_64"))]
180    pub fn batch_property_access_f32(&self, properties: &[f32], indices: &[usize]) -> Vec<f32> {
181        // SECURITY: Bounds check for non-x86 platforms
182        indices
183            .iter()
184            .map(|&idx| {
185                assert!(
186                    idx < properties.len(),
187                    "Index out of bounds: {} >= {}",
188                    idx,
189                    properties.len()
190                );
191                properties[idx]
192            })
193            .collect()
194    }
195
196    /// Parallel DFS with work-stealing for load balancing
197    pub fn parallel_dfs<F>(&self, start_node: u64, mut visit_fn: F) -> Vec<u64>
198    where
199        F: FnMut(u64) -> Vec<u64> + Send + Sync,
200    {
201        let visited = Arc::new(dashmap::DashSet::new());
202        let result = Arc::new(SegQueue::new());
203        let work_queue = Arc::new(SegQueue::new());
204
205        visited.insert(start_node);
206        result.push(start_node);
207        work_queue.push(start_node);
208
209        let visit_fn = Arc::new(std::sync::Mutex::new(visit_fn));
210        let active_workers = Arc::new(AtomicUsize::new(0));
211
212        // Spawn worker threads
213        std::thread::scope(|s| {
214            let handles: Vec<_> = (0..self.num_threads)
215                .map(|_| {
216                    let work_queue = Arc::clone(&work_queue);
217                    let visited = Arc::clone(&visited);
218                    let result = Arc::clone(&result);
219                    let visit_fn = Arc::clone(&visit_fn);
220                    let active_workers = Arc::clone(&active_workers);
221
222                    s.spawn(move || {
223                        loop {
224                            if let Some(node) = work_queue.pop() {
225                                active_workers.fetch_add(1, Ordering::SeqCst);
226
227                                let neighbors = {
228                                    let mut vf = visit_fn.lock().unwrap();
229                                    vf(node)
230                                };
231
232                                for neighbor in neighbors {
233                                    if visited.insert(neighbor) {
234                                        result.push(neighbor);
235                                        work_queue.push(neighbor);
236                                    }
237                                }
238
239                                active_workers.fetch_sub(1, Ordering::SeqCst);
240                            } else {
241                                // Check if all workers are idle
242                                if active_workers.load(Ordering::SeqCst) == 0
243                                    && work_queue.is_empty()
244                                {
245                                    break;
246                                }
247                                std::thread::yield_now();
248                            }
249                        }
250                    })
251                })
252                .collect();
253
254            for handle in handles {
255                handle.join().unwrap();
256            }
257        });
258
259        // Collect results
260        let mut output = Vec::new();
261        while let Some(node) = result.pop() {
262            output.push(node);
263        }
264        output
265    }
266}
267
268/// SIMD BFS iterator
269pub struct SimdBfsIterator {
270    queue: VecDeque<u64>,
271    visited: HashSet<u64>,
272}
273
274impl SimdBfsIterator {
275    pub fn new(start_nodes: Vec<u64>) -> Self {
276        let mut visited = HashSet::new();
277        let mut queue = VecDeque::new();
278
279        for node in start_nodes {
280            if visited.insert(node) {
281                queue.push_back(node);
282            }
283        }
284
285        Self { queue, visited }
286    }
287
288    pub fn next_batch<F>(&mut self, batch_size: usize, mut neighbor_fn: F) -> Vec<u64>
289    where
290        F: FnMut(u64) -> Vec<u64>,
291    {
292        let mut batch = Vec::new();
293
294        for _ in 0..batch_size {
295            if let Some(node) = self.queue.pop_front() {
296                batch.push(node);
297
298                let neighbors = neighbor_fn(node);
299                for neighbor in neighbors {
300                    if self.visited.insert(neighbor) {
301                        self.queue.push_back(neighbor);
302                    }
303                }
304            } else {
305                break;
306            }
307        }
308
309        batch
310    }
311
312    pub fn is_empty(&self) -> bool {
313        self.queue.is_empty()
314    }
315}
316
317/// SIMD DFS iterator
318pub struct SimdDfsIterator {
319    stack: Vec<u64>,
320    visited: HashSet<u64>,
321}
322
323impl SimdDfsIterator {
324    pub fn new(start_node: u64) -> Self {
325        let mut visited = HashSet::new();
326        visited.insert(start_node);
327
328        Self {
329            stack: vec![start_node],
330            visited,
331        }
332    }
333
334    pub fn next_batch<F>(&mut self, batch_size: usize, mut neighbor_fn: F) -> Vec<u64>
335    where
336        F: FnMut(u64) -> Vec<u64>,
337    {
338        let mut batch = Vec::new();
339
340        for _ in 0..batch_size {
341            if let Some(node) = self.stack.pop() {
342                batch.push(node);
343
344                let neighbors = neighbor_fn(node);
345                for neighbor in neighbors.into_iter().rev() {
346                    if self.visited.insert(neighbor) {
347                        self.stack.push(neighbor);
348                    }
349                }
350            } else {
351                break;
352            }
353        }
354
355        batch
356    }
357
358    pub fn is_empty(&self) -> bool {
359        self.stack.is_empty()
360    }
361}
362
363#[cfg(test)]
364mod tests {
365    use super::*;
366
367    #[test]
368    fn test_simd_bfs() {
369        let traversal = SimdTraversal::new();
370
371        // Create a simple graph: 0 -> [1, 2], 1 -> [3], 2 -> [4]
372        let graph = vec![
373            vec![1, 2], // Node 0
374            vec![3],    // Node 1
375            vec![4],    // Node 2
376            vec![],     // Node 3
377            vec![],     // Node 4
378        ];
379
380        let result = traversal.simd_bfs(&[0], |node| {
381            graph.get(node as usize).cloned().unwrap_or_default()
382        });
383
384        assert_eq!(result.len(), 5);
385    }
386
387    #[test]
388    fn test_parallel_dfs() {
389        let traversal = SimdTraversal::new();
390
391        let graph = vec![vec![1, 2], vec![3], vec![4], vec![], vec![]];
392
393        let result = traversal.parallel_dfs(0, |node| {
394            graph.get(node as usize).cloned().unwrap_or_default()
395        });
396
397        assert_eq!(result.len(), 5);
398    }
399
400    #[test]
401    fn test_simd_bfs_iterator() {
402        let mut iter = SimdBfsIterator::new(vec![0]);
403
404        let graph = vec![vec![1, 2], vec![3], vec![4], vec![], vec![]];
405
406        let mut all_nodes = Vec::new();
407        while !iter.is_empty() {
408            let batch = iter.next_batch(2, |node| {
409                graph.get(node as usize).cloned().unwrap_or_default()
410            });
411            all_nodes.extend(batch);
412        }
413
414        assert_eq!(all_nodes.len(), 5);
415    }
416}