uv_pypi_types/
lenient_requirement.rs

1use regex::Regex;
2use serde::{Deserialize, Deserializer, Serialize, de};
3use std::borrow::Cow;
4use std::str::FromStr;
5use std::sync::LazyLock;
6use tracing::warn;
7
8use uv_pep440::{VersionSpecifiers, VersionSpecifiersParseError};
9use uv_pep508::{Pep508Error, Pep508Url, Requirement};
10
11use crate::VerbatimParsedUrl;
12
13/// Ex) `>=7.2.0<8.0.0`
14static MISSING_COMMA: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"(\d)([<>=~^!])").unwrap());
15/// Ex) `!=~5.0`
16static NOT_EQUAL_TILDE: LazyLock<Regex> =
17    LazyLock::new(|| Regex::new(r"!=~((?:\d\.)*\d)").unwrap());
18/// Ex) `>=1.9.*`, `<3.4.*`
19static INVALID_TRAILING_DOT_STAR: LazyLock<Regex> =
20    LazyLock::new(|| Regex::new(r"(<=|>=|<|>)(\d+(\.\d+)*)\.\*").unwrap());
21/// Ex) `!=3.0*`
22static MISSING_DOT: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"(\d\.\d)+\*").unwrap());
23/// Ex) `>=3.6,`
24static TRAILING_COMMA: LazyLock<Regex> = LazyLock::new(|| Regex::new(r",\s*$").unwrap());
25/// Ex) `>dev`
26static GREATER_THAN_DEV: LazyLock<Regex> = LazyLock::new(|| Regex::new(r">dev").unwrap());
27/// Ex) `>=9.0.0a1.0`
28static TRAILING_ZERO: LazyLock<Regex> =
29    LazyLock::new(|| Regex::new(r"(\d+(\.\d)*(a|b|rc|post|dev)\d+)\.0").unwrap());
30
31// Search and replace functions that fix invalid specifiers.
32type FixUp = for<'a> fn(&'a str) -> Cow<'a, str>;
33
34/// A list of fixups with a corresponding message about what was fixed.
35static FIXUPS: &[(FixUp, &str)] = &[
36    // Given `>=7.2.0<8.0.0`, rewrite to `>=7.2.0,<8.0.0`.
37    (
38        |input| MISSING_COMMA.replace_all(input, r"$1,$2"),
39        "inserting missing comma",
40    ),
41    // Given `!=~5.0,>=4.12`, rewrite to `!=5.0.*,>=4.12`.
42    (
43        |input| NOT_EQUAL_TILDE.replace_all(input, r"!=${1}.*"),
44        "replacing invalid tilde with wildcard",
45    ),
46    // Given `>=1.9.*`, rewrite to `>=1.9`.
47    (
48        |input| INVALID_TRAILING_DOT_STAR.replace_all(input, r"${1}${2}"),
49        "removing star after comparison operator other than equal and not equal",
50    ),
51    // Given `!=3.0*`, rewrite to `!=3.0.*`.
52    (
53        |input| MISSING_DOT.replace_all(input, r"${1}.*"),
54        "inserting missing dot",
55    ),
56    // Given `>=3.6,`, rewrite to `>=3.6`
57    (
58        |input| TRAILING_COMMA.replace_all(input, r"${1}"),
59        "removing trailing comma",
60    ),
61    // Given `>dev`, rewrite to `>0.0.0dev`
62    (
63        |input| GREATER_THAN_DEV.replace_all(input, r">0.0.0dev"),
64        "assuming 0.0.0dev",
65    ),
66    // Given `>=9.0.0a1.0`, rewrite to `>=9.0.0a1`
67    (
68        |input| TRAILING_ZERO.replace_all(input, r"${1}"),
69        "removing trailing zero",
70    ),
71    (remove_stray_quotes, "removing stray quotes"),
72];
73
/// Removes every `'` and `"` from the requirement part of `input`.
///
/// Given `>= 2.7'`, rewrite to `>= 2.7`.
///
/// Everything from the first `;` onwards is returned verbatim. When there is nothing to remove,
/// the input is handed back borrowed, without allocating.
fn remove_stray_quotes(input: &str) -> Cow<'_, str> {
    /// Ex) `'>= 2.7'`, `>=3.6'`
    const STRAY_QUOTES: &[char] = &['\'', '"'];

    // Make sure not to touch markers, which can have quotes (e.g. `python_version >= '3.7'`).
    // `;` is ASCII, so the split position is always a valid char boundary.
    let (requirement, markers) = input.split_at(input.find(';').unwrap_or(input.len()));

    if !requirement.contains(STRAY_QUOTES) {
        return Cow::Borrowed(input);
    }

    let mut cleaned = requirement.replace(STRAY_QUOTES, "");
    cleaned.push_str(markers);
    Cow::Owned(cleaned)
}
88
89fn parse_with_fixups<Err, T: FromStr<Err = Err>>(input: &str, type_name: &str) -> Result<T, Err> {
90    match T::from_str(input) {
91        Ok(requirement) => Ok(requirement),
92        Err(err) => {
93            let mut patched_input = input.to_string();
94            let mut messages = Vec::new();
95            for (fixup, message) in FIXUPS {
96                let patched = fixup(patched_input.as_ref());
97                if patched != patched_input {
98                    messages.push(*message);
99
100                    if let Ok(requirement) = T::from_str(&patched) {
101                        warn!(
102                            "Fixing invalid {type_name} by {} (before: `{input}`; after: `{patched}`)",
103                            messages.join(", ")
104                        );
105                        return Ok(requirement);
106                    }
107
108                    patched_input = patched.to_string();
109                }
110            }
111
112            Err(err)
113        }
114    }
115}
116
117/// Like [`Requirement`], but attempts to correct some common errors in user-provided requirements.
118#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
119pub struct LenientRequirement<T: Pep508Url = VerbatimParsedUrl>(Requirement<T>);
120
121impl<T: Pep508Url> FromStr for LenientRequirement<T> {
122    type Err = Pep508Error<T>;
123
124    fn from_str(input: &str) -> Result<Self, Self::Err> {
125        Ok(Self(parse_with_fixups(input, "requirement")?))
126    }
127}
128
129impl<T: Pep508Url> From<LenientRequirement<T>> for Requirement<T> {
130    fn from(requirement: LenientRequirement<T>) -> Self {
131        requirement.0
132    }
133}
134
135/// Like [`VersionSpecifiers`], but attempts to correct some common errors in user-provided requirements.
136///
137/// For example, we turn `>=3.x.*` into `>=3.x`.
138#[derive(Debug, Clone, Serialize, Eq, PartialEq)]
139pub struct LenientVersionSpecifiers(VersionSpecifiers);
140
141impl FromStr for LenientVersionSpecifiers {
142    type Err = VersionSpecifiersParseError;
143
144    fn from_str(input: &str) -> Result<Self, Self::Err> {
145        Ok(Self(parse_with_fixups(input, "version specifier")?))
146    }
147}
148
149impl From<LenientVersionSpecifiers> for VersionSpecifiers {
150    fn from(specifiers: LenientVersionSpecifiers) -> Self {
151        specifiers.0
152    }
153}
154
155impl<'de> Deserialize<'de> for LenientVersionSpecifiers {
156    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
157    where
158        D: Deserializer<'de>,
159    {
160        struct Visitor;
161
162        impl de::Visitor<'_> for Visitor {
163            type Value = LenientVersionSpecifiers;
164
165            fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
166                f.write_str("a string")
167            }
168
169            fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
170                LenientVersionSpecifiers::from_str(v).map_err(de::Error::custom)
171            }
172        }
173
174        deserializer.deserialize_str(Visitor)
175    }
176}
177
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use uv_pep440::VersionSpecifiers;
    use uv_pep508::Requirement;

    use crate::LenientVersionSpecifiers;

    use super::LenientRequirement;

    /// Asserts that leniently parsing `lenient` yields the same requirement as strictly parsing
    /// `strict`.
    #[track_caller]
    fn assert_requirement(lenient: &str, strict: &str) {
        let actual: Requirement = LenientRequirement::from_str(lenient).unwrap().into();
        let expected: Requirement = Requirement::from_str(strict).unwrap();
        assert_eq!(actual, expected);
    }

    /// Asserts that leniently parsing `lenient` yields the same specifiers as strictly parsing
    /// `strict`.
    #[track_caller]
    fn assert_specifiers(lenient: &str, strict: &str) {
        let actual: VersionSpecifiers = LenientVersionSpecifiers::from_str(lenient)
            .unwrap()
            .into();
        let expected: VersionSpecifiers = VersionSpecifiers::from_str(strict).unwrap();
        assert_eq!(actual, expected);
    }

    #[test]
    fn requirement_missing_comma() {
        assert_requirement(
            "elasticsearch-dsl (>=7.2.0<8.0.0)",
            "elasticsearch-dsl (>=7.2.0,<8.0.0)",
        );
    }

    #[test]
    fn requirement_not_equal_tilde() {
        assert_requirement(
            "jupyter-core (!=~5.0,>=4.12)",
            "jupyter-core (!=5.0.*,>=4.12)",
        );
        assert_requirement("jupyter-core (!=~5,>=4.12)", "jupyter-core (!=5.*,>=4.12)");
    }

    #[test]
    fn requirement_greater_than_star() {
        assert_requirement("torch (>=1.9.*)", "torch (>=1.9)");
    }

    #[test]
    fn requirement_missing_dot() {
        assert_requirement(
            "pyzmq (>=2.7,!=3.0*,!=3.1*,!=3.2*)",
            "pyzmq (>=2.7,!=3.0.*,!=3.1.*,!=3.2.*)",
        );
    }

    #[test]
    fn requirement_trailing_comma() {
        assert_requirement("pyzmq >=3.6,", "pyzmq >=3.6");
    }

    #[test]
    fn specifier_missing_comma() {
        assert_specifiers(">=7.2.0<8.0.0", ">=7.2.0,<8.0.0");
    }

    #[test]
    fn specifier_not_equal_tilde() {
        assert_specifiers("!=~5.0,>=4.12", "!=5.0.*,>=4.12");
        assert_specifiers("!=~5,>=4.12", "!=5.*,>=4.12");
    }

    #[test]
    fn specifier_greater_than_star() {
        assert_specifiers(">=1.9.*", ">=1.9");
        assert_specifiers(">=1.*", ">=1");
    }

    #[test]
    fn specifier_missing_dot() {
        assert_specifiers(
            ">=2.7,!=3.0*,!=3.1*,!=3.2*",
            ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*",
        );
    }

    #[test]
    fn specifier_trailing_comma() {
        assert_specifiers(">=3.6,", ">=3.6");
    }

    #[test]
    fn specifier_trailing_comma_trailing_space() {
        assert_specifiers(">=3.6, ", ">=3.6");
    }

    /// <https://pypi.org/simple/shellingham/?format=application/vnd.pypi.simple.v1+json>
    #[test]
    fn specifier_invalid_single_quotes() {
        assert_specifiers(">= '2.7'", ">= 2.7");
    }

    /// <https://pypi.org/simple/tensorflowonspark/?format=application/vnd.pypi.simple.v1+json>
    #[test]
    fn specifier_invalid_double_quotes() {
        assert_specifiers(">=\"3.6\"", ">=3.6");
    }

    /// <https://pypi.org/simple/celery/?format=application/vnd.pypi.simple.v1+json>
    #[test]
    fn specifier_multi_fix() {
        assert_specifiers(
            ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*,",
            ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*",
        );
    }

    /// <https://pypi.org/simple/wincertstore/?format=application/vnd.pypi.simple.v1+json>
    #[test]
    fn smaller_than_star() {
        assert_specifiers(
            ">=2.7,!=3.0.*,!=3.1.*,<3.4.*",
            ">=2.7,!=3.0.*,!=3.1.*,<3.4",
        );
    }

    /// <https://pypi.org/simple/algoliasearch/?format=application/vnd.pypi.simple.v1+json>
    /// <https://pypi.org/simple/okta/?format=application/vnd.pypi.simple.v1+json>
    #[test]
    fn stray_quote() {
        assert_specifiers(
            ">=2.7, !=3.0.*, !=3.1.*', !=3.2.*, !=3.3.*'",
            ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*",
        );
        assert_specifiers(">=3.6'", ">=3.6");
    }

    /// <https://files.pythonhosted.org/packages/74/49/7349527cea7f708e7d3253ab6b32c9b5bdf84a57dde8fc265a33e6a4e662/boto3-1.2.0-py2.py3-none-any.whl>
    #[test]
    fn trailing_comma_after_quote() {
        assert_requirement("botocore>=1.3.0,<1.4.0',", "botocore>=1.3.0,<1.4.0");
    }

    /// <https://github.com/celery/celery/blob/6215f34d2675441ef2177bd850bf5f4b442e944c/requirements/default.txt#L1>
    #[test]
    fn greater_than_dev() {
        assert_specifiers(">dev", ">0.0.0dev");
    }

    /// <https://github.com/astral-sh/uv/issues/1798>
    #[test]
    fn trailing_alpha_zero() {
        assert_specifiers(">=9.0.0a1.0", ">=9.0.0a1");
        assert_specifiers(">=9.0a1.0", ">=9.0a1");
        assert_specifiers(">=9a1.0", ">=9a1");
    }

    /// <https://github.com/astral-sh/uv/issues/2551>
    #[test]
    fn stray_quote_preserve_marker() {
        // Already valid: the quotes inside the marker must survive untouched.
        assert_requirement(
            "numpy >=1.19; python_version >= \"3.7\"",
            "numpy >=1.19; python_version >= \"3.7\"",
        );
        // Quotes around the specifier are stripped, those in the marker are kept.
        assert_requirement(
            "numpy \">=1.19\"; python_version >= \"3.7\"",
            "numpy >=1.19; python_version >= \"3.7\"",
        );
        // Mixed and unbalanced quotes in front of the marker are stripped as well.
        assert_requirement(
            "'numpy' >=1.19\"; python_version >= \"3.7\"",
            "numpy >=1.19; python_version >= \"3.7\"",
        );
    }
}