1use std::{error::Error, future::Future};
10
11use http::header::{self, HeaderName, HeaderValue};
12use motore::{layer::Layer, service::Service};
13
14use crate::{
15 client::{Target, target::RemoteHost, utils::is_default_port},
16 context::ClientContext,
17 error::client::{Result, builder_error},
18 request::Request,
19};
20
/// [`Layer`] that inserts one fixed header into every request passing through it.
///
/// The header is written with `HeaderMap::insert`, so any values already present under
/// the same name are replaced.
#[derive(Clone, Debug)]
pub struct Header {
    /// Name of the header to insert.
    key: HeaderName,
    /// Value written under `key`.
    val: HeaderValue,
}
27
28impl Header {
29 pub fn new(key: HeaderName, val: HeaderValue) -> Self {
36 Self { key, val }
37 }
38
39 pub fn try_new<K, V>(key: K, val: V) -> Result<Self>
48 where
49 K: TryInto<HeaderName>,
50 K::Error: Error + Send + Sync + 'static,
51 V: TryInto<HeaderValue>,
52 V::Error: Error + Send + Sync + 'static,
53 {
54 let key = key.try_into().map_err(builder_error)?;
55 let val = val.try_into().map_err(builder_error)?;
56
57 Ok(Self::new(key, val))
58 }
59}
60
61impl<S> Layer<S> for Header {
62 type Service = HeaderService<S>;
63
64 fn layer(self, inner: S) -> Self::Service {
65 HeaderService {
66 inner,
67 key: self.key,
68 val: self.val,
69 }
70 }
71}
72
73pub struct HeaderService<S> {
77 inner: S,
78 key: HeaderName,
79 val: HeaderValue,
80}
81
82impl<Cx, B, S> Service<Cx, Request<B>> for HeaderService<S>
83where
84 S: Service<Cx, Request<B>>,
85{
86 type Response = S::Response;
87 type Error = S::Error;
88
89 fn call(
90 &self,
91 cx: &mut Cx,
92 mut req: Request<B>,
93 ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send {
94 req.headers_mut().insert(self.key.clone(), self.val.clone());
95 self.inner.call(cx, req)
96 }
97}
98
/// [`Layer`] describing how the `Host` header is set on outgoing requests.
#[derive(Clone, Debug, Default)]
pub enum Host {
    /// Leave the request's `Host` header untouched.
    None,
    /// If the request has no `Host` header, generate one from the target stored in the
    /// client context (via `gen_host`); nothing is inserted if no value can be generated.
    /// This is the `Default` variant.
    #[default]
    Auto,
    /// Always set `Host` to this value, replacing any existing one.
    Force(HeaderValue),
    /// Set `Host` to this value only if the request does not already have one.
    Fallback(HeaderValue),
}
114
115impl<S> Layer<S> for Host {
116 type Service = HostService<S>;
117
118 fn layer(self, inner: S) -> Self::Service {
119 HostService {
120 inner,
121 config: self,
122 }
123 }
124}
125
126pub struct HostService<S> {
130 inner: S,
131 config: Host,
132}
133
/// Placeholder `Host` value that `gen_host` returns for `Target::Local` targets
/// (Unix domain sockets), which are not addressed by a host name or IP.
#[cfg(target_family = "unix")]
const UDS_HOST: HeaderValue = HeaderValue::from_static("unix-domain-socket");
136
137pub(super) fn gen_host(target: &Target) -> Option<HeaderValue> {
138 let rt = match target {
139 Target::None => return None,
140 Target::Remote(rt) => rt,
141 #[cfg(target_family = "unix")]
142 Target::Local(_) => return Some(UDS_HOST.clone()),
143 };
144 let default_port = is_default_port(&rt.scheme, rt.port);
145 match &rt.host {
146 RemoteHost::Ip(ip) => {
147 let s = if default_port {
148 if ip.is_ipv4() {
149 format!("{ip}")
150 } else {
151 format!("[{ip}]")
152 }
153 } else {
154 let port = rt.port;
155 if ip.is_ipv4() {
156 format!("{ip}:{port}")
157 } else {
158 format!("[{ip}]:{port}")
159 }
160 };
161 HeaderValue::from_str(&s).ok()
162 }
163 RemoteHost::Name(name) => {
164 let port = rt.port;
165 if default_port {
166 HeaderValue::from_str(name).ok()
167 } else {
168 HeaderValue::from_str(&format!("{name}:{port}")).ok()
169 }
170 }
171 }
172}
173
174impl<B, S> Service<ClientContext, Request<B>> for HostService<S>
175where
176 S: Service<ClientContext, Request<B>>,
177{
178 type Response = S::Response;
179 type Error = S::Error;
180
181 fn call(
182 &self,
183 cx: &mut ClientContext,
184 mut req: Request<B>,
185 ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send {
186 match &self.config {
187 Host::None => {}
188 Host::Auto => {
189 if !req.headers().contains_key(header::HOST) {
190 if let Some(val) = gen_host(cx.target()) {
191 req.headers_mut().insert(header::HOST, val);
192 }
193 }
194 }
195 Host::Force(val) => {
196 req.headers_mut().insert(header::HOST, val.clone());
197 }
198 Host::Fallback(val) => {
199 if !req.headers().contains_key(header::HOST) {
200 req.headers_mut().insert(header::HOST, val.clone());
201 }
202 }
203 }
204
205 self.inner.call(cx, req)
206 }
207}
208
/// Value used by `UserAgent::auto`: `<package name>/<package version>`, read from Cargo's
/// environment variables at compile time.
const PKG_NAME_WITH_VER: &str = concat!(env!("CARGO_PKG_NAME"), '/', env!("CARGO_PKG_VERSION"));
210
211pub struct UserAgent {
215 val: HeaderValue,
216}
217
218impl UserAgent {
219 pub fn new(val: HeaderValue) -> Self {
223 Self { val }
224 }
225
226 pub fn auto() -> Self {
231 Self {
232 val: HeaderValue::from_static(PKG_NAME_WITH_VER),
233 }
234 }
235}
236
237impl<S> Layer<S> for UserAgent {
238 type Service = UserAgentService<S>;
239
240 fn layer(self, inner: S) -> Self::Service {
241 UserAgentService {
242 inner,
243 val: self.val,
244 }
245 }
246}
247
248pub struct UserAgentService<S> {
252 inner: S,
253 val: HeaderValue,
254}
255
256impl<Cx, B, S> Service<Cx, Request<B>> for UserAgentService<S>
257where
258 S: Service<Cx, Request<B>>,
259{
260 type Response = S::Response;
261 type Error = S::Error;
262
263 fn call(
264 &self,
265 cx: &mut Cx,
266 mut req: Request<B>,
267 ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send {
268 if !req.headers().contains_key(header::USER_AGENT) {
269 req.headers_mut()
270 .insert(header::USER_AGENT, self.val.clone());
271 }
272 self.inner.call(cx, req)
273 }
274}
275
#[cfg(test)]
mod layer_header_tests {
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    use faststr::FastStr;
    use http::uri::Scheme;

    use crate::client::{
        Target,
        layer::header::gen_host,
        target::{RemoteHost, RemoteTarget},
    };

    const IPV4: IpAddr = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
    const IPV6: IpAddr = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1));

    /// Builds a remote target addressed by host name.
    const fn host_target(scheme: Scheme, host: &'static str, port: u16) -> Target {
        Target::Remote(RemoteTarget {
            scheme,
            host: RemoteHost::Name(FastStr::from_static_str(host)),
            port,
        })
    }

    /// Builds a remote target addressed by IP literal.
    const fn ip_target(scheme: Scheme, ip: IpAddr, port: u16) -> Target {
        Target::Remote(RemoteTarget {
            scheme,
            host: RemoteHost::Ip(ip),
            port,
        })
    }

    /// Asserts that every target yields exactly its expected `Host` value.
    fn check(cases: &[(Target, &str)]) {
        for (idx, (target, expected)) in cases.iter().enumerate() {
            let Some(actual) = gen_host(target) else {
                panic!("case #{idx}: no Host generated, expected {expected:?}");
            };
            assert_eq!(actual, *expected, "case #{idx}");
        }
    }

    #[test]
    fn gen_host_test() {
        assert_eq!(gen_host(&Target::None), None);

        check(&[
            // Host names: the port is omitted only for the scheme's default port.
            (host_target(Scheme::HTTP, "github.com", 80), "github.com"),
            (host_target(Scheme::HTTP, "github.com", 8000), "github.com:8000"),
            (host_target(Scheme::HTTP, "github.com", 443), "github.com:443"),
            // IPv4 literals are used verbatim.
            (ip_target(Scheme::HTTP, IPV4, 80), "127.0.0.1"),
            (ip_target(Scheme::HTTP, IPV4, 8000), "127.0.0.1:8000"),
            (ip_target(Scheme::HTTP, IPV4, 443), "127.0.0.1:443"),
            // IPv6 literals are bracketed.
            (ip_target(Scheme::HTTP, IPV6, 80), "[::1]"),
            (ip_target(Scheme::HTTP, IPV6, 8000), "[::1]:8000"),
            (ip_target(Scheme::HTTP, IPV6, 443), "[::1]:443"),
        ]);
    }

    #[cfg(feature = "__tls")]
    #[test]
    fn gen_host_with_tls_test() {
        check(&[
            (host_target(Scheme::HTTPS, "github.com", 443), "github.com"),
            (host_target(Scheme::HTTPS, "github.com", 4430), "github.com:4430"),
            (host_target(Scheme::HTTPS, "github.com", 80), "github.com:80"),
            (ip_target(Scheme::HTTPS, IPV4, 443), "127.0.0.1"),
            (ip_target(Scheme::HTTPS, IPV4, 4430), "127.0.0.1:4430"),
            (ip_target(Scheme::HTTPS, IPV4, 80), "127.0.0.1:80"),
            (ip_target(Scheme::HTTPS, IPV6, 443), "[::1]"),
            (ip_target(Scheme::HTTPS, IPV6, 4430), "[::1]:4430"),
            (ip_target(Scheme::HTTPS, IPV6, 80), "[::1]:80"),
        ]);
    }
}