1use std::fmt;
2use std::str::FromStr;
3
4use anyhow::{anyhow, bail, Context};
5use semver::{Version, VersionReq};
6use serde::de::{Deserialize, Deserializer, Error, Visitor};
7use serde::ser::{Serialize, Serializer};
8
9use crate::package_id::PackageId;
10use crate::package_name::PackageName;
11
12#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
20pub struct PackageReq {
21 name: PackageName,
22 version_req: VersionReq,
23}
24
25impl PackageReq {
26 pub fn new(name: PackageName, version_req: VersionReq) -> Self {
27 PackageReq { name, version_req }
28 }
29
30 pub fn name(&self) -> &PackageName {
31 &self.name
32 }
33
34 pub fn version_req(&self) -> &VersionReq {
35 &self.version_req
36 }
37
38 pub fn matches_id(&self, package_id: &PackageId) -> bool {
39 self.matches(package_id.name(), package_id.version())
40 }
41
42 pub fn matches(&self, name: &PackageName, version: &Version) -> bool {
43 self.name() == name && self.version_req.matches(version)
44 }
45}
46
47impl fmt::Display for PackageReq {
48 fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
49 write!(formatter, "{}@{}", self.name, self.version_req)
50 }
51}
52
53impl FromStr for PackageReq {
54 type Err = anyhow::Error;
55
56 fn from_str(value: &str) -> anyhow::Result<Self> {
57 const BAD_FORMAT_MSG: &str = "a package requirement is of the form SCOPE/NAME@VERSION_REQ";
58
59 let mut first_half = value.splitn(2, '/');
60 let scope = first_half.next().ok_or_else(|| anyhow!(BAD_FORMAT_MSG))?;
61 let name_and_version = first_half.next().ok_or_else(|| anyhow!(BAD_FORMAT_MSG))?;
62
63 let mut second_half = name_and_version.splitn(2, '@');
64 let name = second_half.next().ok_or_else(|| anyhow!(BAD_FORMAT_MSG))?;
65
66 let version_req_source = second_half.next().ok_or_else(|| anyhow!(BAD_FORMAT_MSG))?;
67
68 if version_req_source.len() == 0 || version_req_source.chars().all(char::is_whitespace) {
74 bail!(BAD_FORMAT_MSG);
75 }
76
77 let version_req = version_req_source
78 .parse()
79 .context("could not parse version requirement")?;
80
81 let package_name = PackageName::new(scope, name).context(BAD_FORMAT_MSG)?;
82 Ok(PackageReq::new(package_name, version_req))
83 }
84}
85
86impl Serialize for PackageReq {
87 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
88 let combined_name = format!(
89 "{}/{}@{}",
90 self.name().scope(),
91 self.name().name(),
92 self.version_req()
93 );
94 serializer.serialize_str(&combined_name)
95 }
96}
97
98impl<'de> Deserialize<'de> for PackageReq {
99 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
100 deserializer.deserialize_str(PackageReqVisitor)
101 }
102}
103
104struct PackageReqVisitor;
105
106impl<'de> Visitor<'de> for PackageReqVisitor {
107 type Value = PackageReq;
108
109 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
110 write!(
111 formatter,
112 "a package requirement of the form SCOPE/NAME@VERSION_REQ"
113 )
114 }
115
116 fn visit_str<E: Error>(self, value: &str) -> Result<Self::Value, E> {
117 value.parse().map_err(|err| E::custom(err))
118 }
119}
120
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn new() {
        let req = PackageReq::new(
            PackageName::new("foo", "bar").unwrap(),
            VersionReq::parse("1.2.3").unwrap(),
        );
        assert_eq!(req.name().scope(), "foo");
        assert_eq!(req.name().name(), "bar");
        assert_eq!(req.version_req(), &VersionReq::parse("1.2.3").unwrap());
    }

    #[test]
    fn display() {
        let req = PackageReq::new(
            PackageName::new("hello", "world").unwrap(),
            VersionReq::parse("0.2.3").unwrap(),
        );

        assert_eq!(req.to_string(), "hello/world@>=0.2.3, <0.3.0");
    }

    #[test]
    fn parse() {
        let default_compat: PackageReq = "hello/world@1.2.3".parse().unwrap();
        assert_eq!(default_compat.name().scope(), "hello");
        assert_eq!(default_compat.name().name(), "world");
        assert_eq!(
            default_compat.version_req(),
            &VersionReq::parse("^1.2.3").unwrap()
        );

        let with_ops: PackageReq = "hello/world@>=0.2.0, <0.2.7".parse().unwrap();
        assert_eq!(with_ops.name().scope(), "hello");
        assert_eq!(with_ops.name().name(), "world");
        assert_eq!(
            with_ops.version_req(),
            &VersionReq::parse(">=0.2.0, <0.2.7").unwrap()
        );
    }

    #[test]
    fn parse_invalid() {
        // No `@VERSION_REQ` part at all.
        let no_version: Result<PackageReq, _> = "hello/world".parse();
        no_version.unwrap_err();

        // A trailing `@` with nothing after it.
        let no_version_at: Result<PackageReq, _> = "hello/world@".parse();
        no_version_at.unwrap_err();

        // A version requirement made only of whitespace is rejected as well.
        let blank_version: Result<PackageReq, _> = "hello/world@   ".parse();
        blank_version.unwrap_err();

        // Missing the `SCOPE/` prefix.
        let no_scope: Result<PackageReq, _> = "world@1.2.3".parse();
        no_scope.unwrap_err();
    }

    #[test]
    fn serialization() {
        let name = PackageName::new("lpghatguy", "asink").unwrap();
        let package_req = PackageReq::new(name, VersionReq::parse("2.3.1").unwrap());

        let serialized = serde_json::to_string(&package_req).unwrap();
        assert_eq!(serialized, "\"lpghatguy/asink@>=2.3.1, <3.0.0\"");

        let deserialized: PackageReq = serde_json::from_str(&serialized).unwrap();
        assert_eq!(deserialized, package_req);
    }

    #[test]
    fn deserialization_invalid() {
        // A string that does not follow SCOPE/NAME@VERSION_REQ is rejected.
        let missing_version: Result<PackageReq, _> = serde_json::from_str("\"hello/world\"");
        missing_version.unwrap_err();

        // Only strings are accepted.
        let not_a_string: Result<PackageReq, _> = serde_json::from_str("5");
        not_a_string.unwrap_err();
    }
}