uv_configuration/
trusted_host.rs

1use serde::{Deserialize, Deserializer};
2#[cfg(feature = "schemars")]
3use std::borrow::Cow;
4use std::str::FromStr;
5use url::Url;
6
/// A host for which certificates are not verified when making HTTPS requests.
///
/// This is either the `*` wildcard, or a host name optionally qualified by a scheme and/or a
/// port.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum TrustedHost {
    /// The `*` wildcard, which matches every URL.
    Wildcard,
    /// A specific host, optionally restricted to a scheme and/or port.
    Host {
        /// The scheme the host is restricted to (e.g. `https`), if any.
        scheme: Option<String>,
        /// The host name to trust.
        host: String,
        /// The port the host is restricted to, if any.
        port: Option<u16>,
    },
}
18
19impl TrustedHost {
20    /// Returns `true` if the [`Url`] matches this trusted host.
21    pub fn matches(&self, url: &Url) -> bool {
22        match self {
23            Self::Wildcard => true,
24            Self::Host { scheme, host, port } => {
25                if scheme.as_ref().is_some_and(|scheme| scheme != url.scheme()) {
26                    return false;
27                }
28
29                if port.is_some_and(|port| url.port() != Some(port)) {
30                    return false;
31                }
32
33                if Some(host.as_str()) != url.host_str() {
34                    return false;
35                }
36
37                true
38            }
39        }
40    }
41}
42
43impl<'de> Deserialize<'de> for TrustedHost {
44    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
45    where
46        D: Deserializer<'de>,
47    {
48        #[derive(Deserialize)]
49        struct Inner {
50            scheme: Option<String>,
51            host: String,
52            port: Option<u16>,
53        }
54
55        serde_untagged::UntaggedEnumVisitor::new()
56            .string(|string| Self::from_str(string).map_err(serde::de::Error::custom))
57            .map(|map| {
58                map.deserialize::<Inner>().map(|inner| Self::Host {
59                    scheme: inner.scheme,
60                    host: inner.host,
61                    port: inner.port,
62                })
63            })
64            .deserialize(deserializer)
65    }
66}
67
68impl serde::Serialize for TrustedHost {
69    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
70    where
71        S: serde::ser::Serializer,
72    {
73        let s = self.to_string();
74        serializer.serialize_str(&s)
75    }
76}
77
78#[derive(Debug, thiserror::Error)]
79pub enum TrustedHostError {
80    #[error("missing host for `--trusted-host`: `{0}`")]
81    MissingHost(String),
82    #[error("invalid port for `--trusted-host`: `{0}`")]
83    InvalidPort(String),
84}
85
86impl FromStr for TrustedHost {
87    type Err = TrustedHostError;
88
89    fn from_str(s: &str) -> Result<Self, Self::Err> {
90        if s == "*" {
91            return Ok(Self::Wildcard);
92        }
93
94        // Detect scheme.
95        let (scheme, s) = if let Some(s) = s.strip_prefix("https://") {
96            (Some("https".to_string()), s)
97        } else if let Some(s) = s.strip_prefix("http://") {
98            (Some("http".to_string()), s)
99        } else {
100            (None, s)
101        };
102
103        let mut parts = s.splitn(2, ':');
104
105        // Detect host.
106        let host = parts
107            .next()
108            .and_then(|host| host.split('/').next())
109            .map(ToString::to_string)
110            .ok_or_else(|| TrustedHostError::MissingHost(s.to_string()))?;
111
112        // Detect port.
113        let port = parts
114            .next()
115            .map(str::parse)
116            .transpose()
117            .map_err(|_| TrustedHostError::InvalidPort(s.to_string()))?;
118
119        Ok(Self::Host { scheme, host, port })
120    }
121}
122
123impl std::fmt::Display for TrustedHost {
124    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
125        match self {
126            Self::Wildcard => {
127                write!(f, "*")?;
128            }
129            Self::Host { scheme, host, port } => {
130                if let Some(scheme) = &scheme {
131                    write!(f, "{scheme}://{host}")?;
132                } else {
133                    write!(f, "{host}")?;
134                }
135
136                if let Some(port) = port {
137                    write!(f, ":{port}")?;
138                }
139            }
140        }
141
142        Ok(())
143    }
144}
145
#[cfg(feature = "schemars")]
impl schemars::JsonSchema for TrustedHost {
    /// The name under which this type appears in generated schemas.
    fn schema_name() -> Cow<'static, str> {
        "TrustedHost".into()
    }

    /// Describes a trusted host as a plain string in the generated JSON schema.
    fn json_schema(_generator: &mut schemars::generate::SchemaGenerator) -> schemars::Schema {
        schemars::json_schema!({
            "type": "string",
            "description": "A host or host-port pair."
        })
    }
}
159
#[cfg(test)]
mod tests {
    use super::TrustedHost;

    /// Builds the expected [`TrustedHost::Host`] value from borrowed parts.
    fn host(scheme: Option<&str>, host: &str, port: Option<u16>) -> TrustedHost {
        TrustedHost::Host {
            scheme: scheme.map(ToString::to_string),
            host: host.to_string(),
            port,
        }
    }

    #[test]
    fn parse() {
        let cases = [
            ("*", TrustedHost::Wildcard),
            ("example.com", host(None, "example.com", None)),
            ("example.com:8080", host(None, "example.com", Some(8080))),
            ("https://example.com", host(Some("https"), "example.com", None)),
            (
                "https://example.com/hello/world",
                host(Some("https"), "example.com", None),
            ),
        ];

        for (input, expected) in cases {
            assert_eq!(
                input.parse::<TrustedHost>().unwrap(),
                expected,
                "input: {input}"
            );
        }
    }
}