Skip to main content

rustpython_sre_engine/
string.rs

1use rustpython_wtf8::Wtf8;
2
/// A location inside an input that is being walked by a [`StrDrive`]
/// implementation.
///
/// The cursor pairs a raw byte pointer with a logical element index; the
/// `StrDrive` implementations keep the two in step.
#[derive(Debug, Clone, Copy)]
pub struct StringCursor {
    /// Raw pointer to the byte the cursor currently rests on. It is null for
    /// a default-constructed cursor that has not been attached to an input.
    pub(crate) ptr: *const u8,
    /// Logical index of the cursor, counted in elements of the input.
    pub position: usize,
}

impl Default for StringCursor {
    /// Produces a detached cursor: null pointer, position zero.
    fn default() -> Self {
        let detached = core::ptr::null();
        Self {
            ptr: detached,
            position: 0,
        }
    }
}
17
/// Abstraction over the string-like inputs that can be traversed with a
/// [`StringCursor`].
///
/// The cursor-moving operations are associated functions (they take no
/// `self`): they work purely on the cursor's raw pointer, so the cursor must
/// have been created from, and must stay within, a live input.
pub trait StrDrive: Copy {
    /// Returns the number of elements in the input.
    fn count(&self) -> usize;
    /// Creates a cursor positioned at element index `n` of this input.
    fn create_cursor(&self, n: usize) -> StringCursor;
    /// Repositions an existing `cursor` so that it sits at element index `n`
    /// of this input.
    fn adjust_cursor(&self, cursor: &mut StringCursor, n: usize);
    /// Moves `cursor` one element forward and returns an element value as
    /// `u32`.
    fn advance(cursor: &mut StringCursor) -> u32;
    /// Returns the element at the cursor without moving the cursor.
    fn peek(cursor: &StringCursor) -> u32;
    /// Moves `cursor` forward by `n` elements.
    fn skip(cursor: &mut StringCursor, n: usize);
    /// Moves `cursor` one element backward and returns the element it
    /// stepped over.
    fn back_advance(cursor: &mut StringCursor) -> u32;
    /// Returns the element immediately before the cursor without moving the
    /// cursor.
    fn back_peek(cursor: &StringCursor) -> u32;
    /// Moves `cursor` backward by `n` elements.
    fn back_skip(cursor: &mut StringCursor, n: usize);
}
29
30impl StrDrive for &[u8] {
31    #[inline]
32    fn count(&self) -> usize {
33        self.len()
34    }
35
36    #[inline]
37    fn create_cursor(&self, n: usize) -> StringCursor {
38        StringCursor {
39            ptr: self[n..].as_ptr(),
40            position: n,
41        }
42    }
43
44    #[inline]
45    fn adjust_cursor(&self, cursor: &mut StringCursor, n: usize) {
46        cursor.position = n;
47        cursor.ptr = self[n..].as_ptr();
48    }
49
50    #[inline]
51    fn advance(cursor: &mut StringCursor) -> u32 {
52        cursor.position += 1;
53        unsafe { cursor.ptr = cursor.ptr.add(1) };
54        unsafe { *cursor.ptr as u32 }
55    }
56
57    #[inline]
58    fn peek(cursor: &StringCursor) -> u32 {
59        unsafe { *cursor.ptr as u32 }
60    }
61
62    #[inline]
63    fn skip(cursor: &mut StringCursor, n: usize) {
64        cursor.position += n;
65        unsafe { cursor.ptr = cursor.ptr.add(n) };
66    }
67
68    #[inline]
69    fn back_advance(cursor: &mut StringCursor) -> u32 {
70        cursor.position -= 1;
71        unsafe { cursor.ptr = cursor.ptr.sub(1) };
72        unsafe { *cursor.ptr as u32 }
73    }
74
75    #[inline]
76    fn back_peek(cursor: &StringCursor) -> u32 {
77        unsafe { *cursor.ptr.offset(-1) as u32 }
78    }
79
80    #[inline]
81    fn back_skip(cursor: &mut StringCursor, n: usize) {
82        cursor.position -= n;
83        unsafe { cursor.ptr = cursor.ptr.sub(n) };
84    }
85}
86
87impl StrDrive for &str {
88    #[inline]
89    fn count(&self) -> usize {
90        self.chars().count()
91    }
92
93    #[inline]
94    fn create_cursor(&self, n: usize) -> StringCursor {
95        let mut cursor = StringCursor {
96            ptr: self.as_ptr(),
97            position: 0,
98        };
99        Self::skip(&mut cursor, n);
100        cursor
101    }
102
103    #[inline]
104    fn adjust_cursor(&self, cursor: &mut StringCursor, n: usize) {
105        if cursor.ptr.is_null() || cursor.position > n {
106            *cursor = Self::create_cursor(self, n);
107        } else if cursor.position < n {
108            Self::skip(cursor, n - cursor.position);
109        }
110    }
111
112    #[inline]
113    fn advance(cursor: &mut StringCursor) -> u32 {
114        cursor.position += 1;
115        unsafe { next_code_point(&mut cursor.ptr) }
116    }
117
118    #[inline]
119    fn peek(cursor: &StringCursor) -> u32 {
120        let mut ptr = cursor.ptr;
121        unsafe { next_code_point(&mut ptr) }
122    }
123
124    #[inline]
125    fn skip(cursor: &mut StringCursor, n: usize) {
126        cursor.position += n;
127        for _ in 0..n {
128            unsafe { next_code_point(&mut cursor.ptr) };
129        }
130    }
131
132    #[inline]
133    fn back_advance(cursor: &mut StringCursor) -> u32 {
134        cursor.position -= 1;
135        unsafe { next_code_point_reverse(&mut cursor.ptr) }
136    }
137
138    #[inline]
139    fn back_peek(cursor: &StringCursor) -> u32 {
140        let mut ptr = cursor.ptr;
141        unsafe { next_code_point_reverse(&mut ptr) }
142    }
143
144    #[inline]
145    fn back_skip(cursor: &mut StringCursor, n: usize) {
146        cursor.position -= n;
147        for _ in 0..n {
148            unsafe { next_code_point_reverse(&mut cursor.ptr) };
149        }
150    }
151}
152
153impl StrDrive for &Wtf8 {
154    #[inline]
155    fn count(&self) -> usize {
156        self.code_points().count()
157    }
158
159    #[inline]
160    fn create_cursor(&self, n: usize) -> StringCursor {
161        let mut cursor = StringCursor {
162            ptr: self.as_bytes().as_ptr(),
163            position: 0,
164        };
165        Self::skip(&mut cursor, n);
166        cursor
167    }
168
169    #[inline]
170    fn adjust_cursor(&self, cursor: &mut StringCursor, n: usize) {
171        if cursor.ptr.is_null() || cursor.position > n {
172            *cursor = Self::create_cursor(self, n);
173        } else if cursor.position < n {
174            Self::skip(cursor, n - cursor.position);
175        }
176    }
177
178    #[inline]
179    fn advance(cursor: &mut StringCursor) -> u32 {
180        cursor.position += 1;
181        unsafe { next_code_point(&mut cursor.ptr) }
182    }
183
184    #[inline]
185    fn peek(cursor: &StringCursor) -> u32 {
186        let mut ptr = cursor.ptr;
187        unsafe { next_code_point(&mut ptr) }
188    }
189
190    #[inline]
191    fn skip(cursor: &mut StringCursor, n: usize) {
192        cursor.position += n;
193        for _ in 0..n {
194            unsafe { next_code_point(&mut cursor.ptr) };
195        }
196    }
197
198    #[inline]
199    fn back_advance(cursor: &mut StringCursor) -> u32 {
200        cursor.position -= 1;
201        unsafe { next_code_point_reverse(&mut cursor.ptr) }
202    }
203
204    #[inline]
205    fn back_peek(cursor: &StringCursor) -> u32 {
206        let mut ptr = cursor.ptr;
207        unsafe { next_code_point_reverse(&mut ptr) }
208    }
209
210    #[inline]
211    fn back_skip(cursor: &mut StringCursor, n: usize) {
212        cursor.position -= n;
213        for _ in 0..n {
214            unsafe { next_code_point_reverse(&mut cursor.ptr) };
215        }
216    }
217}
218
/// Decodes the code point that starts at `*ptr` (assuming a UTF-8-like
/// encoding) and advances `*ptr` to the byte just past it.
///
/// Between one and four bytes are read, depending on the value of the
/// leading byte. No validation is performed.
///
/// # Safety
///
/// `*ptr` must point at the first byte of a well-formed UTF-8-like (UTF-8 or
/// WTF-8) sequence, and every byte of that sequence must be readable.
#[inline]
const unsafe fn next_code_point(ptr: &mut *const u8) -> u32 {
    // Decode UTF-8
    // Read the leading byte and step past it.
    let x = unsafe { **ptr };
    *ptr = unsafe { ptr.offset(1) };

    // Single-byte (ASCII) fast path.
    if x < 128 {
        return x as u32;
    }

    // Multibyte case follows
    // Decode from a byte combination out of: [[[x y] z] w]
    // NOTE: Performance is sensitive to the exact formulation here
    let init = utf8_first_byte(x, 2);
    // SAFETY: the caller guarantees a well-formed UTF-8-like sequence,
    // so a continuation byte must follow a non-ASCII leading byte.
    let y = unsafe { **ptr };
    *ptr = unsafe { ptr.offset(1) };
    let mut ch = utf8_acc_cont_byte(init, y);
    if x >= 0xE0 {
        // [[x y z] w] case
        // 5th bit in 0xE0 .. 0xEF is always clear, so `init` is still valid
        // SAFETY: the caller guarantees a well-formed UTF-8-like sequence,
        // so a third byte must be present here.
        let z = unsafe { **ptr };
        *ptr = unsafe { ptr.offset(1) };
        let y_z = utf8_acc_cont_byte((y & CONT_MASK) as u32, z);
        ch = (init << 12) | y_z;
        if x >= 0xF0 {
            // [x y z w] case
            // use only the lower 3 bits of `init`
            // SAFETY: the caller guarantees a well-formed UTF-8-like
            // sequence, so a fourth byte must be present here.
            let w = unsafe { **ptr };
            *ptr = unsafe { ptr.offset(1) };
            ch = ((init & 7) << 18) | utf8_acc_cont_byte(y_z, w);
        }
    }

    ch
}
266
/// Decodes the code point that ends immediately before `*ptr` (assuming a
/// UTF-8-like encoding) and moves `*ptr` back to that code point's first
/// byte.
///
/// Bytes are read backwards until a non-continuation byte is found, up to a
/// maximum of four bytes. No validation is performed.
///
/// # Safety
///
/// The bytes preceding `*ptr` must end with a well-formed UTF-8-like (UTF-8
/// or WTF-8) sequence, and every byte of that sequence must be readable.
#[inline]
const unsafe fn next_code_point_reverse(ptr: &mut *const u8) -> u32 {
    // Decode UTF-8
    // Step back onto the last byte of the sequence; ASCII returns at once.
    *ptr = unsafe { ptr.offset(-1) };
    let w = match unsafe { **ptr } {
        next_byte if next_byte < 128 => return next_byte as u32,
        back_byte => back_byte,
    };

    // Multibyte case follows
    // Decode from a byte combination out of: [x [y [z w]]]
    let mut ch;
    // SAFETY: the caller guarantees a well-formed UTF-8-like sequence,
    // so a non-ASCII trailing byte must be preceded by another byte.
    *ptr = unsafe { ptr.offset(-1) };
    let z = unsafe { **ptr };
    // Assume `z` is the leading byte of a 2-byte sequence until shown otherwise.
    ch = utf8_first_byte(z, 2);
    if utf8_is_cont_byte(z) {
        // SAFETY: the caller guarantees a well-formed UTF-8-like sequence,
        // so a continuation byte must be preceded by another byte.
        *ptr = unsafe { ptr.offset(-1) };
        let y = unsafe { **ptr };
        ch = utf8_first_byte(y, 3);
        if utf8_is_cont_byte(y) {
            // SAFETY: the caller guarantees a well-formed UTF-8-like
            // sequence, so the leading byte of a 4-byte sequence is here.
            *ptr = unsafe { ptr.offset(-1) };
            let x = unsafe { **ptr };
            ch = utf8_first_byte(x, 4);
            ch = utf8_acc_cont_byte(ch, y);
        }
        ch = utf8_acc_cont_byte(ch, z);
    }
    ch = utf8_acc_cont_byte(ch, w);

    ch
}
310
/// Extracts the payload bits of a UTF-8 leading byte as the starting value
/// of the code point accumulator.
///
/// The mask is `0x7F` shifted right by `width`, which keeps the low 5 bits
/// for width 2, the low 4 bits for width 3 and the low 3 bits for width 4.
#[inline]
const fn utf8_first_byte(byte: u8, width: u32) -> u32 {
    let payload_mask: u8 = 0x7F >> width;
    (byte & payload_mask) as u32
}
318
319/// Returns the value of `ch` updated with continuation byte `byte`.
320#[inline]
321const fn utf8_acc_cont_byte(ch: u32, byte: u8) -> u32 {
322    (ch << 6) | (byte & CONT_MASK) as u32
323}
324
/// Reports whether `byte` is a UTF-8 continuation byte, i.e. whether its two
/// most significant bits are `10`.
#[inline]
const fn utf8_is_cont_byte(byte: u8) -> bool {
    (byte & 0b1100_0000) == 0b1000_0000
}
331
/// Bit mask selecting the six value (payload) bits of a continuation byte.
const CONT_MASK: u8 = 0x3F;
334
/// Reports whether `b` is one of the six ASCII whitespace bytes accepted
/// here: tab, line feed, vertical tab, form feed, carriage return or space.
const fn is_py_ascii_whitespace(b: u8) -> bool {
    // 0x09..=0x0D covers \t, \n, \x0B, \x0C and \r; 0x20 is the space.
    matches!(b, 0x09..=0x0D | 0x20)
}
338
/// Reports whether `ch` is an ASCII word character: an underscore or an
/// ASCII letter or digit. Values that do not fit in a byte are never word
/// characters.
#[inline]
pub(crate) fn is_word(ch: u32) -> bool {
    match u8::try_from(ch) {
        Ok(byte) => byte == b'_' || byte.is_ascii_alphanumeric(),
        Err(_) => false,
    }
}
346#[inline]
347pub(crate) fn is_space(ch: u32) -> bool {
348    u8::try_from(ch)
349        .map(is_py_ascii_whitespace)
350        .unwrap_or(false)
351}
/// Reports whether `ch` is an ASCII decimal digit (`'0'..='9'`).
#[inline]
pub(crate) fn is_digit(ch: u32) -> bool {
    matches!(u8::try_from(ch), Ok(byte) if byte.is_ascii_digit())
}
/// Locale-flavoured alphanumeric test. Locales are currently ignored, so
/// this accepts exactly the ASCII letters and digits.
#[inline]
pub(crate) fn is_loc_alnum(ch: u32) -> bool {
    // FIXME: Ignore the locales
    match u8::try_from(ch) {
        Ok(byte) => byte.is_ascii_alphanumeric(),
        Err(_) => false,
    }
}
365#[inline]
366pub(crate) fn is_loc_word(ch: u32) -> bool {
367    ch == '_' as u32 || is_loc_alnum(ch)
368}
/// Reports whether `ch` is the line feed character (`'\n'`, U+000A).
#[inline]
pub(crate) const fn is_linebreak(ch: u32) -> bool {
    matches!(ch, 0x0A)
}
/// Lowercases `ch` using ASCII rules only.
///
/// Values that do not fit in a byte are returned unchanged, as are bytes
/// that are not ASCII uppercase letters.
#[inline]
pub fn lower_ascii(ch: u32) -> u32 {
    match u8::try_from(ch) {
        Ok(byte) => byte.to_ascii_lowercase() as u32,
        Err(_) => ch,
    }
}
/// Locale-flavoured lowercase conversion.
///
/// Locales are currently ignored, so this simply forwards to `lower_ascii`.
#[inline]
pub(crate) fn lower_locate(ch: u32) -> u32 {
    // FIXME: Ignore the locales
    lower_ascii(ch)
}
/// Locale-flavoured uppercase conversion.
///
/// Locales are currently ignored: only ASCII lowercase letters are changed,
/// and values that do not fit in a byte are returned unchanged.
#[inline]
pub(crate) fn upper_locate(ch: u32) -> u32 {
    // FIXME: Ignore the locales
    match u8::try_from(ch) {
        Ok(byte) => byte.to_ascii_uppercase() as u32,
        Err(_) => ch,
    }
}
/// Unicode-mode digit test.
///
/// Currently accepts only the ASCII digits `'0'..='9'`; values that are not
/// valid `char`s are rejected.
#[inline]
pub(crate) fn is_uni_digit(ch: u32) -> bool {
    // TODO: check with cpython
    matches!(char::try_from(ch), Ok(c) if c.is_ascii_digit())
}
/// Unicode-mode whitespace test.
///
/// Accepts U+0009..=U+000D, U+001C..=U+0020, U+0085, U+00A0, U+1680,
/// U+2000..=U+200A, U+2028, U+2029, U+202F, U+205F and U+3000.
///
/// The ASCII whitespace characters (tab, LF, VT, FF, CR, space) all fall
/// inside the first two ranges, so no separate ASCII check is needed.
#[inline]
pub(crate) fn is_uni_space(ch: u32) -> bool {
    // TODO: check with cpython
    matches!(
        ch,
        0x0009..=0x000D
            | 0x001C..=0x0020
            | 0x0085
            | 0x00A0
            | 0x1680
            | 0x2000..=0x200A
            | 0x2028
            | 0x2029
            | 0x202F
            | 0x205F
            | 0x3000
    )
}
/// Unicode-mode line-break test.
///
/// Accepts U+000A..=U+000D, U+001C..=U+001E, U+0085, U+2028 and U+2029.
#[inline]
pub(crate) const fn is_uni_linebreak(ch: u32) -> bool {
    matches!(
        ch,
        0x000A..=0x000D | 0x001C..=0x001E | 0x0085 | 0x2028 | 0x2029
    )
}
/// Unicode-mode alphanumeric test, delegating to `char::is_alphanumeric`.
///
/// Values that are not valid `char`s (surrogates, values above U+10FFFF)
/// are rejected.
#[inline]
pub(crate) fn is_uni_alnum(ch: u32) -> bool {
    // TODO: check with cpython
    match char::try_from(ch) {
        Ok(c) => c.is_alphanumeric(),
        Err(_) => false,
    }
}
449#[inline]
450pub(crate) fn is_uni_word(ch: u32) -> bool {
451    ch == '_' as u32 || is_uni_alnum(ch)
452}
/// Unicode-mode lowercase conversion.
///
/// Only the first `char` of the lowercase expansion is kept, so code points
/// whose lowercase form is several characters long are truncated to the
/// first one. Values that are not valid `char`s are returned unchanged.
#[inline]
pub fn lower_unicode(ch: u32) -> u32 {
    // TODO: check with cpython
    match char::try_from(ch) {
        Ok(c) => {
            let mut expansion = c.to_lowercase();
            expansion.next().unwrap() as u32
        }
        Err(_) => ch,
    }
}
/// Unicode-mode uppercase conversion.
///
/// Only the first `char` of the uppercase expansion is kept, so code points
/// whose uppercase form is several characters long are truncated to the
/// first one. Values that are not valid `char`s are returned unchanged.
#[inline]
pub fn upper_unicode(ch: u32) -> u32 {
    // TODO: check with cpython
    match char::try_from(ch) {
        Ok(c) => {
            let mut expansion = c.to_uppercase();
            expansion.next().unwrap() as u32
        }
        Err(_) => ch,
    }
}