rustpython_common/
str.rs

1use crate::{
2    atomic::{PyAtomic, Radium},
3    hash::PyHash,
4};
5use ascii::AsciiString;
6use rustpython_format::CharLen;
7use std::ops::{Bound, RangeBounds};
8
9#[cfg(not(target_arch = "wasm32"))]
10#[allow(non_camel_case_types)]
11pub type wchar_t = libc::wchar_t;
12#[cfg(target_arch = "wasm32")]
13#[allow(non_camel_case_types)]
14pub type wchar_t = u32;
15
/// Storage classification of a string: UTF-8 plus an "is ASCII" flag
/// (a `PyUnicode_Kind`-like discriminator may be added in the future).
///
/// The type is a plain field-less tag, so it derives full `Eq` in addition to
/// `PartialEq`; equality on it is total.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum PyStrKind {
    /// Every byte of the string is an ASCII character.
    Ascii,
    /// The string is UTF-8 and is not known to be ASCII-only.
    Utf8,
}
22
23impl std::ops::BitOr for PyStrKind {
24    type Output = Self;
25    fn bitor(self, other: Self) -> Self {
26        match (self, other) {
27            (Self::Ascii, Self::Ascii) => Self::Ascii,
28            _ => Self::Utf8,
29        }
30    }
31}
32
33impl PyStrKind {
34    #[inline]
35    pub fn new_data(self) -> PyStrKindData {
36        match self {
37            PyStrKind::Ascii => PyStrKindData::Ascii,
38            PyStrKind::Utf8 => PyStrKindData::Utf8(Radium::new(usize::MAX)),
39        }
40    }
41}
42
43#[derive(Debug)]
44pub enum PyStrKindData {
45    Ascii,
46    // uses usize::MAX as a sentinel for "uncomputed"
47    Utf8(PyAtomic<usize>),
48}
49
50impl PyStrKindData {
51    #[inline]
52    pub fn kind(&self) -> PyStrKind {
53        match self {
54            PyStrKindData::Ascii => PyStrKind::Ascii,
55            PyStrKindData::Utf8(_) => PyStrKind::Utf8,
56        }
57    }
58}
59
60pub struct BorrowedStr<'a> {
61    bytes: &'a [u8],
62    kind: PyStrKindData,
63    #[allow(dead_code)]
64    hash: PyAtomic<PyHash>,
65}
66
67impl<'a> BorrowedStr<'a> {
68    /// # Safety
69    /// `s` have to be an ascii string
70    #[inline]
71    pub unsafe fn from_ascii_unchecked(s: &'a [u8]) -> Self {
72        debug_assert!(s.is_ascii());
73        Self {
74            bytes: s,
75            kind: PyStrKind::Ascii.new_data(),
76            hash: PyAtomic::<PyHash>::new(0),
77        }
78    }
79
80    #[inline]
81    pub fn from_bytes(s: &'a [u8]) -> Self {
82        let k = if s.is_ascii() {
83            PyStrKind::Ascii.new_data()
84        } else {
85            PyStrKind::Utf8.new_data()
86        };
87        Self {
88            bytes: s,
89            kind: k,
90            hash: PyAtomic::<PyHash>::new(0),
91        }
92    }
93
94    #[inline]
95    pub fn as_str(&self) -> &str {
96        unsafe {
97            // SAFETY: Both PyStrKind::{Ascii, Utf8} are valid utf8 string
98            std::str::from_utf8_unchecked(self.bytes)
99        }
100    }
101
102    #[inline]
103    pub fn char_len(&self) -> usize {
104        match self.kind {
105            PyStrKindData::Ascii => self.bytes.len(),
106            PyStrKindData::Utf8(ref len) => match len.load(core::sync::atomic::Ordering::Relaxed) {
107                usize::MAX => self._compute_char_len(),
108                len => len,
109            },
110        }
111    }
112
113    #[cold]
114    fn _compute_char_len(&self) -> usize {
115        match self.kind {
116            PyStrKindData::Utf8(ref char_len) => {
117                let len = self.as_str().chars().count();
118                // len cannot be usize::MAX, since vec.capacity() < sys.maxsize
119                char_len.store(len, core::sync::atomic::Ordering::Relaxed);
120                len
121            }
122            _ => unsafe {
123                debug_assert!(false); // invalid for non-utf8 strings
124                std::hint::unreachable_unchecked()
125            },
126        }
127    }
128}
129
130impl std::ops::Deref for BorrowedStr<'_> {
131    type Target = str;
132    fn deref(&self) -> &str {
133        self.as_str()
134    }
135}
136
137impl std::fmt::Display for BorrowedStr<'_> {
138    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
139        self.as_str().fmt(f)
140    }
141}
142
143impl CharLen for BorrowedStr<'_> {
144    fn char_len(&self) -> usize {
145        self.char_len()
146    }
147}
148
/// Returns the substring of `s` selected by `range`, where the bounds count
/// *characters* rather than bytes.
///
/// Returns `None` when:
/// - `s` has fewer characters than the range requires,
/// - the range ends before it starts, or
/// - a bound cannot be normalised without overflowing `usize`.
///
/// An unbounded end selects everything from the start position onwards.
pub fn try_get_chars(s: &str, range: impl RangeBounds<usize>) -> Option<&str> {
    // Normalise the start bound to an inclusive character index.
    let start = match range.start_bound() {
        Bound::Included(&i) => i,
        Bound::Excluded(&i) => i.checked_add(1)?,
        Bound::Unbounded => 0,
    };

    // Skip `start` characters; running out of input means the range is out of bounds.
    let mut chars = s.chars();
    for _ in 0..start {
        chars.next()?;
    }
    let rest = chars.as_str();

    // Number of characters to keep. Checked arithmetic turns a reversed range
    // (end before start) into `None` instead of an arithmetic overflow.
    let range_len = match range.end_bound() {
        Bound::Included(&i) => i.checked_add(1)?.checked_sub(start)?,
        Bound::Excluded(&i) => i.checked_sub(start)?,
        Bound::Unbounded => return Some(rest),
    };

    // Walk `range_len` characters; the bytes consumed give the end offset.
    let mut tail = rest.chars();
    for _ in 0..range_len {
        tail.next()?;
    }
    let end = rest.len() - tail.as_str().len();
    Some(&rest[..end])
}
167
168pub fn get_chars(s: &str, range: impl RangeBounds<usize>) -> &str {
169    try_get_chars(s, range).unwrap()
170}
171
/// Returns the byte offset just past the first `nchars` characters of `s`.
///
/// Yields `Some(0)` for `nchars == 0` and `None` when `s` holds fewer than
/// `nchars` characters.
#[inline]
pub fn char_range_end(s: &str, nchars: usize) -> Option<usize> {
    if nchars == 0 {
        return Some(0);
    }
    // Locate the last wanted character and step past its encoded bytes.
    s.char_indices()
        .nth(nchars - 1)
        .map(|(offset, ch)| offset + ch.len_utf8())
}
183
/// Left-pads `bytes` with ASCII `'0'` until the result is `width` bytes long.
///
/// A leading `+` or `-` stays in front of the padding. When `bytes` is
/// already at least `width` long, an unchanged copy is returned.
pub fn zfill(bytes: &[u8], width: usize) -> Vec<u8> {
    if width <= bytes.len() {
        return bytes.to_vec();
    }

    // Split off an optional one-byte sign so the zeros go after it.
    let (sign, digits) = match bytes.first() {
        Some(b'+' | b'-') => bytes.split_at(1),
        _ => (&bytes[..0], bytes),
    };

    // The result is exactly `width` bytes, so allocate once.
    let mut filled = Vec::with_capacity(width);
    filled.extend_from_slice(sign);
    filled.resize(sign.len() + (width - bytes.len()), b'0');
    filled.extend_from_slice(digits);
    filled
}
201
202/// Convert a string to ascii compatible, escaping unicodes into escape
203/// sequences.
204pub fn to_ascii(value: &str) -> AsciiString {
205    let mut ascii = Vec::new();
206    for c in value.chars() {
207        if c.is_ascii() {
208            ascii.push(c as u8);
209        } else {
210            let c = c as i64;
211            let hex = if c < 0x100 {
212                format!("\\x{c:02x}")
213            } else if c < 0x10000 {
214                format!("\\u{c:04x}")
215            } else {
216                format!("\\U{c:08x}")
217            };
218            ascii.append(&mut hex.into_bytes());
219        }
220    }
221    unsafe { AsciiString::from_ascii_unchecked(ascii) }
222}
223
/// Bounded, case-aware edit distance over the bytes of two strings.
pub mod levenshtein {
    use std::{cell::RefCell, thread_local};

    /// Cost of inserting, deleting, or substituting an unrelated byte.
    pub const MOVE_COST: usize = 2;
    /// Cost of substituting a byte for the same ASCII letter in the other case.
    const CASE_COST: usize = 1;
    /// Longest trimmed input (in bytes) the distance computation accepts.
    const MAX_STRING_SIZE: usize = 40;

    /// Cost of replacing byte `a` with byte `b`: 0 when equal, `CASE_COST`
    /// when they differ only by ASCII case, `MOVE_COST` otherwise.
    fn substitution_cost(a: u8, b: u8) -> usize {
        // Bytes whose low five bits differ cannot be a case flip of each other.
        if (a & 31) != (b & 31) {
            return MOVE_COST;
        }
        if a == b {
            return 0;
        }
        if a.to_ascii_lowercase() == b.to_ascii_lowercase() {
            CASE_COST
        } else {
            MOVE_COST
        }
    }

    /// Computes the weighted edit distance between `a` and `b`.
    ///
    /// Returns `max_cost + 1` as soon as the distance is known to exceed
    /// `max_cost`, and also when either input (after trimming the common
    /// prefix and suffix) is longer than `MAX_STRING_SIZE` bytes.
    pub fn levenshtein_distance(a: &str, b: &str, max_cost: usize) -> usize {
        thread_local! {
            static ROW: RefCell<[usize; MAX_STRING_SIZE]> = const { RefCell::new([0usize; MAX_STRING_SIZE]) };
        }

        if a == b {
            return 0;
        }

        let (a_bytes, b_bytes) = (a.as_bytes(), b.as_bytes());

        // Trim the common prefix.
        let prefix = a_bytes
            .iter()
            .zip(b_bytes)
            .take_while(|(x, y)| x == y)
            .count();
        let (mut short, mut long) = (&a_bytes[prefix..], &b_bytes[prefix..]);

        // Trim the common suffix of what is left.
        let suffix = short
            .iter()
            .rev()
            .zip(long.iter().rev())
            .take_while(|(x, y)| x == y)
            .count();
        short = &short[..short.len() - suffix];
        long = &long[..long.len() - suffix];

        if short.is_empty() || long.is_empty() {
            return (short.len() + long.len()) * MOVE_COST;
        }
        if short.len() > MAX_STRING_SIZE || long.len() > MAX_STRING_SIZE {
            return max_cost + 1;
        }

        // Keep the shorter string along the row buffer.
        if long.len() < short.len() {
            std::mem::swap(&mut short, &mut long);
        }

        // The length difference alone is a lower bound on the distance.
        if (long.len() - short.len()) * MOVE_COST > max_cost {
            return max_cost + 1;
        }

        ROW.with(|row_cell| {
            let mut guard = row_cell.borrow_mut();
            let row = &mut guard[..short.len()];

            // Distance from the empty prefix of `long` to each prefix of `short`.
            for (i, cell) in row.iter_mut().enumerate() {
                *cell = (i + 1) * MOVE_COST;
            }

            let mut result = 0usize;
            for (long_index, &long_byte) in long.iter().enumerate() {
                result = long_index * MOVE_COST;
                // Value of the previous row one column to the left.
                let mut diagonal = result;
                let mut row_min = usize::MAX;
                for (cell, &short_byte) in row.iter_mut().zip(short) {
                    let substitute = diagonal + substitution_cost(long_byte, short_byte);
                    diagonal = *cell;
                    let insert_delete = result.min(diagonal) + MOVE_COST;
                    result = insert_delete.min(substitute);

                    *cell = result;
                    row_min = row_min.min(result);
                }
                // Every cell of this row already exceeds the budget; give up early.
                if row_min > max_cost {
                    return max_cost + 1;
                }
            }
            result
        })
    }
}
324
/// Builds an [`AsciiStr`][ascii::AsciiStr] from a string literal, failing compilation when the
/// literal contains anything that is not ascii.
///
/// ```compile_fail
/// # use rustpython_common::str::ascii;
/// ascii!("I ❤️ Rust & Python");
/// ```
#[macro_export]
macro_rules! ascii {
    ($x:literal) => {{
        const LITERAL: &str = $x;
        // Evaluated at compile time, so a non-ascii literal becomes a build error.
        const _: () = assert!(
            LITERAL.is_ascii(),
            "ascii!() argument is not an ascii string"
        );
        // SAFETY: the constant assertion above proved that every byte is ascii.
        unsafe { $crate::vendored::ascii::AsciiStr::from_ascii_unchecked(LITERAL.as_bytes()) }
    }};
}
342pub use ascii;
343
#[cfg(test)]
mod tests {
    use super::*;

    /// `get_chars` must index by character, whatever the encoded width of
    /// each character (1, 3 and 4 bytes are exercised here).
    #[test]
    fn test_get_chars() {
        // Single-byte characters: character indices equal byte indices.
        let digits = "0123456789";
        assert_eq!(get_chars(digits, 3..7), "3456");
        assert_eq!(get_chars(digits, 3..7), &digits[3..7]);

        // Multi-byte characters: the same range must still pick four characters.
        let multibyte_cases = [
            ("0유니코드 문자열9", "코드 문"),
            ("0😀😃😄😁😆😅😂🤣9", "😄😁😆😅"),
        ];
        for (input, expected) in multibyte_cases {
            assert_eq!(get_chars(input, 3..7), expected);
        }
    }
}