Skip to main content

ruvector_mincut/optimization/
simd_distance.rs

1//! SIMD-Optimized Distance Array Operations
2//!
3//! Provides vectorized operations for distance arrays:
4//! - Parallel min/max finding
5//! - Batch distance updates
6//! - Vector comparisons
7//!
8//! Uses WASM SIMD128 when available, falls back to scalar.
9
10use crate::graph::VertexId;
11
12#[cfg(target_arch = "wasm32")]
13use core::arch::wasm32::*;
14
/// Alignment for SIMD operations (64 bytes for AVX-512 compatibility)
///
/// NOTE(review): not referenced elsewhere in this file; `DistanceArray` spells
/// its alignment as a literal `align(64)` — confirm the two are meant to match.
pub const SIMD_ALIGNMENT: usize = 64;

/// Number of f64 elements per SIMD operation
///
/// NOTE(review): 4 lanes corresponds to a 256-bit vector. The WASM code paths in
/// this file work on 128-bit vectors (2 x f64), and the scalar fallbacks unroll
/// by 4. This constant is not referenced in this file — confirm which consumers
/// rely on it before changing it.
pub const SIMD_LANES: usize = 4; // 256-bit = 4 x f64
20
21/// Aligned distance array for SIMD operations
22#[repr(C, align(64))]
23pub struct DistanceArray {
24    /// Raw distance values
25    data: Vec<f64>,
26    /// Number of vertices
27    len: usize,
28}
29
30impl DistanceArray {
31    /// Create new distance array initialized to infinity
32    pub fn new(size: usize) -> Self {
33        Self {
34            data: vec![f64::INFINITY; size],
35            len: size,
36        }
37    }
38
39    /// Create from slice
40    pub fn from_slice(slice: &[f64]) -> Self {
41        Self {
42            data: slice.to_vec(),
43            len: slice.len(),
44        }
45    }
46
47    /// Get distance for vertex
48    #[inline]
49    pub fn get(&self, v: VertexId) -> f64 {
50        self.data.get(v as usize).copied().unwrap_or(f64::INFINITY)
51    }
52
53    /// Set distance for vertex
54    #[inline]
55    pub fn set(&mut self, v: VertexId, distance: f64) {
56        if (v as usize) < self.len {
57            self.data[v as usize] = distance;
58        }
59    }
60
61    /// Get number of elements
62    pub fn len(&self) -> usize {
63        self.len
64    }
65
66    /// Check if empty
67    pub fn is_empty(&self) -> bool {
68        self.len == 0
69    }
70
71    /// Reset all distances to infinity
72    pub fn reset(&mut self) {
73        for d in &mut self.data {
74            *d = f64::INFINITY;
75        }
76    }
77
78    /// Get raw slice
79    pub fn as_slice(&self) -> &[f64] {
80        &self.data
81    }
82
83    /// Get mutable slice
84    pub fn as_mut_slice(&mut self) -> &mut [f64] {
85        &mut self.data
86    }
87}
88
/// SIMD-optimized distance operations
///
/// Carries no data: every operation is an associated function that takes the
/// `DistanceArray` it works on as an argument.
pub struct SimdDistanceOps;
91
92impl SimdDistanceOps {
93    /// Find minimum distance and its index using SIMD
94    ///
95    /// Returns (min_distance, min_index)
96    #[cfg(target_arch = "wasm32")]
97    pub fn find_min(distances: &DistanceArray) -> (f64, usize) {
98        let data = distances.as_slice();
99        if data.is_empty() {
100            return (f64::INFINITY, 0);
101        }
102
103        let mut min_val = f64::INFINITY;
104        let mut min_idx = 0;
105
106        // Process in chunks of 2 (WASM SIMD has 128-bit = 2 x f64)
107        let chunks = data.len() / 2;
108
109        unsafe {
110            for i in 0..chunks {
111                let offset = i * 2;
112                let v = v128_load(data.as_ptr().add(offset) as *const v128);
113
114                let a = f64x2_extract_lane::<0>(v);
115                let b = f64x2_extract_lane::<1>(v);
116
117                if a < min_val {
118                    min_val = a;
119                    min_idx = offset;
120                }
121                if b < min_val {
122                    min_val = b;
123                    min_idx = offset + 1;
124                }
125            }
126        }
127
128        // Handle remainder
129        for i in (chunks * 2)..data.len() {
130            if data[i] < min_val {
131                min_val = data[i];
132                min_idx = i;
133            }
134        }
135
136        (min_val, min_idx)
137    }
138
139    /// Find minimum distance and its index (scalar fallback)
140    #[cfg(not(target_arch = "wasm32"))]
141    pub fn find_min(distances: &DistanceArray) -> (f64, usize) {
142        let data = distances.as_slice();
143        if data.is_empty() {
144            return (f64::INFINITY, 0);
145        }
146
147        let mut min_val = f64::INFINITY;
148        let mut min_idx = 0;
149
150        // Unrolled loop for better ILP
151        let chunks = data.len() / 4;
152        for i in 0..chunks {
153            let base = i * 4;
154            let a = data[base];
155            let b = data[base + 1];
156            let c = data[base + 2];
157            let d = data[base + 3];
158
159            if a < min_val {
160                min_val = a;
161                min_idx = base;
162            }
163            if b < min_val {
164                min_val = b;
165                min_idx = base + 1;
166            }
167            if c < min_val {
168                min_val = c;
169                min_idx = base + 2;
170            }
171            if d < min_val {
172                min_val = d;
173                min_idx = base + 3;
174            }
175        }
176
177        // Handle remainder
178        for i in (chunks * 4)..data.len() {
179            if data[i] < min_val {
180                min_val = data[i];
181                min_idx = i;
182            }
183        }
184
185        (min_val, min_idx)
186    }
187
188    /// Batch update: dist[i] = min(dist[i], dist[source] + weight[i])
189    ///
190    /// This is the core Dijkstra relaxation operation
191    #[cfg(target_arch = "wasm32")]
192    pub fn relax_batch(
193        distances: &mut DistanceArray,
194        source_dist: f64,
195        neighbors: &[(VertexId, f64)], // (neighbor_id, edge_weight)
196    ) -> usize {
197        let mut updated = 0;
198        let data = distances.as_mut_slice();
199
200        unsafe {
201            let source_v = f64x2_splat(source_dist);
202
203            // Process pairs
204            let pairs = neighbors.len() / 2;
205            for i in 0..pairs {
206                let idx0 = neighbors[i * 2].0 as usize;
207                let idx1 = neighbors[i * 2 + 1].0 as usize;
208                let w0 = neighbors[i * 2].1;
209                let w1 = neighbors[i * 2 + 1].1;
210
211                if idx0 < data.len() && idx1 < data.len() {
212                    let weights = f64x2(w0, w1);
213                    let new_dist = f64x2_add(source_v, weights);
214
215                    let old0 = data[idx0];
216                    let old1 = data[idx1];
217
218                    let new0 = f64x2_extract_lane::<0>(new_dist);
219                    let new1 = f64x2_extract_lane::<1>(new_dist);
220
221                    if new0 < old0 {
222                        data[idx0] = new0;
223                        updated += 1;
224                    }
225                    if new1 < old1 {
226                        data[idx1] = new1;
227                        updated += 1;
228                    }
229                }
230            }
231        }
232
233        // Handle odd remainder
234        if neighbors.len() % 2 == 1 {
235            let (idx, weight) = neighbors[neighbors.len() - 1];
236            let idx = idx as usize;
237            if idx < data.len() {
238                let new_dist = source_dist + weight;
239                if new_dist < data[idx] {
240                    data[idx] = new_dist;
241                    updated += 1;
242                }
243            }
244        }
245
246        updated
247    }
248
249    /// Batch update (scalar fallback)
250    #[cfg(not(target_arch = "wasm32"))]
251    pub fn relax_batch(
252        distances: &mut DistanceArray,
253        source_dist: f64,
254        neighbors: &[(VertexId, f64)],
255    ) -> usize {
256        let mut updated = 0;
257        let data = distances.as_mut_slice();
258
259        // Process in chunks of 4 for better ILP
260        let chunks = neighbors.len() / 4;
261
262        for i in 0..chunks {
263            let base = i * 4;
264
265            let (idx0, w0) = neighbors[base];
266            let (idx1, w1) = neighbors[base + 1];
267            let (idx2, w2) = neighbors[base + 2];
268            let (idx3, w3) = neighbors[base + 3];
269
270            let new0 = source_dist + w0;
271            let new1 = source_dist + w1;
272            let new2 = source_dist + w2;
273            let new3 = source_dist + w3;
274
275            let idx0 = idx0 as usize;
276            let idx1 = idx1 as usize;
277            let idx2 = idx2 as usize;
278            let idx3 = idx3 as usize;
279
280            if idx0 < data.len() && new0 < data[idx0] {
281                data[idx0] = new0;
282                updated += 1;
283            }
284            if idx1 < data.len() && new1 < data[idx1] {
285                data[idx1] = new1;
286                updated += 1;
287            }
288            if idx2 < data.len() && new2 < data[idx2] {
289                data[idx2] = new2;
290                updated += 1;
291            }
292            if idx3 < data.len() && new3 < data[idx3] {
293                data[idx3] = new3;
294                updated += 1;
295            }
296        }
297
298        // Handle remainder
299        for i in (chunks * 4)..neighbors.len() {
300            let (idx, weight) = neighbors[i];
301            let idx = idx as usize;
302            if idx < data.len() {
303                let new_dist = source_dist + weight;
304                if new_dist < data[idx] {
305                    data[idx] = new_dist;
306                    updated += 1;
307                }
308            }
309        }
310
311        updated
312    }
313
314    /// Count vertices with distance less than threshold
315    #[cfg(target_arch = "wasm32")]
316    pub fn count_below_threshold(distances: &DistanceArray, threshold: f64) -> usize {
317        let data = distances.as_slice();
318        let mut count = 0;
319
320        unsafe {
321            let thresh_v = f64x2_splat(threshold);
322
323            let chunks = data.len() / 2;
324            for i in 0..chunks {
325                let offset = i * 2;
326                let v = v128_load(data.as_ptr().add(offset) as *const v128);
327                let cmp = f64x2_lt(v, thresh_v);
328
329                // Extract comparison results
330                let mask = i8x16_bitmask(cmp);
331                // Each f64 lane uses 8 bits in bitmask
332                if mask & 0xFF != 0 {
333                    count += 1;
334                }
335                if mask & 0xFF00 != 0 {
336                    count += 1;
337                }
338            }
339        }
340
341        // Handle remainder
342        for i in (data.len() / 2 * 2)..data.len() {
343            if data[i] < threshold {
344                count += 1;
345            }
346        }
347
348        count
349    }
350
351    /// Count vertices with distance less than threshold (scalar fallback)
352    #[cfg(not(target_arch = "wasm32"))]
353    pub fn count_below_threshold(distances: &DistanceArray, threshold: f64) -> usize {
354        distances
355            .as_slice()
356            .iter()
357            .filter(|&&d| d < threshold)
358            .count()
359    }
360
361    /// Compute sum of distances (for average)
362    pub fn sum_finite(distances: &DistanceArray) -> (f64, usize) {
363        let mut sum = 0.0;
364        let mut count = 0;
365
366        for &d in distances.as_slice() {
367            if d.is_finite() {
368                sum += d;
369                count += 1;
370            }
371        }
372
373        (sum, count)
374    }
375
376    /// Element-wise minimum of two distance arrays
377    pub fn elementwise_min(a: &DistanceArray, b: &DistanceArray) -> DistanceArray {
378        let len = a.len().min(b.len());
379        let mut result = DistanceArray::new(len);
380
381        let a_data = a.as_slice();
382        let b_data = b.as_slice();
383        let r_data = result.as_mut_slice();
384
385        // Unrolled loop
386        let chunks = len / 4;
387        for i in 0..chunks {
388            let base = i * 4;
389            r_data[base] = a_data[base].min(b_data[base]);
390            r_data[base + 1] = a_data[base + 1].min(b_data[base + 1]);
391            r_data[base + 2] = a_data[base + 2].min(b_data[base + 2]);
392            r_data[base + 3] = a_data[base + 3].min(b_data[base + 3]);
393        }
394
395        for i in (chunks * 4)..len {
396            r_data[i] = a_data[i].min(b_data[i]);
397        }
398
399        result
400    }
401
402    /// Scale all distances by a factor
403    pub fn scale(distances: &mut DistanceArray, factor: f64) {
404        for d in distances.as_mut_slice() {
405            if d.is_finite() {
406                *d *= factor;
407            }
408        }
409    }
410}
411
412/// Priority queue entry for Dijkstra with SIMD-friendly layout
413#[repr(C)]
414#[derive(Debug, Clone, Copy)]
415pub struct PriorityEntry {
416    /// Distance (key)
417    pub distance: f64,
418    /// Vertex ID
419    pub vertex: VertexId,
420}
421
422impl PriorityEntry {
423    /// Create a new priority entry with given distance and vertex.
424    pub fn new(distance: f64, vertex: VertexId) -> Self {
425        Self { distance, vertex }
426    }
427}
428
429impl PartialEq for PriorityEntry {
430    fn eq(&self, other: &Self) -> bool {
431        self.distance == other.distance && self.vertex == other.vertex
432    }
433}
434
435impl Eq for PriorityEntry {}
436
437impl PartialOrd for PriorityEntry {
438    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
439        // Reverse order for min-heap
440        other.distance.partial_cmp(&self.distance)
441    }
442}
443
444impl Ord for PriorityEntry {
445    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
446        self.partial_cmp(other).unwrap_or(std::cmp::Ordering::Equal)
447    }
448}
449
#[cfg(test)]
mod tests {
    use super::*;

    /// Values written with `set` are read back by `get`; untouched slots stay infinite.
    #[test]
    fn test_distance_array_basic() {
        let mut distances = DistanceArray::new(10);
        distances.set(0, 1.0);
        distances.set(5, 5.0);

        for &(vertex, expected) in &[(0, 1.0), (5, 5.0), (9, f64::INFINITY)] {
            assert_eq!(distances.get(vertex), expected);
        }
    }

    /// The smallest stored distance and its vertex index are reported.
    #[test]
    fn test_find_min() {
        let mut distances = DistanceArray::new(100);
        for &(vertex, dist) in &[(50, 1.0), (25, 0.5), (75, 2.0)] {
            distances.set(vertex, dist);
        }

        let (smallest, at) = SimdDistanceOps::find_min(&distances);
        assert_eq!(smallest, 0.5);
        assert_eq!(at, 25);
    }

    /// An empty array reports an infinite minimum.
    #[test]
    fn test_find_min_empty() {
        let empty = DistanceArray::new(0);
        assert!(SimdDistanceOps::find_min(&empty).0.is_infinite());
    }

    /// Every neighbor of an unexplored frontier is lowered to source + weight.
    #[test]
    fn test_relax_batch() {
        let mut distances = DistanceArray::new(10);
        distances.set(0, 0.0); // Source

        let edges = [(1, 1.0), (2, 2.0), (3, 3.0), (4, 4.0)];
        let relaxed = SimdDistanceOps::relax_batch(&mut distances, 0.0, &edges);

        assert_eq!(relaxed, 4);
        // The source distance is 0, so each neighbor ends at its edge weight.
        for &(vertex, weight) in &edges {
            assert_eq!(distances.get(vertex), weight);
        }
    }

    /// Candidates that are worse than the stored distances change nothing.
    #[test]
    fn test_relax_batch_no_update() {
        let mut distances = DistanceArray::from_slice(&[0.0, 0.5, 1.0, 1.5, 2.0]);

        // New dist for vertex 1 = 0 + 2.0 = 2.0 > 0.5
        // New dist for vertex 2 = 0 + 3.0 = 3.0 > 1.0
        let relaxed = SimdDistanceOps::relax_batch(&mut distances, 0.0, &[(1, 2.0), (2, 3.0)]);

        assert_eq!(relaxed, 0); // No updates, existing distances are better
    }

    /// Only entries strictly below the threshold are counted.
    #[test]
    fn test_count_below_threshold() {
        let distances = DistanceArray::from_slice(&[0.0, 0.5, 1.0, 1.5, 2.0, f64::INFINITY]);

        for &(threshold, expected) in &[(1.0, 2), (2.0, 4), (10.0, 5)] {
            assert_eq!(
                SimdDistanceOps::count_below_threshold(&distances, threshold),
                expected
            );
        }
    }

    /// Infinite entries contribute to neither the sum nor the count.
    #[test]
    fn test_sum_finite() {
        let distances = DistanceArray::from_slice(&[1.0, 2.0, 3.0, f64::INFINITY, f64::INFINITY]);

        let (total, finite) = SimdDistanceOps::sum_finite(&distances);
        assert_eq!(total, 6.0);
        assert_eq!(finite, 3);
    }

    /// Each output slot holds the smaller of the two inputs.
    #[test]
    fn test_elementwise_min() {
        let left = DistanceArray::from_slice(&[1.0, 5.0, 3.0, 7.0]);
        let right = DistanceArray::from_slice(&[2.0, 4.0, 6.0, 1.0]);

        let merged = SimdDistanceOps::elementwise_min(&left, &right);
        assert_eq!(merged.as_slice(), &[1.0, 4.0, 3.0, 1.0]);
    }

    /// Finite entries are multiplied; infinite entries stay infinite.
    #[test]
    fn test_scale() {
        let mut distances = DistanceArray::from_slice(&[1.0, 2.0, f64::INFINITY, 4.0]);
        SimdDistanceOps::scale(&mut distances, 2.0);

        for &(vertex, expected) in &[(0, 2.0), (1, 4.0), (3, 8.0)] {
            assert_eq!(distances.get(vertex), expected);
        }
        assert!(distances.get(2).is_infinite());
    }

    /// Min-heap ordering: smaller distance is "greater".
    #[test]
    fn test_priority_entry_ordering() {
        let near = PriorityEntry::new(1.0, 1);
        let far = PriorityEntry::new(2.0, 2);
        assert!(near > far);
    }
}