rustpython_vm/
anystr.rs

1use crate::{
2    builtins::{PyIntRef, PyTuple},
3    cformat::cformat_string,
4    convert::TryFromBorrowedObject,
5    function::OptionalOption,
6    Py, PyObject, PyObjectRef, PyResult, TryFromObject, VirtualMachine,
7};
8use num_traits::{cast::ToPrimitive, sign::Signed};
9
10#[derive(FromArgs)]
11pub struct SplitArgs<T: TryFromObject + AnyStrWrapper> {
12    #[pyarg(any, default)]
13    sep: Option<T>,
14    #[pyarg(any, default = "-1")]
15    maxsplit: isize,
16}
17
18impl<T: TryFromObject + AnyStrWrapper> SplitArgs<T> {
19    pub fn get_value(self, vm: &VirtualMachine) -> PyResult<(Option<T>, isize)> {
20        let sep = if let Some(s) = self.sep {
21            let sep = s.as_ref();
22            if sep.is_empty() {
23                return Err(vm.new_value_error("empty separator".to_owned()));
24            }
25            Some(s)
26        } else {
27            None
28        };
29        Ok((sep, self.maxsplit))
30    }
31}
32
/// Arguments accepted by `splitlines`-style methods.
#[derive(FromArgs)]
pub struct SplitLinesArgs {
    // Whether each returned line keeps its line terminator; defaults to false.
    #[pyarg(any, default = "false")]
    pub keepends: bool,
}
38
39#[derive(FromArgs)]
40pub struct ExpandTabsArgs {
41    #[pyarg(any, default = "8")]
42    tabsize: isize,
43}
44
45impl ExpandTabsArgs {
46    pub fn tabsize(&self) -> usize {
47        self.tabsize.to_usize().unwrap_or(0)
48    }
49}
50
51#[derive(FromArgs)]
52pub struct StartsEndsWithArgs {
53    #[pyarg(positional)]
54    affix: PyObjectRef,
55    #[pyarg(positional, default)]
56    start: Option<PyIntRef>,
57    #[pyarg(positional, default)]
58    end: Option<PyIntRef>,
59}
60
61impl StartsEndsWithArgs {
62    pub fn get_value(self, len: usize) -> (PyObjectRef, Option<std::ops::Range<usize>>) {
63        let range = if self.start.is_some() || self.end.is_some() {
64            Some(adjust_indices(self.start, self.end, len))
65        } else {
66            None
67        };
68        (self.affix, range)
69    }
70
71    #[inline]
72    pub fn prepare<S, F>(self, s: &S, len: usize, substr: F) -> Option<(PyObjectRef, &S)>
73    where
74        S: ?Sized + AnyStr,
75        F: Fn(&S, std::ops::Range<usize>) -> &S,
76    {
77        let (affix, range) = self.get_value(len);
78        let substr = if let Some(range) = range {
79            if !range.is_normal() {
80                return None;
81            }
82            substr(s, range)
83        } else {
84            s
85        };
86        Some((affix, substr))
87    }
88}
89
90fn saturate_to_isize(py_int: PyIntRef) -> isize {
91    let big = py_int.as_bigint();
92    big.to_isize().unwrap_or_else(|| {
93        if big.is_negative() {
94            isize::MIN
95        } else {
96            isize::MAX
97        }
98    })
99}
100
101// help get optional string indices
102pub fn adjust_indices(
103    start: Option<PyIntRef>,
104    end: Option<PyIntRef>,
105    len: usize,
106) -> std::ops::Range<usize> {
107    let mut start = start.map_or(0, saturate_to_isize);
108    let mut end = end.map_or(len as isize, saturate_to_isize);
109    if end > len as isize {
110        end = len as isize;
111    } else if end < 0 {
112        end += len as isize;
113        if end < 0 {
114            end = 0;
115        }
116    }
117    if start < 0 {
118        start += len as isize;
119        if start < 0 {
120            start = 0;
121        }
122    }
123    start as usize..end as usize
124}
125
/// Extension trait for index ranges used to slice string-like data.
pub trait StringRange {
    /// Returns `true` when the range's bounds are in non-decreasing order.
    fn is_normal(&self) -> bool;
}

impl StringRange for std::ops::Range<usize> {
    /// A range is normal unless its start lies past its end; empty ranges are normal.
    fn is_normal(&self) -> bool {
        self.end >= self.start
    }
}
135
/// A value that can be viewed as a borrowed string-like slice implementing [`AnyStr`].
pub trait AnyStrWrapper {
    /// The string-like slice type this wrapper exposes.
    type Str: ?Sized + AnyStr;
    /// Borrows the wrapped value as its string-like slice.
    fn as_ref(&self) -> &Self::Str;
}
140
/// An owned, growable buffer that slices of type `S` can be appended to.
pub trait AnyStrContainer<S: ?Sized> {
    /// Creates an empty container.
    fn new() -> Self;
    /// Creates an empty container, passing `capacity` as a sizing hint.
    fn with_capacity(capacity: usize) -> Self;
    /// Appends the contents of `s` to the container.
    fn push_str(&mut self, s: &S);
}
149
/// Common interface over string-like slices, with default implementations of several
/// Python string methods (`py_*`) built on a small set of required primitives.
pub trait AnyStr {
    /// The element type of the slice.
    type Char: Copy;
    /// The owned buffer type that results are built in.
    type Container: AnyStrContainer<Self> + Extend<Self::Char>;

    /// Number of bytes `c` contributes; used by `py_pad` to size its buffer.
    fn element_bytes_len(c: Self::Char) -> usize;

    /// Copies `self` into a new owned container.
    fn to_container(&self) -> Self::Container;
    /// Views `self` as raw bytes.
    fn as_bytes(&self) -> &[u8];
    /// Views `self` as a UTF-8 `str`, failing if the contents are not valid UTF-8.
    fn as_utf8_str(&self) -> Result<&str, std::str::Utf8Error>;
    /// Iterates over `self` as `char`s.
    fn chars(&self) -> impl Iterator<Item = char>;
    /// Iterates over the native elements of `self`.
    fn elements(&self) -> impl Iterator<Item = Self::Char>;
    /// Returns the sub-slice selected by `range`.
    /// NOTE(review): presumably `range` is measured in bytes — confirm against implementors.
    fn get_bytes(&self, range: std::ops::Range<usize>) -> &Self;
    // FIXME: get_chars is expensive for str
    /// Returns the sub-slice selected by `range`.
    /// NOTE(review): presumably `range` is measured in characters — confirm against implementors.
    fn get_chars(&self, range: std::ops::Range<usize>) -> &Self;
    /// Length of `self` in bytes.
    fn bytes_len(&self) -> usize;
    // NOTE: str::chars().count() takes O(n) time, whereas pystr::char_len caches its result.
    //       Calling chars_len directly would therefore be too expensive, so the method below
    //       is intentionally left unimplemented.
    // fn chars_len(&self) -> usize;
    /// Returns `true` when `self` contains no elements.
    fn is_empty(&self) -> bool;
169
170    fn py_add(&self, other: &Self) -> Self::Container {
171        let mut new = Self::Container::with_capacity(self.bytes_len() + other.bytes_len());
172        new.push_str(self);
173        new.push_str(other);
174        new
175    }
176
177    fn py_split<T, SP, SN, SW, R>(
178        &self,
179        args: SplitArgs<T>,
180        vm: &VirtualMachine,
181        split: SP,
182        splitn: SN,
183        splitw: SW,
184    ) -> PyResult<Vec<R>>
185    where
186        T: TryFromObject + AnyStrWrapper<Str = Self>,
187        SP: Fn(&Self, &Self, &VirtualMachine) -> Vec<R>,
188        SN: Fn(&Self, &Self, usize, &VirtualMachine) -> Vec<R>,
189        SW: Fn(&Self, isize, &VirtualMachine) -> Vec<R>,
190    {
191        let (sep, maxsplit) = args.get_value(vm)?;
192        let splits = if let Some(pattern) = sep {
193            if maxsplit < 0 {
194                split(self, pattern.as_ref(), vm)
195            } else {
196                splitn(self, pattern.as_ref(), (maxsplit + 1) as usize, vm)
197            }
198        } else {
199            splitw(self, maxsplit, vm)
200        };
201        Ok(splits)
202    }
    /// Required method: splits `self` and returns the pieces as Python objects produced by
    /// `convert`.
    /// NOTE(review): presumably splits on runs of whitespace from the left, honoring
    /// `maxsplit` — confirm against implementors.
    fn py_split_whitespace<F>(&self, maxsplit: isize, convert: F) -> Vec<PyObjectRef>
    where
        F: Fn(&Self) -> PyObjectRef;
    /// Required method: splits `self` and returns the pieces as Python objects produced by
    /// `convert`.
    /// NOTE(review): presumably splits on runs of whitespace from the right, honoring
    /// `maxsplit` — confirm against implementors.
    fn py_rsplit_whitespace<F>(&self, maxsplit: isize, convert: F) -> Vec<PyObjectRef>
    where
        F: Fn(&Self) -> PyObjectRef;
209
210    #[inline]
211    fn py_startsendswith<'a, T, F>(
212        &self,
213        affix: &'a PyObject,
214        func_name: &str,
215        py_type_name: &str,
216        func: F,
217        vm: &VirtualMachine,
218    ) -> PyResult<bool>
219    where
220        T: TryFromBorrowedObject<'a>,
221        F: Fn(&Self, T) -> bool,
222    {
223        single_or_tuple_any(
224            affix,
225            &|s: T| Ok(func(self, s)),
226            &|o| {
227                format!(
228                    "{} first arg must be {} or a tuple of {}, not {}",
229                    func_name,
230                    py_type_name,
231                    py_type_name,
232                    o.class(),
233                )
234            },
235            vm,
236        )
237    }
238
239    #[inline]
240    fn py_strip<'a, S, FC, FD>(
241        &'a self,
242        chars: OptionalOption<S>,
243        func_chars: FC,
244        func_default: FD,
245    ) -> &'a Self
246    where
247        S: AnyStrWrapper<Str = Self>,
248        FC: Fn(&'a Self, &Self) -> &'a Self,
249        FD: Fn(&'a Self) -> &'a Self,
250    {
251        let chars = chars.flatten();
252        match chars {
253            Some(chars) => func_chars(self, chars.as_ref()),
254            None => func_default(self),
255        }
256    }
257
258    #[inline]
259    fn py_find<F>(&self, needle: &Self, range: std::ops::Range<usize>, find: F) -> Option<usize>
260    where
261        F: Fn(&Self, &Self) -> Option<usize>,
262    {
263        if range.is_normal() {
264            let start = range.start;
265            let index = find(self.get_chars(range), needle)?;
266            Some(start + index)
267        } else {
268            None
269        }
270    }
271
272    #[inline]
273    fn py_count<F>(&self, needle: &Self, range: std::ops::Range<usize>, count: F) -> usize
274    where
275        F: Fn(&Self, &Self) -> usize,
276    {
277        if range.is_normal() {
278            count(self.get_chars(range), needle)
279        } else {
280            0
281        }
282    }
283
284    fn py_pad(&self, left: usize, right: usize, fillchar: Self::Char) -> Self::Container {
285        let mut u = Self::Container::with_capacity(
286            (left + right) * Self::element_bytes_len(fillchar) + self.bytes_len(),
287        );
288        u.extend(std::iter::repeat(fillchar).take(left));
289        u.push_str(self);
290        u.extend(std::iter::repeat(fillchar).take(right));
291        u
292    }
293
294    fn py_center(&self, width: usize, fillchar: Self::Char, len: usize) -> Self::Container {
295        let marg = width - len;
296        let left = marg / 2 + (marg & width & 1);
297        self.py_pad(left, marg - left, fillchar)
298    }
299
300    fn py_ljust(&self, width: usize, fillchar: Self::Char, len: usize) -> Self::Container {
301        self.py_pad(0, width - len, fillchar)
302    }
303
304    fn py_rjust(&self, width: usize, fillchar: Self::Char, len: usize) -> Self::Container {
305        self.py_pad(width - len, 0, fillchar)
306    }
307
308    fn py_join(
309        &self,
310        mut iter: impl std::iter::Iterator<
311            Item = PyResult<impl AnyStrWrapper<Str = Self> + TryFromObject>,
312        >,
313    ) -> PyResult<Self::Container> {
314        let mut joined = if let Some(elem) = iter.next() {
315            elem?.as_ref().to_container()
316        } else {
317            return Ok(Self::Container::new());
318        };
319        for elem in iter {
320            let elem = elem?;
321            joined.push_str(self);
322            joined.push_str(elem.as_ref());
323        }
324        Ok(joined)
325    }
326
327    fn py_partition<'a, F, S>(
328        &'a self,
329        sub: &Self,
330        split: F,
331        vm: &VirtualMachine,
332    ) -> PyResult<(Self::Container, bool, Self::Container)>
333    where
334        F: Fn() -> S,
335        S: std::iter::Iterator<Item = &'a Self>,
336    {
337        if sub.is_empty() {
338            return Err(vm.new_value_error("empty separator".to_owned()));
339        }
340
341        let mut sp = split();
342        let front = sp.next().unwrap().to_container();
343        let (has_mid, back) = if let Some(back) = sp.next() {
344            (true, back.to_container())
345        } else {
346            (false, Self::Container::new())
347        };
348        Ok((front, has_mid, back))
349    }
350
351    fn py_removeprefix<FC>(&self, prefix: &Self, prefix_len: usize, is_prefix: FC) -> &Self
352    where
353        FC: Fn(&Self, &Self) -> bool,
354    {
355        //if self.py_starts_with(prefix) {
356        if is_prefix(self, prefix) {
357            self.get_bytes(prefix_len..self.bytes_len())
358        } else {
359            self
360        }
361    }
362
363    fn py_removesuffix<FC>(&self, suffix: &Self, suffix_len: usize, is_suffix: FC) -> &Self
364    where
365        FC: Fn(&Self, &Self) -> bool,
366    {
367        if is_suffix(self, suffix) {
368            self.get_bytes(0..self.bytes_len() - suffix_len)
369        } else {
370            self
371        }
372    }
373
374    // TODO: remove this function from anystr.
375    // See https://github.com/RustPython/RustPython/pull/4709/files#r1141013993
376    fn py_bytes_splitlines<FW, W>(&self, options: SplitLinesArgs, into_wrapper: FW) -> Vec<W>
377    where
378        FW: Fn(&Self) -> W,
379    {
380        let keep = options.keepends as usize;
381        let mut elements = Vec::new();
382        let mut last_i = 0;
383        let mut enumerated = self.as_bytes().iter().enumerate().peekable();
384        while let Some((i, ch)) = enumerated.next() {
385            let (end_len, i_diff) = match *ch {
386                b'\n' => (keep, 1),
387                b'\r' => {
388                    let is_rn = enumerated.peek().map_or(false, |(_, ch)| **ch == b'\n');
389                    if is_rn {
390                        let _ = enumerated.next();
391                        (keep + keep, 2)
392                    } else {
393                        (keep, 1)
394                    }
395                }
396                _ => {
397                    continue;
398                }
399            };
400            let range = last_i..i + end_len;
401            last_i = i + i_diff;
402            elements.push(into_wrapper(self.get_bytes(range)));
403        }
404        if last_i != self.bytes_len() {
405            elements.push(into_wrapper(self.get_bytes(last_i..self.bytes_len())));
406        }
407        elements
408    }
409
410    fn py_zfill(&self, width: isize) -> Vec<u8> {
411        let width = width.to_usize().unwrap_or(0);
412        rustpython_common::str::zfill(self.as_bytes(), width)
413    }
414
415    fn py_iscase<F, G>(&self, is_case: F, is_opposite: G) -> bool
416    where
417        F: Fn(char) -> bool,
418        G: Fn(char) -> bool,
419    {
420        // Unified form of CPython functions:
421        //  _Py_bytes_islower
422        //   Py_bytes_isupper
423        //  unicode_islower_impl
424        //  unicode_isupper_impl
425        let mut cased = false;
426        for c in self.chars() {
427            if is_opposite(c) {
428                return false;
429            } else if !cased && is_case(c) {
430                cased = true
431            }
432        }
433        cased
434    }
435
436    fn py_cformat(&self, values: PyObjectRef, vm: &VirtualMachine) -> PyResult<String> {
437        let format_string = self.as_utf8_str().unwrap();
438        cformat_string(vm, format_string, values)
439    }
440}
441
442/// Tests that the predicate is True on a single value, or if the value is a tuple a tuple, then
443/// test that any of the values contained within the tuples satisfies the predicate. Type parameter
444/// T specifies the type that is expected, if the input value is not of that type or a tuple of
445/// values of that type, then a TypeError is raised.
446pub fn single_or_tuple_any<'a, T, F, M>(
447    obj: &'a PyObject,
448    predicate: &F,
449    message: &M,
450    vm: &VirtualMachine,
451) -> PyResult<bool>
452where
453    T: TryFromBorrowedObject<'a>,
454    F: Fn(T) -> PyResult<bool>,
455    M: Fn(&PyObject) -> String,
456{
457    match obj.try_to_value::<T>(vm) {
458        Ok(single) => (predicate)(single),
459        Err(_) => {
460            let tuple: &Py<PyTuple> = obj
461                .try_to_value(vm)
462                .map_err(|_| vm.new_type_error((message)(obj)))?;
463            for obj in tuple {
464                if single_or_tuple_any(obj, predicate, message, vm)? {
465                    return Ok(true);
466                }
467            }
468            Ok(false)
469        }
470    }
471}