// ruvector_graph/optimization/simd_traversal.rs

use crossbeam::queue::SegQueue;
use rayon::prelude::*;
use std::collections::{HashSet, VecDeque};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

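/// Batched, multi-threaded graph traversal helpers.
///
/// The `Simd` prefix reflects the batch-oriented design: frontiers are
/// drained and expanded in fixed-size batches. Visited-set filtering is
/// scalar (a concurrent `DashSet` insert has no direct SIMD formulation);
/// explicit SIMD is limited to `batch_property_access_f32` on
/// AVX2-capable x86_64 CPUs.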
pub struct SimdTraversal {
    /// Worker thread count; defaults to the number of available CPUs.
    num_threads: usize,
    /// Number of frontier nodes drained per parallel round.
    batch_size: usize,
}

impl Default for SimdTraversal {
    fn default() -> Self {
        Self::new()
    }
}

impl SimdTraversal {
    pub fn new() -> Self {
        Self {
            num_threads: num_cpus::get(),
            batch_size: 256,
        }
    }

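    /// Breadth-first traversal that drains the frontier in batches of
    /// `batch_size` and filters newly discovered nodes in parallel.
    ///
    /// `visit_fn` is an `FnMut` guarded by a `Mutex`, so neighbor expansion
    /// itself is serialized; the returned order is concurrent discovery
    /// order, not strict BFS level order.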
    pub fn simd_bfs<F>(&self, start_nodes: &[u64], visit_fn: F) -> Vec<u64>
    where
        F: FnMut(u64) -> Vec<u64> + Send + Sync,
    {
        let visited = Arc::new(dashmap::DashSet::new());
        let queue = Arc::new(SegQueue::new());
        let result = Arc::new(SegQueue::new());

        // Seed the frontier, deduplicating the start nodes.
        for &node in start_nodes {
            if visited.insert(node) {
                queue.push(node);
                result.push(node);
            }
        }

        // Share the FnMut callback across rayon workers behind a Mutex.
        let visit_fn = Arc::new(std::sync::Mutex::new(visit_fn));

        while !queue.is_empty() {
            // Drain up to one batch from the shared frontier.
            let mut batch = Vec::with_capacity(self.batch_size);
            for _ in 0..self.batch_size {
                if let Some(node) = queue.pop() {
                    batch.push(node);
                } else {
                    break;
                }
            }

            if batch.is_empty() {
                break;
            }

            // Split the batch across worker threads (ceiling division).
            let chunk_size = (batch.len() + self.num_threads - 1) / self.num_threads;

            batch.par_chunks(chunk_size).for_each(|chunk| {
                for &node in chunk {
                    let neighbors = {
                        let mut vf = visit_fn.lock().unwrap();
                        vf(node)
                    };

                    self.filter_unvisited_simd(&neighbors, &visited, &queue, &result);
                }
            });
        }

        // Drain the concurrent result queue into a Vec.
        let mut output = Vec::new();
        while let Some(node) = result.pop() {
            output.push(node);
        }
        output
    }

    /// Push not-yet-visited neighbors onto the frontier and result queues.
    ///
    /// Both target architectures use this scalar implementation: the hot
    /// operation is a concurrent `DashSet` insert, which has no direct
    /// SIMD formulation.
    fn filter_unvisited_simd(
        &self,
        neighbors: &[u64],
        visited: &Arc<dashmap::DashSet<u64>>,
        queue: &Arc<SegQueue<u64>>,
        result: &Arc<SegQueue<u64>>,
    ) {
        for &neighbor in neighbors {
            if visited.insert(neighbor) {
                queue.push(neighbor);
                result.push(neighbor);
            }
        }
    }

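    /// Gathers `properties[idx]` for each index, using an AVX2 gather when
    /// the CPU supports it and a checked scalar loop otherwise.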
    #[cfg(target_arch = "x86_64")]
    pub fn batch_property_access_f32(&self, properties: &[f32], indices: &[usize]) -> Vec<f32> {
        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 support was just verified at runtime.
            unsafe { self.batch_property_access_f32_avx2(properties, indices) }
        } else {
            indices
                .iter()
                .map(|&idx| {
                    assert!(
                        idx < properties.len(),
                        "Index out of bounds: {} >= {}",
                        idx,
                        properties.len()
                    );
                    properties[idx]
                })
                .collect()
        }
    }

    /// AVX2 gather: loads eight `f32` values per iteration via
    /// `_mm256_i32gather_ps`, with a checked scalar loop for the tail.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2")]
    unsafe fn batch_property_access_f32_avx2(
        &self,
        properties: &[f32],
        indices: &[usize],
    ) -> Vec<f32> {
        let mut result = Vec::with_capacity(indices.len());
        let mut chunks = indices.chunks_exact(8);

        for chunk in chunks.by_ref() {
            // Validate and narrow the indices; the gather takes i32 offsets.
            let mut idx32 = [0i32; 8];
            for (i, &idx) in chunk.iter().enumerate() {
                assert!(idx < properties.len(), "Index out of bounds: {} >= {}", idx, properties.len());
                idx32[i] = i32::try_from(idx).expect("index exceeds i32::MAX");
            }
            let vindex = _mm256_loadu_si256(idx32.as_ptr() as *const __m256i);
            // SCALE = 4 bytes per element, so offsets are f32 element indices.
            let gathered = _mm256_i32gather_ps::<4>(properties.as_ptr(), vindex);
            let mut lane = [0.0f32; 8];
            _mm256_storeu_ps(lane.as_mut_ptr(), gathered);
            result.extend_from_slice(&lane);
        }

        // Scalar tail for the final `indices.len() % 8` elements.
        for &idx in chunks.remainder() {
            assert!(idx < properties.len(), "Index out of bounds: {} >= {}", idx, properties.len());
            result.push(properties[idx]);
        }

        result
    }

    /// Portable fallback: checked scalar gather on non-x86_64 targets.
    #[cfg(not(target_arch = "x86_64"))]
    pub fn batch_property_access_f32(&self, properties: &[f32], indices: &[usize]) -> Vec<f32> {
        indices
            .iter()
            .map(|&idx| {
                assert!(
                    idx < properties.len(),
                    "Index out of bounds: {} >= {}",
                    idx,
                    properties.len()
                );
                properties[idx]
            })
            .collect()
    }

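    /// Multi-threaded traversal of every node reachable from `start_node`.
    ///
    /// Worker threads race on a shared queue, so the visitation order is
    /// nondeterministic rather than strict depth-first order; each
    /// reachable node appears in the result exactly once.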
    pub fn parallel_dfs<F>(&self, start_node: u64, visit_fn: F) -> Vec<u64>
    where
        F: FnMut(u64) -> Vec<u64> + Send + Sync,
    {
        let visited = Arc::new(dashmap::DashSet::new());
        let result = Arc::new(SegQueue::new());
        let work_queue = Arc::new(SegQueue::new());

        visited.insert(start_node);
        result.push(start_node);
        work_queue.push(start_node);

        let visit_fn = Arc::new(std::sync::Mutex::new(visit_fn));
        // Nodes queued but not yet fully expanded; starts at 1 for the root.
        // Workers may only exit once this reaches zero, so no thread quits
        // while another is still expanding a node and about to push work.
        let pending = Arc::new(AtomicUsize::new(1));

        std::thread::scope(|s| {
            let handles: Vec<_> = (0..self.num_threads)
                .map(|_| {
                    let work_queue = Arc::clone(&work_queue);
                    let visited = Arc::clone(&visited);
                    let result = Arc::clone(&result);
                    let visit_fn = Arc::clone(&visit_fn);
                    let pending = Arc::clone(&pending);

                    s.spawn(move || loop {
                        if let Some(node) = work_queue.pop() {
                            let neighbors = {
                                let mut vf = visit_fn.lock().unwrap();
                                vf(node)
                            };

                            for neighbor in neighbors {
                                if visited.insert(neighbor) {
                                    result.push(neighbor);
                                    // Count the child before publishing it so
                                    // `pending` never under-reports work.
                                    pending.fetch_add(1, Ordering::SeqCst);
                                    work_queue.push(neighbor);
                                }
                            }

                            // This node is now fully expanded.
                            pending.fetch_sub(1, Ordering::SeqCst);
                        } else {
                            // Zero pending nodes means no work exists anywhere.
                            if pending.load(Ordering::SeqCst) == 0 {
                                break;
                            }
                            std::thread::yield_now();
                        }
                    })
                })
                .collect();

            for handle in handles {
                handle.join().unwrap();
            }
        });

        let mut output = Vec::new();
        while let Some(node) = result.pop() {
            output.push(node);
        }
        output
    }
}

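/// Single-threaded BFS frontier that yields nodes in fixed-size batches.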
pub struct SimdBfsIterator {
    queue: VecDeque<u64>,
    visited: HashSet<u64>,
}

impl SimdBfsIterator {
    pub fn new(start_nodes: Vec<u64>) -> Self {
        let mut visited = HashSet::new();
        let mut queue = VecDeque::new();

        for node in start_nodes {
            if visited.insert(node) {
                queue.push_back(node);
            }
        }

        Self { queue, visited }
    }

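    /// Pops up to `batch_size` nodes, enqueueing their unvisited neighbors.
    /// An empty return batch means the frontier has been drained.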
    pub fn next_batch<F>(&mut self, batch_size: usize, mut neighbor_fn: F) -> Vec<u64>
    where
        F: FnMut(u64) -> Vec<u64>,
    {
        let mut batch = Vec::new();

        for _ in 0..batch_size {
            if let Some(node) = self.queue.pop_front() {
                batch.push(node);

                let neighbors = neighbor_fn(node);
                for neighbor in neighbors {
                    if self.visited.insert(neighbor) {
                        self.queue.push_back(neighbor);
                    }
                }
            } else {
                break;
            }
        }

        batch
    }

    pub fn is_empty(&self) -> bool {
        self.queue.is_empty()
    }
}

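/// Single-threaded DFS frontier that yields nodes in fixed-size batches.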
pub struct SimdDfsIterator {
    stack: Vec<u64>,
    visited: HashSet<u64>,
}

impl SimdDfsIterator {
    pub fn new(start_node: u64) -> Self {
        let mut visited = HashSet::new();
        visited.insert(start_node);

        Self {
            stack: vec![start_node],
            visited,
        }
    }

    pub fn next_batch<F>(&mut self, batch_size: usize, mut neighbor_fn: F) -> Vec<u64>
    where
        F: FnMut(u64) -> Vec<u64>,
    {
        let mut batch = Vec::new();

        for _ in 0..batch_size {
            if let Some(node) = self.stack.pop() {
                batch.push(node);

                let neighbors = neighbor_fn(node);
                // Push in reverse so the first-listed neighbor is popped
                // first, preserving conventional DFS visit order.
                for neighbor in neighbors.into_iter().rev() {
                    if self.visited.insert(neighbor) {
                        self.stack.push(neighbor);
                    }
                }
            } else {
                break;
            }
        }

        batch
    }

    pub fn is_empty(&self) -> bool {
        self.stack.is_empty()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simd_bfs() {
        let traversal = SimdTraversal::new();

        // 0 -> {1, 2}, 1 -> {3}, 2 -> {4}: five reachable nodes in total.
        let graph = vec![vec![1, 2], vec![3], vec![4], vec![], vec![]];

        let result = traversal.simd_bfs(&[0], |node| {
            graph.get(node as usize).cloned().unwrap_or_default()
        });

        assert_eq!(result.len(), 5);
    }

    #[test]
    fn test_parallel_dfs() {
        let traversal = SimdTraversal::new();

        let graph = vec![vec![1, 2], vec![3], vec![4], vec![], vec![]];

        let result = traversal.parallel_dfs(0, |node| {
            graph.get(node as usize).cloned().unwrap_or_default()
        });

        assert_eq!(result.len(), 5);
    }
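
    // Exercises both gather paths: with nine indices the AVX2 path also
    // covers its scalar tail, and non-AVX2 machines take the fallback.
    #[test]
    fn test_batch_property_access_f32() {
        let traversal = SimdTraversal::new();

        let properties: Vec<f32> = (0..20).map(|i| i as f32).collect();
        let indices = vec![0, 5, 19, 3, 7, 2, 11, 13, 4];

        let result = traversal.batch_property_access_f32(&properties, &indices);
        let expected: Vec<f32> = indices.iter().map(|&i| properties[i]).collect();
        assert_eq!(result, expected);
    }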

    #[test]
    fn test_simd_bfs_iterator() {
        let mut iter = SimdBfsIterator::new(vec![0]);

        let graph = vec![vec![1, 2], vec![3], vec![4], vec![], vec![]];

        let mut all_nodes = Vec::new();
        while !iter.is_empty() {
            let batch = iter.next_batch(2, |node| {
                graph.get(node as usize).cloned().unwrap_or_default()
            });
            all_nodes.extend(batch);
        }

        assert_eq!(all_nodes.len(), 5);
    }
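
    // Mirrors test_simd_bfs_iterator for the depth-first variant.
    #[test]
    fn test_simd_dfs_iterator() {
        let mut iter = SimdDfsIterator::new(0);

        let graph = vec![vec![1, 2], vec![3], vec![4], vec![], vec![]];

        let mut all_nodes = Vec::new();
        while !iter.is_empty() {
            let batch = iter.next_batch(2, |node| {
                graph.get(node as usize).cloned().unwrap_or_default()
            });
            all_nodes.extend(batch);
        }

        assert_eq!(all_nodes.len(), 5);
    }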
}