ruvector_graph/optimization/simd_traversal.rs

use crossbeam::queue::SegQueue;
use rayon::prelude::*;
use std::collections::{HashSet, VecDeque};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

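/// Batched, multi-threaded graph traversal with an AVX2 fast path for bulk
/// property access on x86_64. Traversal callbacks supply the neighbors of a
/// node, so the structure is agnostic to how the graph is stored.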
pub struct SimdTraversal {
    /// Number of worker threads; defaults to the available CPU count.
    num_threads: usize,
    /// Number of nodes drained from the frontier per parallel round.
    batch_size: usize,
}

impl Default for SimdTraversal {
    fn default() -> Self {
        Self::new()
    }
}

impl SimdTraversal {
    pub fn new() -> Self {
        Self {
            num_threads: num_cpus::get(),
            batch_size: 256,
        }
    }

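    /// Breadth-first traversal that drains the frontier in batches of
    /// `batch_size` and expands each batch across rayon worker threads.
    /// Because `visit_fn` is `FnMut`, calls to it are serialized behind a
    /// mutex; the parallelism pays off when neighbor filtering and queue
    /// traffic dominate. The returned visit order is nondeterministic.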
    pub fn simd_bfs<F>(&self, start_nodes: &[u64], visit_fn: F) -> Vec<u64>
    where
        F: FnMut(u64) -> Vec<u64> + Send + Sync,
    {
        let visited = Arc::new(dashmap::DashSet::new());
        let queue = Arc::new(SegQueue::new());
        let result = Arc::new(SegQueue::new());

        // Seed the frontier, deduplicating the start nodes.
        for &node in start_nodes {
            if visited.insert(node) {
                queue.push(node);
                result.push(node);
            }
        }

        let visit_fn = Arc::new(std::sync::Mutex::new(visit_fn));

        while !queue.is_empty() {
            // Drain up to `batch_size` nodes into the current frontier batch.
            let mut batch = Vec::with_capacity(self.batch_size);
            for _ in 0..self.batch_size {
                if let Some(node) = queue.pop() {
                    batch.push(node);
                } else {
                    break;
                }
            }

            if batch.is_empty() {
                break;
            }

            // Split the batch evenly across the worker threads.
            let chunk_size = (batch.len() + self.num_threads - 1) / self.num_threads;

            batch.par_chunks(chunk_size).for_each(|chunk| {
                for &node in chunk {
                    let neighbors = {
                        let mut vf = visit_fn.lock().unwrap();
                        vf(node)
                    };

                    self.filter_unvisited(&neighbors, &visited, &queue, &result);
                }
            });
        }

        let mut output = Vec::new();
        while let Some(node) = result.pop() {
            output.push(node);
        }
        output
    }

    /// Insert-and-filter for a batch of candidate neighbors.
    /// `DashSet::insert` already acts as the visited check, so this path is
    /// scalar on every architecture; hash-set membership does not vectorize
    /// readily.
    fn filter_unvisited(
        &self,
        neighbors: &[u64],
        visited: &Arc<dashmap::DashSet<u64>>,
        queue: &Arc<SegQueue<u64>>,
        result: &Arc<SegQueue<u64>>,
    ) {
        for &neighbor in neighbors {
            if visited.insert(neighbor) {
                queue.push(neighbor);
                result.push(neighbor);
            }
        }
    }

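    /// Gathers `properties[idx]` for each index in `indices`. Uses AVX2
    /// gather instructions when available at runtime, otherwise scalar
    /// loads. All indices must be in bounds of `properties`.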
    #[cfg(target_arch = "x86_64")]
    pub fn batch_property_access_f32(&self, properties: &[f32], indices: &[usize]) -> Vec<f32> {
        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 support was just verified at runtime.
            unsafe { self.batch_property_access_f32_avx2(properties, indices) }
        } else {
            indices.iter().map(|&idx| properties[idx]).collect()
        }
    }

    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2")]
    unsafe fn batch_property_access_f32_avx2(
        &self,
        properties: &[f32],
        indices: &[usize],
    ) -> Vec<f32> {
        debug_assert!(indices.iter().all(|&idx| idx < properties.len()));

        let mut result = Vec::with_capacity(indices.len());
        let mut chunks = indices.chunks_exact(8);

        // Gather eight f32 values per iteration with a single AVX2
        // instruction. Indices are narrowed to i32, as vgatherdps requires.
        for chunk in chunks.by_ref() {
            let idx: [i32; 8] = std::array::from_fn(|i| chunk[i] as i32);
            let offsets = _mm256_loadu_si256(idx.as_ptr() as *const __m256i);
            let gathered = _mm256_i32gather_ps::<4>(properties.as_ptr(), offsets);
            let mut lane = [0.0f32; 8];
            _mm256_storeu_ps(lane.as_mut_ptr(), gathered);
            result.extend_from_slice(&lane);
        }

        // Scalar tail for the last `len % 8` indices.
        for &idx in chunks.remainder() {
            result.push(properties[idx]);
        }

        result
    }

    #[cfg(not(target_arch = "x86_64"))]
    pub fn batch_property_access_f32(&self, properties: &[f32], indices: &[usize]) -> Vec<f32> {
        indices.iter().map(|&idx| properties[idx]).collect()
    }

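    /// Work-sharing parallel DFS from a single start node. Every worker
    /// thread pops from a shared queue, expands the node, and pushes
    /// unvisited neighbors back. Because expansion is parallel, the result
    /// order is not classic depth-first order; only the visited set matches
    /// a sequential DFS.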
    pub fn parallel_dfs<F>(&self, start_node: u64, visit_fn: F) -> Vec<u64>
    where
        F: FnMut(u64) -> Vec<u64> + Send + Sync,
    {
        let visited = Arc::new(dashmap::DashSet::new());
        let result = Arc::new(SegQueue::new());
        let work_queue = Arc::new(SegQueue::new());

        visited.insert(start_node);
        result.push(start_node);
        work_queue.push(start_node);

        let visit_fn = Arc::new(std::sync::Mutex::new(visit_fn));
        // Nodes that are queued or currently being expanded. Counting work
        // items (rather than idle workers) avoids the race where a thread
        // sees an empty queue between another thread's pop and its first
        // neighbor push, and exits early.
        let pending = Arc::new(AtomicUsize::new(1));

        std::thread::scope(|s| {
            let handles: Vec<_> = (0..self.num_threads)
                .map(|_| {
                    let work_queue = Arc::clone(&work_queue);
                    let visited = Arc::clone(&visited);
                    let result = Arc::clone(&result);
                    let visit_fn = Arc::clone(&visit_fn);
                    let pending = Arc::clone(&pending);

                    s.spawn(move || {
                        loop {
                            if let Some(node) = work_queue.pop() {
                                let neighbors = {
                                    let mut vf = visit_fn.lock().unwrap();
                                    vf(node)
                                };

                                for neighbor in neighbors {
                                    if visited.insert(neighbor) {
                                        result.push(neighbor);
                                        pending.fetch_add(1, Ordering::SeqCst);
                                        work_queue.push(neighbor);
                                    }
                                }

                                // This node is fully expanded.
                                pending.fetch_sub(1, Ordering::SeqCst);
                            } else if pending.load(Ordering::SeqCst) == 0 {
                                // No queued or in-flight work remains anywhere.
                                break;
                            } else {
                                std::thread::yield_now();
                            }
                        }
                    })
                })
                .collect();

            for handle in handles {
                handle.join().unwrap();
            }
        });

        let mut output = Vec::new();
        while let Some(node) = result.pop() {
            output.push(node);
        }
        output
    }
}

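/// Single-threaded, batch-oriented BFS frontier. Useful when the caller
/// wants to drive the traversal incrementally, e.g. to interleave batches
/// with other work.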
pub struct SimdBfsIterator {
    queue: VecDeque<u64>,
    visited: HashSet<u64>,
}

impl SimdBfsIterator {
    pub fn new(start_nodes: Vec<u64>) -> Self {
        let mut visited = HashSet::new();
        let mut queue = VecDeque::new();

        for node in start_nodes {
            if visited.insert(node) {
                queue.push_back(node);
            }
        }

        Self { queue, visited }
    }

    /// Pops up to `batch_size` nodes from the frontier, enqueueing their
    /// unvisited neighbors as it goes, and returns the popped nodes.
    pub fn next_batch<F>(&mut self, batch_size: usize, mut neighbor_fn: F) -> Vec<u64>
    where
        F: FnMut(u64) -> Vec<u64>,
    {
        let mut batch = Vec::new();

        for _ in 0..batch_size {
            if let Some(node) = self.queue.pop_front() {
                batch.push(node);

                for neighbor in neighbor_fn(node) {
                    if self.visited.insert(neighbor) {
                        self.queue.push_back(neighbor);
                    }
                }
            } else {
                break;
            }
        }

        batch
    }

    pub fn is_empty(&self) -> bool {
        self.queue.is_empty()
    }
}

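/// Single-threaded, batch-oriented DFS stack. Neighbors are pushed in
/// reverse so that nodes are yielded in left-to-right preorder.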
pub struct SimdDfsIterator {
    stack: Vec<u64>,
    visited: HashSet<u64>,
}

impl SimdDfsIterator {
    pub fn new(start_node: u64) -> Self {
        let mut visited = HashSet::new();
        visited.insert(start_node);

        Self {
            stack: vec![start_node],
            visited,
        }
    }

    /// Pops up to `batch_size` nodes from the stack, pushing their unvisited
    /// neighbors, and returns the popped nodes.
    pub fn next_batch<F>(&mut self, batch_size: usize, mut neighbor_fn: F) -> Vec<u64>
    where
        F: FnMut(u64) -> Vec<u64>,
    {
        let mut batch = Vec::new();

        for _ in 0..batch_size {
            if let Some(node) = self.stack.pop() {
                batch.push(node);

                // Push in reverse so the first-listed neighbor is popped first.
                for neighbor in neighbor_fn(node).into_iter().rev() {
                    if self.visited.insert(neighbor) {
                        self.stack.push(neighbor);
                    }
                }
            } else {
                break;
            }
        }

        batch
    }

    pub fn is_empty(&self) -> bool {
        self.stack.is_empty()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simd_bfs() {
        let traversal = SimdTraversal::new();

        // Adjacency list: 0 -> {1, 2}, 1 -> {3}, 2 -> {4}.
        let graph = vec![
            vec![1, 2],
            vec![3],
            vec![4],
            vec![],
            vec![],
        ];

        let result = traversal.simd_bfs(&[0], |node| {
            graph.get(node as usize).cloned().unwrap_or_default()
        });

        assert_eq!(result.len(), 5);
    }

    #[test]
    fn test_parallel_dfs() {
        let traversal = SimdTraversal::new();

        let graph = vec![vec![1, 2], vec![3], vec![4], vec![], vec![]];

        let result = traversal.parallel_dfs(0, |node| {
            graph.get(node as usize).cloned().unwrap_or_default()
        });

        assert_eq!(result.len(), 5);
    }

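    // Cross-check batch_property_access_f32 against a plain scalar gather.
    // Ten indices exercise both the eight-lane AVX2 body (when available)
    // and the scalar tail.
    #[test]
    fn test_batch_property_access_f32() {
        let traversal = SimdTraversal::new();

        let properties: Vec<f32> = (0..16).map(|i| i as f32).collect();
        let indices = vec![9, 3, 0, 7, 1, 8, 2, 6, 5, 4];

        let fast = traversal.batch_property_access_f32(&properties, &indices);
        let scalar: Vec<f32> = indices.iter().map(|&idx| properties[idx]).collect();

        assert_eq!(fast, scalar);
    }
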
    #[test]
    fn test_simd_bfs_iterator() {
        let mut iter = SimdBfsIterator::new(vec![0]);

        let graph = vec![vec![1, 2], vec![3], vec![4], vec![], vec![]];

        let mut all_nodes = Vec::new();
        while !iter.is_empty() {
            let batch = iter.next_batch(2, |node| {
                graph.get(node as usize).cloned().unwrap_or_default()
            });
            all_nodes.extend(batch);
        }

        assert_eq!(all_nodes.len(), 5);
    }
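
    // The single-threaded DFS iterator is deterministic: neighbors are
    // pushed in reverse, so nodes come out in left-to-right preorder.
    #[test]
    fn test_simd_dfs_iterator() {
        let mut iter = SimdDfsIterator::new(0);

        let graph = vec![vec![1, 2], vec![3], vec![4], vec![], vec![]];

        let mut all_nodes = Vec::new();
        while !iter.is_empty() {
            let batch = iter.next_batch(2, |node| {
                graph.get(node as usize).cloned().unwrap_or_default()
            });
            all_nodes.extend(batch);
        }

        assert_eq!(all_nodes, vec![0, 1, 3, 2, 4]);
    }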
}