volo_http/client/layer/
header.rs

//! [`Layer`]s for inserting headers into requests.
//!
//! - [`Header`] inserts any [`HeaderName`] and [`HeaderValue`].
//! - [`Host`] inserts the given `Host`, or a `Host` generated from the target hostname or
//!   target address, followed by the port when it is not the default port of the scheme.
//! - [`UserAgent`] inserts the given `User-Agent`, or a `User-Agent` generated from the
//!   current package information.

9use std::{error::Error, future::Future};
10
11use http::header::{self, HeaderName, HeaderValue};
12use motore::{layer::Layer, service::Service};
13
14use crate::{
15    client::{Target, target::RemoteHost, utils::is_default_port},
16    context::ClientContext,
17    error::client::{Result, builder_error},
18    request::Request,
19};
20
21/// [`Layer`] for inserting a header to requests.
22#[derive(Clone, Debug)]
23pub struct Header {
24    key: HeaderName,
25    val: HeaderValue,
26}
27
28impl Header {
29    /// Create a new [`Header`] layer for inserting a header to requests.
30    ///
31    /// This function takes [`HeaderName`] and [`HeaderValue`], users should create it by
32    /// themselves.
33    ///
34    /// For using string types directly, see [`Header::try_new`].
35    pub fn new(key: HeaderName, val: HeaderValue) -> Self {
36        Self { key, val }
37    }
38
39    /// Create a new [`Header`] layer for inserting a header to requests.
40    ///
41    /// This function takes any types that can be converted into [`HeaderName`] or [`HeaderValue`].
42    /// If the values are invalid [`HeaderName`] or [`HeaderValue`], an [`ClientError`] with
43    /// [`ErrorKind::Builder`] will be returned.
44    ///
45    /// [`ClientError`]: crate::error::client::ClientError
46    /// [`ErrorKind::Builder`]: crate::error::client::ErrorKind::Builder
47    pub fn try_new<K, V>(key: K, val: V) -> Result<Self>
48    where
49        K: TryInto<HeaderName>,
50        K::Error: Error + Send + Sync + 'static,
51        V: TryInto<HeaderValue>,
52        V::Error: Error + Send + Sync + 'static,
53    {
54        let key = key.try_into().map_err(builder_error)?;
55        let val = val.try_into().map_err(builder_error)?;
56
57        Ok(Self::new(key, val))
58    }
59}
60
61impl<S> Layer<S> for Header {
62    type Service = HeaderService<S>;
63
64    fn layer(self, inner: S) -> Self::Service {
65        HeaderService {
66            inner,
67            key: self.key,
68            val: self.val,
69        }
70    }
71}
72
73/// [`Service`] generated by [`Header`].
74///
75/// See [`Header`], [`Header::new`] and [`Header::try_new`] for more details.
76pub struct HeaderService<S> {
77    inner: S,
78    key: HeaderName,
79    val: HeaderValue,
80}
81
82impl<Cx, B, S> Service<Cx, Request<B>> for HeaderService<S>
83where
84    S: Service<Cx, Request<B>>,
85{
86    type Response = S::Response;
87    type Error = S::Error;
88
89    fn call(
90        &self,
91        cx: &mut Cx,
92        mut req: Request<B>,
93    ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send {
94        req.headers_mut().insert(self.key.clone(), self.val.clone());
95        self.inner.call(cx, req)
96    }
97}
98
99/// [`Layer`] for inserting `Host` into the request header.
100#[derive(Clone, Debug, Default)]
101pub enum Host {
102    /// Do not insert `Host` into the request headers.
103    None,
104    /// If there is no `Host` in request headers, the layer will generate it through target
105    /// address.
106    #[default]
107    Auto,
108    /// Forcely use the given value as `Host` in request headers, it will override the previous
109    /// one.
110    Force(HeaderValue),
111    /// If there is no `Host` in request headers, the `Host` will be set to the given value.
112    Fallback(HeaderValue),
113}
114
115impl<S> Layer<S> for Host {
116    type Service = HostService<S>;
117
118    fn layer(self, inner: S) -> Self::Service {
119        HostService {
120            inner,
121            config: self,
122        }
123    }
124}
125
126/// [`Service`] generated by [`Host`].
127///
128/// See [`Host`] for more details.
129pub struct HostService<S> {
130    inner: S,
131    config: Host,
132}
133
134#[cfg(target_family = "unix")]
135const UDS_HOST: HeaderValue = HeaderValue::from_static("unix-domain-socket");
136
137pub(super) fn gen_host(target: &Target) -> Option<HeaderValue> {
138    let rt = match target {
139        Target::None => return None,
140        Target::Remote(rt) => rt,
141        #[cfg(target_family = "unix")]
142        Target::Local(_) => return Some(UDS_HOST.clone()),
143    };
144    let default_port = is_default_port(&rt.scheme, rt.port);
145    match &rt.host {
146        RemoteHost::Ip(ip) => {
147            let s = if default_port {
148                if ip.is_ipv4() {
149                    format!("{ip}")
150                } else {
151                    format!("[{ip}]")
152                }
153            } else {
154                let port = rt.port;
155                if ip.is_ipv4() {
156                    format!("{ip}:{port}")
157                } else {
158                    format!("[{ip}]:{port}")
159                }
160            };
161            HeaderValue::from_str(&s).ok()
162        }
163        RemoteHost::Name(name) => {
164            let port = rt.port;
165            if default_port {
166                HeaderValue::from_str(name).ok()
167            } else {
168                HeaderValue::from_str(&format!("{name}:{port}")).ok()
169            }
170        }
171    }
172}
173
174impl<B, S> Service<ClientContext, Request<B>> for HostService<S>
175where
176    S: Service<ClientContext, Request<B>>,
177{
178    type Response = S::Response;
179    type Error = S::Error;
180
181    fn call(
182        &self,
183        cx: &mut ClientContext,
184        mut req: Request<B>,
185    ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send {
186        match &self.config {
187            Host::None => {}
188            Host::Auto => {
189                if !req.headers().contains_key(header::HOST) {
190                    if let Some(val) = gen_host(cx.target()) {
191                        req.headers_mut().insert(header::HOST, val);
192                    }
193                }
194            }
195            Host::Force(val) => {
196                req.headers_mut().insert(header::HOST, val.clone());
197            }
198            Host::Fallback(val) => {
199                if !req.headers().contains_key(header::HOST) {
200                    req.headers_mut().insert(header::HOST, val.clone());
201                }
202            }
203        }
204
205        self.inner.call(cx, req)
206    }
207}
208
209const PKG_NAME_WITH_VER: &str = concat!(env!("CARGO_PKG_NAME"), '/', env!("CARGO_PKG_VERSION"));
210
211/// [`Layer`] for inserting `User-Agent` into the request header.
212///
213/// See [`UserAgent::new`] for more details.
214pub struct UserAgent {
215    val: HeaderValue,
216}
217
218impl UserAgent {
219    /// Create a new [`UserAgent`] layer that inserts `User-Agent` into the request header.
220    ///
221    /// Note that the layer only inserts it if there is no `User-Agent`
222    pub fn new(val: HeaderValue) -> Self {
223        Self { val }
224    }
225
226    /// Create a new [`UserAgent`] layer with the package name and package version as its default
227    /// value.
228    ///
229    /// Note that the layer only inserts it if there is no `User-Agent`
230    pub fn auto() -> Self {
231        Self {
232            val: HeaderValue::from_static(PKG_NAME_WITH_VER),
233        }
234    }
235}
236
237impl<S> Layer<S> for UserAgent {
238    type Service = UserAgentService<S>;
239
240    fn layer(self, inner: S) -> Self::Service {
241        UserAgentService {
242            inner,
243            val: self.val,
244        }
245    }
246}
247
248/// [`Service`] generated by [`UserAgent`].
249///
250/// See [`UserAgent`] and [`UserAgent::new`] for more details.
251pub struct UserAgentService<S> {
252    inner: S,
253    val: HeaderValue,
254}
255
256impl<Cx, B, S> Service<Cx, Request<B>> for UserAgentService<S>
257where
258    S: Service<Cx, Request<B>>,
259{
260    type Response = S::Response;
261    type Error = S::Error;
262
263    fn call(
264        &self,
265        cx: &mut Cx,
266        mut req: Request<B>,
267    ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send {
268        if !req.headers().contains_key(header::USER_AGENT) {
269            req.headers_mut()
270                .insert(header::USER_AGENT, self.val.clone());
271        }
272        self.inner.call(cx, req)
273    }
274}
275
#[cfg(test)]
mod layer_header_tests {
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    use faststr::FastStr;
    use http::uri::Scheme;

    use crate::client::{
        Target,
        layer::header::gen_host,
        target::{RemoteHost, RemoteTarget},
    };

    /// `127.0.0.1`
    const IPV4: IpAddr = IpAddr::V4(Ipv4Addr::LOCALHOST);
    /// `::1`
    const IPV6: IpAddr = IpAddr::V6(Ipv6Addr::LOCALHOST);

    /// Builds a remote target addressed by hostname.
    const fn host_target(scheme: Scheme, host: &'static str, port: u16) -> Target {
        let host = RemoteHost::Name(FastStr::from_static_str(host));
        Target::Remote(RemoteTarget { scheme, host, port })
    }

    /// Builds a remote target addressed by IP.
    const fn ip_target(scheme: Scheme, ip: IpAddr, port: u16) -> Target {
        let host = RemoteHost::Ip(ip);
        Target::Remote(RemoteTarget { scheme, host, port })
    }

    /// Asserts that `gen_host` yields the expected `Host` value for every target.
    fn assert_hosts(cases: &[(Target, &str)]) {
        for (target, expected) in cases {
            let host = gen_host(target).expect("a remote target should produce a `Host`");
            assert_eq!(host, *expected, "unexpected `Host` (expected `{expected}`)");
        }
    }

    #[test]
    fn gen_host_test() {
        // Without any target there is nothing to derive `Host` from.
        assert_eq!(gen_host(&Target::None), None);

        assert_hosts(&[
            // hostname: the default port is omitted, any other port is kept
            (host_target(Scheme::HTTP, "github.com", 80), "github.com"),
            (host_target(Scheme::HTTP, "github.com", 8000), "github.com:8000"),
            (host_target(Scheme::HTTP, "github.com", 443), "github.com:443"),
            // IPv4 address
            (ip_target(Scheme::HTTP, IPV4, 80), "127.0.0.1"),
            (ip_target(Scheme::HTTP, IPV4, 8000), "127.0.0.1:8000"),
            (ip_target(Scheme::HTTP, IPV4, 443), "127.0.0.1:443"),
            // IPv6 address: always bracketed
            (ip_target(Scheme::HTTP, IPV6, 80), "[::1]"),
            (ip_target(Scheme::HTTP, IPV6, 8000), "[::1]:8000"),
            (ip_target(Scheme::HTTP, IPV6, 443), "[::1]:443"),
        ]);
    }

    #[cfg(feature = "__tls")]
    #[test]
    fn gen_host_with_tls_test() {
        assert_hosts(&[
            // hostname: the default port is omitted, any other port is kept
            (host_target(Scheme::HTTPS, "github.com", 443), "github.com"),
            (host_target(Scheme::HTTPS, "github.com", 4430), "github.com:4430"),
            (host_target(Scheme::HTTPS, "github.com", 80), "github.com:80"),
            // IPv4 address
            (ip_target(Scheme::HTTPS, IPV4, 443), "127.0.0.1"),
            (ip_target(Scheme::HTTPS, IPV4, 4430), "127.0.0.1:4430"),
            (ip_target(Scheme::HTTPS, IPV4, 80), "127.0.0.1:80"),
            // IPv6 address: always bracketed
            (ip_target(Scheme::HTTPS, IPV6, 443), "[::1]"),
            (ip_target(Scheme::HTTPS, IPV6, 4430), "[::1]:4430"),
            (ip_target(Scheme::HTTPS, IPV6, 80), "[::1]:80"),
        ]);
    }
}