sqry_core/search/simd/
mod.rs1pub mod scalar;
20
21mod common;
22
23#[cfg(target_arch = "x86_64")]
24pub mod avx2;
25
26#[cfg(target_arch = "x86_64")]
27pub mod sse42;
28
29#[cfg(target_arch = "aarch64")]
30pub mod neon;
31
32use std::fmt;
33
/// Result of a byte-substring search: `Some(position)` when the needle occurs
/// in the haystack, `None` when it does not.
pub type SearchResult = std::option::Option<usize>;
36
/// SIMD instruction set reported by [`detect_platform`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SimdPlatform {
    /// x86_64 AVX2.
    Avx2,
    /// x86_64 SSE4.2.
    Sse42,
    /// AArch64 NEON.
    Neon,
    /// Portable scalar fallback; no SIMD extension is used.
    Scalar,
}

impl fmt::Display for SimdPlatform {
    /// Writes the conventional name of the instruction set
    /// (`"AVX2"`, `"SSE4.2"`, `"NEON"`, `"Scalar"`).
    ///
    /// Goes through [`fmt::Formatter::pad`] rather than `write!`, so width,
    /// fill, alignment and precision flags are honored just as they are for
    /// `str` — e.g. `format!("{:<8}", SimdPlatform::Neon)` yields `"NEON    "`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            SimdPlatform::Avx2 => "AVX2",
            SimdPlatform::Sse42 => "SSE4.2",
            SimdPlatform::Neon => "NEON",
            SimdPlatform::Scalar => "Scalar",
        };
        f.pad(name)
    }
}
60
61#[must_use]
63pub fn detect_platform() -> SimdPlatform {
64 #[cfg(target_arch = "x86_64")]
65 {
66 if is_x86_feature_detected!("avx2") {
67 log::debug!("SIMD platform: AVX2");
68 return SimdPlatform::Avx2;
69 }
70 if is_x86_feature_detected!("sse4.2") {
71 log::debug!("SIMD platform: SSE4.2");
72 return SimdPlatform::Sse42;
73 }
74 }
75
76 #[cfg(target_arch = "aarch64")]
77 {
78 if std::arch::is_aarch64_feature_detected!("neon") {
79 log::debug!("SIMD platform: NEON");
80 return SimdPlatform::Neon;
81 }
82 }
83
84 log::debug!("SIMD platform: Scalar (no SIMD available)");
85 SimdPlatform::Scalar
86}
87
88#[must_use]
104pub fn search(haystack: &[u8], needle: &[u8]) -> SearchResult {
105 if needle.is_empty() {
106 return Some(0);
107 }
108 if haystack.len() < needle.len() {
109 return None;
110 }
111
112 #[cfg(target_arch = "x86_64")]
114 {
115 if is_x86_feature_detected!("avx2") {
116 return unsafe { avx2::search(haystack, needle) };
117 }
118 if is_x86_feature_detected!("sse4.2") {
120 return unsafe { sse42::search(haystack, needle) };
121 }
122 }
123
124 #[cfg(target_arch = "aarch64")]
126 {
127 if std::arch::is_aarch64_feature_detected!("neon") {
128 return unsafe { neon::search(haystack, needle) };
129 }
130 }
131
132 scalar::search(haystack, needle)
133}
134
135#[must_use]
151pub fn extract_trigrams(text: &str) -> Vec<String> {
152 if text.len() < 3 {
153 return vec![text.to_string()];
154 }
155
156 #[cfg(target_arch = "x86_64")]
158 {
159 if is_x86_feature_detected!("avx2") {
160 return unsafe { avx2::extract_trigrams(text) };
161 }
162 if is_x86_feature_detected!("sse4.2") {
164 return unsafe { sse42::extract_trigrams(text) };
165 }
166 }
167
168 #[cfg(target_arch = "aarch64")]
170 {
171 if std::arch::is_aarch64_feature_detected!("neon") {
172 return unsafe { neon::extract_trigrams(text) };
173 }
174 }
175
176 scalar::extract_trigrams(text)
177}
178
179#[must_use]
192pub fn to_lowercase_ascii(text: &str) -> String {
193 #[cfg(target_arch = "x86_64")]
195 {
196 if is_x86_feature_detected!("avx2") {
197 return unsafe { avx2::to_lowercase_ascii(text) };
198 }
199 if is_x86_feature_detected!("sse4.2") {
201 return unsafe { sse42::to_lowercase_ascii(text) };
202 }
203 }
204
205 #[cfg(target_arch = "aarch64")]
207 {
208 if std::arch::is_aarch64_feature_detected!("neon") {
209 return unsafe { neon::to_lowercase_ascii(text) };
210 }
211 }
212
213 scalar::to_lowercase_ascii(text)
214}
215
/// Tests for the dispatch layer. They cover the edge cases handled before
/// dispatch, and equivalence between the kernel the host CPU selects and the
/// scalar reference implementation.
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that the dispatched `search` agrees with `scalar::search`.
    fn assert_search_matches_scalar(haystack: &[u8], needle: &[u8]) {
        assert_eq!(
            search(haystack, needle),
            scalar::search(haystack, needle),
            "SIMD ≡ scalar mismatch for search: haystack={:?} needle={:?}",
            String::from_utf8_lossy(haystack),
            String::from_utf8_lossy(needle)
        );
    }

    #[test]
    fn test_detect_platform() {
        // Re-derive the expected platform from the same runtime probes, so the
        // test actually fails if the detection order regresses (matching on
        // every variant, as before, could never fail).
        let platform = detect_platform();

        #[cfg(target_arch = "x86_64")]
        let expected = if is_x86_feature_detected!("avx2") {
            SimdPlatform::Avx2
        } else if is_x86_feature_detected!("sse4.2") {
            SimdPlatform::Sse42
        } else {
            SimdPlatform::Scalar
        };

        #[cfg(target_arch = "aarch64")]
        let expected = if std::arch::is_aarch64_feature_detected!("neon") {
            SimdPlatform::Neon
        } else {
            SimdPlatform::Scalar
        };

        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        let expected = SimdPlatform::Scalar;

        assert_eq!(platform, expected);
    }

    #[test]
    fn test_platform_display() {
        assert_eq!(SimdPlatform::Avx2.to_string(), "AVX2");
        assert_eq!(SimdPlatform::Sse42.to_string(), "SSE4.2");
        assert_eq!(SimdPlatform::Neon.to_string(), "NEON");
        assert_eq!(SimdPlatform::Scalar.to_string(), "Scalar");
    }

    #[test]
    fn test_search_empty_needle() {
        let haystack = b"hello";
        let needle = b"";
        assert_eq!(search(haystack, needle), Some(0));
    }

    #[test]
    fn test_search_needle_too_long() {
        let haystack = b"hi";
        let needle = b"hello";
        assert_eq!(search(haystack, needle), None);
    }

    #[test]
    fn test_search_known_positions() {
        assert_eq!(search(b"hello world", b"world"), Some(6));
        assert_eq!(search(b"abc", b"abc"), Some(0));
        assert_eq!(search(b"abc", b"abd"), None);
    }

    #[test]
    fn test_search_matches_scalar() {
        // The long haystack exceeds a single 32-byte vector, so block loops and
        // tail handling in the SIMD kernels are exercised too.
        let mut long = b"ab".repeat(40);
        long.extend_from_slice(b"needle");

        assert_search_matches_scalar(b"hello world", b"o");
        assert_search_matches_scalar(b"hello world", b"world");
        assert_search_matches_scalar(b"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab", b"aab");
        assert_search_matches_scalar(&long, b"needle");
        assert_search_matches_scalar(&long, b"needlf");
        assert_search_matches_scalar(&long, b"ba");
    }

    #[test]
    fn test_extract_trigrams_short_string() {
        assert_eq!(extract_trigrams("ab"), vec!["ab"]);
        assert_eq!(extract_trigrams(""), vec![""]);
    }

    #[test]
    fn test_to_lowercase_ascii_empty() {
        assert_eq!(to_lowercase_ascii(""), "");
    }

    #[test]
    fn test_to_lowercase_ascii_known_value() {
        assert_eq!(to_lowercase_ascii("HELLO_World 42"), "hello_world 42");
    }

    #[test]
    fn test_to_lowercase_ascii_matches_scalar() {
        let long = "MiXeD_CaSe_".repeat(8);
        let inputs = [
            "A",
            "HELLO_WORLD",
            "createCompilerHost",
            "already lower",
            "ÄÖÜ and ASCII",
            long.as_str(),
        ];
        for input in &inputs {
            assert_eq!(
                to_lowercase_ascii(input),
                scalar::to_lowercase_ascii(input),
                "SIMD ≡ scalar mismatch for to_lowercase_ascii input: {input}"
            );
        }
    }

    #[test]
    fn test_extract_trigrams_ascii_matches_scalar() {
        let inputs = [
            "hello",
            "abc",
            "abcdefghijklmnopqrstuvwxyz0123456789",
            "createCompilerHost",
            "aaaa",
            "HELLO_WORLD",
        ];
        for input in &inputs {
            // Kernels may emit trigrams in different orders; compare as sets.
            let mut dispatched = extract_trigrams(input);
            let mut scalar_result = scalar::extract_trigrams(input);
            dispatched.sort();
            scalar_result.sort();
            assert_eq!(
                dispatched, scalar_result,
                "SIMD ≡ scalar mismatch for ASCII input: {input}"
            );
        }
    }

    #[test]
    fn test_extract_trigrams_non_ascii_matches_scalar() {
        let inputs = ["héllo", "日本語", "café", "naïve", "über"];
        for input in &inputs {
            let mut dispatched = extract_trigrams(input);
            let mut scalar_result = scalar::extract_trigrams(input);
            dispatched.sort();
            scalar_result.sort();
            assert_eq!(
                dispatched, scalar_result,
                "SIMD ≡ scalar mismatch for non-ASCII input: {input}"
            );
        }
    }
}