Skip to main content

sqry_core/search/simd/
mod.rs

//! SIMD-accelerated text search operations
//!
//! This module provides SIMD-optimized implementations of core search operations:
//! - Substring search (Boyer-Moore-Horspool with SIMD)
//! - Trigram extraction (bulk loading with SIMD)
//! - ASCII case conversion (range check with SIMD)
//!
//! Platform support:
//! - x86_64: AVX2 (primary), SSE4.2 (fallback)
//! - aarch64: NEON
//! - Other: Scalar fallback
//!
//! Safety:
//! - The public functions in this module are safe to call; `unsafe` is confined
//!   to the calls into the architecture-specific backends
//! - Runtime feature detection confirms CPU support before a SIMD backend is called
//! - Fallback to scalar when SIMD is unavailable
//! - Property tests validate SIMD ≡ scalar
18
19pub mod scalar;
20
21mod common;
22
23#[cfg(target_arch = "x86_64")]
24pub mod avx2;
25
26#[cfg(target_arch = "x86_64")]
27pub mod sse42;
28
29#[cfg(target_arch = "aarch64")]
30pub mod neon;
31
32use std::fmt;
33
/// Result of a substring search: `Some(offset)` holding the byte offset of a
/// match within the haystack, or `None` when the needle was not found.
pub type SearchResult = Option<usize>;
36
/// SIMD platform in use.
///
/// The [`fmt::Display`] implementation prints a short label for each variant
/// (`AVX2`, `SSE4.2`, `NEON`, `Scalar`) and honours the formatter's width,
/// fill, and alignment flags, so the label can be used directly in aligned
/// log or table output (e.g. `format!("{:>8}", platform)`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SimdPlatform {
    /// `x86_64` AVX2 (32-byte vectors)
    Avx2,
    /// `x86_64` SSE4.2 (16-byte vectors)
    Sse42,
    /// ARM64 NEON (16-byte vectors)
    Neon,
    /// Scalar fallback (no SIMD)
    Scalar,
}

impl fmt::Display for SimdPlatform {
    /// Writes the human-readable label of the platform.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Avx2 => "AVX2",
            Self::Sse42 => "SSE4.2",
            Self::Neon => "NEON",
            Self::Scalar => "Scalar",
        };
        // `pad` applies the caller's width/fill/alignment flags; writing the
        // literal with `write!` would silently ignore them.
        f.pad(label)
    }
}
60
61/// Detect the best available SIMD platform for the current CPU
62#[must_use]
63pub fn detect_platform() -> SimdPlatform {
64    #[cfg(target_arch = "x86_64")]
65    {
66        if is_x86_feature_detected!("avx2") {
67            log::debug!("SIMD platform: AVX2");
68            return SimdPlatform::Avx2;
69        }
70        if is_x86_feature_detected!("sse4.2") {
71            log::debug!("SIMD platform: SSE4.2");
72            return SimdPlatform::Sse42;
73        }
74    }
75
76    #[cfg(target_arch = "aarch64")]
77    {
78        if std::arch::is_aarch64_feature_detected!("neon") {
79            log::debug!("SIMD platform: NEON");
80            return SimdPlatform::Neon;
81        }
82    }
83
84    log::debug!("SIMD platform: Scalar (no SIMD available)");
85    SimdPlatform::Scalar
86}
87
88/// Search for needle in haystack using the best available SIMD implementation
89///
90/// # Safety
91/// This function performs runtime feature detection and dispatches to the
92/// appropriate SIMD implementation. All SIMD code uses safe wrappers from
93/// `std::arch`, so this function is safe to call.
94///
95/// # Examples
96/// ```
97/// use sqry_core::search::simd::search;
98///
99/// let haystack = b"hello world";
100/// let needle = b"world";
101/// assert_eq!(search(haystack, needle), Some(6));
102/// ```
103#[must_use]
104pub fn search(haystack: &[u8], needle: &[u8]) -> SearchResult {
105    if needle.is_empty() {
106        return Some(0);
107    }
108    if haystack.len() < needle.len() {
109        return None;
110    }
111
112    // Phase 2: AVX2 implementation active
113    #[cfg(target_arch = "x86_64")]
114    {
115        if is_x86_feature_detected!("avx2") {
116            return unsafe { avx2::search(haystack, needle) };
117        }
118        // Phase 3: SSE4.2 fallback for x86_64
119        if is_x86_feature_detected!("sse4.2") {
120            return unsafe { sse42::search(haystack, needle) };
121        }
122    }
123
124    // Phase 3: NEON support for aarch64
125    #[cfg(target_arch = "aarch64")]
126    {
127        if std::arch::is_aarch64_feature_detected!("neon") {
128            return unsafe { neon::search(haystack, needle) };
129        }
130    }
131
132    scalar::search(haystack, needle)
133}
134
135/// Extract trigrams from text using the best available SIMD implementation
136///
137/// A trigram is a 3-character sliding window over the input text.
138/// For example: `"abcd"` → `["abc", "bcd"]`
139///
140/// Strings shorter than 3 characters return a single-element vector with
141/// the original string.
142///
143/// # Examples
144/// ```
145/// use sqry_core::search::simd::extract_trigrams;
146///
147/// let trigrams = extract_trigrams("hello");
148/// assert_eq!(trigrams, vec!["hel", "ell", "llo"]);
149/// ```
150#[must_use]
151pub fn extract_trigrams(text: &str) -> Vec<String> {
152    if text.len() < 3 {
153        return vec![text.to_string()];
154    }
155
156    // Phase 2: AVX2 implementation active
157    #[cfg(target_arch = "x86_64")]
158    {
159        if is_x86_feature_detected!("avx2") {
160            return unsafe { avx2::extract_trigrams(text) };
161        }
162        // Phase 3: SSE4.2 fallback for x86_64
163        if is_x86_feature_detected!("sse4.2") {
164            return unsafe { sse42::extract_trigrams(text) };
165        }
166    }
167
168    // Phase 3: NEON support for aarch64
169    #[cfg(target_arch = "aarch64")]
170    {
171        if std::arch::is_aarch64_feature_detected!("neon") {
172            return unsafe { neon::extract_trigrams(text) };
173        }
174    }
175
176    scalar::extract_trigrams(text)
177}
178
179/// Convert ASCII text to lowercase using the best available SIMD implementation
180///
181/// Only ASCII characters (A-Z) are converted to lowercase. Non-ASCII characters
182/// are preserved unchanged.
183///
184/// # Examples
185/// ```
186/// use sqry_core::search::simd::to_lowercase_ascii;
187///
188/// assert_eq!(to_lowercase_ascii("HELLO"), "hello");
189/// assert_eq!(to_lowercase_ascii("HeLLo"), "hello");
190/// ```
191#[must_use]
192pub fn to_lowercase_ascii(text: &str) -> String {
193    // Phase 2: AVX2 implementation active
194    #[cfg(target_arch = "x86_64")]
195    {
196        if is_x86_feature_detected!("avx2") {
197            return unsafe { avx2::to_lowercase_ascii(text) };
198        }
199        // Phase 3: SSE4.2 fallback for x86_64
200        if is_x86_feature_detected!("sse4.2") {
201            return unsafe { sse42::to_lowercase_ascii(text) };
202        }
203    }
204
205    // Phase 3: NEON support for aarch64
206    #[cfg(target_arch = "aarch64")]
207    {
208        if std::arch::is_aarch64_feature_detected!("neon") {
209            return unsafe { neon::to_lowercase_ascii(text) };
210        }
211    }
212
213    scalar::to_lowercase_ascii(text)
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    /// The detected platform must be one that can exist on the compile target,
    /// and repeated detection must agree with itself.
    #[test]
    fn test_detect_platform() {
        let platform = detect_platform();

        // Matching against every enum variant can never fail, so restrict the
        // expectation to the variants `detect_platform` can return here.
        let reachable: &[SimdPlatform] = if cfg!(target_arch = "x86_64") {
            &[SimdPlatform::Avx2, SimdPlatform::Sse42, SimdPlatform::Scalar]
        } else if cfg!(target_arch = "aarch64") {
            &[SimdPlatform::Neon, SimdPlatform::Scalar]
        } else {
            &[SimdPlatform::Scalar]
        };
        assert!(
            reachable.contains(&platform),
            "detected {platform:?}, which is not reachable on this target"
        );

        // CPU features do not change while the process runs.
        assert_eq!(platform, detect_platform());
    }

    /// Each variant prints its documented label.
    #[test]
    fn test_platform_display() {
        assert_eq!(SimdPlatform::Avx2.to_string(), "AVX2");
        assert_eq!(SimdPlatform::Sse42.to_string(), "SSE4.2");
        assert_eq!(SimdPlatform::Neon.to_string(), "NEON");
        assert_eq!(SimdPlatform::Scalar.to_string(), "Scalar");
    }

    /// An empty needle matches at offset 0.
    #[test]
    fn test_search_empty_needle() {
        let haystack = b"hello";
        let needle = b"";
        assert_eq!(search(haystack, needle), Some(0));
    }

    /// A needle longer than the haystack can never match.
    #[test]
    fn test_search_needle_too_long() {
        let haystack = b"hi";
        let needle = b"hello";
        assert_eq!(search(haystack, needle), None);
    }

    /// Inputs shorter than three bytes come back as a single element.
    #[test]
    fn test_extract_trigrams_short_string() {
        assert_eq!(extract_trigrams("ab"), vec!["ab"]);
        assert_eq!(extract_trigrams(""), vec![""]);
    }

    /// Lowercasing the empty string yields the empty string.
    #[test]
    fn test_to_lowercase_ascii_empty() {
        assert_eq!(to_lowercase_ascii(""), "");
    }

    /// The dispatched implementation and the scalar one produce the same
    /// trigrams (compared after sorting) for ASCII inputs.
    #[test]
    fn test_extract_trigrams_ascii_matches_scalar() {
        let inputs = [
            "hello",
            "abc",
            "abcdefghijklmnopqrstuvwxyz0123456789",
            "createCompilerHost",
            "aaaa",
            "HELLO_WORLD",
        ];
        for input in &inputs {
            let mut dispatched = extract_trigrams(input);
            let mut scalar_result = scalar::extract_trigrams(input);
            dispatched.sort();
            scalar_result.sort();
            assert_eq!(
                dispatched, scalar_result,
                "SIMD ≡ scalar mismatch for ASCII input: {input}"
            );
        }
    }

    /// The dispatched implementation and the scalar one produce the same
    /// trigrams (compared after sorting) for non-ASCII inputs.
    #[test]
    fn test_extract_trigrams_non_ascii_matches_scalar() {
        let inputs = ["héllo", "日本語", "café", "naïve", "über"];
        for input in &inputs {
            let mut dispatched = extract_trigrams(input);
            let mut scalar_result = scalar::extract_trigrams(input);
            dispatched.sort();
            scalar_result.sort();
            assert_eq!(
                dispatched, scalar_result,
                "SIMD ≡ scalar mismatch for non-ASCII input: {input}"
            );
        }
    }
}