Skip to main content

varpulis_simd/
lib.rs

1//! SIMD-optimized operations for high-performance event processing
2//!
3//! This crate provides vectorized implementations of common operations:
4//! - Aggregations (sum, min, max, avg)
5//! - Batch comparisons for filtering
6//! - Field extraction to contiguous arrays
7//! - Incremental aggregation accumulators
8
9#![allow(unsafe_code)]
10
11#[cfg(target_arch = "x86_64")]
12use std::arch::x86_64::*;
13
14use varpulis_core::Event;
15
16// =============================================================================
17// SIMD Aggregations (f64)
18// =============================================================================
19
20/// SIMD-optimized sum of f64 values
21/// Uses AVX2 (256-bit) when available, falls back to scalar
22#[inline]
23pub fn sum_f64(values: &[f64]) -> f64 {
24    #[cfg(target_arch = "x86_64")]
25    {
26        if is_x86_feature_detected!("avx2") {
27            // SAFETY: We checked for AVX2 support
28            unsafe { sum_f64_avx2(values) }
29        } else {
30            sum_f64_scalar(values)
31        }
32    }
33    #[cfg(not(target_arch = "x86_64"))]
34    {
35        sum_f64_scalar(values)
36    }
37}
38
/// Portable fallback for [`sum_f64`].
///
/// Keeps four independent partial sums so consecutive additions do not form
/// one long dependency chain, giving the CPU more instruction-level
/// parallelism. Elements left over after the last full group of four are
/// folded into the first partial sum, and the partials are combined left to
/// right at the end.
#[inline]
fn sum_f64_scalar(values: &[f64]) -> f64 {
    let groups = values.chunks_exact(4);
    let tail = groups.remainder();

    // `chunks_exact` guarantees every group has exactly four elements, so the
    // indexing below is safe and the optimizer can drop the bounds checks.
    // The previous `get_unchecked` + `unsafe` bought nothing over this.
    let mut partial = [0.0f64; 4];
    for group in groups {
        partial[0] += group[0];
        partial[1] += group[1];
        partial[2] += group[2];
        partial[3] += group[3];
    }
    for &v in tail {
        partial[0] += v;
    }

    partial[0] + partial[1] + partial[2] + partial[3]
}
76
/// AVX2 kernel for [`sum_f64`]: keeps four running sums in one 256-bit register.
///
/// The four lanes are combined left to right, and any trailing elements that
/// do not fill a full 4-wide block are then added in order.
///
/// # Safety
/// The executing CPU must support AVX2 (check with `is_x86_feature_detected!("avx2")`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn sum_f64_avx2(values: &[f64]) -> f64 {
    let blocks = values.chunks_exact(4);
    let tail = blocks.remainder();

    let mut acc = _mm256_setzero_pd();
    for block in blocks {
        // `block` is exactly four contiguous f64s; the load tolerates any alignment.
        acc = _mm256_add_pd(acc, _mm256_loadu_pd(block.as_ptr()));
    }

    // Spill the register to memory so the lanes can be reduced with scalar adds.
    let mut lanes = [0.0f64; 4];
    _mm256_storeu_pd(lanes.as_mut_ptr(), acc);
    let head = lanes[0] + lanes[1] + lanes[2] + lanes[3];

    tail.iter().fold(head, |running, &x| running + x)
}
108
109/// SIMD-optimized min of f64 values
110#[inline]
111pub fn min_f64(values: &[f64]) -> Option<f64> {
112    if values.is_empty() {
113        return None;
114    }
115
116    #[cfg(target_arch = "x86_64")]
117    {
118        if is_x86_feature_detected!("avx2") {
119            // SAFETY: We checked for AVX2 support and non-empty
120            unsafe { Some(min_f64_avx2(values)) }
121        } else {
122            Some(min_f64_scalar(values))
123        }
124    }
125    #[cfg(not(target_arch = "x86_64"))]
126    {
127        Some(min_f64_scalar(values))
128    }
129}
130
/// Portable fallback for [`min_f64`].
///
/// Scans left to right keeping the smallest value seen so far. A NaN element
/// never wins the `<` comparison, so it is effectively skipped. The result is
/// `f64::INFINITY` when nothing is smaller than that (empty or all-NaN input).
#[inline]
fn min_f64_scalar(values: &[f64]) -> f64 {
    values
        .iter()
        .fold(f64::INFINITY, |best, &v| if v < best { v } else { best })
}
141
/// AVX2 kernel for [`min_f64`]: tracks four running minima in one 256-bit register.
///
/// NaN elements are skipped, matching the scalar fallback (where `v < min` is
/// false for NaN). `_mm256_min_pd(a, b)` returns its *second* operand whenever
/// either input is NaN. The accumulator is therefore passed second, so a NaN
/// element leaves the lane unchanged. Passing it first, as before, let a NaN
/// overwrite the lane and lose the minimum accumulated so far, giving wrong
/// results such as 5.0 for `[1, 2, 3, 4, NaN, NaN, NaN, NaN, 5, 6, 7, 8]`.
/// With all-NaN input the result is `f64::INFINITY`, the same as the scalar path.
///
/// # Safety
/// The executing CPU must support AVX2. An empty slice is sound and yields
/// `f64::INFINITY`; [`min_f64`] never passes one.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn min_f64_avx2(values: &[f64]) -> f64 {
    let chunks = values.chunks_exact(4);
    let tail = chunks.remainder();

    let mut min_vec = _mm256_set1_pd(f64::INFINITY);
    for chunk in chunks {
        let v = _mm256_loadu_pd(chunk.as_ptr());
        // Accumulator second: on a NaN element the accumulator is kept.
        min_vec = _mm256_min_pd(v, min_vec);
    }

    // Horizontal reduction. Lanes start at +inf and are only ever replaced by
    // non-NaN values, so none of them is NaN here.
    let mut lanes = [0.0f64; 4];
    _mm256_storeu_pd(lanes.as_mut_ptr(), min_vec);
    let mut min = lanes[0].min(lanes[1]).min(lanes[2]).min(lanes[3]);

    // Trailing elements that did not fill a 4-wide block (NaN never compares less).
    for &v in tail {
        if v < min {
            min = v;
        }
    }

    min
}
175
176/// SIMD-optimized max of f64 values
177#[inline]
178pub fn max_f64(values: &[f64]) -> Option<f64> {
179    if values.is_empty() {
180        return None;
181    }
182
183    #[cfg(target_arch = "x86_64")]
184    {
185        if is_x86_feature_detected!("avx2") {
186            // SAFETY: We checked for AVX2 support and non-empty
187            unsafe { Some(max_f64_avx2(values)) }
188        } else {
189            Some(max_f64_scalar(values))
190        }
191    }
192    #[cfg(not(target_arch = "x86_64"))]
193    {
194        Some(max_f64_scalar(values))
195    }
196}
197
/// Portable fallback for [`max_f64`].
///
/// Scans left to right keeping the largest value seen so far. A NaN element
/// never wins the `>` comparison, so it is effectively skipped. The result is
/// `f64::NEG_INFINITY` when nothing is larger than that (empty or all-NaN input).
#[inline]
fn max_f64_scalar(values: &[f64]) -> f64 {
    values
        .iter()
        .fold(f64::NEG_INFINITY, |best, &v| if v > best { v } else { best })
}
208
/// AVX2 kernel for [`max_f64`]: tracks four running maxima in one 256-bit register.
///
/// NaN elements are skipped, matching the scalar fallback (where `v > max` is
/// false for NaN). `_mm256_max_pd(a, b)` returns its *second* operand whenever
/// either input is NaN. The accumulator is therefore passed second, so a NaN
/// element leaves the lane unchanged. Passing it first, as before, let a NaN
/// overwrite the lane and lose the maximum accumulated so far. With all-NaN
/// input the result is `f64::NEG_INFINITY`, the same as the scalar path.
///
/// # Safety
/// The executing CPU must support AVX2. An empty slice is sound and yields
/// `f64::NEG_INFINITY`; [`max_f64`] never passes one.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn max_f64_avx2(values: &[f64]) -> f64 {
    let chunks = values.chunks_exact(4);
    let tail = chunks.remainder();

    let mut max_vec = _mm256_set1_pd(f64::NEG_INFINITY);
    for chunk in chunks {
        let v = _mm256_loadu_pd(chunk.as_ptr());
        // Accumulator second: on a NaN element the accumulator is kept.
        max_vec = _mm256_max_pd(v, max_vec);
    }

    // Horizontal reduction. Lanes start at -inf and are only ever replaced by
    // non-NaN values, so none of them is NaN here.
    let mut lanes = [0.0f64; 4];
    _mm256_storeu_pd(lanes.as_mut_ptr(), max_vec);
    let mut max = lanes[0].max(lanes[1]).max(lanes[2]).max(lanes[3]);

    // Trailing elements that did not fill a 4-wide block (NaN never compares greater).
    for &v in tail {
        if v > max {
            max = v;
        }
    }

    max
}
242
243// =============================================================================
244// SIMD Batch Comparisons
245// =============================================================================
246
247/// SIMD-optimized greater-than comparison
248/// Returns a bitmask where bit `i` is set if `values[i] > threshold`
249#[inline]
250pub fn compare_gt_f64(values: &[f64], threshold: f64) -> Vec<bool> {
251    let mut result = vec![false; values.len()];
252
253    #[cfg(target_arch = "x86_64")]
254    {
255        if is_x86_feature_detected!("avx2") {
256            // SAFETY: We checked for AVX2 support
257            unsafe { compare_gt_f64_avx2(values, threshold, &mut result) };
258            return result;
259        }
260    }
261
262    // Scalar fallback
263    for (i, &v) in values.iter().enumerate() {
264        result[i] = v > threshold;
265    }
266    result
267}
268
/// AVX2 kernel behind [`compare_gt_f64`]: stores `values[i] > threshold` into `result[i]`.
///
/// Each 4-wide block is compared with an ordered, non-signalling predicate
/// (`_CMP_GT_OQ`, false when either side is NaN). The per-lane sign bits are
/// then collected with `movemask`, where bit `k` corresponds to lane `k`.
///
/// # Safety
/// The executing CPU must support AVX2, and `result.len()` must be at least `values.len()`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn compare_gt_f64_avx2(values: &[f64], threshold: f64, result: &mut [bool]) {
    let limit = _mm256_set1_pd(threshold);
    let blocks = values.chunks_exact(4);
    let tail = blocks.remainder();
    let head_len = values.len() - tail.len();

    for (block, out) in blocks.zip(result.chunks_exact_mut(4)) {
        let packed = _mm256_loadu_pd(block.as_ptr());
        let bits = _mm256_movemask_pd(_mm256_cmp_pd(packed, limit, _CMP_GT_OQ));
        for (lane, slot) in out.iter_mut().enumerate() {
            *slot = ((bits >> lane) & 1) == 1;
        }
    }

    // Elements past the last full block are compared one at a time.
    for (slot, &x) in result[head_len..].iter_mut().zip(tail) {
        *slot = x > threshold;
    }
}
300
301/// SIMD-optimized less-than comparison
302#[inline]
303pub fn compare_lt_f64(values: &[f64], threshold: f64) -> Vec<bool> {
304    let mut result = vec![false; values.len()];
305
306    #[cfg(target_arch = "x86_64")]
307    {
308        if is_x86_feature_detected!("avx2") {
309            // SAFETY: We checked for AVX2 support
310            unsafe { compare_lt_f64_avx2(values, threshold, &mut result) };
311            return result;
312        }
313    }
314
315    // Scalar fallback
316    for (i, &v) in values.iter().enumerate() {
317        result[i] = v < threshold;
318    }
319    result
320}
321
/// AVX2 kernel behind [`compare_lt_f64`]: stores `values[i] < threshold` into `result[i]`.
///
/// Each 4-wide block is compared with an ordered, non-signalling predicate
/// (`_CMP_LT_OQ`, false when either side is NaN). The per-lane sign bits are
/// then collected with `movemask`, where bit `k` corresponds to lane `k`.
///
/// # Safety
/// The executing CPU must support AVX2, and `result.len()` must be at least `values.len()`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn compare_lt_f64_avx2(values: &[f64], threshold: f64, result: &mut [bool]) {
    let limit = _mm256_set1_pd(threshold);
    let blocks = values.chunks_exact(4);
    let tail = blocks.remainder();
    let head_len = values.len() - tail.len();

    for (block, out) in blocks.zip(result.chunks_exact_mut(4)) {
        let packed = _mm256_loadu_pd(block.as_ptr());
        let bits = _mm256_movemask_pd(_mm256_cmp_pd(packed, limit, _CMP_LT_OQ));
        for (lane, slot) in out.iter_mut().enumerate() {
            *slot = ((bits >> lane) & 1) == 1;
        }
    }

    // Elements past the last full block are compared one at a time.
    for (slot, &x) in result[head_len..].iter_mut().zip(tail) {
        *slot = x < threshold;
    }
}
353
354// =============================================================================
355// Field Extraction
356// =============================================================================
357
358/// Extract float field values from events into a contiguous array
359/// Returns None values as NaN for SIMD processing
360#[inline]
361pub fn extract_field_f64(events: &[Event], field: &str) -> Vec<f64> {
362    let mut values = Vec::with_capacity(events.len());
363    for event in events {
364        values.push(event.get_float(field).unwrap_or(f64::NAN));
365    }
366    values
367}
368
369/// Extract float field values from events, filtering out None values
370/// Returns (values, indices) where indices maps back to original positions
371#[inline]
372pub fn extract_field_f64_filtered(events: &[Event], field: &str) -> (Vec<f64>, Vec<usize>) {
373    let mut values = Vec::with_capacity(events.len());
374    let mut indices = Vec::with_capacity(events.len());
375
376    for (i, event) in events.iter().enumerate() {
377        if let Some(v) = event.get_float(field) {
378            if !v.is_nan() {
379                values.push(v);
380                indices.push(i);
381            }
382        }
383    }
384
385    (values, indices)
386}
387
388// =============================================================================
389// SIMD Aggregation Wrappers
390// =============================================================================
391
392/// Compute sum of a field across events using SIMD
393pub fn simd_sum(events: &[Event], field: &str) -> f64 {
394    let values = extract_field_f64(events, field);
395    // Filter out NaN before summing
396    let valid: Vec<f64> = values.into_iter().filter(|v| !v.is_nan()).collect();
397    sum_f64(&valid)
398}
399
400/// Compute avg of a field across events using SIMD
401pub fn simd_avg(events: &[Event], field: &str) -> Option<f64> {
402    let values = extract_field_f64(events, field);
403    let valid: Vec<f64> = values.into_iter().filter(|v| !v.is_nan()).collect();
404    if valid.is_empty() {
405        None
406    } else {
407        Some(sum_f64(&valid) / valid.len() as f64)
408    }
409}
410
411/// Compute min of a field across events using SIMD
412pub fn simd_min(events: &[Event], field: &str) -> Option<f64> {
413    let (values, _) = extract_field_f64_filtered(events, field);
414    min_f64(&values)
415}
416
417/// Compute max of a field across events using SIMD
418pub fn simd_max(events: &[Event], field: &str) -> Option<f64> {
419    let (values, _) = extract_field_f64_filtered(events, field);
420    max_f64(&values)
421}
422
423// =============================================================================
424// Incremental Aggregation
425// =============================================================================
426
/// Incremental sum accumulator - O(1) updates instead of O(n) recomputation.
///
/// Intended for sliding windows: each value passed to [`add`](Self::add) is
/// later passed to [`remove`](Self::remove) when it leaves the window. NaN
/// inputs are ignored by both methods.
///
/// Infinite values are counted separately from the finite running total.
/// Folding `inf` into the total would leave it at `inf - inf = NaN` forever
/// once the infinity was removed. When the last value is removed, the
/// accumulator resets to exactly zero, discarding any rounding residue the
/// subtractions left behind.
#[derive(Debug, Clone, Default)]
pub struct IncrementalSum {
    /// Running total of the finite values currently held.
    finite_sum: f64,
    /// Number of `+inf` values currently held.
    pos_inf: usize,
    /// Number of `-inf` values currently held.
    neg_inf: usize,
    /// Number of non-NaN values currently held (finite and infinite).
    count: usize,
}

impl IncrementalSum {
    /// Create a new, empty incremental sum accumulator.
    pub fn new() -> Self {
        Self::default()
    }

    /// Add a value to the accumulator. NaN is ignored.
    #[inline]
    pub fn add(&mut self, value: f64) {
        if value.is_nan() {
            return;
        }
        if value == f64::INFINITY {
            self.pos_inf += 1;
        } else if value == f64::NEG_INFINITY {
            self.neg_inf += 1;
        } else {
            self.finite_sum += value;
        }
        self.count += 1;
    }

    /// Remove a value previously added. NaN is ignored.
    ///
    /// Once the accumulator becomes empty it is reset to exactly zero.
    #[inline]
    pub fn remove(&mut self, value: f64) {
        if value.is_nan() {
            return;
        }
        if value == f64::INFINITY {
            self.pos_inf = self.pos_inf.saturating_sub(1);
        } else if value == f64::NEG_INFINITY {
            self.neg_inf = self.neg_inf.saturating_sub(1);
        } else {
            self.finite_sum -= value;
        }
        self.count = self.count.saturating_sub(1);
        if self.count == 0 {
            // Nothing is held any more: drop accumulated floating-point residue.
            self.reset();
        }
    }

    /// Current sum value.
    ///
    /// `+inf` if any `+inf` is held, `-inf` if any `-inf` is held, NaN if both
    /// are held, and otherwise the running total of the finite values.
    #[inline]
    pub const fn sum(&self) -> f64 {
        match (self.pos_inf > 0, self.neg_inf > 0) {
            (true, true) => f64::NAN,
            (true, false) => f64::INFINITY,
            (false, true) => f64::NEG_INFINITY,
            (false, false) => self.finite_sum,
        }
    }

    /// Number of non-NaN values currently held.
    #[inline]
    pub const fn count(&self) -> usize {
        self.count
    }

    /// Current average, or `None` if empty.
    #[inline]
    pub fn avg(&self) -> Option<f64> {
        if self.count == 0 {
            None
        } else {
            Some(self.sum() / self.count as f64)
        }
    }

    /// Reset the accumulator to zero.
    pub const fn reset(&mut self) {
        self.finite_sum = 0.0;
        self.pos_inf = 0;
        self.neg_inf = 0;
        self.count = 0;
    }
}
486
/// Wrapper for `f64` that implements `Ord` via `f64::total_cmp`.
///
/// Under this total order `-0.0` sorts strictly before `+0.0`, so the two are
/// distinct keys. NaNs with the sign bit clear sort after `+inf`, and NaNs
/// with the sign bit set sort before `-inf`. Two values compare equal only
/// when their bit patterns are identical, so `Eq` and `Ord` agree as the
/// trait contracts require.
#[derive(Debug, Clone, Copy)]
struct OrderedF64(f64);

impl PartialEq for OrderedF64 {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for OrderedF64 {}

impl PartialOrd for OrderedF64 {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for OrderedF64 {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.0.total_cmp(&other.0)
    }
}

/// Incremental min/max tracker backed by a `BTreeMap` multiset.
///
/// `add`, `remove`, `min` and `max` are all O(log n). NaN inputs are ignored
/// by `add` and `remove`, and removing a value that is not currently tracked
/// is a no-op. Keys are ordered with `f64::total_cmp`, so `remove` must be
/// given a bit-identical value (`-0.0` does not remove a tracked `0.0`).
#[derive(Debug, Clone)]
pub struct IncrementalMinMax {
    /// Multiset of tracked values: key -> number of copies currently held.
    values: std::collections::BTreeMap<OrderedF64, usize>,
}

impl Default for IncrementalMinMax {
    fn default() -> Self {
        Self::new()
    }
}

impl IncrementalMinMax {
    /// Create a new, empty incremental min/max tracker.
    pub const fn new() -> Self {
        Self {
            values: std::collections::BTreeMap::new(),
        }
    }

    /// Add a value to the tracker. NaN is ignored.
    #[inline]
    pub fn add(&mut self, value: f64) {
        if !value.is_nan() {
            *self.values.entry(OrderedF64(value)).or_default() += 1;
        }
    }

    /// Remove one copy of a value from the tracker. NaN or untracked values are ignored.
    #[inline]
    pub fn remove(&mut self, value: f64) {
        if value.is_nan() {
            return;
        }
        if let std::collections::btree_map::Entry::Occupied(mut slot) =
            self.values.entry(OrderedF64(value))
        {
            if *slot.get() > 1 {
                *slot.get_mut() -= 1;
            } else {
                slot.remove();
            }
        }
    }

    /// Current minimum value, or `None` if empty.
    ///
    /// Read-only: takes `&self` (callers holding `&mut` still work).
    pub fn min(&self) -> Option<f64> {
        self.values.first_key_value().map(|(k, _)| k.0)
    }

    /// Current maximum value, or `None` if empty.
    ///
    /// Read-only: takes `&self` (callers holding `&mut` still work).
    pub fn max(&self) -> Option<f64> {
        self.values.last_key_value().map(|(k, _)| k.0)
    }

    /// Clear all tracked values.
    pub fn reset(&mut self) {
        self.values.clear();
    }
}
574
575// =============================================================================
576// Tests
577// =============================================================================
578
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sum_f64_empty() {
        assert_eq!(sum_f64(&[]), 0.0);
    }

    #[test]
    fn test_sum_f64_single() {
        assert_eq!(sum_f64(&[42.0]), 42.0);
    }

    #[test]
    fn test_sum_f64_multiple() {
        let values: Vec<f64> = (1..=100).map(|x| x as f64).collect();
        assert_eq!(sum_f64(&values), 5050.0);
    }

    #[test]
    fn test_sum_f64_large() {
        let values: Vec<f64> = (1..=10000).map(|x| x as f64).collect();
        let expected: f64 = (1..=10000).sum::<i64>() as f64;
        assert!((sum_f64(&values) - expected).abs() < 0.001);
    }

    /// Covers every remainder length (0..=3) of the 4-wide kernels on both paths.
    #[test]
    fn test_sum_f64_all_remainders() {
        for n in 0..=13u32 {
            let values: Vec<f64> = (1..=n).map(f64::from).collect();
            let expected = f64::from(n * (n + 1) / 2);
            assert_eq!(sum_f64(&values), expected, "n = {n}");
            assert_eq!(sum_f64_scalar(&values), expected, "n = {n}");
        }
    }

    #[test]
    fn test_min_f64() {
        assert_eq!(min_f64(&[3.0, 1.0, 4.0, 1.0, 5.0]), Some(1.0));
        assert_eq!(min_f64(&[]), None);
    }

    #[test]
    fn test_max_f64() {
        assert_eq!(max_f64(&[3.0, 1.0, 4.0, 1.0, 5.0]), Some(5.0));
        assert_eq!(max_f64(&[]), None);
    }

    /// Extremes land in the vector body or in the scalar tail depending on
    /// length and direction, so both reduction stages are exercised.
    #[test]
    fn test_min_max_all_remainders() {
        for n in 1..=13u32 {
            let ascending: Vec<f64> = (0..n).map(f64::from).collect();
            let descending: Vec<f64> = ascending.iter().rev().copied().collect();
            let top = f64::from(n - 1);
            for values in [&ascending, &descending] {
                assert_eq!(min_f64(values), Some(0.0), "n = {n}");
                assert_eq!(max_f64(values), Some(top), "n = {n}");
            }
        }
    }

    #[test]
    fn test_compare_gt() {
        let values = vec![1.0, 5.0, 3.0, 7.0, 2.0];
        let result = compare_gt_f64(&values, 3.0);
        assert_eq!(result, vec![false, true, false, true, false]);
    }

    #[test]
    fn test_compare_lt() {
        let values = vec![1.0, 5.0, 3.0, 7.0, 2.0];
        let result = compare_lt_f64(&values, 3.0);
        assert_eq!(result, vec![true, false, false, false, true]);
    }

    /// Every prefix length, including NaN and values equal to the threshold,
    /// must agree with plain scalar comparisons.
    #[test]
    fn test_compare_matches_scalar_for_all_lengths() {
        let source = [0.5, 4.0, -1.0, 3.0, f64::NAN, 7.5, 3.0, 2.9, 10.0];
        for len in 0..=source.len() {
            let values = &source[..len];
            let gt: Vec<bool> = values.iter().map(|&v| v > 3.0).collect();
            let lt: Vec<bool> = values.iter().map(|&v| v < 3.0).collect();
            assert_eq!(compare_gt_f64(values, 3.0), gt, "len = {len}");
            assert_eq!(compare_lt_f64(values, 3.0), lt, "len = {len}");
        }
    }

    #[test]
    fn test_incremental_sum() {
        let mut acc = IncrementalSum::new();
        acc.add(10.0);
        acc.add(20.0);
        acc.add(30.0);
        assert_eq!(acc.sum(), 60.0);
        assert_eq!(acc.count(), 3);
        assert_eq!(acc.avg(), Some(20.0));

        acc.remove(20.0);
        assert_eq!(acc.sum(), 40.0);
        assert_eq!(acc.count(), 2);
        assert_eq!(acc.avg(), Some(20.0));
    }

    #[test]
    fn test_incremental_sum_ignores_nan() {
        let mut acc = IncrementalSum::new();
        acc.add(f64::NAN);
        assert_eq!(acc.count(), 0);
        assert_eq!(acc.avg(), None);
        acc.add(4.0);
        acc.remove(f64::NAN);
        assert_eq!(acc.sum(), 4.0);
        assert_eq!(acc.count(), 1);
        acc.reset();
        assert_eq!(acc.sum(), 0.0);
        assert_eq!(acc.count(), 0);
    }

    #[test]
    fn test_incremental_minmax() {
        let mut tracker = IncrementalMinMax::new();
        tracker.add(5.0);
        tracker.add(3.0);
        tracker.add(7.0);
        tracker.add(1.0);

        assert_eq!(tracker.min(), Some(1.0));
        assert_eq!(tracker.max(), Some(7.0));

        tracker.remove(1.0);
        assert_eq!(tracker.min(), Some(3.0));

        tracker.remove(7.0);
        assert_eq!(tracker.max(), Some(5.0));
    }

    #[test]
    fn test_incremental_minmax_duplicates() {
        let mut tracker = IncrementalMinMax::new();
        tracker.add(2.0);
        tracker.add(2.0);
        tracker.add(f64::NAN);
        tracker.remove(2.0);
        assert_eq!(tracker.min(), Some(2.0));
        tracker.remove(2.0);
        assert_eq!(tracker.min(), None);
        assert_eq!(tracker.max(), None);
    }

    #[test]
    fn test_sum_scalar_vs_avx2() {
        // Scalar and AVX2 must produce the same result.
        let values: Vec<f64> = (1..=1000).map(|x| x as f64).collect();
        let scalar = sum_f64_scalar(&values);

        #[cfg(target_arch = "x86_64")]
        {
            if is_x86_feature_detected!("avx2") {
                let avx2 = unsafe { sum_f64_avx2(&values) };
                assert!((scalar - avx2).abs() < 0.001);
            }
        }
    }
}