winget_types/shared/version.rs

1use std::{
2    cmp::{Ordering, Reverse},
3    convert::Infallible,
4    hash::{Hash, Hasher},
5    str::FromStr,
6};
7
8use compact_str::CompactString;
9use derive_more::Display;
10use itertools::{EitherOrBoth, Itertools};
11use serde_with::{DeserializeFromStr, SerializeDisplay};
12use smallvec::SmallVec;
13
14use crate::traits::Closest;
15
16#[derive(Clone, Debug, Default, Display, Eq, SerializeDisplay, DeserializeFromStr)]
17#[display("{raw}")]
18pub struct Version {
19    raw: CompactString,
20    parts: SmallVec<[VersionPart; 4]>, // Most versions have 4 parts or fewer
21}
22
23impl Version {
24    pub const SEPARATOR: char = '.';
25
26    pub fn new(input: &str) -> Self {
27        let mut trimmed = input.trim();
28
29        // If there is a digit before the delimiter, or no delimiters, trim off all leading
30        // non-digit characters
31        if let Some(digit_pos) = trimmed.find(|char: char| char.is_ascii_digit()) {
32            if trimmed
33                .find('.')
34                .is_none_or(|delimiter_pos| digit_pos < delimiter_pos)
35            {
36                trimmed = &trimmed[digit_pos..];
37            }
38        }
39
40        let mut parts = trimmed
41            .split(Self::SEPARATOR)
42            .map(VersionPart::new)
43            .collect::<SmallVec<[_; 4]>>();
44
45        if parts.is_empty() {
46            parts.push(VersionPart::new(trimmed));
47        }
48
49        let droppable_parts = parts
50            .iter()
51            .rev()
52            .take_while(|part| part.is_droppable())
53            .count();
54
55        parts.truncate(parts.len() - droppable_parts);
56
57        Self {
58            raw: CompactString::from(input),
59            parts,
60        }
61    }
62
63    #[must_use]
64    pub fn as_str(&self) -> &str {
65        &self.raw
66    }
67
68    #[must_use]
69    pub fn is_latest(&self) -> bool {
70        const LATEST: &str = "latest";
71
72        self.raw.eq_ignore_ascii_case(LATEST)
73    }
74}
75
76impl FromStr for Version {
77    type Err = Infallible;
78
79    fn from_str(s: &str) -> Result<Self, Self::Err> {
80        Ok(Self::new(s))
81    }
82}
83
84impl PartialEq for Version {
85    fn eq(&self, other: &Self) -> bool {
86        self.parts.eq(&other.parts)
87    }
88}
89
90impl Hash for Version {
91    fn hash<H: Hasher>(&self, state: &mut H) {
92        self.parts.hash(state);
93    }
94}
95
96impl PartialOrd for Version {
97    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
98        Some(self.cmp(other))
99    }
100}
101
102impl Ord for Version {
103    fn cmp(&self, other: &Self) -> Ordering {
104        self.parts
105            .iter()
106            .zip_longest(&other.parts)
107            .map(|pair| match pair {
108                EitherOrBoth::Both(a, b) => a.cmp(b),
109                EitherOrBoth::Left(a) => a.cmp(&VersionPart::DEFAULT),
110                EitherOrBoth::Right(b) => VersionPart::DEFAULT.cmp(b),
111            })
112            .find(|&ordering| ordering != Ordering::Equal)
113            .unwrap_or(Ordering::Equal)
114    }
115}
116
117impl Closest for Version {
118    fn distance_key(&self, other: &Self) -> impl Ord {
119        other
120            .parts
121            .iter()
122            .zip_longest(&self.parts)
123            .enumerate()
124            .find_map(|(index, pair)| {
125                let (candidate_part, target_part) = match pair {
126                    EitherOrBoth::Both(a, b) => (a, b),
127                    EitherOrBoth::Left(a) => (a, &VersionPart::DEFAULT),
128                    EitherOrBoth::Right(b) => (&VersionPart::DEFAULT, b),
129                };
130
131                (candidate_part != target_part).then(|| {
132                    (
133                        Reverse(index), // Prefer versions that diverge later
134                        candidate_part.number.abs_diff(target_part.number), // Prefer smaller numerical differences
135                        Reverse(candidate_part.cmp(target_part)), // Prefer higher versions
136                        Reverse(candidate_part.supplement.as_deref()), // Prefer higher supplements lexicographically
137                    )
138                })
139            })
140            .unwrap_or((
141                Reverse(usize::MAX),
142                0,
143                Reverse(Ordering::Equal),
144                Reverse(None),
145            ))
146    }
147}
148
149#[derive(Clone, Debug, PartialEq, Eq, Hash)]
150struct VersionPart {
151    number: u64,
152    supplement: Option<CompactString>,
153}
154
155impl VersionPart {
156    const DEFAULT: Self = Self {
157        number: 0,
158        supplement: None,
159    };
160
161    pub fn new(input: &str) -> Self {
162        let input = input.trim();
163
164        let split_index = input
165            .find(|char: char| !char.is_ascii_digit())
166            .unwrap_or(input.len());
167
168        let (number_str, supplement) = input.split_at(split_index);
169
170        Self {
171            number: number_str.parse().unwrap_or_default(),
172            supplement: Option::from(supplement)
173                .filter(|supplement| !supplement.is_empty())
174                .map(CompactString::from),
175        }
176    }
177
178    // WinGet ignores trailing parts that are 0 and have no supplemental value
179    pub const fn is_droppable(&self) -> bool {
180        self.number == 0 && self.supplement.is_none()
181    }
182}
183
184impl PartialOrd for VersionPart {
185    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
186        Some(self.cmp(other))
187    }
188}
189
190impl Ord for VersionPart {
191    fn cmp(&self, other: &Self) -> Ordering {
192        self.number.cmp(&other.number).then_with(|| {
193            match (self.supplement.as_deref(), other.supplement.as_deref()) {
194                (None, None) => Ordering::Equal,
195                (None, Some(_)) => Ordering::Greater,
196                (Some(_), None) => Ordering::Less,
197                (Some(a), Some(b)) => a.cmp(b),
198            }
199        })
200    }
201}
202
#[cfg(test)]
mod tests {
    use std::{
        cmp::Ordering,
        hash::{DefaultHasher, Hash, Hasher},
    };

    use rstest::rstest;

    use crate::{shared::Version, traits::Closest};

    /// Returns the `DefaultHasher` digest of `version`.
    fn hash_of(version: &Version) -> u64 {
        let mut hasher = DefaultHasher::default();
        version.hash(&mut hasher);
        hasher.finish()
    }

    /// Asserts that `lower` sorts strictly below `higher`, checking both `<` and `>`.
    fn assert_strictly_less(lower: &str, higher: &str) {
        let lower = Version::new(lower);
        let higher = Version::new(higher);
        assert!(lower < higher);
        assert!(higher > lower);
    }

    #[rstest]
    #[case("1.0", "1.0.0")]
    #[case("1.2.00.3", "1.2.0.3")]
    #[case("1.2.003.4", "1.2.3.4")]
    #[case("01.02.03.04", "1.2.3.4")]
    #[case("1.2.03-beta", "1.2.3-beta")]
    #[case("1.0", "1.0 ")]
    #[case("1.0", "1. 0")]
    #[case("1.0", "1.0.")]
    #[case("1.0", "Version 1.0")]
    #[case("foo1", "bar1")]
    fn version_equality(#[case] left: &str, #[case] right: &str) {
        let (left, right) = (Version::new(left), Version::new(right));
        assert_eq!(left, right);
        assert_eq!(left.cmp(&right), Ordering::Equal);
    }

    #[rstest]
    #[case("1", "2")]
    #[case("1.2-rc", "1.2")]
    #[case("1.0-rc", "1.0")]
    #[case("1.0.0-rc", "1")]
    #[case("22.0.0-rc.1", "22.0.0")]
    #[case("22.0.0-rc.1", "22.0.0.1")]
    #[case("22.0.0-rc.1", "22.0.0.1-rc")]
    #[case("22.0.0-rc.1", "22.0.0-rc.1.1")]
    #[case("22.0.0-rc.1.1", "22.0.0-rc.1.2")]
    #[case("22.0.0-rc.1.2", "22.0.0-rc.2")]
    #[case("v0.0.1", "0.0.2")]
    #[case("v0.0.1", "v0.0.2")]
    #[case("1.a2", "1.b1")]
    #[case("alpha", "beta")]
    fn version_comparison(#[case] left: &str, #[case] right: &str) {
        assert_strictly_less(left, right);
    }

    #[rstest]
    #[case("1", "2")]
    #[case("1-rc", "1")]
    #[case("1-a2", "1-b1")]
    #[case("alpha", "beta")]
    fn version_part_comparison(#[case] left: &str, #[case] right: &str) {
        assert_strictly_less(left, right);
    }

    #[test]
    fn version_hash() {
        // Keys that are equal must also hash equally:
        // https://doc.rust-lang.org/std/hash/trait.Hash.html#hash-and-eq
        let short = Version::new("1.2.3");
        let padded = Version::new("1.2.3.0");
        assert_eq!(short, padded);
        assert_eq!(hash_of(&short), hash_of(&padded));
    }

    #[test]
    fn only_supplement() {
        const ALPHA: &str = "alpha";

        let parts = Version::new(ALPHA).parts;
        assert_eq!(parts.len(), 1);

        let only = &parts[0];
        assert_eq!(only.number, 0);
        assert_eq!(only.supplement.as_deref(), Some(ALPHA));
    }

    #[rstest]
    #[case("1.2.3", &["1.0.0", "0.9.0", "1.5.6.3", "1.3.2"], "1.3.2")]
    #[case("10.20.30", &["10.20.29", "10.20.31", "10.20.40"], "10.20.31")]
    #[case("5.5.5", &["5.5.50", "5.5.0", "5.5.10"], "5.5.10")]
    #[case("3.0.0", &["3.0.0-beta", "3.0.0-alpha.1", "3.0.0-rc.1"], "3.0.0-rc.1")]
    #[case("2.1.0-beta", &["2.1.0-alpha", "2.1.0-beta.2", "2.1.0"], "2.1.0-beta.2")]
    #[case("1.5.0", &["1.0.0", "2.0.0"], "1.0.0")]
    #[case("3.3.3", &["1.1.1", "5.5.5"], "5.5.5")]
    #[case("3.3.3", &["5.5.5", "1.1.1"], "5.5.5")]
    #[case("2.2.2", &["2.2.2", "2.2.2", "2.2.3"], "2.2.2")]
    #[case("0.0.2", &["0.0.1", "0.0.3", "0.2.0"], "0.0.3")]
    #[case("999.999.999", &["999.999.998", "1000.0.0"], "999.999.998")]
    fn closest_version(#[case] version: &str, #[case] versions: &[&str], #[case] expected: &str) {
        let target = Version::new(version);
        let candidates: Vec<Version> = versions.iter().map(|&raw| Version::new(raw)).collect();
        assert_eq!(target.closest(&candidates), Some(&Version::new(expected)));
    }
}