1#[cfg(target_arch = "x86")]
2use std::arch::x86::*;
3#[cfg(target_arch = "x86_64")]
4use std::arch::x86_64::*;
5
6use super::{A, str_to_bytes, bytes_len};
7
/// Shift-and (bitap) matcher that tracks eight short patterns simultaneously,
/// one pattern per 16-bit lane of a 128-bit SSE register.
pub struct Bitap8x16 {
    /// Running match state: bit `k` of lane `p` is set while the bytes seen so
    /// far match pattern `p` up to lane position `k`.
    v: __m128i,
    /// One bit per lane marking the position of that pattern's first byte;
    /// OR-ed into `v` on every step so a match can begin at any input byte.
    start_mask: __m128i,
}
12
13const fn get_masks(patterns: &[&str]) -> [A; 256] {
14 const TEMP_A: A = A([0u8; 16]);
19 let mut res = [TEMP_A; 256];
20 let mut i = 0;
21 let bit5 = 0b0010_0000u8;
22
23 while i < patterns.len() {
24 let bytes = patterns[i].as_bytes();
25 let offset = 16 - bytes.len();
28 let mut j = 0;
29
30 while j < bytes.len() {
31 let idx = i * 16 + j + offset;
32 res[bytes[j] as usize].0[idx / 8] |= 1u8 << (idx % 8);
33
34 if bytes[j].is_ascii_alphabetic() {
36 res[(bytes[j] ^ bit5) as usize].0[idx / 8] |= 1u8 << (idx % 8);
37 }
38
39 j += 1;
40 }
41
42 i += 1;
43 }
44
45 res
46}
47
48const fn get_start_mask(patterns: &[&str]) -> A {
49 let mut res = A([0u8; 16]);
51 let mut i = 0;
52
53 while i < patterns.len() {
54 let j = 16 - patterns[i].as_bytes().len();
55 let idx = i * 16 + j;
56 res.0[idx / 8] |= 1u8 << (idx % 8);
57 i += 1;
58 }
59
60 res
61}
62
/// Words recognised by the matcher, one per 16-bit lane; entry `i` pairs with
/// entry `i` of the replacement table. Each word must be at most 16 bytes.
///
/// `"meow"` appears twice. Only the lowest matching lane is ever reported, so
/// the second copy never wins; presumably it just fills the eighth lane.
static PATTERNS: [&str; 8] = [
    "small", "cute", "fluff", "love", "stupid", "what", "meow", "meow",
];
73
/// Per-input-byte lane masks for `PATTERNS`, indexed by the raw byte value
/// (ASCII letters are registered in both cases by `get_masks`).
static MASKS: [A; 256] = get_masks(&PATTERNS);
/// First-byte position bits for `PATTERNS`; loaded once by `Bitap8x16::new`
/// and OR-ed into the state on every `next` step.
static START_MASK: A = get_start_mask(&PATTERNS);
76
/// Replacement text for each entry of `PATTERNS`, at the same index
/// (e.g. `"cute"` -> `"kawaii~"`), converted to `A` buffers with `str_to_bytes`.
/// `Bitap8x16::next` hands these out as `*const __m128i` pointers.
static REPLACE: [A; 8] = [
    str_to_bytes("smol"),
    str_to_bytes("kawaii~"),
    str_to_bytes("floof"),
    str_to_bytes("luv"),
    str_to_bytes("baka"),
    str_to_bytes("nani"),
    str_to_bytes("nya~"),
    str_to_bytes("nya~")
];
91
92const fn get_len(a: &[A]) -> [usize; 8] {
93 let mut res = [0usize; 8];
94 let mut i = 0;
95
96 while i < a.len() {
97 res[i] = bytes_len(&a[i].0);
98 i += 1;
99 }
100
101 res
102}
103
/// Length of each `REPLACE` entry as computed by `bytes_len`, computed once
/// here and reported by `Bitap8x16::next` as `Match::replace_len`.
static REPLACE_LEN: [usize; 8] = get_len(&REPLACE);
105
/// A completed pattern match, as reported by `Bitap8x16::next`.
#[derive(Debug, PartialEq)]
pub struct Match {
    /// Length in bytes of the pattern that just finished matching.
    pub match_len: usize,
    /// Start of the 16-byte buffer holding the replacement text.
    pub replace_ptr: *const __m128i,
    /// Number of meaningful replacement bytes at `replace_ptr`.
    pub replace_len: usize,
}
112
113impl Bitap8x16 {
114 #[inline]
115 #[target_feature(enable = "sse4.1")]
116 pub unsafe fn new() -> Self {
117 Self {
118 v: _mm_setzero_si128(),
119 start_mask: _mm_load_si128(START_MASK.0.as_ptr() as *const __m128i)
120 }
121 }
122
123 #[inline]
124 #[target_feature(enable = "sse4.1")]
125 pub unsafe fn next(&mut self, c: u8) -> Option<Match> {
126 self.v = _mm_slli_epi16(self.v, 1);
127 self.v = _mm_or_si128(self.v, self.start_mask);
128 let mask = _mm_load_si128(MASKS.get_unchecked(c as usize).0.as_ptr() as *const __m128i);
129 self.v = _mm_and_si128(self.v, mask);
130
131 let match_mask = (_mm_movemask_epi8(self.v) as u32) & 0xAAAAAAAAu32;
132
133 if match_mask != 0 {
134 let match_idx = (match_mask.trailing_zeros() as usize) / 2;
135
136 return Some(Match {
137 match_len: PATTERNS.get_unchecked(match_idx).len(),
138 replace_ptr: REPLACE.get_unchecked(match_idx).0.as_ptr() as *const __m128i,
139 replace_len: *REPLACE_LEN.get_unchecked(match_idx)
140 });
141 }
142
143 None
144 }
145
146 #[inline]
147 #[target_feature(enable = "sse4.1")]
148 pub unsafe fn reset(&mut self) {
149 self.v = _mm_setzero_si128();
150 }
151}
152
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds every byte of `input`, asserting that none of them completes a match.
    ///
    /// # Safety
    ///
    /// The CPU must support SSE4.1.
    unsafe fn feed_no_match(b: &mut Bitap8x16, input: &[u8]) {
        for &byte in input {
            assert_eq!(b.next(byte), None);
        }
    }

    /// Feeds `byte` and asserts it completes a match with the given lengths.
    ///
    /// # Safety
    ///
    /// The CPU must support SSE4.1.
    unsafe fn expect_match(b: &mut Bitap8x16, byte: u8, match_len: usize, replace_len: usize) {
        let found = b.next(byte).unwrap();
        assert_eq!(found.match_len, match_len);
        assert_eq!(found.replace_len, replace_len);
    }

    #[test]
    fn test_bitap() {
        assert!(is_x86_feature_detected!("sse4.1"), "sse4.1 feature not detected!");

        unsafe {
            let mut b = Bitap8x16::new();

            // "cute" -> "kawaii~"
            feed_no_match(&mut b, b"cut");
            expect_match(&mut b, b'e', 4, 7);

            // "what" -> "nani", starting from a cleared state.
            b.reset();
            feed_no_match(&mut b, b"wha");
            expect_match(&mut b, b't', 4, 4);

            // A near miss must not match.
            feed_no_match(&mut b, b"whaa");

            // Matching ignores ASCII case.
            feed_no_match(&mut b, b"WhA");
            expect_match(&mut b, b't', 4, 4);
        }
    }
}