1use regex::Regex;
2use serde::{Deserialize, Deserializer, Serialize, de};
3use std::borrow::Cow;
4use std::str::FromStr;
5use std::sync::LazyLock;
6use tracing::warn;
7
8use uv_pep440::{VersionSpecifiers, VersionSpecifiersParseError};
9use uv_pep508::{Pep508Error, Pep508Url, Requirement};
10
11use crate::VerbatimParsedUrl;
12
13static MISSING_COMMA: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"(\d)([<>=~^!])").unwrap());
15static NOT_EQUAL_TILDE: LazyLock<Regex> =
17 LazyLock::new(|| Regex::new(r"!=~((?:\d\.)*\d)").unwrap());
18static INVALID_TRAILING_DOT_STAR: LazyLock<Regex> =
20 LazyLock::new(|| Regex::new(r"(<=|>=|<|>)(\d+(\.\d+)*)\.\*").unwrap());
21static MISSING_DOT: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"(\d\.\d)+\*").unwrap());
23static TRAILING_COMMA: LazyLock<Regex> = LazyLock::new(|| Regex::new(r",\s*$").unwrap());
25static GREATER_THAN_DEV: LazyLock<Regex> = LazyLock::new(|| Regex::new(r">dev").unwrap());
27static TRAILING_ZERO: LazyLock<Regex> =
29 LazyLock::new(|| Regex::new(r"(\d+(\.\d)*(a|b|rc|post|dev)\d+)\.0").unwrap());
30
31type FixUp = for<'a> fn(&'a str) -> Cow<'a, str>;
33
34static FIXUPS: &[(FixUp, &str)] = &[
36 (
38 |input| MISSING_COMMA.replace_all(input, r"$1,$2"),
39 "inserting missing comma",
40 ),
41 (
43 |input| NOT_EQUAL_TILDE.replace_all(input, r"!=${1}.*"),
44 "replacing invalid tilde with wildcard",
45 ),
46 (
48 |input| INVALID_TRAILING_DOT_STAR.replace_all(input, r"${1}${2}"),
49 "removing star after comparison operator other than equal and not equal",
50 ),
51 (
53 |input| MISSING_DOT.replace_all(input, r"${1}.*"),
54 "inserting missing dot",
55 ),
56 (
58 |input| TRAILING_COMMA.replace_all(input, r"${1}"),
59 "removing trailing comma",
60 ),
61 (
63 |input| GREATER_THAN_DEV.replace_all(input, r">0.0.0dev"),
64 "assuming 0.0.0dev",
65 ),
66 (
68 |input| TRAILING_ZERO.replace_all(input, r"${1}"),
69 "removing trailing zero",
70 ),
71 (remove_stray_quotes, "removing stray quotes"),
72];
73
/// Removes every `'` and `"` that appears before the first `;` in `input`.
///
/// Everything from the first `;` onwards is copied through untouched, so quotes after it are
/// preserved. When there is no `;`, the whole input is cleaned.
///
/// Returns `Cow::Borrowed(input)` when there is nothing to strip, and an owned string otherwise.
fn remove_stray_quotes(input: &str) -> Cow<'_, str> {
    let is_quote = |c: char| matches!(c, '\'' | '"');

    // `;` is a single-byte character, so splitting at its index is always on a char boundary.
    let (requirement, markers) = match input.find(';') {
        Some(index) => input.split_at(index),
        None => (input, ""),
    };

    // Avoid allocating when the part we are allowed to touch is already clean.
    if !requirement.contains(is_quote) {
        return Cow::Borrowed(input);
    }

    let mut cleaned = String::with_capacity(input.len());
    cleaned.extend(requirement.chars().filter(|&c| !is_quote(c)));
    cleaned.push_str(markers);
    Cow::Owned(cleaned)
}
88
89fn parse_with_fixups<Err, T: FromStr<Err = Err>>(input: &str, type_name: &str) -> Result<T, Err> {
90 match T::from_str(input) {
91 Ok(requirement) => Ok(requirement),
92 Err(err) => {
93 let mut patched_input = input.to_string();
94 let mut messages = Vec::new();
95 for (fixup, message) in FIXUPS {
96 let patched = fixup(patched_input.as_ref());
97 if patched != patched_input {
98 messages.push(*message);
99
100 if let Ok(requirement) = T::from_str(&patched) {
101 warn!(
102 "Fixing invalid {type_name} by {} (before: `{input}`; after: `{patched}`)",
103 messages.join(", ")
104 );
105 return Ok(requirement);
106 }
107
108 patched_input = patched.to_string();
109 }
110 }
111
112 Err(err)
113 }
114 }
115}
116
117#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
119pub struct LenientRequirement<T: Pep508Url = VerbatimParsedUrl>(Requirement<T>);
120
121impl<T: Pep508Url> FromStr for LenientRequirement<T> {
122 type Err = Pep508Error<T>;
123
124 fn from_str(input: &str) -> Result<Self, Self::Err> {
125 Ok(Self(parse_with_fixups(input, "requirement")?))
126 }
127}
128
129impl<T: Pep508Url> From<LenientRequirement<T>> for Requirement<T> {
130 fn from(requirement: LenientRequirement<T>) -> Self {
131 requirement.0
132 }
133}
134
135#[derive(Debug, Clone, Serialize, Eq, PartialEq)]
139pub struct LenientVersionSpecifiers(VersionSpecifiers);
140
141impl FromStr for LenientVersionSpecifiers {
142 type Err = VersionSpecifiersParseError;
143
144 fn from_str(input: &str) -> Result<Self, Self::Err> {
145 Ok(Self(parse_with_fixups(input, "version specifier")?))
146 }
147}
148
149impl From<LenientVersionSpecifiers> for VersionSpecifiers {
150 fn from(specifiers: LenientVersionSpecifiers) -> Self {
151 specifiers.0
152 }
153}
154
155impl<'de> Deserialize<'de> for LenientVersionSpecifiers {
156 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
157 where
158 D: Deserializer<'de>,
159 {
160 struct Visitor;
161
162 impl de::Visitor<'_> for Visitor {
163 type Value = LenientVersionSpecifiers;
164
165 fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
166 f.write_str("a string")
167 }
168
169 fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
170 LenientVersionSpecifiers::from_str(v).map_err(de::Error::custom)
171 }
172 }
173
174 deserializer.deserialize_str(Visitor)
175 }
176}
177
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use uv_pep440::VersionSpecifiers;
    use uv_pep508::Requirement;

    use super::{LenientRequirement, LenientVersionSpecifiers};

    /// Asserts that leniently parsing `lenient` yields the same requirement as strictly parsing
    /// `strict`.
    #[track_caller]
    fn assert_requirement(lenient: &str, strict: &str) {
        let actual: Requirement = LenientRequirement::from_str(lenient).unwrap().into();
        let expected: Requirement = Requirement::from_str(strict).unwrap();
        assert_eq!(actual, expected);
    }

    /// Asserts that leniently parsing `lenient` yields the same specifiers as strictly parsing
    /// `strict`.
    #[track_caller]
    fn assert_specifiers(lenient: &str, strict: &str) {
        let actual: VersionSpecifiers = LenientVersionSpecifiers::from_str(lenient)
            .unwrap()
            .into();
        let expected: VersionSpecifiers = VersionSpecifiers::from_str(strict).unwrap();
        assert_eq!(actual, expected);
    }

    #[test]
    fn requirement_missing_comma() {
        assert_requirement(
            "elasticsearch-dsl (>=7.2.0<8.0.0)",
            "elasticsearch-dsl (>=7.2.0,<8.0.0)",
        );
    }

    #[test]
    fn requirement_not_equal_tilde() {
        assert_requirement(
            "jupyter-core (!=~5.0,>=4.12)",
            "jupyter-core (!=5.0.*,>=4.12)",
        );
        assert_requirement("jupyter-core (!=~5,>=4.12)", "jupyter-core (!=5.*,>=4.12)");
    }

    #[test]
    fn requirement_greater_than_star() {
        assert_requirement("torch (>=1.9.*)", "torch (>=1.9)");
    }

    #[test]
    fn requirement_missing_dot() {
        assert_requirement(
            "pyzmq (>=2.7,!=3.0*,!=3.1*,!=3.2*)",
            "pyzmq (>=2.7,!=3.0.*,!=3.1.*,!=3.2.*)",
        );
    }

    #[test]
    fn requirement_trailing_comma() {
        assert_requirement("pyzmq >=3.6,", "pyzmq >=3.6");
    }

    #[test]
    fn specifier_missing_comma() {
        assert_specifiers(">=7.2.0<8.0.0", ">=7.2.0,<8.0.0");
    }

    #[test]
    fn specifier_not_equal_tilde() {
        assert_specifiers("!=~5.0,>=4.12", "!=5.0.*,>=4.12");
        assert_specifiers("!=~5,>=4.12", "!=5.*,>=4.12");
    }

    #[test]
    fn specifier_greater_than_star() {
        assert_specifiers(">=1.9.*", ">=1.9");
        assert_specifiers(">=1.*", ">=1");
    }

    #[test]
    fn specifier_missing_dot() {
        assert_specifiers(
            ">=2.7,!=3.0*,!=3.1*,!=3.2*",
            ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*",
        );
    }

    #[test]
    fn specifier_trailing_comma() {
        assert_specifiers(">=3.6,", ">=3.6");
    }

    #[test]
    fn specifier_trailing_comma_trailing_space() {
        assert_specifiers(">=3.6, ", ">=3.6");
    }

    /// Single quotes wrapped around the version.
    #[test]
    fn specifier_invalid_single_quotes() {
        assert_specifiers(">= '2.7'", ">= 2.7");
    }

    /// Double quotes wrapped around the version.
    #[test]
    fn specifier_invalid_double_quotes() {
        assert_specifiers(">=\"3.6\"", ">=3.6");
    }

    /// Valid wildcards must survive while the trailing comma is removed.
    #[test]
    fn specifier_multi_fix() {
        assert_specifiers(
            ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*,",
            ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*",
        );
    }

    /// Only the star after `<` is dropped; the `!=` wildcards are kept.
    #[test]
    fn smaller_than_star() {
        assert_specifiers(
            ">=2.7,!=3.0.*,!=3.1.*,<3.4.*",
            ">=2.7,!=3.0.*,!=3.1.*,<3.4",
        );
    }

    /// Unbalanced quotes scattered through the specifiers.
    #[test]
    fn stray_quote() {
        assert_specifiers(
            ">=2.7, !=3.0.*, !=3.1.*', !=3.2.*, !=3.3.*'",
            ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*",
        );
        assert_specifiers(">=3.6'", ">=3.6");
    }

    /// Needs two fix-ups: the trailing comma and the stray quote.
    #[test]
    fn trailing_comma_after_quote() {
        assert_requirement("botocore>=1.3.0,<1.4.0',", "botocore>=1.3.0,<1.4.0");
    }

    /// A bare `dev` is read as `0.0.0dev`.
    #[test]
    fn greater_than_dev() {
        assert_specifiers(">dev", ">0.0.0dev");
    }

    /// A `.0` after a pre-release number is dropped, whatever the length of the release.
    #[test]
    fn trailing_alpha_zero() {
        assert_specifiers(">=9.0.0a1.0", ">=9.0.0a1");
        assert_specifiers(">=9.0a1.0", ">=9.0a1");
        assert_specifiers(">=9a1.0", ">=9a1");
    }

    /// Quotes inside the marker expression must never be stripped.
    #[test]
    fn stray_quote_preserve_marker() {
        let strict = "numpy >=1.19; python_version >= \"3.7\"";

        // Already valid: must round-trip unchanged.
        assert_requirement("numpy >=1.19; python_version >= \"3.7\"", strict);
        // Quotes around the specifier.
        assert_requirement("numpy \">=1.19\"; python_version >= \"3.7\"", strict);
        // Quotes around the name plus an unbalanced quote after the specifier.
        assert_requirement("'numpy' >=1.19\"; python_version >= \"3.7\"", strict);
    }
}