rustpython_sre_engine/
string.rs

/// A cursor into a haystack that tracks two things at once: the index of the
/// current character and the raw address of the byte where it starts.
///
/// The pointer carries no lifetime. It is only meaningful while the haystack
/// it was created from is alive, and every read through it is `unsafe`.
#[derive(Clone, Copy, Debug)]
pub struct StringCursor {
    /// Address of the current byte; null for a cursor produced by `Default`.
    pub(crate) ptr: *const u8,
    /// Index of the current character (not byte) within the haystack.
    pub position: usize,
}

impl Default for StringCursor {
    /// Returns a cursor that points nowhere (null) at index zero.
    fn default() -> Self {
        let ptr = std::ptr::null();
        let position = 0;
        Self { ptr, position }
    }
}
15
16pub trait StrDrive: Copy {
17    fn count(&self) -> usize;
18    fn create_cursor(&self, n: usize) -> StringCursor;
19    fn adjust_cursor(&self, cursor: &mut StringCursor, n: usize);
20    fn advance(cursor: &mut StringCursor) -> u32;
21    fn peek(cursor: &StringCursor) -> u32;
22    fn skip(cursor: &mut StringCursor, n: usize);
23    fn back_advance(cursor: &mut StringCursor) -> u32;
24    fn back_peek(cursor: &StringCursor) -> u32;
25    fn back_skip(cursor: &mut StringCursor, n: usize);
26}
27
28impl<'a> StrDrive for &'a [u8] {
29    #[inline]
30    fn count(&self) -> usize {
31        self.len()
32    }
33
34    #[inline]
35    fn create_cursor(&self, n: usize) -> StringCursor {
36        StringCursor {
37            ptr: self[n..].as_ptr(),
38            position: n,
39        }
40    }
41
42    #[inline]
43    fn adjust_cursor(&self, cursor: &mut StringCursor, n: usize) {
44        cursor.position = n;
45        cursor.ptr = self[n..].as_ptr();
46    }
47
48    #[inline]
49    fn advance(cursor: &mut StringCursor) -> u32 {
50        cursor.position += 1;
51        unsafe { cursor.ptr = cursor.ptr.add(1) };
52        unsafe { *cursor.ptr as u32 }
53    }
54
55    #[inline]
56    fn peek(cursor: &StringCursor) -> u32 {
57        unsafe { *cursor.ptr as u32 }
58    }
59
60    #[inline]
61    fn skip(cursor: &mut StringCursor, n: usize) {
62        cursor.position += n;
63        unsafe { cursor.ptr = cursor.ptr.add(n) };
64    }
65
66    #[inline]
67    fn back_advance(cursor: &mut StringCursor) -> u32 {
68        cursor.position -= 1;
69        unsafe { cursor.ptr = cursor.ptr.sub(1) };
70        unsafe { *cursor.ptr as u32 }
71    }
72
73    #[inline]
74    fn back_peek(cursor: &StringCursor) -> u32 {
75        unsafe { *cursor.ptr.offset(-1) as u32 }
76    }
77
78    #[inline]
79    fn back_skip(cursor: &mut StringCursor, n: usize) {
80        cursor.position -= n;
81        unsafe { cursor.ptr = cursor.ptr.sub(n) };
82    }
83}
84
85impl StrDrive for &str {
86    #[inline]
87    fn count(&self) -> usize {
88        self.chars().count()
89    }
90
91    #[inline]
92    fn create_cursor(&self, n: usize) -> StringCursor {
93        let mut cursor = StringCursor {
94            ptr: self.as_ptr(),
95            position: 0,
96        };
97        Self::skip(&mut cursor, n);
98        cursor
99    }
100
101    #[inline]
102    fn adjust_cursor(&self, cursor: &mut StringCursor, n: usize) {
103        if cursor.ptr.is_null() || cursor.position > n {
104            *cursor = Self::create_cursor(self, n);
105        } else if cursor.position < n {
106            Self::skip(cursor, n - cursor.position);
107        }
108    }
109
110    #[inline]
111    fn advance(cursor: &mut StringCursor) -> u32 {
112        cursor.position += 1;
113        unsafe { next_code_point(&mut cursor.ptr) }
114    }
115
116    #[inline]
117    fn peek(cursor: &StringCursor) -> u32 {
118        let mut ptr = cursor.ptr;
119        unsafe { next_code_point(&mut ptr) }
120    }
121
122    #[inline]
123    fn skip(cursor: &mut StringCursor, n: usize) {
124        cursor.position += n;
125        for _ in 0..n {
126            unsafe { next_code_point(&mut cursor.ptr) };
127        }
128    }
129
130    #[inline]
131    fn back_advance(cursor: &mut StringCursor) -> u32 {
132        cursor.position -= 1;
133        unsafe { next_code_point_reverse(&mut cursor.ptr) }
134    }
135
136    #[inline]
137    fn back_peek(cursor: &StringCursor) -> u32 {
138        let mut ptr = cursor.ptr;
139        unsafe { next_code_point_reverse(&mut ptr) }
140    }
141
142    #[inline]
143    fn back_skip(cursor: &mut StringCursor, n: usize) {
144        cursor.position -= n;
145        for _ in 0..n {
146            unsafe { next_code_point_reverse(&mut cursor.ptr) };
147        }
148    }
149}
150
151/// Reads the next code point out of a byte iterator (assuming a
152/// UTF-8-like encoding).
153///
154/// # Safety
155///
156/// `bytes` must produce a valid UTF-8-like (UTF-8 or WTF-8) string
157#[inline]
158unsafe fn next_code_point(ptr: &mut *const u8) -> u32 {
159    // Decode UTF-8
160    let x = **ptr;
161    *ptr = ptr.offset(1);
162
163    if x < 128 {
164        return x as u32;
165    }
166
167    // Multibyte case follows
168    // Decode from a byte combination out of: [[[x y] z] w]
169    // NOTE: Performance is sensitive to the exact formulation here
170    let init = utf8_first_byte(x, 2);
171    // SAFETY: `bytes` produces an UTF-8-like string,
172    // so the iterator must produce a value here.
173    let y = **ptr;
174    *ptr = ptr.offset(1);
175    let mut ch = utf8_acc_cont_byte(init, y);
176    if x >= 0xE0 {
177        // [[x y z] w] case
178        // 5th bit in 0xE0 .. 0xEF is always clear, so `init` is still valid
179        // SAFETY: `bytes` produces an UTF-8-like string,
180        // so the iterator must produce a value here.
181        let z = **ptr;
182        *ptr = ptr.offset(1);
183        let y_z = utf8_acc_cont_byte((y & CONT_MASK) as u32, z);
184        ch = init << 12 | y_z;
185        if x >= 0xF0 {
186            // [x y z w] case
187            // use only the lower 3 bits of `init`
188            // SAFETY: `bytes` produces an UTF-8-like string,
189            // so the iterator must produce a value here.
190            let w = **ptr;
191            *ptr = ptr.offset(1);
192            ch = (init & 7) << 18 | utf8_acc_cont_byte(y_z, w);
193        }
194    }
195
196    ch
197}
198
199/// Reads the last code point out of a byte iterator (assuming a
200/// UTF-8-like encoding).
201///
202/// # Safety
203///
204/// `bytes` must produce a valid UTF-8-like (UTF-8 or WTF-8) string
205#[inline]
206unsafe fn next_code_point_reverse(ptr: &mut *const u8) -> u32 {
207    // Decode UTF-8
208    *ptr = ptr.offset(-1);
209    let w = match **ptr {
210        next_byte if next_byte < 128 => return next_byte as u32,
211        back_byte => back_byte,
212    };
213
214    // Multibyte case follows
215    // Decode from a byte combination out of: [x [y [z w]]]
216    let mut ch;
217    // SAFETY: `bytes` produces an UTF-8-like string,
218    // so the iterator must produce a value here.
219    *ptr = ptr.offset(-1);
220    let z = **ptr;
221    ch = utf8_first_byte(z, 2);
222    if utf8_is_cont_byte(z) {
223        // SAFETY: `bytes` produces an UTF-8-like string,
224        // so the iterator must produce a value here.
225        *ptr = ptr.offset(-1);
226        let y = **ptr;
227        ch = utf8_first_byte(y, 3);
228        if utf8_is_cont_byte(y) {
229            // SAFETY: `bytes` produces an UTF-8-like string,
230            // so the iterator must produce a value here.
231            *ptr = ptr.offset(-1);
232            let x = **ptr;
233            ch = utf8_first_byte(x, 4);
234            ch = utf8_acc_cont_byte(ch, y);
235        }
236        ch = utf8_acc_cont_byte(ch, z);
237    }
238    ch = utf8_acc_cont_byte(ch, w);
239
240    ch
241}
242
/// Seeds the code-point accumulator from the lead byte of a `width`-byte
/// sequence by keeping only its payload bits. That is the low 5 bits for
/// width 2, 4 bits for width 3 and 3 bits for width 4.
#[inline]
const fn utf8_first_byte(byte: u8, width: u32) -> u32 {
    let payload_mask: u8 = 0x7F >> width;
    (byte & payload_mask) as u32
}
250
/// Shifts the accumulator `ch` left by six bits and appends the payload bits
/// of the continuation byte `byte`.
#[inline]
const fn utf8_acc_cont_byte(ch: u32, byte: u8) -> u32 {
    let payload = (byte & CONT_MASK) as u32;
    (ch << 6) | payload
}

/// Reports whether `byte` has the form `0b10xx_xxxx`, i.e. is a UTF-8
/// continuation byte (0x80..=0xBF).
#[inline]
const fn utf8_is_cont_byte(byte: u8) -> bool {
    (byte & 0b1100_0000) == 0b1000_0000
}

/// Mask selecting the six payload bits of a continuation byte.
const CONT_MASK: u8 = 0x3F;
266
/// Python's ASCII whitespace set: the five bytes accepted by
/// `u8::is_ascii_whitespace` (`\t`, `\n`, `\x0C`, `\r`, space) plus the
/// vertical tab `\x0B`.
const fn is_py_ascii_whitespace(b: u8) -> bool {
    b == b'\x0B' || b.is_ascii_whitespace()
}

/// ASCII word character: `_`, `0-9`, `A-Z` or `a-z`. Anything outside the
/// byte range is rejected.
#[inline]
pub(crate) fn is_word(ch: u32) -> bool {
    matches!(u8::try_from(ch), Ok(b) if b == b'_' || b.is_ascii_alphanumeric())
}

/// ASCII whitespace in Python's sense; values above 0xFF are never spaces.
#[inline]
pub(crate) fn is_space(ch: u32) -> bool {
    match u8::try_from(ch) {
        Ok(b) => is_py_ascii_whitespace(b),
        Err(_) => false,
    }
}
/// ASCII decimal digit `0`..=`9`.
#[inline]
pub(crate) fn is_digit(ch: u32) -> bool {
    (u32::from(b'0')..=u32::from(b'9')).contains(&ch)
}
/// Locale-aware alphanumeric test.
#[inline]
pub(crate) fn is_loc_alnum(ch: u32) -> bool {
    // FIXME: locales are ignored; plain ASCII `[0-9A-Za-z]` is used instead.
    matches!(char::from_u32(ch), Some(c) if c.is_ascii_alphanumeric())
}
/// Locale-aware word character: an alphanumeric or `_`.
#[inline]
pub(crate) fn is_loc_word(ch: u32) -> bool {
    is_loc_alnum(ch) || ch == u32::from(b'_')
}
/// Non-Unicode line break: only `\n` (0x0A) qualifies.
#[inline]
pub(crate) fn is_linebreak(ch: u32) -> bool {
    matches!(ch, 0x0A)
}
/// Lowercases ASCII letters. Every other value, including non-ASCII bytes and
/// values above 0xFF, is returned unchanged.
#[inline]
pub fn lower_ascii(ch: u32) -> u32 {
    match u8::try_from(ch) {
        Ok(b) => u32::from(b.to_ascii_lowercase()),
        Err(_) => ch,
    }
}
/// Locale-aware lowercasing.
#[inline]
pub(crate) fn lower_locate(ch: u32) -> u32 {
    // FIXME: locales are ignored; this is plain ASCII lowercasing.
    lower_ascii(ch)
}
/// Locale-aware uppercasing.
#[inline]
pub(crate) fn upper_locate(ch: u32) -> u32 {
    // FIXME: locales are ignored; only ASCII letters are uppercased and every
    // other value is returned unchanged.
    if let Ok(b) = u8::try_from(ch) {
        u32::from(b.to_ascii_uppercase())
    } else {
        ch
    }
}
/// Unicode-mode digit test. Currently this accepts only ASCII `0`..=`9`.
#[inline]
pub(crate) fn is_uni_digit(ch: u32) -> bool {
    // TODO: check with cpython
    // NOTE(review): CPython's Unicode `\d` presumably also accepts non-ASCII
    // decimal digits (e.g. U+0663); Rust's std has no direct "decimal digit"
    // predicate, so this stays ASCII-only for now.
    matches!(ch, 0x30..=0x39)
}
/// Unicode-mode whitespace test against a fixed table of code points.
///
/// The table covers every ASCII whitespace byte, the information separators
/// U+001C..=U+001F, NEL, NBSP, the U+2000..=U+200A spaces, the line and
/// paragraph separators, and a few other space characters.
#[inline]
pub(crate) fn is_uni_space(ch: u32) -> bool {
    // TODO: check with cpython
    matches!(
        ch,
        0x0009..=0x000D
            | 0x001C..=0x0020
            | 0x0085
            | 0x00A0
            | 0x1680
            | 0x2000..=0x200A
            | 0x2028
            | 0x2029
            | 0x202F
            | 0x205F
            | 0x3000
    )
}
/// Unicode-mode line-break test: LF, VT, FF, CR, the separators
/// U+001C..=U+001E, NEL, and the line/paragraph separators U+2028/U+2029.
#[inline]
pub(crate) fn is_uni_linebreak(ch: u32) -> bool {
    matches!(ch, 0x0A..=0x0D | 0x1C..=0x1E | 0x85 | 0x2028 | 0x2029)
}
/// Unicode-mode alphanumeric test using Rust's `char::is_alphanumeric`.
/// Surrogates and values above U+10FFFF are rejected.
#[inline]
pub(crate) fn is_uni_alnum(ch: u32) -> bool {
    // TODO: check with cpython
    matches!(char::from_u32(ch), Some(c) if c.is_alphanumeric())
}
/// Unicode-mode word character: an alphanumeric or `_`.
#[inline]
pub(crate) fn is_uni_word(ch: u32) -> bool {
    is_uni_alnum(ch) || ch == u32::from(b'_')
}
/// Lowercases `ch` as a Unicode scalar value.
///
/// When the full lowercase mapping expands to several characters (for
/// example U+0130 becomes `i` followed by U+0307), only the first one is
/// returned. Values that are not valid `char`s (surrogates, or above
/// U+10FFFF) are returned unchanged.
#[inline]
pub fn lower_unicode(ch: u32) -> u32 {
    // TODO: check with cpython
    match char::from_u32(ch) {
        // `to_lowercase` always yields at least one char.
        Some(c) => c.to_lowercase().next().unwrap() as u32,
        None => ch,
    }
}
/// Uppercases `ch` as a Unicode scalar value.
///
/// When the full uppercase mapping expands to several characters (for
/// example `ß` becomes `SS`), only the first one is returned. Values that are
/// not valid `char`s (surrogates, or above U+10FFFF) are returned unchanged.
#[inline]
pub fn upper_unicode(ch: u32) -> u32 {
    // TODO: check with cpython
    match char::from_u32(ch) {
        // `to_uppercase` always yields at least one char.
        Some(c) => c.to_uppercase().next().unwrap() as u32,
        None => ch,
    }
}