winget_types/shared/
package_identifier.rs

1use core::{fmt, str::FromStr};
2
3use compact_str::CompactString;
4use thiserror::Error;
5
6use super::DISALLOWED_CHARACTERS;
7
8#[derive(Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd, Hash)]
9#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
10#[cfg_attr(feature = "serde", serde(try_from = "CompactString"))]
11#[repr(transparent)]
12pub struct PackageIdentifier(CompactString);
13
14#[derive(Error, Debug, Eq, PartialEq)]
15pub enum PackageIdentifierError {
16    #[error("Package identifier cannot be empty")]
17    Empty,
18    #[error("A part of a package identifier cannot be empty")]
19    EmptyPart,
20    #[error(
21        "Package identifier cannot be more than {} characters long",
22        PackageIdentifier::MAX_CHAR_LENGTH
23    )]
24    TooLong,
25    #[error("Package identifier contains invalid character {_0:?}")]
26    InvalidCharacter(char),
27    #[error(
28        "The length of a part in a package identifier cannot be more than {} characters long",
29        PackageIdentifier::MAX_PART_CHAR_LENGTH
30    )]
31    PartTooLong,
32    #[error(
33        "The number of parts in the package identifier must be between {} and {}",
34        PackageIdentifier::MIN_PARTS,
35        PackageIdentifier::MAX_PARTS
36    )]
37    InvalidPartCount,
38}
39
40impl PackageIdentifier {
41    pub const MAX_CHAR_LENGTH: usize = 128;
42    pub const MIN_PARTS: usize = 2;
43    pub const MAX_PARTS: usize = 8;
44    pub const MAX_PART_CHAR_LENGTH: usize = 32;
45
46    /// Creates a new `PackageIdentifier` from any type that implements `AsRef<str>` and
47    /// `Into<CompactString>`.
48    ///
49    /// # Errors
50    ///
51    /// Will return `Err` if the package identifier:
52    /// 1. Is empty
53    /// 2. Has an empty part
54    /// 3. Is more than 128 characters long
55    /// 4. Has a part more than 32 characters long
56    /// 5. Contains a disallowed character (control, whitespace, or one of [`DISALLOWED_CHARACTERS`])
57    pub fn new<T: AsRef<str> + Into<CompactString>>(
58        identifier: T,
59    ) -> Result<Self, PackageIdentifierError> {
60        let identifier_str = identifier.as_ref();
61
62        if identifier_str.is_empty() {
63            return Err(PackageIdentifierError::Empty);
64        }
65
66        let (char_count, parts_count) = identifier_str.split('.').try_fold(
67            (0, 0),
68            |(total_char_count, part_count), part| {
69                if part.is_empty() {
70                    return Err(PackageIdentifierError::EmptyPart);
71                }
72
73                let part_char_count = part.chars().try_fold(0, |char_count, char| {
74                    if DISALLOWED_CHARACTERS.contains(&char)
75                        || char.is_control()
76                        || char.is_whitespace()
77                    {
78                        return Err(PackageIdentifierError::InvalidCharacter(char));
79                    }
80
81                    Ok(char_count + 1)
82                })?;
83
84                if part_char_count > Self::MAX_PART_CHAR_LENGTH {
85                    return Err(PackageIdentifierError::PartTooLong);
86                }
87
88                Ok((
89                    total_char_count + part_char_count + '.'.len_utf8(),
90                    part_count + 1,
91                ))
92            },
93        )?;
94
95        if char_count > Self::MAX_CHAR_LENGTH {
96            return Err(PackageIdentifierError::TooLong);
97        }
98
99        if !(Self::MIN_PARTS..=Self::MAX_PARTS).contains(&parts_count) {
100            return Err(PackageIdentifierError::InvalidPartCount);
101        }
102
103        Ok(Self(identifier.into()))
104    }
105
106    /// Creates a new `PackageIdentifier` from any type that implements `Into<CompactString>`
107    /// without checking its validity.
108    ///
109    /// # Safety
110    ///
111    /// The package identifier must not:
112    /// 1. Be empty
113    /// 2. Have an empty part
114    /// 3. Be more than 128 characters long
115    /// 4. Have a part more than 32 characters long
116    /// 5. Contain a disallowed character (control, whitespace, or one of [`DISALLOWED_CHARACTERS`])
117    #[must_use]
118    #[inline]
119    pub unsafe fn new_unchecked<T: Into<CompactString>>(identifier: T) -> Self {
120        Self(identifier.into())
121    }
122
123    /// Extracts a string slice containing the entire `PackageIdentifier`.
124    #[must_use]
125    #[inline]
126    pub fn as_str(&self) -> &str {
127        self.0.as_str()
128    }
129}
130
131impl fmt::Display for PackageIdentifier {
132    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
133        self.0.fmt(f)
134    }
135}
136
137impl FromStr for PackageIdentifier {
138    type Err = PackageIdentifierError;
139
140    fn from_str(s: &str) -> Result<Self, PackageIdentifierError> {
141        Self::new(s)
142    }
143}
144
145impl TryFrom<CompactString> for PackageIdentifier {
146    type Error = PackageIdentifierError;
147
148    #[inline]
149    fn try_from(value: CompactString) -> Result<Self, Self::Error> {
150        Self::new(value)
151    }
152}
153
#[cfg(test)]
mod tests {
    use alloc::{format, string::String};
    use core::iter::repeat_n;

    #[cfg(feature = "serde")]
    use indoc::indoc;
    use rstest::rstest;

    use crate::shared::{
        DISALLOWED_CHARACTERS,
        package_identifier::{PackageIdentifier, PackageIdentifierError},
    };

    /// Builds an identifier made of `count` copies of `part`, separated by `.`.
    fn dotted(part: &str, count: usize) -> String {
        itertools::intersperse(repeat_n(part, count), ".").collect()
    }

    /// Parses `identifier` through `FromStr`, keeping the assertions below short.
    fn parse(identifier: &str) -> Result<PackageIdentifier, PackageIdentifierError> {
        identifier.parse()
    }

    #[rstest]
    #[case("Package.Identifier")]
    #[case("Microsoft.PowerShell")]
    #[case("Google.Chrome.Canary")]
    #[case("EclipseAdoptium.Temurin.21.JDK")]
    #[case("A.Long.Package.Identifier.With.Exactly.Eight.Parts")]
    fn valid_package_identifier(#[case] package_identifier: &str) {
        assert!(parse(package_identifier).is_ok());
    }

    #[test]
    fn too_long_package_identifier() {
        let delimiters = PackageIdentifier::MAX_PARTS - 1;
        let part_length = (PackageIdentifier::MAX_CHAR_LENGTH - delimiters)
            .div_ceil(PackageIdentifier::MAX_PARTS);
        let identifier = dotted(&"a".repeat(part_length), PackageIdentifier::MAX_PARTS);

        assert_eq!(parse(&identifier), Err(PackageIdentifierError::TooLong));
    }

    #[test]
    fn too_many_parts_package_identifier() {
        let generated = dotted("a", PackageIdentifier::MAX_PARTS + 1);
        let handwritten = "Really.Long.Package.Identifier.Spanning.More.Than.Eight.Parts";

        for identifier in [generated.as_str(), handwritten] {
            assert_eq!(
                parse(identifier),
                Err(PackageIdentifierError::InvalidPartCount)
            );
        }
    }

    #[test]
    fn package_identifier_parts_too_long() {
        let oversized_part = "a".repeat(PackageIdentifier::MAX_PART_CHAR_LENGTH + 1);
        let identifier = dotted(&oversized_part, PackageIdentifier::MIN_PARTS);

        assert_eq!(parse(&identifier), Err(PackageIdentifierError::PartTooLong));
    }

    #[test]
    fn too_few_parts_package_identifier() {
        let generated = "a".repeat(PackageIdentifier::MIN_PARTS - 1);

        for identifier in [generated.as_str(), "OnePart"] {
            assert_eq!(
                parse(identifier),
                Err(PackageIdentifierError::InvalidPartCount)
            );
        }
    }

    #[test]
    fn whitespace_in_package_identifier() {
        assert_eq!(
            parse("Publisher.Pack age"),
            Err(PackageIdentifierError::InvalidCharacter(' '))
        );
    }

    #[test]
    fn control_chars_in_package_identifier() {
        for control in '\u{0}'..='\u{1F}' {
            let identifier = format!("Publisher.Pack{control}age");

            assert_eq!(
                parse(&identifier),
                Err(PackageIdentifierError::InvalidCharacter(control))
            );
        }
    }

    #[test]
    fn package_identifier_disallowed_characters() {
        for disallowed in DISALLOWED_CHARACTERS {
            let identifier = format!("Publisher.Pack{disallowed}age");

            assert_eq!(
                parse(&identifier),
                Err(PackageIdentifierError::InvalidCharacter(disallowed))
            );
        }
    }

    #[test]
    fn package_identifier_part_empty() {
        assert!(parse("a.b").is_ok());

        for identifier in ["a.b.", "a..b"] {
            assert_eq!(parse(identifier), Err(PackageIdentifierError::EmptyPart));
        }
    }

    #[cfg(feature = "serde")]
    #[derive(serde::Serialize, serde::Deserialize)]
    #[serde(rename_all = "PascalCase")]
    struct Manifest {
        package_identifier: PackageIdentifier,
    }

    #[cfg(feature = "serde")]
    #[test]
    fn serialize_package_identifier() {
        let manifest = Manifest {
            package_identifier: parse("Microsoft.PowerShell").unwrap(),
        };

        let yaml = serde_yaml::to_string(&manifest).unwrap();

        assert_eq!(
            yaml,
            indoc! {"
                PackageIdentifier: Microsoft.PowerShell
            "}
        );
    }

    #[cfg(feature = "serde")]
    #[test]
    fn deserialize_package_identifier() {
        let manifest = serde_yaml::from_str::<Manifest>(indoc! {"
            PackageIdentifier: Microsoft.PowerShell
        "})
        .unwrap();

        assert_eq!(
            manifest.package_identifier,
            parse("Microsoft.PowerShell").unwrap()
        );
    }
}