scouter_semver/
semver.rs

1use crate::error::VersionError;
2use pyo3::prelude::*;
3use semver::{BuildMetadata, Prerelease, Version};
4use serde::{Deserialize, Serialize};
5use std::str::FromStr;
6use std::string::ToString;
7#[pyclass]
8#[derive(Debug, PartialEq, Deserialize, Serialize, Clone)]
9pub enum VersionType {
10    Major,
11    Minor,
12    Patch,
13    Pre,
14    Build,
15    PreBuild,
16}
17
18impl FromStr for VersionType {
19    type Err = ();
20
21    fn from_str(input: &str) -> Result<VersionType, Self::Err> {
22        match input.to_lowercase().as_str() {
23            "major" => Ok(VersionType::Major),
24            "minor" => Ok(VersionType::Minor),
25            "patch" => Ok(VersionType::Patch),
26            "pre" => Ok(VersionType::Pre),
27            "build" => Ok(VersionType::Build),
28            "pre_build" => Ok(VersionType::PreBuild),
29            _ => Err(()),
30        }
31    }
32}
33
34#[pymethods]
35impl VersionType {
36    #[new]
37    fn new(version_type: &str) -> PyResult<Self> {
38        match VersionType::from_str(version_type) {
39            Ok(version_type) => Ok(version_type),
40            Err(()) => Err(pyo3::exceptions::PyValueError::new_err(
41                "Invalid version type",
42            )),
43        }
44    }
45
46    fn __eq__(&self, other: &Self) -> bool {
47        self == other
48    }
49}
50
51#[derive(Debug, PartialEq)]
52pub struct VersionArgs {
53    pub version: String,
54    pub version_type: VersionType,
55    pub pre: Option<String>,
56    pub build: Option<String>,
57}
58
59pub struct VersionValidator {}
60
61impl VersionValidator {
62    pub fn validate_version(version: &str) -> Result<(), VersionError> {
63        match Version::parse(version) {
64            Ok(_) => Ok(()),
65            Err(e) => Err(VersionError::InvalidVersion(e)),
66        }
67    }
68
69    pub fn bump_version(version_args: &VersionArgs) -> Result<Version, VersionError> {
70        // parse the version
71        let version = match Version::parse(&version_args.version) {
72            Ok(v) => v,
73            Err(e) => return Err(VersionError::InvalidVersion(e)),
74        };
75
76        let mut new_version = Version::new(version.major, version.minor, version.patch);
77
78        // check if version type is major, minor, or patch. If not, return the version as is
79        match version_args.version_type {
80            VersionType::Major => {
81                new_version.major += 1;
82                new_version.minor = 0;
83                new_version.patch = 0;
84            }
85            VersionType::Minor => {
86                new_version.minor += 1;
87                new_version.patch = 0;
88            }
89            VersionType::Patch => new_version.patch += 1,
90
91            // do nothing for pre and build
92            VersionType::Pre | VersionType::Build | VersionType::PreBuild => {}
93        };
94
95        // its possible someone creates a major, minor, patch version with a pre or build, or both
96        // in this case, we need to add the pre and build to the new version
97        if let Some(pre) = &version_args.pre {
98            new_version.pre = match Prerelease::new(pre) {
99                Ok(p) => p,
100                Err(e) => return Err(VersionError::InvalidPreReleaseIdentifier(e)),
101            };
102        }
103
104        if let Some(build) = &version_args.build {
105            new_version.build = match BuildMetadata::new(build) {
106                Ok(b) => b,
107                Err(e) => return Err(VersionError::InvalidPreReleaseIdentifier(e)),
108            };
109        }
110
111        Ok(new_version)
112    }
113
114    pub fn sort_string_versions(versions: Vec<String>) -> Result<Vec<String>, VersionError> {
115        let mut versions: Vec<Version> = versions
116            .iter()
117            .map(|v| Version::parse(v).map_err(VersionError::InvalidVersion))
118            .collect::<Result<Vec<_>, _>>()?;
119
120        versions.sort();
121
122        Ok(versions.iter().map(ToString::to_string).collect())
123    }
124
125    pub fn sort_semver_versions(
126        mut versions: Vec<Version>,
127        reverse: bool,
128    ) -> Result<Vec<String>, VersionError> {
129        if versions.is_empty() {
130            return Ok(vec![]);
131        } else {
132            versions.sort();
133
134            if reverse {
135                versions.reverse();
136            }
137        }
138
139        Ok(versions.iter().map(ToString::to_string).collect())
140    }
141
142    /// Take a semver that may be incomplete and expand it to a full semver
143    ///
144    fn expand_version(version: &str) -> String {
145        let version_parts: Vec<&str> = version.split('.').collect();
146
147        // Return early if we already have all parts
148        if version_parts.len() >= 3 {
149            return version.to_string();
150        }
151
152        // Create a new vector with the existing parts
153        let mut expanded_version = version_parts.to_vec();
154
155        // Fill in missing parts with "0"
156        while expanded_version.len() < 3 {
157            expanded_version.push("0");
158        }
159
160        // Join parts with dots and return owned String
161        expanded_version.join(".")
162    }
163
164    pub fn clean_version(version: &str) -> Result<Version, VersionError> {
165        // Check if the version is empty
166        if version.is_empty() {
167            return Err(VersionError::EmptyVersionString);
168        }
169
170        match Version::parse(&Self::expand_version(version)) {
171            Ok(version) => Ok(version),
172            Err(e) => Err(VersionError::InvalidVersion(e)),
173        }
174    }
175}
176
177#[derive(Debug, PartialEq)]
178pub struct VersionBounds {
179    pub lower_bound: Version,
180    pub upper_bound: Version,
181    pub no_upper_bound: bool,
182    pub parser_type: VersionParser,
183    pub num_parts: usize,
184}
185
/// Range operator detected in a version requirement string.
#[derive(Debug, PartialEq)]
pub enum VersionParser {
    /// The string contains `*` (wildcard, e.g. `*` or `1.*`).
    Star,
    /// The string contains `^` (e.g. `^1.2.3`).
    Caret,
    /// The string contains `~` (e.g. `~1.2.3`).
    Tilde,
    /// No range operator; the string is a plain (possibly partial) version.
    Exact,
}
193
194impl VersionParser {
195    /// Create a new VersionParser from a version string
196    ///
197    /// # Errors
198    ///
199    /// Returns an error if the version string is invalid
200    pub fn new(version: &str) -> Result<VersionParser, VersionError> {
201        // check if version contains
202        if version.contains('*') {
203            Ok(VersionParser::Star)
204        } else if version.contains('^') {
205            Ok(VersionParser::Caret)
206        } else if version.contains('~') {
207            Ok(VersionParser::Tilde)
208        } else {
209            Ok(VersionParser::Exact)
210        }
211    }
212
213    pub fn remove_version_prefix(&self, version: &str) -> String {
214        // break version into parts
215        match self {
216            VersionParser::Star => version.replace('*', ""),
217            VersionParser::Caret => version.replace('^', ""),
218            VersionParser::Tilde => version.replace('~', ""),
219            VersionParser::Exact => version.to_string(),
220        }
221    }
222
223    /// Parse a version string into a Version struct
224    ///
225    /// # Errors
226    /// Errors if the version string is invalid
227    fn parse_version(version: &str) -> Result<Version, VersionError> {
228        Version::parse(version).map_err(VersionError::InvalidVersion)
229    }
230
231    /// Create a VersionBounds struct from a lower and upper version string
232    ///
233    /// # Errors
234    /// Errors if the version strings are invalid
235    fn create_bounds(
236        lower: &str,
237        upper: &str,
238        parser_type: VersionParser,
239        num_parts: usize,
240        no_upper_bound: bool,
241    ) -> Result<VersionBounds, VersionError> {
242        Ok(VersionBounds {
243            lower_bound: Self::parse_version(lower)?,
244            upper_bound: Self::parse_version(upper)?,
245            no_upper_bound,
246            parser_type,
247            num_parts,
248        })
249    }
250
251    pub fn get_version_to_search(version: &str) -> Result<VersionBounds, VersionError> {
252        let parser = VersionParser::new(version)?;
253
254        let cleaned_version = parser.remove_version_prefix(version);
255
256        // determine number of "." in the version and split into int parts
257        let version_parts = cleaned_version
258            .split('.')
259            .filter(|v| !v.is_empty())
260            .map(str::parse::<u64>)
261            .collect::<Result<Vec<_>, _>>()
262            .map_err(VersionError::ParseError)?;
263
264        let num_parts = version_parts.len();
265
266        match parser {
267            VersionParser::Star => match num_parts {
268                0 => Self::create_bounds("0.0.0", "0.0.0", VersionParser::Star, num_parts, true),
269                1 => Self::create_bounds(
270                    &format!("{}.0.0", version_parts[0]),
271                    &format!("{}.0.0", version_parts[0] + 1),
272                    VersionParser::Star,
273                    num_parts,
274                    false,
275                ),
276                2 => Self::create_bounds(
277                    &format!("{}.{}.0", version_parts[0], version_parts[1]),
278                    &format!("{}.{}.0", version_parts[0], version_parts[1] + 1),
279                    VersionParser::Star,
280                    num_parts,
281                    false,
282                ),
283                3 => Self::create_bounds(
284                    &format!(
285                        "{}.{}.{}",
286                        version_parts[0], version_parts[1], version_parts[2]
287                    ),
288                    &format!(
289                        "{}.{}.{}",
290                        version_parts[0],
291                        version_parts[1],
292                        version_parts[2] + 1
293                    ),
294                    VersionParser::Star,
295                    num_parts,
296                    false,
297                ),
298                _ => Err(VersionError::StarSyntaxError),
299            },
300            VersionParser::Tilde => match num_parts {
301                1 => Self::create_bounds(
302                    &format!("{}.0.0", version_parts[0]),
303                    &format!("{}.0.0", version_parts[0] + 1),
304                    VersionParser::Tilde,
305                    num_parts,
306                    false,
307                ),
308                2 => Self::create_bounds(
309                    &format!("{}.{}.0", version_parts[0], version_parts[1]),
310                    &format!("{}.{}.0", version_parts[0], version_parts[1] + 1),
311                    VersionParser::Tilde,
312                    num_parts,
313                    false,
314                ),
315                _ => Self::create_bounds(
316                    &format!(
317                        "{}.{}.{}",
318                        version_parts[0], version_parts[1], version_parts[2]
319                    ),
320                    &format!("{}.{}.0", version_parts[0], version_parts[1] + 1),
321                    VersionParser::Tilde,
322                    num_parts,
323                    false,
324                ),
325            },
326            VersionParser::Caret => {
327                if num_parts >= 2 {
328                    Self::create_bounds(
329                        &format!(
330                            "{}.{}.{}",
331                            version_parts[0], version_parts[1], version_parts[2]
332                        ),
333                        &format!("{}.{}.0", version_parts[0], version_parts[1] + 1),
334                        VersionParser::Caret,
335                        num_parts,
336                        false,
337                    )
338                } else {
339                    Err(VersionError::CaretSyntaxError)
340                }
341            }
342            VersionParser::Exact => match num_parts {
343                1 => Self::create_bounds(
344                    &format!("{}.0.0", version_parts[0]),
345                    &format!("{}.0.0", version_parts[0] + 1),
346                    VersionParser::Exact,
347                    num_parts,
348                    false,
349                ),
350                2 => Self::create_bounds(
351                    &format!("{}.{}.0", version_parts[0], version_parts[1]),
352                    &format!("{}.{}.0", version_parts[0], version_parts[1] + 1),
353                    VersionParser::Exact,
354                    num_parts,
355                    false,
356                ),
357                3 => Self::create_bounds(
358                    &format!(
359                        "{}.{}.{}",
360                        version_parts[0], version_parts[1], version_parts[2]
361                    ),
362                    &format!(
363                        "{}.{}.{}",
364                        version_parts[0],
365                        version_parts[1],
366                        version_parts[2] + 1
367                    ),
368                    VersionParser::Exact,
369                    num_parts,
370                    false,
371                ),
372                _ => Err(VersionError::ExactSyntaxError),
373            },
374        }
375    }
376}
377
#[cfg(test)]
mod tests {
    use super::*;
    use semver::Version;

    #[test]
    fn test_version_validator_validate_version() {
        assert!(VersionValidator::validate_version("1.2.3").is_ok());
        assert!(VersionValidator::validate_version("invalid.version").is_err());

        // validate default
        assert_eq!(
            VersionValidator::clean_version("1.2.3-alpha").unwrap(),
            Version::parse("1.2.3-alpha").unwrap()
        );

        // validate with build
        assert_eq!(
            VersionValidator::clean_version("1.2.3+001").unwrap(),
            Version::parse("1.2.3+001").unwrap()
        );

        // validate with pre and build
        assert_eq!(
            VersionValidator::clean_version("1.2.3-alpha+001").unwrap(),
            Version::parse("1.2.3-alpha+001").unwrap()
        );

        // validate missing parts
        assert_eq!(
            VersionValidator::clean_version("1.2").unwrap(),
            Version::parse("1.2.0").unwrap()
        );

        assert_eq!(
            VersionValidator::clean_version("1").unwrap(),
            Version::parse("1.0.0").unwrap()
        );
    }

    #[test]
    fn test_version_validator_bump_version() {
        let args = VersionArgs {
            version: "1.2.3".to_string(),
            version_type: VersionType::Major,
            pre: None,
            build: None,
        };
        assert_eq!(
            VersionValidator::bump_version(&args).unwrap(),
            Version::parse("2.0.0").unwrap()
        );

        let args = VersionArgs {
            version: "1.2.3".to_string(),
            version_type: VersionType::Minor,
            pre: None,
            build: None,
        };
        assert_eq!(
            VersionValidator::bump_version(&args).unwrap(),
            Version::parse("1.3.0").unwrap()
        );

        let args = VersionArgs {
            version: "1.2.3".to_string(),
            version_type: VersionType::Patch,
            pre: None,
            build: None,
        };
        assert_eq!(
            VersionValidator::bump_version(&args).unwrap(),
            Version::parse("1.2.4").unwrap()
        );

        let args = VersionArgs {
            version: "1.2.3".to_string(),
            version_type: VersionType::Pre,
            pre: Some("alpha".to_string()),
            build: None,
        };
        assert_eq!(
            VersionValidator::bump_version(&args).unwrap(),
            Version::parse("1.2.3-alpha").unwrap()
        );

        let args = VersionArgs {
            version: "1.2.3".to_string(),
            version_type: VersionType::Build,
            pre: None,
            build: Some("001".to_string()),
        };
        assert_eq!(
            VersionValidator::bump_version(&args).unwrap(),
            Version::parse("1.2.3+001").unwrap()
        );

        let args = VersionArgs {
            version: "1.2.3".to_string(),
            version_type: VersionType::PreBuild,
            pre: Some("alpha".to_string()),
            build: Some("001".to_string()),
        };
        assert_eq!(
            VersionValidator::bump_version(&args).unwrap(),
            Version::parse("1.2.3-alpha+001").unwrap()
        );
    }

    #[test]
    fn test_version_validator_sort_versions() {
        let versions = vec![
            "1.2.1".to_string(),
            "1.3.0".to_string(),
            "1.2.2".to_string(),
            "1.2.3-alpha+001".to_string(),
            "1.2.3+001".to_string(),
            "1.2.3+0b1".to_string(),
            "1.2.3".to_string(),
        ];
        let sorted_versions = VersionValidator::sort_string_versions(versions).unwrap();
        assert_eq!(
            sorted_versions,
            vec![
                "1.2.1",
                "1.2.2",
                "1.2.3-alpha+001",
                "1.2.3",
                "1.2.3+001",
                "1.2.3+0b1",
                "1.3.0"
            ]
        );
    }

    #[test]
    fn test_version_parser_new() {
        assert_eq!(VersionParser::new("*").unwrap(), VersionParser::Star);
        assert_eq!(VersionParser::new("^1.2.3").unwrap(), VersionParser::Caret);
        assert_eq!(VersionParser::new("~1.2.3").unwrap(), VersionParser::Tilde);
        assert_eq!(VersionParser::new("1.2.3").unwrap(), VersionParser::Exact);
    }

    #[test]
    fn test_version_parser_remove_version_prefix() {
        assert_eq!(VersionParser::Star.remove_version_prefix("*"), "");
        assert_eq!(VersionParser::Star.remove_version_prefix("*1"), "1");
        // Every `*` is stripped; the separator before it is left in place and the
        // resulting empty segment is ignored by `get_version_to_search`.
        assert_eq!(VersionParser::Star.remove_version_prefix("1.*"), "1.");
        assert_eq!(VersionParser::Star.remove_version_prefix("1.2.*"), "1.2.");
        assert_eq!(
            VersionParser::Caret.remove_version_prefix("^1.2.3"),
            "1.2.3"
        );
        assert_eq!(
            VersionParser::Tilde.remove_version_prefix("~1.2.3"),
            "1.2.3"
        );
        assert_eq!(VersionParser::Exact.remove_version_prefix("1.2.3"), "1.2.3");
    }

    #[test]
    fn test_version_parser_get_version_to_search() {
        let bounds = VersionParser::get_version_to_search("*").unwrap();
        assert_eq!(bounds.lower_bound, Version::parse("0.0.0").unwrap());
        assert!(bounds.no_upper_bound);

        // `1.*` covers every 1.x release.
        let bounds = VersionParser::get_version_to_search("1.*").unwrap();
        assert_eq!(bounds.lower_bound, Version::parse("1.0.0").unwrap());
        assert_eq!(bounds.upper_bound, Version::parse("2.0.0").unwrap());
        assert!(!bounds.no_upper_bound);

        // `1.2.*` covers every 1.2.x release.
        let bounds = VersionParser::get_version_to_search("1.2.*").unwrap();
        assert_eq!(bounds.lower_bound, Version::parse("1.2.0").unwrap());
        assert_eq!(bounds.upper_bound, Version::parse("1.3.0").unwrap());

        let bounds = VersionParser::get_version_to_search("^1.2.3").unwrap();
        assert_eq!(bounds.lower_bound, Version::parse("1.2.3").unwrap());
        assert_eq!(bounds.upper_bound, Version::parse("1.3.0").unwrap());

        let bounds = VersionParser::get_version_to_search("~1.2.3").unwrap();
        assert_eq!(bounds.lower_bound, Version::parse("1.2.3").unwrap());
        assert_eq!(bounds.upper_bound, Version::parse("1.3.0").unwrap());
    }
}