ruvector_graph/optimization/
simd_traversal.rs1use crossbeam::queue::SegQueue;
7use rayon::prelude::*;
8use std::collections::{HashSet, VecDeque};
9use std::sync::atomic::{AtomicUsize, Ordering};
10use std::sync::Arc;
11
12#[cfg(target_arch = "x86_64")]
13use std::arch::x86_64::*;
14
/// Parallel graph-traversal engine.
///
/// Offers breadth-first expansion in rayon-parallel batches and a scoped-thread worker pool
/// for reachability search. Despite the `Simd` prefix, the traversal code is scalar.
#[derive(Debug, Clone)]
pub struct SimdTraversal {
    /// Worker count: each BFS batch is split into at most this many rayon chunks, and
    /// `parallel_dfs` spawns this many threads.
    num_threads: usize,
    /// Maximum number of frontier nodes drained from the queue per BFS iteration.
    batch_size: usize,
}
22
23impl Default for SimdTraversal {
24 fn default() -> Self {
25 Self::new()
26 }
27}
28
29impl SimdTraversal {
30 pub fn new() -> Self {
32 Self {
33 num_threads: num_cpus::get(),
34 batch_size: 256, }
36 }
37
38 pub fn simd_bfs<F>(&self, start_nodes: &[u64], mut visit_fn: F) -> Vec<u64>
40 where
41 F: FnMut(u64) -> Vec<u64> + Send + Sync,
42 {
43 let visited = Arc::new(dashmap::DashSet::new());
44 let queue = Arc::new(SegQueue::new());
45 let result = Arc::new(SegQueue::new());
46
47 for &node in start_nodes {
49 if visited.insert(node) {
50 queue.push(node);
51 result.push(node);
52 }
53 }
54
55 let visit_fn = Arc::new(std::sync::Mutex::new(visit_fn));
56
57 while !queue.is_empty() {
59 let mut batch = Vec::with_capacity(self.batch_size);
60
61 for _ in 0..self.batch_size {
63 if let Some(node) = queue.pop() {
64 batch.push(node);
65 } else {
66 break;
67 }
68 }
69
70 if batch.is_empty() {
71 break;
72 }
73
74 let chunk_size = (batch.len() + self.num_threads - 1) / self.num_threads;
76
77 batch.par_chunks(chunk_size).for_each(|chunk| {
78 for &node in chunk {
79 let neighbors = {
80 let mut vf = visit_fn.lock().unwrap();
81 vf(node)
82 };
83
84 self.filter_unvisited_simd(&neighbors, &visited, &queue, &result);
86 }
87 });
88 }
89
90 let mut output = Vec::new();
92 while let Some(node) = result.pop() {
93 output.push(node);
94 }
95 output
96 }
97
98 #[cfg(target_arch = "x86_64")]
100 fn filter_unvisited_simd(
101 &self,
102 neighbors: &[u64],
103 visited: &Arc<dashmap::DashSet<u64>>,
104 queue: &Arc<SegQueue<u64>>,
105 result: &Arc<SegQueue<u64>>,
106 ) {
107 for neighbor in neighbors {
109 if visited.insert(*neighbor) {
110 queue.push(*neighbor);
111 result.push(*neighbor);
112 }
113 }
114 }
115
116 #[cfg(not(target_arch = "x86_64"))]
117 fn filter_unvisited_simd(
118 &self,
119 neighbors: &[u64],
120 visited: &Arc<dashmap::DashSet<u64>>,
121 queue: &Arc<SegQueue<u64>>,
122 result: &Arc<SegQueue<u64>>,
123 ) {
124 for neighbor in neighbors {
125 if visited.insert(*neighbor) {
126 queue.push(*neighbor);
127 result.push(*neighbor);
128 }
129 }
130 }
131
132 #[cfg(target_arch = "x86_64")]
134 pub fn batch_property_access_f32(&self, properties: &[f32], indices: &[usize]) -> Vec<f32> {
135 if is_x86_feature_detected!("avx2") {
136 unsafe { self.batch_property_access_f32_avx2(properties, indices) }
137 } else {
138 indices.iter().map(|&idx| {
140 assert!(idx < properties.len(), "Index out of bounds: {} >= {}", idx, properties.len());
141 properties[idx]
142 }).collect()
143 }
144 }
145
146 #[cfg(target_arch = "x86_64")]
147 #[target_feature(enable = "avx2")]
148 unsafe fn batch_property_access_f32_avx2(
149 &self,
150 properties: &[f32],
151 indices: &[usize],
152 ) -> Vec<f32> {
153 let mut result = Vec::with_capacity(indices.len());
154
155 for &idx in indices {
159 assert!(idx < properties.len(), "Index out of bounds: {} >= {}", idx, properties.len());
160 result.push(properties[idx]);
161 }
162
163 result
164 }
165
166 #[cfg(not(target_arch = "x86_64"))]
167 pub fn batch_property_access_f32(&self, properties: &[f32], indices: &[usize]) -> Vec<f32> {
168 indices.iter().map(|&idx| {
170 assert!(idx < properties.len(), "Index out of bounds: {} >= {}", idx, properties.len());
171 properties[idx]
172 }).collect()
173 }
174
175 pub fn parallel_dfs<F>(&self, start_node: u64, mut visit_fn: F) -> Vec<u64>
177 where
178 F: FnMut(u64) -> Vec<u64> + Send + Sync,
179 {
180 let visited = Arc::new(dashmap::DashSet::new());
181 let result = Arc::new(SegQueue::new());
182 let work_queue = Arc::new(SegQueue::new());
183
184 visited.insert(start_node);
185 result.push(start_node);
186 work_queue.push(start_node);
187
188 let visit_fn = Arc::new(std::sync::Mutex::new(visit_fn));
189 let active_workers = Arc::new(AtomicUsize::new(0));
190
191 std::thread::scope(|s| {
193 let handles: Vec<_> = (0..self.num_threads)
194 .map(|_| {
195 let work_queue = Arc::clone(&work_queue);
196 let visited = Arc::clone(&visited);
197 let result = Arc::clone(&result);
198 let visit_fn = Arc::clone(&visit_fn);
199 let active_workers = Arc::clone(&active_workers);
200
201 s.spawn(move || {
202 loop {
203 if let Some(node) = work_queue.pop() {
204 active_workers.fetch_add(1, Ordering::SeqCst);
205
206 let neighbors = {
207 let mut vf = visit_fn.lock().unwrap();
208 vf(node)
209 };
210
211 for neighbor in neighbors {
212 if visited.insert(neighbor) {
213 result.push(neighbor);
214 work_queue.push(neighbor);
215 }
216 }
217
218 active_workers.fetch_sub(1, Ordering::SeqCst);
219 } else {
220 if active_workers.load(Ordering::SeqCst) == 0
222 && work_queue.is_empty()
223 {
224 break;
225 }
226 std::thread::yield_now();
227 }
228 }
229 })
230 })
231 .collect();
232
233 for handle in handles {
234 handle.join().unwrap();
235 }
236 });
237
238 let mut output = Vec::new();
240 while let Some(node) = result.pop() {
241 output.push(node);
242 }
243 output
244 }
245}
246
/// Single-threaded, resumable breadth-first traversal that yields nodes in caller-sized batches.
#[derive(Debug, Clone)]
pub struct SimdBfsIterator {
    /// FIFO frontier of discovered nodes that have not been returned yet.
    queue: VecDeque<u64>,
    /// Every node ever enqueued, including start nodes; prevents re-enqueueing.
    visited: HashSet<u64>,
}
252
253impl SimdBfsIterator {
254 pub fn new(start_nodes: Vec<u64>) -> Self {
255 let mut visited = HashSet::new();
256 let mut queue = VecDeque::new();
257
258 for node in start_nodes {
259 if visited.insert(node) {
260 queue.push_back(node);
261 }
262 }
263
264 Self { queue, visited }
265 }
266
267 pub fn next_batch<F>(&mut self, batch_size: usize, mut neighbor_fn: F) -> Vec<u64>
268 where
269 F: FnMut(u64) -> Vec<u64>,
270 {
271 let mut batch = Vec::new();
272
273 for _ in 0..batch_size {
274 if let Some(node) = self.queue.pop_front() {
275 batch.push(node);
276
277 let neighbors = neighbor_fn(node);
278 for neighbor in neighbors {
279 if self.visited.insert(neighbor) {
280 self.queue.push_back(neighbor);
281 }
282 }
283 } else {
284 break;
285 }
286 }
287
288 batch
289 }
290
291 pub fn is_empty(&self) -> bool {
292 self.queue.is_empty()
293 }
294}
295
/// Single-threaded, resumable depth-first traversal that yields nodes in caller-sized batches.
#[derive(Debug, Clone)]
pub struct SimdDfsIterator {
    /// LIFO stack of discovered nodes that have not been returned yet.
    stack: Vec<u64>,
    /// Every node ever pushed, including the start node; prevents pushing a node twice.
    visited: HashSet<u64>,
}
301
302impl SimdDfsIterator {
303 pub fn new(start_node: u64) -> Self {
304 let mut visited = HashSet::new();
305 visited.insert(start_node);
306
307 Self {
308 stack: vec![start_node],
309 visited,
310 }
311 }
312
313 pub fn next_batch<F>(&mut self, batch_size: usize, mut neighbor_fn: F) -> Vec<u64>
314 where
315 F: FnMut(u64) -> Vec<u64>,
316 {
317 let mut batch = Vec::new();
318
319 for _ in 0..batch_size {
320 if let Some(node) = self.stack.pop() {
321 batch.push(node);
322
323 let neighbors = neighbor_fn(node);
324 for neighbor in neighbors.into_iter().rev() {
325 if self.visited.insert(neighbor) {
326 self.stack.push(neighbor);
327 }
328 }
329 } else {
330 break;
331 }
332 }
333
334 batch
335 }
336
337 pub fn is_empty(&self) -> bool {
338 self.stack.is_empty()
339 }
340}
341
#[cfg(test)]
mod tests {
    use super::*;

    /// Small DAG shared by the tests: 0 -> {1, 2}, 1 -> {3}, 2 -> {4}.
    fn sample_graph() -> Vec<Vec<u64>> {
        vec![vec![1, 2], vec![3], vec![4], vec![], vec![]]
    }

    /// Neighbours of `node` in `graph`; ids outside the graph have none.
    fn neighbors(graph: &[Vec<u64>], node: u64) -> Vec<u64> {
        graph.get(node as usize).cloned().unwrap_or_default()
    }

    /// Sorts the nodes so results of nondeterministic parallel traversals can be compared.
    fn sorted(mut nodes: Vec<u64>) -> Vec<u64> {
        nodes.sort_unstable();
        nodes
    }

    #[test]
    fn test_simd_bfs() {
        let traversal = SimdTraversal::new();
        let graph = sample_graph();

        let result = traversal.simd_bfs(&[0], |node| neighbors(&graph, node));

        assert_eq!(result.first(), Some(&0));
        assert_eq!(sorted(result), vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_simd_bfs_dedups_start_nodes_and_handles_cycles() {
        let traversal = SimdTraversal::new();
        // 0 <-> 1 cycle plus a self-loop on 2.
        let graph = vec![vec![1], vec![0], vec![2]];

        let result = traversal.simd_bfs(&[0, 0, 2], |node| neighbors(&graph, node));

        assert_eq!(sorted(result), vec![0, 1, 2]);
    }

    #[test]
    fn test_simd_bfs_empty_start() {
        let traversal = SimdTraversal::new();

        let result = traversal.simd_bfs(&[], |_| Vec::new());

        assert!(result.is_empty());
    }

    #[test]
    fn test_parallel_dfs() {
        let traversal = SimdTraversal::new();
        let graph = sample_graph();

        let result = traversal.parallel_dfs(0, |node| neighbors(&graph, node));

        assert_eq!(result.first(), Some(&0));
        assert_eq!(sorted(result), vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_parallel_dfs_with_cycle() {
        let traversal = SimdTraversal::new();
        let graph = vec![vec![1, 2], vec![2, 0], vec![0, 1]];

        let result = traversal.parallel_dfs(0, |node| neighbors(&graph, node));

        assert_eq!(sorted(result), vec![0, 1, 2]);
    }

    #[test]
    fn test_batch_property_access_f32() {
        let traversal = SimdTraversal::new();
        let props = [1.0_f32, 2.0, 3.0];

        assert_eq!(
            traversal.batch_property_access_f32(&props, &[2, 0, 2]),
            vec![3.0, 1.0, 3.0]
        );
        assert!(traversal.batch_property_access_f32(&props, &[]).is_empty());
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn test_batch_property_access_f32_out_of_bounds() {
        let traversal = SimdTraversal::new();
        let _ = traversal.batch_property_access_f32(&[1.0], &[1]);
    }

    #[test]
    fn test_simd_bfs_iterator() {
        let graph = sample_graph();
        let mut iter = SimdBfsIterator::new(vec![0]);

        let mut all_nodes = Vec::new();
        while !iter.is_empty() {
            all_nodes.extend(iter.next_batch(2, |node| neighbors(&graph, node)));
        }

        // Single-threaded and FIFO, so the order is deterministic: level by level.
        assert_eq!(all_nodes, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_simd_dfs_iterator() {
        let graph = sample_graph();
        let mut iter = SimdDfsIterator::new(0);

        let mut all_nodes = Vec::new();
        while !iter.is_empty() {
            all_nodes.extend(iter.next_batch(2, |node| neighbors(&graph, node)));
        }

        // Neighbours are pushed in reverse, so the first-listed neighbour is explored first.
        assert_eq!(all_nodes, vec![0, 1, 3, 2, 4]);
    }
}