1#![allow(dead_code)]
9
10extern crate unchecked_index;
11extern crate memchr;
12
13use std::cmp;
14use std::mem;
15use std::iter::Zip;
16
17use self::unchecked_index::get_unchecked;
18
19use TwoWaySearcher;
20
/// Pairs up the items of two `IntoIterator`s.
///
/// Free-function form of `Iterator::zip` that converts both arguments with
/// `into_iter()`, so slices and references to collections can be passed
/// directly. Iteration stops as soon as either side is exhausted.
fn zip<I, J>(i: I, j: J) -> Zip<I::IntoIter, J::IntoIter>
where
    I: IntoIterator,
    J: IntoIterator,
{
    let left = i.into_iter();
    let right = j.into_iter();
    left.zip(right)
}
27
28#[cfg(target_arch = "x86")]
29use std::arch::x86::*;
30
31#[cfg(target_arch = "x86_64")]
32use std::arch::x86_64::*;
33
34#[target_feature(enable = "sse4.2")]
45unsafe fn pcmpestri_16_mask(text: *const u8, offset: usize, text_len: usize,
46 needle: __m128i, needle_len: usize) -> u32 {
47 let text = mask_load(text.offset(offset as _) as *const _, text_len);
50 _mm_cmpestri(needle, needle_len as _, text, text_len as _, _SIDD_CMP_EQUAL_ORDERED) as _
51}
52
/// Runs `pcmpestri` (equal-ordered mode) of `needle` against the 16 bytes at
/// `text + offset`, loaded with a plain unaligned 16-byte load.
///
/// `text_len` is only passed to the instruction as the explicit text length;
/// the load itself always touches 16 bytes.
///
/// Returns the raw index produced by the instruction; callers in this file
/// treat `16` as "no match in this window".
///
/// # Safety
///
/// The CPU must support SSE4.2 and `text + offset` must be valid for reading
/// a full 16 bytes, regardless of `text_len`.
#[target_feature(enable = "sse4.2")]
unsafe fn pcmpestri_16_nomask(text: *const u8, offset: usize, text_len: usize,
                              needle: __m128i, needle_len: usize) -> u32 {
    let window = _mm_loadu_si128(text.add(offset) as *const __m128i);
    let index = _mm_cmpestri(needle, needle_len as i32, window, text_len as i32,
                             _SIDD_CMP_EQUAL_ORDERED);
    index as u32
}
70
/// Compares 16 bytes at `text + offset` with 16 bytes at `needle + noffset`
/// using `pcmpestrm` in equal-each mode and returns the low 64 bits of the
/// resulting mask register.
///
/// `shared_prefix_inner` compares the result against `0xffff` to decide
/// whether all 16 lanes agreed.
///
/// # Safety
///
/// The CPU must support SSE4.2, and both `text + offset` and
/// `needle + noffset` must be valid for reading 16 bytes (both loads are
/// unconditional 16-byte loads).
#[target_feature(enable = "sse4.2")]
unsafe fn pcmpestrm_eq_each(text: *const u8, offset: usize, text_len: usize,
                            needle: *const u8, noffset: usize, needle_len: usize) -> u64 {
    let needle_chunk = _mm_loadu_si128(needle.add(noffset) as *const __m128i);
    let text_chunk = _mm_loadu_si128(text.add(offset) as *const __m128i);
    let lanes = _mm_cmpestrm(needle_chunk, needle_len as i32, text_chunk, text_len as i32,
                             _SIDD_CMP_EQUAL_EACH);

    // 32-bit x86 has no 64-bit extract, so assemble the value from two halves.
    #[cfg(target_arch = "x86")] {
        let low = _mm_extract_epi32(lanes, 0) as u64;
        let high = _mm_extract_epi32(lanes, 1) as u64;
        low | (high << 32)
    }

    #[cfg(target_arch = "x86_64")] {
        _mm_extract_epi64(lanes, 0) as u64
    }
}
97
98
/// Test-only convenience wrapper: loads `pat` (at most 16 bytes) into a vector
/// register and forwards to `first_start_of_match_mask`.
///
/// Returns whatever `first_start_of_match_mask` returns, i.e. the position and
/// the matched length of the first (possibly partial) match.
#[cfg(test)]
fn first_start_of_match(text: &[u8], pat: &[u8]) -> Option<(usize, usize)> {
    let patl = pat.len();
    assert!(patl <= 16);
    let needle = pat128(pat);
    unsafe { first_start_of_match_mask(text, patl, needle) }
}
108
109#[target_feature(enable = "sse4.2")]
114unsafe fn first_start_of_match_mask(text: &[u8], pat_len: usize, p: __m128i) -> Option<(usize, usize)> {
115 let tp = text.as_ptr();
116 debug_assert!(pat_len <= 16);
117
118 let mut offset = 0;
119
120 while text.len() >= offset + pat_len {
121 let tlen = text.len() - offset;
122 let ret = pcmpestri_16_mask(tp, offset, tlen, p, pat_len) as usize;
123 if ret == 16 {
124 offset += 16;
125 } else {
126 let match_len = cmp::min(pat_len, 16 - ret);
127 return Some((offset + ret, match_len));
128 }
129 }
130
131 None
132}
133
134
135#[target_feature(enable = "sse4.2")]
140unsafe fn first_start_of_match_nomask(text: &[u8], pat_len: usize, p: __m128i) -> Option<(usize, usize)> {
141 let tp = text.as_ptr();
142 debug_assert!(pat_len <= 16);
143 debug_assert!(pat_len <= text.len());
144
145 let mut offset = 0;
146
147 while text.len() - pat_len >= offset {
148 let tlen = text.len() - offset;
149 let ret = pcmpestri_16_nomask(tp, offset, tlen, p, pat_len) as usize;
150 if ret == 16 {
151 offset += 16;
152 } else {
153 let match_len = cmp::min(pat_len, 16 - ret);
154 return Some((offset + ret, match_len));
155 }
156 }
157
158 None
159}
160
/// Checks `first_start_of_match` on a few fixed patterns and then on every
/// window of 1..=16 bytes taken from a longer text.
#[test]
fn test_first_start_of_match() {
    let text = b"abc";
    let longer = "longer text and so on";

    // (pattern, expected result) pairs for the short text.
    let fixed_cases: [(&[u8], Option<(usize, usize)>); 6] = [
        (&b"d"[..], None),
        (&b"c"[..], Some((2, 1))),
        (&b"abc"[..], Some((0, 3))),
        (&b"T"[..], None),
        (&b"\0text"[..], None),
        (&b"\0"[..], None),
    ];
    for &(pat, expected) in fixed_cases.iter() {
        assert_eq!(first_start_of_match(text, pat), expected);
    }

    let haystack = longer.as_bytes();
    for wsz in 1..17 {
        for window in haystack.windows(wsz) {
            let str_find = longer.find(::std::str::from_utf8(window).unwrap());
            assert!(str_find.is_some());

            let first_start = first_start_of_match(haystack, window);
            assert!(first_start.is_some());
            let (pos, len) = first_start.unwrap();

            assert!(len <= wsz);
            // Either a full match at the same place `str::find` reports, or a
            // (partial) hit that is not later than the real match.
            let exact = len == wsz && Some(pos) == str_find;
            assert!(exact || pos <= str_find.unwrap());
        }
    }
}
186
187fn find_2byte_pat(text: &[u8], pat: &[u8]) -> Option<(usize, usize)> {
188 debug_assert!(text.len() >= pat.len());
189 debug_assert!(pat.len() == 2);
190 let mut off = 1;
194 while let Some(i) = memchr::memchr(pat[1], &text[off..]) {
195 match text.get(off + i - 1) {
196 None => break,
197 Some(&c) if c == pat[0] => return Some((off + i - 1, off + i + 1)),
198 _ => off += i + 1,
199 }
200
201 }
202 None
203}
204
205#[target_feature(enable = "sse4.2")]
207unsafe fn find_short_pat(text: &[u8], pat: &[u8]) -> Option<usize> {
208 debug_assert!(pat.len() <= 8);
209 let r = pat128(pat);
215
216 let safetext = &text[..cmp::max(text.len(), 16) - 16];
218
219 let mut pos = 0;
220 'search: loop {
221 if pos + pat.len() > safetext.len() {
222 break;
223 }
224 match first_start_of_match_nomask(&safetext[pos..], pat.len(), r) {
226 None => {
227 pos = cmp::max(pos, safetext.len() - pat.len());
228 break }
230 Some((mpos, mlen)) => {
231 pos += mpos;
232 if mlen < pat.len() {
233 if pos > text.len() - pat.len() {
234 return None;
235 }
236 for (&a, &b) in zip(&text[pos + mlen..], &pat[mlen..]) {
237 if a != b {
238 pos += 1;
239 continue 'search;
240 }
241 }
242 }
243
244 return Some(pos);
245 }
246 }
247 }
248
249 'tail: loop {
250 if pos > text.len() - pat.len() {
251 return None;
252 }
253 match first_start_of_match_mask(&text[pos..], pat.len(), r) {
255 None => return None, Some((mpos, mlen)) => {
257 pos += mpos;
258 if mlen < pat.len() {
259 if pos > text.len() - pat.len() {
260 return None;
261 }
262 for (&a, &b) in zip(&text[pos + mlen..], &pat[mlen..]) {
263 if a != b {
264 pos += 1;
265 continue 'tail;
266 }
267 }
268 }
269
270 return Some(pos);
271 }
272 }
273 }
274}
275
/// Reports whether the SSE4.2 based search routines may be used.
///
/// With the `use_std` feature the CPU is queried at run time; without it the
/// answer is fixed at compile time from the enabled target features.
pub fn is_supported() -> bool {
    #[cfg(feature = "use_std")]
    let supported = is_x86_feature_detected!("sse4.2");
    #[cfg(not(feature = "use_std"))]
    let supported = cfg!(target_feature = "sse4.2");
    supported
}
283
284pub fn find(text: &[u8], pattern: &[u8]) -> Option<usize> {
288 assert!(is_supported());
289
290 if pattern.is_empty() {
291 return Some(0);
292 } else if text.len() < pattern.len() {
293 return None;
294 } else if pattern.len() == 1 {
295 return memchr::memchr(pattern[0], text);
296 } else {
297 unsafe { find_inner(text, pattern) }
298 }
299}
300
/// Searches `text` for the first occurrence of `pat` and returns its start
/// index.
///
/// Patterns of at most 6 bytes are delegated to `find_short_pat`. Longer
/// patterns are split at `crit_pos` (obtained from
/// `TwoWaySearcher::crit_params`) into `left` and `right`. Candidate positions
/// are located by scanning for the first (up to 16) bytes of `right` with the
/// SSE4.2 helpers; the remainder of `right` is then checked with
/// `shared_prefix_inner`, and finally `left` is compared.
///
/// The bulk of the text is searched with the unmasked loader on `safetext`
/// (the text minus its last 16 bytes); the `'tail` loop finishes the job with
/// the masked loader.
///
/// # Safety
///
/// The CPU must support SSE4.2. The `'tail` loop computes
/// `text.len() - pat.len()`, so callers are expected to pass
/// `text.len() >= pat.len()` (as `find` does).
#[target_feature(enable = "sse4.2")]
pub(crate) unsafe fn find_inner(text: &[u8], pat: &[u8]) -> Option<usize> {
    if pat.len() <= 6 {
        return find_short_pat(text, pat);
    }

    // NOTE(review): presumably the critical position and period of the
    // Two-Way string matching algorithm -- confirm against `TwoWaySearcher`.
    let (crit_pos, mut period) = TwoWaySearcher::crit_params(pat);
    // `memory == !0` selects the loop that keeps no state between attempts;
    // otherwise `memory` records how much of the pattern is already known to
    // match after a shift by `period` (0 = nothing remembered).
    let mut memory;

    // If the left part repeats `period` bytes later, `period` is kept and the
    // memory-carrying loop is used; otherwise fall back to a larger shift.
    if &pat[..crit_pos] == &pat[period.. period + crit_pos] {
        memory = 0; } else {
        memory = !0; period = cmp::max(crit_pos, pat.len() - crit_pos) + 1;
    }

    let (left, right) = pat.split_at(crit_pos);
    // Only the first (at most 16) bytes of `right` fit in the vector register.
    let (right16, _right17) = right.split_at(cmp::min(16, right.len()));
    assert!(right.len() != 0);

    // `pat128` loads at most 16 bytes, i.e. exactly the `right16` prefix.
    let r = pat128(right);

    // Text without its last 16 bytes: unmasked 16-byte loads that start in
    // here stay inside `text`.
    let safetext = &text[..cmp::max(text.len(), 16) - 16];

    let mut pos = 0;
    if memory == !0 {
        'search: loop {
            // Stop once a whole pattern no longer fits into the safe region.
            if pos + pat.len() > safetext.len() {
                break;
            }
            let start = crit_pos;
            match first_start_of_match_nomask(&safetext[pos + start..], right16.len(), r) {
                None => {
                    // No candidate left in the safe region; hand over to the
                    // tail loop from the last start that still fits in it.
                    pos = cmp::max(pos, safetext.len() - pat.len());
                    break }
                Some((mpos, mlen)) => {
                    pos += mpos;
                    // Extend the (possibly partial) vector hit to see how much
                    // of `right` really matches.
                    let mut pfxlen = mlen;
                    if pfxlen < right.len() {
                        pfxlen += shared_prefix_inner(&text[pos + start + mlen..], &right[mlen..]);
                    }
                    if pfxlen != right.len() {
                        // Mismatch inside `right`: skip past the matched prefix
                        // and the mismatching byte.
                        pos += pfxlen + 1;
                        continue 'search;
                    } else {
                    }
                }
            }

            // `right` matched completely; now verify `left`.
            if left != &text[pos..pos + left.len()] {
                pos += period;
                continue 'search;
            }

            return Some(pos);
        }
    } else {
        'search_memory: loop {
            // Stop once a whole pattern no longer fits into the safe region.
            if pos + pat.len() > safetext.len() {
                break;
            }
            // With nothing remembered, look for a fresh candidate; otherwise
            // the first `memory - crit_pos` bytes of `right` are taken as
            // already matched.
            let mut pfxlen = if memory == 0 {
                let start = crit_pos;
                match first_start_of_match_nomask(&safetext[pos + start..], right16.len(), r) {
                    None => {
                        pos = cmp::max(pos, safetext.len() - pat.len());
                        break }
                    Some((mpos, mlen)) => {
                        pos += mpos;
                        mlen
                    }
                }
            } else {
                memory - crit_pos
            };
            if pfxlen < right.len() {
                pfxlen += shared_prefix_inner(&text[pos + crit_pos + pfxlen..], &right[pfxlen..]);
            }
            if pfxlen != right.len() {
                // Mismatch inside `right`: shift and forget what was matched.
                pos += pfxlen + 1;
                memory = 0;
                continue 'search_memory;
            } else {
            }

            // Compare the part of `left` that is not covered by `memory`.
            if memory <= left.len() && &left[memory..] != &text[pos + memory..pos + left.len()] {
                pos += period;
                // After shifting by one period, this many leading bytes are
                // recorded as already matched.
                memory = pat.len() - period;
                continue 'search_memory;
            }

            return Some(pos);
        }
    }

    // Remainder of the text: same scheme as the `'search` loop, but with the
    // masked loader and without using `memory`.
    'tail: loop {
        if pos > text.len() - pat.len() {
            return None;
        }
        let start = crit_pos;
        match first_start_of_match_mask(&text[pos + start..], right16.len(), r) {
            None => return None,
            Some((mpos, mlen)) => {
                pos += mpos;
                let mut pfxlen = mlen;
                if pfxlen < right.len() {
                    pfxlen += shared_prefix_inner(&text[pos + start + mlen..], &right[mlen..]);
                }
                if pfxlen != right.len() {
                    pos += pfxlen + 1;
                    continue 'tail;

                } else {
                }
            }
        }

        if left != &text[pos..pos + left.len()] {
            pos += period;
            continue 'tail;
        }

        return Some(pos);
    }
}
457
/// Compares `find` with `str::find` for every window of a sample text plus a
/// few hand-picked long and periodic patterns.
#[test]
fn test_find() {
    let text = b"abc";
    assert_eq!(find(text, b"d"), None);
    assert_eq!(find(text, b"c"), Some(2));

    let longer = "longer text and so on, a bit more";
    let haystack = longer.as_bytes();

    // Every substring of every length below the full text must be found where
    // the standard library finds it.
    for wsz in 1..longer.len() {
        for window in haystack.windows(wsz) {
            let str_find = longer.find(::std::str::from_utf8(window).unwrap());
            assert!(str_find.is_some());
            assert_eq!(find(haystack, window), str_find, "{:?} {:?}",
                       longer, ::std::str::from_utf8(window));
        }
    }

    // A pattern longer than one vector register.
    let pat = b"ger text and so on";
    assert!(pat.len() > 16);
    assert_eq!(Some(3), find(haystack, pat));

    // Periodic patterns inside periodic texts.
    let periodic_cases = [
        ("cbabababcbabababab", "abababab"),
        ("cbababababababababababababababab", "abababab"),
    ];
    for &(text, n) in periodic_cases.iter() {
        assert_eq!(text.find(n), find(text.as_bytes(), n.as_bytes()));
    }
}
492
493#[inline(always)]
495fn pat128(pat: &[u8]) -> __m128i {
496 unsafe {
497 mask_load(pat.as_ptr() as *const _, pat.len())
498 }
499}
500
/// Copies `min(len, 16)` bytes from `ptr` into the low lanes of a zeroed
/// 128-bit vector and returns it.
///
/// # Safety
///
/// `ptr` must be valid for reading `min(len, 16)` bytes.
#[inline(always)]
unsafe fn mask_load(ptr: *const u8, len: usize) -> __m128i {
    let mut data: __m128i = _mm_setzero_si128();
    let count = cmp::min(len, mem::size_of::<__m128i>());
    let dst = &mut data as *mut __m128i as *mut u8;

    ::std::ptr::copy_nonoverlapping(ptr, dst, count);
    data
}
510
511pub fn shared_prefix(text: &[u8], pat: &[u8]) -> usize {
515 assert!(is_supported());
516
517 unsafe { shared_prefix_inner(text, pat) }
518}
519
520#[target_feature(enable = "sse4.2")]
521unsafe fn shared_prefix_inner(text: &[u8], pat: &[u8]) -> usize {
522 let tp = text.as_ptr();
523 let tlen = text.len();
524 let pp = pat.as_ptr();
525 let plen = pat.len();
526 let len = cmp::min(tlen, plen);
527
528 let initial_part = len.saturating_sub(16);
531 let mut prefix_len = 0;
532 let mut offset = 0;
533 while offset < initial_part {
534 let initial_tail = initial_part - offset;
535 let mask = pcmpestrm_eq_each(tp, offset, initial_tail, pp, offset, initial_tail);
536 if mask != 0xffff {
538 let first_bit_set = (mask ^ 0xffff).trailing_zeros() as usize;
539 prefix_len += first_bit_set;
540 return prefix_len;
541 } else {
542 prefix_len += cmp::min(initial_tail, 16);
543 }
544 offset += 16;
545 }
546 let text_suffix = get_unchecked(text, prefix_len..len);
549 let pat_suffix = get_unchecked(pat, prefix_len..len);
550 for (&a, &b) in zip(text_suffix, pat_suffix) {
551 if a != b {
552 break;
553 }
554 prefix_len += 1;
555 }
556
557 prefix_len
558}
559
/// Exercises `shared_prefix` on short inputs, inputs longer than one vector
/// register, every suffix of a sample text, and a 1 KiB buffer with a single
/// differing byte.
#[test]
fn test_prefixlen() {
    let text_long = b"0123456789abcdefeffect";
    let text_long2 = b"9123456789abcdefeffect";
    let text_long3 = b"0123456789abcdefgffect";

    assert_eq!(shared_prefix(text_long, text_long), text_long.len());
    assert_eq!(shared_prefix(b"abcd", b"abc"), 3);
    assert_eq!(shared_prefix(b"abcd", b"abcf"), 3);
    assert_eq!(0, shared_prefix(text_long, text_long2));
    assert_eq!(0, shared_prefix(text_long, &text_long[1..]));
    assert_eq!(16, shared_prefix(text_long, text_long3));

    // Every suffix shares its full length with itself (including the empty one).
    for start in 0..text_long.len() + 1 {
        let suffix = &text_long[start..];
        assert_eq!(text_long.len() - start, shared_prefix(suffix, suffix));
    }

    // Two long buffers that differ only at index `off`.
    let off = 1000;
    let l1 = [7u8; 1024];
    let mut l2 = [7u8; 1024];
    l2[off] = 0;
    for i in 0..off {
        assert_eq!(shared_prefix(&l1[i..], &l2[i..]), off - i);
    }
}