1#![no_std]
2#![feature(portable_simd)]
3#![allow(dead_code)]
4#![allow(unused_imports)]
5#![feature(maybe_uninit_uninit_array)]
6
7extern crate alloc;
8use crate::alloc::vec::Vec;
9
10use core::simd::{Simd, cmp::SimdPartialEq, LaneCount, SupportedLaneCount};
11use core::ops::{BitAnd, BitAndAssign};
12use core::arch::x86_64::*;
13pub use simd_bmh_macro::parse_pattern;
14
/// A fixed-length byte pattern with a per-byte bit mask, bundled with the
/// precomputed values that the SSE prefilter of the search routine reads.
///
/// `N` is the pattern length in bytes. As an example of the encoding, the
/// tests in this file expect the pattern string `"A?C?FF"` to yield the bytes
/// `[0xA0, 0xC0, 0xFF]` with the masks `[0xF0, 0xF0, 0xFF]`.
#[derive(Debug, Clone)]
#[repr(align(32))]
pub struct Pattern<const N: usize> {
    /// Expected value of every pattern byte.
    pub bytes: [u8; N],
    /// Bit mask for every pattern byte. During verification both the pattern
    /// byte and the text byte are AND-ed with this mask before being compared,
    /// so cleared bits act as wildcards.
    pub masks: [u8; N],
    /// Value that a text byte must equal, after being AND-ed with
    /// `best_skip_mask`, to count as a prefilter hit.
    pub best_skip_value: u8,
    /// Mask AND-ed onto every text byte by the prefilter.
    pub best_skip_mask: u8,
    /// Not read by the search code visible in this file.
    /// NOTE(review): presumably the largest shift the pattern allows; confirm
    /// against the `parse_pattern!` macro that fills it in.
    pub max_skip: usize,
    /// Distance from the start of the pattern to the byte the prefilter looks
    /// for; a prefilter hit at text position `p` yields the candidate start
    /// `p - best_skip_offset`.
    pub best_skip_offset: usize,
}
25
26impl<const N: usize> Pattern<N> {
27 #[inline(always)]
28 pub fn find_all_matches(&self, text: &[u8]) -> Vec<usize> {
29 find_all_matches_sse::<N>(text, self)
30 }
31}
32
33#[inline(always)]
34pub fn find_all_matches_sse<const PATTERN_LEN: usize>(text: &[u8], pattern: &Pattern<PATTERN_LEN>) -> Vec<usize> {
35 if PATTERN_LEN > text.len() {
36 return Vec::new();
37 }
38
39 let mut matches = Vec::new();
40 let mut i = 0;
41
42 let best_skip = pattern.best_skip_value as i32;
43 let best_mask = pattern.best_skip_mask as i32;
44 let best_skip_offset = pattern.best_skip_offset as i32;
45
46 unsafe {
47 let skip_vector = _mm_set1_epi8(best_skip as i8);
48 let mask_vector = _mm_set1_epi8(best_mask as i8);
49
50 while i + 16 <= text.len() {
51 let mut match_masks = _mm_setzero_si128();
52 let chunk = _mm_loadu_si128(text.as_ptr().add(i) as *const __m128i);
53 let masked_chunk = _mm_and_si128(chunk, mask_vector);
54 let cmp_result = _mm_cmpeq_epi8(masked_chunk, skip_vector);
55 match_masks = _mm_or_si128(match_masks, cmp_result);
56
57 let match_positions = _mm_movemask_epi8(match_masks);
58 if match_positions != 0 {
59 for pos in 0..16 {
60 if (match_positions & (1 << pos)) != 0 {
61 let match_pos = i + pos;
62 let start_pos = match_pos - best_skip_offset as usize;
63
64 let mut valid = true;
65 for k in 0..PATTERN_LEN {
66 let pattern_byte = pattern.bytes[k];
67 let pattern_mask = pattern.masks[k];
68 let text_index = start_pos + k;
69
70 let masked_pattern_byte = pattern_byte & pattern_mask;
71 let masked_text_byte = text[text_index] & pattern_mask;
72 if masked_text_byte != masked_pattern_byte {
73 valid = false;
74 break;
75 }
76 }
77
78 if valid {
79 matches.push(start_pos);
80 }
81 }
82 }
83 }
84
85 i += 16;
86 }
87 }
88
89 while i + PATTERN_LEN <= text.len() {
90 let start_pos = i;
91 let mut match_found = true;
92
93 for k in 0..PATTERN_LEN {
94 let pattern_byte = pattern.bytes[k];
95 let pattern_mask = pattern.masks[k];
96 let text_index = start_pos + k;
97
98 let masked_pattern_byte = pattern_byte & pattern_mask;
99 let masked_text_byte = text[text_index] & pattern_mask;
100
101 if masked_text_byte != masked_pattern_byte {
102 match_found = false;
103 break;
104 }
105 }
106
107 if match_found {
108 matches.push(start_pos);
109 i += PATTERN_LEN;
110 } else {
111 let mismatch_byte = text[start_pos + PATTERN_LEN - 1];
112 i += (0..PATTERN_LEN - 1)
113 .rev()
114 .find(|&j| pattern.bytes[j] == mismatch_byte)
115 .map_or(PATTERN_LEN, |j| PATTERN_LEN - 1 - j);
116 }
117 }
118
119 matches
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;
    use rand::Rng;

    /// `?` nibbles become zero bits in both the byte and its mask.
    #[test]
    fn test_parse_pattern() {
        let pattern = parse_pattern!("A?C?FF");
        assert_eq!(&pattern.bytes[..], &[0xA0, 0xC0, 0xFF]);
        assert_eq!(&pattern.masks[..], &[0xF0, 0xF0, 0xFF]);
    }

    /// A text shorter than one 16-byte chunk still reports both occurrences.
    #[test]
    fn test_match() {
        let pattern = parse_pattern!("A?C?FF");
        let text = b"\xA0\xC0\xFF\x00\xA0\xC0\xFF";

        let matches = pattern.find_all_matches(text);
        assert_eq!(matches, [0, 4]);
    }

    /// Plants a known occurrence inside random noise and checks that it is
    /// reported at exactly the planted offset.
    #[test]
    fn test_random_pool_with_fixed_pattern() {
        let buffer_size = 2_000;
        // One RNG handle for the whole buffer instead of one per byte.
        let mut rng = rand::rng();
        let mut random_buffer: Vec<u8> = (0..buffer_size).map(|_| rng.random()).collect();
        random_buffer[1337..1342].copy_from_slice(b"\xAA\xCC\xFF\xFF\xFF");

        let pattern = parse_pattern!("A?C?FF");
        let matches = find_all_matches_sse(&random_buffer, &pattern);
        assert!(!matches.is_empty(), "Pattern matches should not be empty!");
        // Random bytes may match by chance, so a non-empty result alone does
        // not prove that the planted occurrence was found.
        assert!(
            matches.contains(&1337),
            "planted pattern at offset 1337 was not reported: {:?}",
            matches
        );
    }
}