ruvector_mincut/optimization/
simd_distance.rs1use crate::graph::VertexId;
11
12#[cfg(target_arch = "wasm32")]
13use core::arch::wasm32::*;
14
15pub const SIMD_ALIGNMENT: usize = 64;
17
18pub const SIMD_LANES: usize = 4; #[repr(C, align(64))]
23pub struct DistanceArray {
24 data: Vec<f64>,
26 len: usize,
28}
29
30impl DistanceArray {
31 pub fn new(size: usize) -> Self {
33 Self {
34 data: vec![f64::INFINITY; size],
35 len: size,
36 }
37 }
38
39 pub fn from_slice(slice: &[f64]) -> Self {
41 Self {
42 data: slice.to_vec(),
43 len: slice.len(),
44 }
45 }
46
47 #[inline]
49 pub fn get(&self, v: VertexId) -> f64 {
50 self.data.get(v as usize).copied().unwrap_or(f64::INFINITY)
51 }
52
53 #[inline]
55 pub fn set(&mut self, v: VertexId, distance: f64) {
56 if (v as usize) < self.len {
57 self.data[v as usize] = distance;
58 }
59 }
60
61 pub fn len(&self) -> usize {
63 self.len
64 }
65
66 pub fn is_empty(&self) -> bool {
68 self.len == 0
69 }
70
71 pub fn reset(&mut self) {
73 for d in &mut self.data {
74 *d = f64::INFINITY;
75 }
76 }
77
78 pub fn as_slice(&self) -> &[f64] {
80 &self.data
81 }
82
83 pub fn as_mut_slice(&mut self) -> &mut [f64] {
85 &mut self.data
86 }
87}
88
89pub struct SimdDistanceOps;
91
92impl SimdDistanceOps {
93 #[cfg(target_arch = "wasm32")]
97 pub fn find_min(distances: &DistanceArray) -> (f64, usize) {
98 let data = distances.as_slice();
99 if data.is_empty() {
100 return (f64::INFINITY, 0);
101 }
102
103 let mut min_val = f64::INFINITY;
104 let mut min_idx = 0;
105
106 let chunks = data.len() / 2;
108
109 unsafe {
110 for i in 0..chunks {
111 let offset = i * 2;
112 let v = v128_load(data.as_ptr().add(offset) as *const v128);
113
114 let a = f64x2_extract_lane::<0>(v);
115 let b = f64x2_extract_lane::<1>(v);
116
117 if a < min_val {
118 min_val = a;
119 min_idx = offset;
120 }
121 if b < min_val {
122 min_val = b;
123 min_idx = offset + 1;
124 }
125 }
126 }
127
128 for i in (chunks * 2)..data.len() {
130 if data[i] < min_val {
131 min_val = data[i];
132 min_idx = i;
133 }
134 }
135
136 (min_val, min_idx)
137 }
138
139 #[cfg(not(target_arch = "wasm32"))]
141 pub fn find_min(distances: &DistanceArray) -> (f64, usize) {
142 let data = distances.as_slice();
143 if data.is_empty() {
144 return (f64::INFINITY, 0);
145 }
146
147 let mut min_val = f64::INFINITY;
148 let mut min_idx = 0;
149
150 let chunks = data.len() / 4;
152 for i in 0..chunks {
153 let base = i * 4;
154 let a = data[base];
155 let b = data[base + 1];
156 let c = data[base + 2];
157 let d = data[base + 3];
158
159 if a < min_val {
160 min_val = a;
161 min_idx = base;
162 }
163 if b < min_val {
164 min_val = b;
165 min_idx = base + 1;
166 }
167 if c < min_val {
168 min_val = c;
169 min_idx = base + 2;
170 }
171 if d < min_val {
172 min_val = d;
173 min_idx = base + 3;
174 }
175 }
176
177 for i in (chunks * 4)..data.len() {
179 if data[i] < min_val {
180 min_val = data[i];
181 min_idx = i;
182 }
183 }
184
185 (min_val, min_idx)
186 }
187
188 #[cfg(target_arch = "wasm32")]
192 pub fn relax_batch(
193 distances: &mut DistanceArray,
194 source_dist: f64,
195 neighbors: &[(VertexId, f64)], ) -> usize {
197 let mut updated = 0;
198 let data = distances.as_mut_slice();
199
200 unsafe {
201 let source_v = f64x2_splat(source_dist);
202
203 let pairs = neighbors.len() / 2;
205 for i in 0..pairs {
206 let idx0 = neighbors[i * 2].0 as usize;
207 let idx1 = neighbors[i * 2 + 1].0 as usize;
208 let w0 = neighbors[i * 2].1;
209 let w1 = neighbors[i * 2 + 1].1;
210
211 if idx0 < data.len() && idx1 < data.len() {
212 let weights = f64x2(w0, w1);
213 let new_dist = f64x2_add(source_v, weights);
214
215 let old0 = data[idx0];
216 let old1 = data[idx1];
217
218 let new0 = f64x2_extract_lane::<0>(new_dist);
219 let new1 = f64x2_extract_lane::<1>(new_dist);
220
221 if new0 < old0 {
222 data[idx0] = new0;
223 updated += 1;
224 }
225 if new1 < old1 {
226 data[idx1] = new1;
227 updated += 1;
228 }
229 }
230 }
231 }
232
233 if neighbors.len() % 2 == 1 {
235 let (idx, weight) = neighbors[neighbors.len() - 1];
236 let idx = idx as usize;
237 if idx < data.len() {
238 let new_dist = source_dist + weight;
239 if new_dist < data[idx] {
240 data[idx] = new_dist;
241 updated += 1;
242 }
243 }
244 }
245
246 updated
247 }
248
249 #[cfg(not(target_arch = "wasm32"))]
251 pub fn relax_batch(
252 distances: &mut DistanceArray,
253 source_dist: f64,
254 neighbors: &[(VertexId, f64)],
255 ) -> usize {
256 let mut updated = 0;
257 let data = distances.as_mut_slice();
258
259 let chunks = neighbors.len() / 4;
261
262 for i in 0..chunks {
263 let base = i * 4;
264
265 let (idx0, w0) = neighbors[base];
266 let (idx1, w1) = neighbors[base + 1];
267 let (idx2, w2) = neighbors[base + 2];
268 let (idx3, w3) = neighbors[base + 3];
269
270 let new0 = source_dist + w0;
271 let new1 = source_dist + w1;
272 let new2 = source_dist + w2;
273 let new3 = source_dist + w3;
274
275 let idx0 = idx0 as usize;
276 let idx1 = idx1 as usize;
277 let idx2 = idx2 as usize;
278 let idx3 = idx3 as usize;
279
280 if idx0 < data.len() && new0 < data[idx0] {
281 data[idx0] = new0;
282 updated += 1;
283 }
284 if idx1 < data.len() && new1 < data[idx1] {
285 data[idx1] = new1;
286 updated += 1;
287 }
288 if idx2 < data.len() && new2 < data[idx2] {
289 data[idx2] = new2;
290 updated += 1;
291 }
292 if idx3 < data.len() && new3 < data[idx3] {
293 data[idx3] = new3;
294 updated += 1;
295 }
296 }
297
298 for i in (chunks * 4)..neighbors.len() {
300 let (idx, weight) = neighbors[i];
301 let idx = idx as usize;
302 if idx < data.len() {
303 let new_dist = source_dist + weight;
304 if new_dist < data[idx] {
305 data[idx] = new_dist;
306 updated += 1;
307 }
308 }
309 }
310
311 updated
312 }
313
314 #[cfg(target_arch = "wasm32")]
316 pub fn count_below_threshold(distances: &DistanceArray, threshold: f64) -> usize {
317 let data = distances.as_slice();
318 let mut count = 0;
319
320 unsafe {
321 let thresh_v = f64x2_splat(threshold);
322
323 let chunks = data.len() / 2;
324 for i in 0..chunks {
325 let offset = i * 2;
326 let v = v128_load(data.as_ptr().add(offset) as *const v128);
327 let cmp = f64x2_lt(v, thresh_v);
328
329 let mask = i8x16_bitmask(cmp);
331 if mask & 0xFF != 0 {
333 count += 1;
334 }
335 if mask & 0xFF00 != 0 {
336 count += 1;
337 }
338 }
339 }
340
341 for i in (data.len() / 2 * 2)..data.len() {
343 if data[i] < threshold {
344 count += 1;
345 }
346 }
347
348 count
349 }
350
351 #[cfg(not(target_arch = "wasm32"))]
353 pub fn count_below_threshold(distances: &DistanceArray, threshold: f64) -> usize {
354 distances
355 .as_slice()
356 .iter()
357 .filter(|&&d| d < threshold)
358 .count()
359 }
360
361 pub fn sum_finite(distances: &DistanceArray) -> (f64, usize) {
363 let mut sum = 0.0;
364 let mut count = 0;
365
366 for &d in distances.as_slice() {
367 if d.is_finite() {
368 sum += d;
369 count += 1;
370 }
371 }
372
373 (sum, count)
374 }
375
376 pub fn elementwise_min(a: &DistanceArray, b: &DistanceArray) -> DistanceArray {
378 let len = a.len().min(b.len());
379 let mut result = DistanceArray::new(len);
380
381 let a_data = a.as_slice();
382 let b_data = b.as_slice();
383 let r_data = result.as_mut_slice();
384
385 let chunks = len / 4;
387 for i in 0..chunks {
388 let base = i * 4;
389 r_data[base] = a_data[base].min(b_data[base]);
390 r_data[base + 1] = a_data[base + 1].min(b_data[base + 1]);
391 r_data[base + 2] = a_data[base + 2].min(b_data[base + 2]);
392 r_data[base + 3] = a_data[base + 3].min(b_data[base + 3]);
393 }
394
395 for i in (chunks * 4)..len {
396 r_data[i] = a_data[i].min(b_data[i]);
397 }
398
399 result
400 }
401
402 pub fn scale(distances: &mut DistanceArray, factor: f64) {
404 for d in distances.as_mut_slice() {
405 if d.is_finite() {
406 *d *= factor;
407 }
408 }
409 }
410}
411
412#[repr(C)]
414#[derive(Debug, Clone, Copy)]
415pub struct PriorityEntry {
416 pub distance: f64,
418 pub vertex: VertexId,
420}
421
422impl PriorityEntry {
423 pub fn new(distance: f64, vertex: VertexId) -> Self {
425 Self { distance, vertex }
426 }
427}
428
429impl PartialEq for PriorityEntry {
430 fn eq(&self, other: &Self) -> bool {
431 self.distance == other.distance && self.vertex == other.vertex
432 }
433}
434
435impl Eq for PriorityEntry {}
436
437impl PartialOrd for PriorityEntry {
438 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
439 other.distance.partial_cmp(&self.distance)
441 }
442}
443
444impl Ord for PriorityEntry {
445 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
446 self.partial_cmp(other).unwrap_or(std::cmp::Ordering::Equal)
447 }
448}
449
#[cfg(test)]
mod tests {
    use super::*;

    /// Written slots read back; untouched slots stay at infinity.
    #[test]
    fn test_distance_array_basic() {
        let mut table = DistanceArray::new(10);
        table.set(0, 1.0);
        table.set(5, 5.0);

        assert_eq!(table.get(0), 1.0);
        assert_eq!(table.get(5), 5.0);
        assert_eq!(table.get(9), f64::INFINITY);
    }

    /// The minimum is located regardless of the order the slots were written.
    #[test]
    fn test_find_min() {
        let mut table = DistanceArray::new(100);
        for &(vertex, dist) in [(50, 1.0), (25, 0.5), (75, 2.0)].iter() {
            table.set(vertex, dist);
        }

        let (smallest, position) = SimdDistanceOps::find_min(&table);

        assert_eq!(smallest, 0.5);
        assert_eq!(position, 25);
    }

    /// An empty table reports an infinite minimum.
    #[test]
    fn test_find_min_empty() {
        let (smallest, _) = SimdDistanceOps::find_min(&DistanceArray::new(0));
        assert!(smallest.is_infinite());
    }

    /// Every neighbour of an all-infinite table is improved.
    #[test]
    fn test_relax_batch() {
        let mut table = DistanceArray::new(10);
        table.set(0, 0.0);
        let edges: [(VertexId, f64); 4] = [(1, 1.0), (2, 2.0), (3, 3.0), (4, 4.0)];

        let changed = SimdDistanceOps::relax_batch(&mut table, 0.0, &edges);

        assert_eq!(changed, 4);
        assert_eq!(table.get(1), 1.0);
        assert_eq!(table.get(2), 2.0);
        assert_eq!(table.get(3), 3.0);
        assert_eq!(table.get(4), 4.0);
    }

    /// Candidates that are worse than the stored distances change nothing.
    #[test]
    fn test_relax_batch_no_update() {
        let mut table = DistanceArray::from_slice(&[0.0, 0.5, 1.0, 1.5, 2.0]);
        let edges: [(VertexId, f64); 2] = [(1, 2.0), (2, 3.0)];

        let changed = SimdDistanceOps::relax_batch(&mut table, 0.0, &edges);

        assert_eq!(changed, 0);
    }

    /// The comparison is strict and infinity is never below a finite threshold.
    #[test]
    fn test_count_below_threshold() {
        let table = DistanceArray::from_slice(&[0.0, 0.5, 1.0, 1.5, 2.0, f64::INFINITY]);

        for &(threshold, expected) in [(1.0, 2), (2.0, 4), (10.0, 5)].iter() {
            assert_eq!(
                SimdDistanceOps::count_below_threshold(&table, threshold),
                expected
            );
        }
    }

    /// Infinite entries contribute to neither the sum nor the count.
    #[test]
    fn test_sum_finite() {
        let table = DistanceArray::from_slice(&[1.0, 2.0, 3.0, f64::INFINITY, f64::INFINITY]);

        let (total, finite) = SimdDistanceOps::sum_finite(&table);

        assert_eq!(total, 6.0);
        assert_eq!(finite, 3);
    }

    /// Each output slot holds the smaller of the two inputs.
    #[test]
    fn test_elementwise_min() {
        let left = DistanceArray::from_slice(&[1.0, 5.0, 3.0, 7.0]);
        let right = DistanceArray::from_slice(&[2.0, 4.0, 6.0, 1.0]);

        let merged = SimdDistanceOps::elementwise_min(&left, &right);

        assert_eq!(merged.as_slice(), &[1.0, 4.0, 3.0, 1.0]);
    }

    /// Finite entries are multiplied; infinity is left alone.
    #[test]
    fn test_scale() {
        let mut table = DistanceArray::from_slice(&[1.0, 2.0, f64::INFINITY, 4.0]);

        SimdDistanceOps::scale(&mut table, 2.0);

        assert_eq!(table.get(0), 2.0);
        assert_eq!(table.get(1), 4.0);
        assert!(table.get(2).is_infinite());
        assert_eq!(table.get(3), 8.0);
    }

    /// The entry with the smaller distance compares as the greater one.
    #[test]
    fn test_priority_entry_ordering() {
        let closer = PriorityEntry::new(1.0, 1);
        let farther = PriorityEntry::new(2.0, 2);

        assert!(closer > farther);
    }
}