1use crate::{
2 atomic::{PyAtomic, Radium},
3 hash::PyHash,
4};
5use ascii::AsciiString;
6use rustpython_format::CharLen;
7use std::ops::{Bound, RangeBounds};
8
9#[cfg(not(target_arch = "wasm32"))]
10#[allow(non_camel_case_types)]
11pub type wchar_t = libc::wchar_t;
12#[cfg(target_arch = "wasm32")]
13#[allow(non_camel_case_types)]
14pub type wchar_t = u32;
15
/// Classification of a string's byte content.
///
/// `Eq` is derived alongside `PartialEq`: the enum has no fields, so equality
/// is total and the type can be used wherever an `Eq` bound is required.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum PyStrKind {
    /// Every byte is ASCII, so the byte length equals the character count.
    Ascii,
    /// The bytes contain at least one non-ASCII sequence.
    Utf8,
}
22
23impl std::ops::BitOr for PyStrKind {
24 type Output = Self;
25 fn bitor(self, other: Self) -> Self {
26 match (self, other) {
27 (Self::Ascii, Self::Ascii) => Self::Ascii,
28 _ => Self::Utf8,
29 }
30 }
31}
32
33impl PyStrKind {
34 #[inline]
35 pub fn new_data(self) -> PyStrKindData {
36 match self {
37 PyStrKind::Ascii => PyStrKindData::Ascii,
38 PyStrKind::Utf8 => PyStrKindData::Utf8(Radium::new(usize::MAX)),
39 }
40 }
41}
42
43#[derive(Debug)]
44pub enum PyStrKindData {
45 Ascii,
46 Utf8(PyAtomic<usize>),
48}
49
50impl PyStrKindData {
51 #[inline]
52 pub fn kind(&self) -> PyStrKind {
53 match self {
54 PyStrKindData::Ascii => PyStrKind::Ascii,
55 PyStrKindData::Utf8(_) => PyStrKind::Utf8,
56 }
57 }
58}
59
60pub struct BorrowedStr<'a> {
61 bytes: &'a [u8],
62 kind: PyStrKindData,
63 #[allow(dead_code)]
64 hash: PyAtomic<PyHash>,
65}
66
67impl<'a> BorrowedStr<'a> {
68 #[inline]
71 pub unsafe fn from_ascii_unchecked(s: &'a [u8]) -> Self {
72 debug_assert!(s.is_ascii());
73 Self {
74 bytes: s,
75 kind: PyStrKind::Ascii.new_data(),
76 hash: PyAtomic::<PyHash>::new(0),
77 }
78 }
79
80 #[inline]
81 pub fn from_bytes(s: &'a [u8]) -> Self {
82 let k = if s.is_ascii() {
83 PyStrKind::Ascii.new_data()
84 } else {
85 PyStrKind::Utf8.new_data()
86 };
87 Self {
88 bytes: s,
89 kind: k,
90 hash: PyAtomic::<PyHash>::new(0),
91 }
92 }
93
94 #[inline]
95 pub fn as_str(&self) -> &str {
96 unsafe {
97 std::str::from_utf8_unchecked(self.bytes)
99 }
100 }
101
102 #[inline]
103 pub fn char_len(&self) -> usize {
104 match self.kind {
105 PyStrKindData::Ascii => self.bytes.len(),
106 PyStrKindData::Utf8(ref len) => match len.load(core::sync::atomic::Ordering::Relaxed) {
107 usize::MAX => self._compute_char_len(),
108 len => len,
109 },
110 }
111 }
112
113 #[cold]
114 fn _compute_char_len(&self) -> usize {
115 match self.kind {
116 PyStrKindData::Utf8(ref char_len) => {
117 let len = self.as_str().chars().count();
118 char_len.store(len, core::sync::atomic::Ordering::Relaxed);
120 len
121 }
122 _ => unsafe {
123 debug_assert!(false); std::hint::unreachable_unchecked()
125 },
126 }
127 }
128}
129
130impl std::ops::Deref for BorrowedStr<'_> {
131 type Target = str;
132 fn deref(&self) -> &str {
133 self.as_str()
134 }
135}
136
137impl std::fmt::Display for BorrowedStr<'_> {
138 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
139 self.as_str().fmt(f)
140 }
141}
142
143impl CharLen for BorrowedStr<'_> {
144 fn char_len(&self) -> usize {
145 self.char_len()
146 }
147}
148
149pub fn try_get_chars(s: &str, range: impl RangeBounds<usize>) -> Option<&str> {
150 let mut chars = s.chars();
151 let start = match range.start_bound() {
152 Bound::Included(&i) => i,
153 Bound::Excluded(&i) => i + 1,
154 Bound::Unbounded => 0,
155 };
156 for _ in 0..start {
157 chars.next()?;
158 }
159 let s = chars.as_str();
160 let range_len = match range.end_bound() {
161 Bound::Included(&i) => i + 1 - start,
162 Bound::Excluded(&i) => i - start,
163 Bound::Unbounded => return Some(s),
164 };
165 char_range_end(s, range_len).map(|end| &s[..end])
166}
167
168pub fn get_chars(s: &str, range: impl RangeBounds<usize>) -> &str {
169 try_get_chars(s, range).unwrap()
170}
171
/// Returns the byte offset just past the first `nchars` characters of `s`.
///
/// Yields `Some(0)` for `nchars == 0` and `None` when `s` contains fewer than
/// `nchars` characters.
#[inline]
pub fn char_range_end(s: &str, nchars: usize) -> Option<usize> {
    if nchars == 0 {
        return Some(0);
    }
    s.char_indices()
        .nth(nchars - 1)
        .map(|(offset, ch)| offset + ch.len_utf8())
}
183
/// Left-pads `bytes` with ASCII `'0'` up to `width` bytes.
///
/// A leading `+` or `-` stays in front of the padding. When `bytes` is
/// already at least `width` long, an unchanged copy is returned.
pub fn zfill(bytes: &[u8], width: usize) -> Vec<u8> {
    if width <= bytes.len() {
        return bytes.to_vec();
    }

    // Split off an optional sign so the zeros go between it and the digits.
    let (sign, digits) = match bytes.first() {
        Some(b'+' | b'-') => bytes.split_at(1),
        _ => (&b""[..], bytes),
    };

    let padding = width - bytes.len();
    // The result is exactly `width` bytes long, so allocate once up front.
    let mut filled = Vec::with_capacity(width);
    filled.extend_from_slice(sign);
    filled.resize(sign.len() + padding, b'0');
    filled.extend_from_slice(digits);
    filled
}
201
202pub fn to_ascii(value: &str) -> AsciiString {
205 let mut ascii = Vec::new();
206 for c in value.chars() {
207 if c.is_ascii() {
208 ascii.push(c as u8);
209 } else {
210 let c = c as i64;
211 let hex = if c < 0x100 {
212 format!("\\x{c:02x}")
213 } else if c < 0x10000 {
214 format!("\\u{c:04x}")
215 } else {
216 format!("\\U{c:08x}")
217 };
218 ascii.append(&mut hex.into_bytes());
219 }
220 }
221 unsafe { AsciiString::from_ascii_unchecked(ascii) }
222}
223
pub mod levenshtein {
    //! Bounded edit distance over the raw bytes of two strings, with a
    //! reduced cost for substitutions that only flip ASCII case.
    use std::{cell::RefCell, thread_local};

    /// Cost of an insertion, a deletion, or a substitution between unrelated bytes.
    pub const MOVE_COST: usize = 2;
    /// Cost of substituting a byte with its other-case ASCII counterpart.
    const CASE_COST: usize = 1;
    /// Longest trimmed input, in bytes, for which the distance is computed.
    const MAX_STRING_SIZE: usize = 40;

    /// Cost of replacing byte `a` with byte `b`: 0 when equal, `CASE_COST`
    /// when they differ only by ASCII case, `MOVE_COST` otherwise.
    fn substitution_cost(a: u8, b: u8) -> usize {
        if a == b {
            return 0;
        }
        // Bytes whose low five bits differ cannot be a case flip of each other.
        if (a & 31) != (b & 31) {
            return MOVE_COST;
        }
        if a.to_ascii_lowercase() == b.to_ascii_lowercase() {
            CASE_COST
        } else {
            MOVE_COST
        }
    }

    /// Computes the weighted edit distance between `a` and `b`, compared byte
    /// by byte.
    ///
    /// Returns `max_cost + 1` instead of the exact distance when either
    /// trimmed input exceeds `MAX_STRING_SIZE` bytes, when the length
    /// difference alone already costs more than `max_cost`, or when every
    /// entry of a row of the distance table exceeds `max_cost`.
    pub fn levenshtein_distance(a: &str, b: &str, max_cost: usize) -> usize {
        thread_local! {
            static BUFFER: RefCell<[usize; MAX_STRING_SIZE]> = const { RefCell::new([0usize; MAX_STRING_SIZE]) };
        }

        if a == b {
            return 0;
        }

        // The names become accurate after the swap further down.
        let (mut shorter, mut longer) = (a.as_bytes(), b.as_bytes());

        // Drop the common prefix, then the common suffix of what remains.
        let shared_prefix = shorter
            .iter()
            .zip(longer)
            .take_while(|(x, y)| x == y)
            .count();
        shorter = &shorter[shared_prefix..];
        longer = &longer[shared_prefix..];

        let shared_suffix = shorter
            .iter()
            .rev()
            .zip(longer.iter().rev())
            .take_while(|(x, y)| x == y)
            .count();
        shorter = &shorter[..shorter.len() - shared_suffix];
        longer = &longer[..longer.len() - shared_suffix];

        if shorter.is_empty() || longer.is_empty() {
            return (shorter.len() + longer.len()) * MOVE_COST;
        }
        if shorter.len() > MAX_STRING_SIZE || longer.len() > MAX_STRING_SIZE {
            return max_cost + 1;
        }

        if longer.len() < shorter.len() {
            std::mem::swap(&mut shorter, &mut longer);
        }

        if (longer.len() - shorter.len()) * MOVE_COST > max_cost {
            return max_cost + 1;
        }

        BUFFER.with(|cell| {
            let mut row = cell.borrow_mut();
            // Seed the row with the cost of building each prefix of `shorter`
            // from nothing.
            for (i, slot) in row[..shorter.len()].iter_mut().enumerate() {
                *slot = (i + 1) * MOVE_COST;
            }

            let mut result = 0usize;
            for (long_index, &long_byte) in longer.iter().enumerate() {
                result = long_index * MOVE_COST;
                let mut distance = result;
                let mut minimum = usize::MAX;
                // `zip` stops after `shorter.len()` slots, which never exceeds
                // the buffer size thanks to the length check above.
                for (slot, &short_byte) in row.iter_mut().zip(shorter) {
                    let substitute = distance + substitution_cost(long_byte, short_byte);
                    distance = *slot;
                    let insert_delete = result.min(distance) + MOVE_COST;
                    result = insert_delete.min(substitute);

                    *slot = result;
                    minimum = minimum.min(result);
                }
                // Every entry in this row is already too expensive: give up.
                if minimum > max_cost {
                    return max_cost + 1;
                }
            }
            result
        })
    }
}
324
/// Turns a string literal into an ASCII string by calling
/// `AsciiStr::from_ascii_unchecked` on its bytes, after checking at compile
/// time that the literal is pure ASCII.
#[macro_export]
macro_rules! ascii {
    ($x:literal) => {{
        const STR: &str = $x;
        // Evaluated at compile time, so a non-ASCII literal fails the build.
        const _: () = assert!(STR.is_ascii(), "ascii!() argument is not an ascii string");
        // SAFETY: the constant assertion above proved that `STR` is ASCII.
        unsafe { $crate::vendored::ascii::AsciiStr::from_ascii_unchecked(STR.as_bytes()) }
    }};
}
342pub use ascii;
343
#[cfg(test)]
mod tests {
    use super::*;

    /// `get_chars` must slice by character index, not by byte offset.
    #[test]
    fn test_get_chars() {
        // For pure ASCII input, character and byte indices coincide.
        let digits = "0123456789";
        assert_eq!(get_chars(digits, 3..7), "3456");
        assert_eq!(get_chars(digits, 3..7), &digits[3..7]);

        // Multi-byte inputs: three-byte Hangul and four-byte emoji.
        let cases = [
            ("0유니코드 문자열9", "코드 문"),
            ("0😀😃😄😁😆😅😂🤣9", "😄😁😆😅"),
        ];
        for &(text, expected) in &cases {
            assert_eq!(get_chars(text, 3..7), expected);
        }
    }
}