Skip to main content

shape_runtime/
simd_comparisons.rs

1//! SIMD-accelerated comparison operations for series
2//!
3//! Provides vectorized implementations of comparison and logical operators.
4//! Uses manual loop unrolling to enable compiler auto-vectorization.
5//!
6//! Expected speedup: 2-4x for arrays larger than SIMD_THRESHOLD (64 elements)
7
/// Minimum input length for the unrolled code paths.
///
/// The `*_simd` functions in this file hand any input shorter than this to
/// their `*_scalar` counterparts instead of running the 4-wide loop.
const SIMD_THRESHOLD: usize = 64;
9
10// ===== Public API - Feature-gated SIMD selection =====
11
12/// Greater-than comparison (auto-selects SIMD or scalar)
13#[cfg(feature = "simd")]
14#[inline]
15pub fn gt(left: &[f64], right: &[f64]) -> Vec<f64> {
16    gt_simd(left, right)
17}
18
19#[cfg(not(feature = "simd"))]
20#[inline]
21pub fn gt(left: &[f64], right: &[f64]) -> Vec<f64> {
22    gt_scalar(left, right)
23}
24
25/// Less-than comparison (auto-selects SIMD or scalar)
26#[cfg(feature = "simd")]
27#[inline]
28pub fn lt(left: &[f64], right: &[f64]) -> Vec<f64> {
29    lt_simd(left, right)
30}
31
32#[cfg(not(feature = "simd"))]
33#[inline]
34pub fn lt(left: &[f64], right: &[f64]) -> Vec<f64> {
35    lt_scalar(left, right)
36}
37
38/// Greater-than-or-equal comparison (auto-selects SIMD or scalar)
39#[cfg(feature = "simd")]
40#[inline]
41pub fn gte(left: &[f64], right: &[f64]) -> Vec<f64> {
42    gte_simd(left, right)
43}
44
45#[cfg(not(feature = "simd"))]
46#[inline]
47pub fn gte(left: &[f64], right: &[f64]) -> Vec<f64> {
48    gte_scalar(left, right)
49}
50
51/// Less-than-or-equal comparison (auto-selects SIMD or scalar)
52#[cfg(feature = "simd")]
53#[inline]
54pub fn lte(left: &[f64], right: &[f64]) -> Vec<f64> {
55    lte_simd(left, right)
56}
57
58#[cfg(not(feature = "simd"))]
59#[inline]
60pub fn lte(left: &[f64], right: &[f64]) -> Vec<f64> {
61    lte_scalar(left, right)
62}
63
64/// Equality comparison (auto-selects SIMD or scalar)
65#[cfg(feature = "simd")]
66#[inline]
67pub fn eq(left: &[f64], right: &[f64]) -> Vec<f64> {
68    eq_simd(left, right)
69}
70
71#[cfg(not(feature = "simd"))]
72#[inline]
73pub fn eq(left: &[f64], right: &[f64]) -> Vec<f64> {
74    eq_scalar(left, right)
75}
76
77/// Not-equal comparison (auto-selects SIMD or scalar)
78#[cfg(feature = "simd")]
79#[inline]
80pub fn ne(left: &[f64], right: &[f64]) -> Vec<f64> {
81    ne_simd(left, right)
82}
83
84#[cfg(not(feature = "simd"))]
85#[inline]
86pub fn ne(left: &[f64], right: &[f64]) -> Vec<f64> {
87    ne_scalar(left, right)
88}
89
90/// Logical AND operation (auto-selects SIMD or scalar)
91#[cfg(feature = "simd")]
92#[inline]
93pub fn and(left: &[f64], right: &[f64]) -> Vec<f64> {
94    and_simd(left, right)
95}
96
97#[cfg(not(feature = "simd"))]
98#[inline]
99pub fn and(left: &[f64], right: &[f64]) -> Vec<f64> {
100    and_scalar(left, right)
101}
102
103/// Logical OR operation (auto-selects SIMD or scalar)
104#[cfg(feature = "simd")]
105#[inline]
106pub fn or(left: &[f64], right: &[f64]) -> Vec<f64> {
107    or_simd(left, right)
108}
109
110#[cfg(not(feature = "simd"))]
111#[inline]
112pub fn or(left: &[f64], right: &[f64]) -> Vec<f64> {
113    or_scalar(left, right)
114}
115
116/// Logical NOT operation (auto-selects SIMD or scalar)
117#[cfg(feature = "simd")]
118#[inline]
119pub fn not(values: &[f64]) -> Vec<f64> {
120    not_simd(values)
121}
122
123#[cfg(not(feature = "simd"))]
124#[inline]
125pub fn not(values: &[f64]) -> Vec<f64> {
126    not_scalar(values)
127}
128
129// ===== Internal SIMD implementations =====
130
131/// SIMD greater-than comparison
132/// Returns 1.0 where left > right, 0.0 otherwise
133fn gt_simd(left: &[f64], right: &[f64]) -> Vec<f64> {
134    assert_eq!(left.len(), right.len(), "Array lengths must match");
135    let n = left.len();
136
137    if n < SIMD_THRESHOLD {
138        return gt_scalar(left, right);
139    }
140
141    let mut result = vec![0.0; n];
142    let chunks = n / 4;
143
144    // Process 4 elements at a time - allows compiler auto-vectorization
145    for i in 0..chunks {
146        let idx = i * 4;
147        result[idx] = if left[idx] > right[idx] { 1.0 } else { 0.0 };
148        result[idx + 1] = if left[idx + 1] > right[idx + 1] {
149            1.0
150        } else {
151            0.0
152        };
153        result[idx + 2] = if left[idx + 2] > right[idx + 2] {
154            1.0
155        } else {
156            0.0
157        };
158        result[idx + 3] = if left[idx + 3] > right[idx + 3] {
159            1.0
160        } else {
161            0.0
162        };
163    }
164
165    // Handle remainder
166    for i in (chunks * 4)..n {
167        result[i] = if left[i] > right[i] { 1.0 } else { 0.0 };
168    }
169
170    result
171}
172
/// Element-wise greater-than over paired elements.
///
/// Emits `1.0` for each pair with `a > b` and `0.0` otherwise. Pairs are
/// formed with `zip`, so when the slices differ in length the output is only
/// as long as the shorter one.
fn gt_scalar(left: &[f64], right: &[f64]) -> Vec<f64> {
    let mut flags = Vec::with_capacity(left.len().min(right.len()));
    for (&lhs, &rhs) in left.iter().zip(right) {
        flags.push(if lhs > rhs { 1.0 } else { 0.0 });
    }
    flags
}
179
180/// SIMD less-than comparison (internal)
181fn lt_simd(left: &[f64], right: &[f64]) -> Vec<f64> {
182    assert_eq!(left.len(), right.len(), "Array lengths must match");
183    let n = left.len();
184
185    if n < SIMD_THRESHOLD {
186        return lt_scalar(left, right);
187    }
188
189    let mut result = vec![0.0; n];
190    let chunks = n / 4;
191
192    for i in 0..chunks {
193        let idx = i * 4;
194        result[idx] = if left[idx] < right[idx] { 1.0 } else { 0.0 };
195        result[idx + 1] = if left[idx + 1] < right[idx + 1] {
196            1.0
197        } else {
198            0.0
199        };
200        result[idx + 2] = if left[idx + 2] < right[idx + 2] {
201            1.0
202        } else {
203            0.0
204        };
205        result[idx + 3] = if left[idx + 3] < right[idx + 3] {
206            1.0
207        } else {
208            0.0
209        };
210    }
211
212    for i in (chunks * 4)..n {
213        result[i] = if left[i] < right[i] { 1.0 } else { 0.0 };
214    }
215
216    result
217}
218
/// Element-wise less-than over paired elements.
///
/// Emits `1.0` for each pair with `a < b` and `0.0` otherwise. Pairs are
/// formed with `zip`, so when the slices differ in length the output is only
/// as long as the shorter one.
fn lt_scalar(left: &[f64], right: &[f64]) -> Vec<f64> {
    let mut flags = Vec::with_capacity(left.len().min(right.len()));
    for (&lhs, &rhs) in left.iter().zip(right) {
        flags.push(if lhs < rhs { 1.0 } else { 0.0 });
    }
    flags
}
225
226/// SIMD greater-than-or-equal comparison (internal)
227fn gte_simd(left: &[f64], right: &[f64]) -> Vec<f64> {
228    assert_eq!(left.len(), right.len(), "Array lengths must match");
229    let n = left.len();
230
231    if n < SIMD_THRESHOLD {
232        return gte_scalar(left, right);
233    }
234
235    let mut result = vec![0.0; n];
236    let chunks = n / 4;
237
238    for i in 0..chunks {
239        let idx = i * 4;
240        result[idx] = if left[idx] >= right[idx] { 1.0 } else { 0.0 };
241        result[idx + 1] = if left[idx + 1] >= right[idx + 1] {
242            1.0
243        } else {
244            0.0
245        };
246        result[idx + 2] = if left[idx + 2] >= right[idx + 2] {
247            1.0
248        } else {
249            0.0
250        };
251        result[idx + 3] = if left[idx + 3] >= right[idx + 3] {
252            1.0
253        } else {
254            0.0
255        };
256    }
257
258    for i in (chunks * 4)..n {
259        result[i] = if left[i] >= right[i] { 1.0 } else { 0.0 };
260    }
261
262    result
263}
264
/// Element-wise greater-than-or-equal over paired elements.
///
/// Emits `1.0` for each pair with `a >= b` and `0.0` otherwise. Pairs are
/// formed with `zip`, so when the slices differ in length the output is only
/// as long as the shorter one.
fn gte_scalar(left: &[f64], right: &[f64]) -> Vec<f64> {
    let mut flags = Vec::with_capacity(left.len().min(right.len()));
    for (&lhs, &rhs) in left.iter().zip(right) {
        flags.push(if lhs >= rhs { 1.0 } else { 0.0 });
    }
    flags
}
271
272/// SIMD less-than-or-equal comparison (internal)
273fn lte_simd(left: &[f64], right: &[f64]) -> Vec<f64> {
274    assert_eq!(left.len(), right.len(), "Array lengths must match");
275    let n = left.len();
276
277    if n < SIMD_THRESHOLD {
278        return lte_scalar(left, right);
279    }
280
281    let mut result = vec![0.0; n];
282    let chunks = n / 4;
283
284    for i in 0..chunks {
285        let idx = i * 4;
286        result[idx] = if left[idx] <= right[idx] { 1.0 } else { 0.0 };
287        result[idx + 1] = if left[idx + 1] <= right[idx + 1] {
288            1.0
289        } else {
290            0.0
291        };
292        result[idx + 2] = if left[idx + 2] <= right[idx + 2] {
293            1.0
294        } else {
295            0.0
296        };
297        result[idx + 3] = if left[idx + 3] <= right[idx + 3] {
298            1.0
299        } else {
300            0.0
301        };
302    }
303
304    for i in (chunks * 4)..n {
305        result[i] = if left[i] <= right[i] { 1.0 } else { 0.0 };
306    }
307
308    result
309}
310
/// Element-wise less-than-or-equal over paired elements.
///
/// Emits `1.0` for each pair with `a <= b` and `0.0` otherwise. Pairs are
/// formed with `zip`, so when the slices differ in length the output is only
/// as long as the shorter one.
fn lte_scalar(left: &[f64], right: &[f64]) -> Vec<f64> {
    let mut flags = Vec::with_capacity(left.len().min(right.len()));
    for (&lhs, &rhs) in left.iter().zip(right) {
        flags.push(if lhs <= rhs { 1.0 } else { 0.0 });
    }
    flags
}
317
318/// SIMD equality comparison (internal)
319fn eq_simd(left: &[f64], right: &[f64]) -> Vec<f64> {
320    assert_eq!(left.len(), right.len(), "Array lengths must match");
321    let n = left.len();
322
323    if n < SIMD_THRESHOLD {
324        return eq_scalar(left, right);
325    }
326
327    let mut result = vec![0.0; n];
328    let chunks = n / 4;
329
330    for i in 0..chunks {
331        let idx = i * 4;
332        result[idx] = if left[idx] == right[idx] { 1.0 } else { 0.0 };
333        result[idx + 1] = if left[idx + 1] == right[idx + 1] {
334            1.0
335        } else {
336            0.0
337        };
338        result[idx + 2] = if left[idx + 2] == right[idx + 2] {
339            1.0
340        } else {
341            0.0
342        };
343        result[idx + 3] = if left[idx + 3] == right[idx + 3] {
344            1.0
345        } else {
346            0.0
347        };
348    }
349
350    for i in (chunks * 4)..n {
351        result[i] = if left[i] == right[i] { 1.0 } else { 0.0 };
352    }
353
354    result
355}
356
/// Element-wise exact equality over paired elements.
///
/// Emits `1.0` for each pair with `a == b` and `0.0` otherwise. Pairs are
/// formed with `zip`, so when the slices differ in length the output is only
/// as long as the shorter one.
fn eq_scalar(left: &[f64], right: &[f64]) -> Vec<f64> {
    let mut flags = Vec::with_capacity(left.len().min(right.len()));
    for (&lhs, &rhs) in left.iter().zip(right) {
        flags.push(if lhs == rhs { 1.0 } else { 0.0 });
    }
    flags
}
363
364/// SIMD not-equal comparison (internal)
365fn ne_simd(left: &[f64], right: &[f64]) -> Vec<f64> {
366    assert_eq!(left.len(), right.len(), "Array lengths must match");
367    let n = left.len();
368
369    if n < SIMD_THRESHOLD {
370        return ne_scalar(left, right);
371    }
372
373    let mut result = vec![0.0; n];
374    let chunks = n / 4;
375
376    for i in 0..chunks {
377        let idx = i * 4;
378        result[idx] = if left[idx] != right[idx] { 1.0 } else { 0.0 };
379        result[idx + 1] = if left[idx + 1] != right[idx + 1] {
380            1.0
381        } else {
382            0.0
383        };
384        result[idx + 2] = if left[idx + 2] != right[idx + 2] {
385            1.0
386        } else {
387            0.0
388        };
389        result[idx + 3] = if left[idx + 3] != right[idx + 3] {
390            1.0
391        } else {
392            0.0
393        };
394    }
395
396    for i in (chunks * 4)..n {
397        result[i] = if left[i] != right[i] { 1.0 } else { 0.0 };
398    }
399
400    result
401}
402
/// Element-wise exact inequality over paired elements.
///
/// Emits `1.0` for each pair with `a != b` and `0.0` otherwise. Pairs are
/// formed with `zip`, so when the slices differ in length the output is only
/// as long as the shorter one.
fn ne_scalar(left: &[f64], right: &[f64]) -> Vec<f64> {
    let mut flags = Vec::with_capacity(left.len().min(right.len()));
    for (&lhs, &rhs) in left.iter().zip(right) {
        flags.push(if lhs != rhs { 1.0 } else { 0.0 });
    }
    flags
}
409
410/// SIMD logical AND operation (internal)
411fn and_simd(left: &[f64], right: &[f64]) -> Vec<f64> {
412    assert_eq!(left.len(), right.len(), "Array lengths must match");
413    let n = left.len();
414
415    if n < SIMD_THRESHOLD {
416        return and_scalar(left, right);
417    }
418
419    let mut result = vec![0.0; n];
420    let chunks = n / 4;
421
422    for i in 0..chunks {
423        let idx = i * 4;
424        result[idx] = if left[idx] > 0.5 && right[idx] > 0.5 {
425            1.0
426        } else {
427            0.0
428        };
429        result[idx + 1] = if left[idx + 1] > 0.5 && right[idx + 1] > 0.5 {
430            1.0
431        } else {
432            0.0
433        };
434        result[idx + 2] = if left[idx + 2] > 0.5 && right[idx + 2] > 0.5 {
435            1.0
436        } else {
437            0.0
438        };
439        result[idx + 3] = if left[idx + 3] > 0.5 && right[idx + 3] > 0.5 {
440            1.0
441        } else {
442            0.0
443        };
444    }
445
446    for i in (chunks * 4)..n {
447        result[i] = if left[i] > 0.5 && right[i] > 0.5 {
448            1.0
449        } else {
450            0.0
451        };
452    }
453
454    result
455}
456
/// Element-wise logical AND over paired elements.
///
/// An element counts as true when it is greater than `0.5`. Emits `1.0` for
/// each pair in which both elements are true and `0.0` otherwise. Pairs are
/// formed with `zip`, so when the slices differ in length the output is only
/// as long as the shorter one.
fn and_scalar(left: &[f64], right: &[f64]) -> Vec<f64> {
    let mut flags = Vec::with_capacity(left.len().min(right.len()));
    for (&lhs, &rhs) in left.iter().zip(right) {
        let both_set = lhs > 0.5 && rhs > 0.5;
        flags.push(if both_set { 1.0 } else { 0.0 });
    }
    flags
}
463
464/// SIMD logical OR operation (internal)
465fn or_simd(left: &[f64], right: &[f64]) -> Vec<f64> {
466    assert_eq!(left.len(), right.len(), "Array lengths must match");
467    let n = left.len();
468
469    if n < SIMD_THRESHOLD {
470        return or_scalar(left, right);
471    }
472
473    let mut result = vec![0.0; n];
474    let chunks = n / 4;
475
476    for i in 0..chunks {
477        let idx = i * 4;
478        result[idx] = if left[idx] > 0.5 || right[idx] > 0.5 {
479            1.0
480        } else {
481            0.0
482        };
483        result[idx + 1] = if left[idx + 1] > 0.5 || right[idx + 1] > 0.5 {
484            1.0
485        } else {
486            0.0
487        };
488        result[idx + 2] = if left[idx + 2] > 0.5 || right[idx + 2] > 0.5 {
489            1.0
490        } else {
491            0.0
492        };
493        result[idx + 3] = if left[idx + 3] > 0.5 || right[idx + 3] > 0.5 {
494            1.0
495        } else {
496            0.0
497        };
498    }
499
500    for i in (chunks * 4)..n {
501        result[i] = if left[i] > 0.5 || right[i] > 0.5 {
502            1.0
503        } else {
504            0.0
505        };
506    }
507
508    result
509}
510
/// Element-wise logical OR over paired elements.
///
/// An element counts as true when it is greater than `0.5`. Emits `1.0` for
/// each pair in which at least one element is true and `0.0` otherwise. Pairs
/// are formed with `zip`, so when the slices differ in length the output is
/// only as long as the shorter one.
fn or_scalar(left: &[f64], right: &[f64]) -> Vec<f64> {
    let mut flags = Vec::with_capacity(left.len().min(right.len()));
    for (&lhs, &rhs) in left.iter().zip(right) {
        let any_set = lhs > 0.5 || rhs > 0.5;
        flags.push(if any_set { 1.0 } else { 0.0 });
    }
    flags
}
517
518/// SIMD logical NOT operation (internal)
519fn not_simd(values: &[f64]) -> Vec<f64> {
520    let n = values.len();
521
522    if n < SIMD_THRESHOLD {
523        return not_scalar(values);
524    }
525
526    let mut result = vec![0.0; n];
527    let chunks = n / 4;
528
529    for i in 0..chunks {
530        let idx = i * 4;
531        result[idx] = if values[idx] > 0.5 { 0.0 } else { 1.0 };
532        result[idx + 1] = if values[idx + 1] > 0.5 { 0.0 } else { 1.0 };
533        result[idx + 2] = if values[idx + 2] > 0.5 { 0.0 } else { 1.0 };
534        result[idx + 3] = if values[idx + 3] > 0.5 { 0.0 } else { 1.0 };
535    }
536
537    for i in (chunks * 4)..n {
538        result[i] = if values[i] > 0.5 { 0.0 } else { 1.0 };
539    }
540
541    result
542}
543
/// Element-wise logical NOT.
///
/// An element counts as true when it is greater than `0.5`. Emits `0.0` for
/// each true element and `1.0` for every other element; the output has the
/// same length as `values`.
fn not_scalar(values: &[f64]) -> Vec<f64> {
    let mut flags = Vec::with_capacity(values.len());
    for &value in values {
        flags.push(if value > 0.5 { 0.0 } else { 1.0 });
    }
    flags
}
550
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds two equal-length inputs that take the unrolled path.
    ///
    /// The length is above `SIMD_THRESHOLD` and not a multiple of 4, so both
    /// the 4-wide loop and the remainder loop run. The values cycle through
    /// less-than, equal and greater-than pairs, and through every combination
    /// of "above 0.5" / "not above 0.5" for the logical operators.
    fn long_inputs() -> (Vec<f64>, Vec<f64>) {
        let n = SIMD_THRESHOLD + 3;
        let left = (0..n).map(|i| (i % 5) as f64).collect();
        let right = (0..n).map(|i| (i % 3) as f64).collect();
        (left, right)
    }

    // The short-input tests below are under SIMD_THRESHOLD, so they exercise
    // the scalar fallback reached through the `*_simd` entry points.

    #[test]
    fn test_gt_simd_small() {
        let left = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let right = vec![0.5, 2.5, 3.0, 3.5, 6.0];
        let result = gt_simd(&left, &right);
        assert_eq!(result, vec![1.0, 0.0, 0.0, 1.0, 0.0]);
    }

    #[test]
    fn test_gt_simd_large() {
        let n = 1000;
        let left: Vec<f64> = (0..n).map(|i| i as f64).collect();
        let right: Vec<f64> = (0..n).map(|i| (i - 500) as f64).collect();
        let result = gt_simd(&left, &right);

        // left[i] > right[i] means i > i - 500, which is always true
        assert_eq!(result[0], 1.0);
        assert_eq!(result[600], 1.0);
        // Check every element, not just two samples.
        assert_eq!(result.len(), left.len());
        assert!(result.iter().all(|&flag| flag == 1.0));
    }

    #[test]
    fn test_lt_simd() {
        let left = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let right = vec![2.0, 2.0, 2.0, 5.0, 3.0];
        let result = lt_simd(&left, &right);
        assert_eq!(result, vec![1.0, 0.0, 0.0, 1.0, 0.0]);
    }

    #[test]
    fn test_gte_simd() {
        let left = vec![1.0, 2.0, 3.0];
        let right = vec![2.0, 2.0, 2.0];
        let result = gte_simd(&left, &right);
        assert_eq!(result, vec![0.0, 1.0, 1.0]);
    }

    #[test]
    fn test_lte_simd() {
        let left = vec![1.0, 2.0, 3.0];
        let right = vec![2.0, 2.0, 2.0];
        let result = lte_simd(&left, &right);
        assert_eq!(result, vec![1.0, 1.0, 0.0]);
    }

    #[test]
    fn test_eq_simd() {
        let left = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let right = vec![1.0, 3.0, 3.0, 4.0, 6.0];
        let result = eq_simd(&left, &right);
        assert_eq!(result, vec![1.0, 0.0, 1.0, 1.0, 0.0]);
    }

    #[test]
    fn test_ne_simd() {
        let left = vec![1.0, 2.0, 3.0];
        let right = vec![2.0, 2.0, 2.0];
        let result = ne_simd(&left, &right);
        assert_eq!(result, vec![1.0, 0.0, 1.0]);
    }

    #[test]
    fn test_and_simd() {
        let left = vec![1.0, 1.0, 0.0, 0.0];
        let right = vec![1.0, 0.0, 1.0, 0.0];
        let result = and_simd(&left, &right);
        assert_eq!(result, vec![1.0, 0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_or_simd() {
        let left = vec![1.0, 1.0, 0.0, 0.0];
        let right = vec![1.0, 0.0, 1.0, 0.0];
        let result = or_simd(&left, &right);
        assert_eq!(result, vec![1.0, 1.0, 1.0, 0.0]);
    }

    #[test]
    fn test_not_simd() {
        let values = vec![1.0, 0.0, 1.0, 0.0];
        let result = not_simd(&values);
        assert_eq!(result, vec![0.0, 1.0, 0.0, 1.0]);
    }

    #[test]
    fn test_simd_vs_scalar_correctness() {
        let n = 10000;
        let left: Vec<f64> = (0..n).map(|i| (i as f64 * 0.5).sin()).collect();
        let right: Vec<f64> = (0..n).map(|i| (i as f64 * 0.5).cos()).collect();

        let simd_result = gt_simd(&left, &right);
        let scalar_result = gt_scalar(&left, &right);

        assert_eq!(
            simd_result, scalar_result,
            "SIMD and scalar results must match"
        );
    }

    /// Every operator's unrolled path (main loop plus remainder) must agree
    /// with its scalar implementation.
    #[test]
    fn test_unrolled_paths_match_scalar() {
        let (left, right) = long_inputs();

        assert_eq!(gt_simd(&left, &right), gt_scalar(&left, &right));
        assert_eq!(lt_simd(&left, &right), lt_scalar(&left, &right));
        assert_eq!(gte_simd(&left, &right), gte_scalar(&left, &right));
        assert_eq!(lte_simd(&left, &right), lte_scalar(&left, &right));
        assert_eq!(eq_simd(&left, &right), eq_scalar(&left, &right));
        assert_eq!(ne_simd(&left, &right), ne_scalar(&left, &right));
        assert_eq!(and_simd(&left, &right), and_scalar(&left, &right));
        assert_eq!(or_simd(&left, &right), or_scalar(&left, &right));
        assert_eq!(not_simd(&left), not_scalar(&left));
    }

    #[test]
    #[should_panic(expected = "Array lengths must match")]
    fn test_simd_length_mismatch_panics() {
        let _ = gt_simd(&[1.0, 2.0], &[1.0]);
    }
}