tower_reqwest/
set_header.rs

//! Middleware for setting headers on requests.
//!
//! This module borrows heavily from the `set_header` module in the `tower-http` crate.
//! The main difference is that this module operates on [`reqwest::Request`] values sent through
//! a `reqwest` client service, while [`set-header`] works with any service over `http::Request`.
6//!
7//! # Example
8//!
9//! Setting a header from a fixed value
10//!
11//! ```
12#![doc = include_str!("../examples/set_header.rs")]
13//! ```
14//!
15//! [`set-header`]: https://docs.rs/tower-http/latest/tower_http/set_header/index.html
16
17use std::{
18    fmt,
19    task::{Context, Poll},
20};
21
22use http::{HeaderName, HeaderValue};
23use tower_layer::Layer;
24use tower_service::Service;
25
26/// Trait for producing header values.
27///
28/// This trait is implemented for closures with the correct type signature. Typically users will
29/// not have to implement this trait for their own types.
30///
31/// It is also implemented directly for [`HeaderValue`]. When a fixed header value should be added
32/// to all responses, it can be supplied directly to the middleware.
33pub trait MakeHeaderValue<T> {
34    /// Try to create a header value from the request or response.
35    fn make_header_value(&mut self, message: &T) -> Option<HeaderValue>;
36}
37
38impl<F, T> MakeHeaderValue<T> for F
39where
40    F: FnMut(&T) -> Option<HeaderValue>,
41{
42    fn make_header_value(&mut self, message: &T) -> Option<HeaderValue> {
43        self(message)
44    }
45}
46
47impl<T> MakeHeaderValue<T> for HeaderValue {
48    fn make_header_value(&mut self, _message: &T) -> Option<HeaderValue> {
49        Some(self.clone())
50    }
51}
52
53impl<T> MakeHeaderValue<T> for Option<HeaderValue> {
54    fn make_header_value(&mut self, _message: &T) -> Option<HeaderValue> {
55        self.clone()
56    }
57}
58
/// Strategy for combining a freshly produced header value with values already on the request.
#[derive(Debug, Clone, Copy)]
enum InsertHeaderMode {
    /// Replace every existing value of the header with the new one.
    Override,
    /// Keep existing values and add the new one alongside them.
    Append,
    /// Insert the new value only when the header is not present yet.
    IfNotPresent,
}
65
66impl InsertHeaderMode {
67    fn apply<M>(self, header_name: &HeaderName, target: &mut reqwest::Request, make: &mut M)
68    where
69        M: MakeHeaderValue<reqwest::Request>,
70    {
71        match self {
72            InsertHeaderMode::Override => {
73                if let Some(value) = make.make_header_value(target) {
74                    target.headers_mut().insert(header_name.clone(), value);
75                }
76            }
77            InsertHeaderMode::IfNotPresent => {
78                if !target.headers().contains_key(header_name)
79                    && let Some(value) = make.make_header_value(target)
80                {
81                    target.headers_mut().insert(header_name.clone(), value);
82                }
83            }
84            InsertHeaderMode::Append => {
85                if let Some(value) = make.make_header_value(target) {
86                    target.headers_mut().append(header_name.clone(), value);
87                }
88            }
89        }
90    }
91}
92
93/// Layer that applies [`SetRequestHeader`] which adds a request header.
94///
95/// # Example
96///
97/// Setting a header from a fixed value
98///
99/// ```
100#[doc = include_str!("../examples/set_header.rs")]
101/// ```
102pub struct SetRequestHeaderLayer<M> {
103    header_name: HeaderName,
104    make: M,
105    mode: InsertHeaderMode,
106}
107
108impl<M> fmt::Debug for SetRequestHeaderLayer<M> {
109    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
110        f.debug_struct("SetRequestHeaderLayer")
111            .field("header_name", &self.header_name)
112            .field("mode", &self.mode)
113            .field("make", &std::any::type_name::<M>())
114            .finish()
115    }
116}
117
118impl<M> SetRequestHeaderLayer<M> {
119    /// Create a new [`SetRequestHeaderLayer`].
120    ///
121    /// If a previous value exists for the same header, it is removed and replaced with the new
122    /// header value.
123    pub fn overriding(header_name: HeaderName, make: M) -> Self {
124        Self::new(header_name, make, InsertHeaderMode::Override)
125    }
126
127    /// Create a new [`SetRequestHeaderLayer`].
128    ///
129    /// The new header is always added, preserving any existing values. If previous values exist,
130    /// the header will have multiple values.
131    pub fn appending(header_name: HeaderName, make: M) -> Self {
132        Self::new(header_name, make, InsertHeaderMode::Append)
133    }
134
135    /// Create a new [`SetRequestHeaderLayer`].
136    ///
137    /// If a previous value exists for the header, the new value is not inserted.
138    pub fn if_not_present(header_name: HeaderName, make: M) -> Self {
139        Self::new(header_name, make, InsertHeaderMode::IfNotPresent)
140    }
141
142    fn new(header_name: HeaderName, make: M, mode: InsertHeaderMode) -> Self {
143        Self {
144            header_name,
145            make,
146            mode,
147        }
148    }
149}
150
151impl<S, M> Layer<S> for SetRequestHeaderLayer<M>
152where
153    M: Clone,
154{
155    type Service = SetRequestHeader<S, M>;
156
157    fn layer(&self, inner: S) -> Self::Service {
158        SetRequestHeader {
159            inner,
160            header_name: self.header_name.clone(),
161            make: self.make.clone(),
162            mode: self.mode,
163        }
164    }
165}
166
167impl<M> Clone for SetRequestHeaderLayer<M>
168where
169    M: Clone,
170{
171    fn clone(&self) -> Self {
172        Self {
173            make: self.make.clone(),
174            header_name: self.header_name.clone(),
175            mode: self.mode,
176        }
177    }
178}
179
180/// Middleware that sets a header on the request.
181#[derive(Clone)]
182pub struct SetRequestHeader<S, M> {
183    inner: S,
184    header_name: HeaderName,
185    make: M,
186    mode: InsertHeaderMode,
187}
188
189impl<S, M> SetRequestHeader<S, M> {
190    /// Create a new [`SetRequestHeader`].
191    ///
192    /// If a previous value exists for the same header, it is removed and replaced with the new
193    /// header value.
194    pub fn overriding(inner: S, header_name: HeaderName, make: M) -> Self {
195        Self::new(inner, header_name, make, InsertHeaderMode::Override)
196    }
197
198    /// Create a new [`SetRequestHeader`].
199    ///
200    /// The new header is always added, preserving any existing values. If previous values exist,
201    /// the header will have multiple values.
202    pub fn appending(inner: S, header_name: HeaderName, make: M) -> Self {
203        Self::new(inner, header_name, make, InsertHeaderMode::Append)
204    }
205
206    /// Create a new [`SetRequestHeader`].
207    ///
208    /// If a previous value exists for the header, the new value is not inserted.
209    pub fn if_not_present(inner: S, header_name: HeaderName, make: M) -> Self {
210        Self::new(inner, header_name, make, InsertHeaderMode::IfNotPresent)
211    }
212
213    fn new(inner: S, header_name: HeaderName, make: M, mode: InsertHeaderMode) -> Self {
214        Self {
215            inner,
216            header_name,
217            make,
218            mode,
219        }
220    }
221}
222
223impl<S, M> fmt::Debug for SetRequestHeader<S, M>
224where
225    S: fmt::Debug,
226{
227    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
228        f.debug_struct("SetRequestHeader")
229            .field("inner", &self.inner)
230            .field("header_name", &self.header_name)
231            .field("mode", &self.mode)
232            .field("make", &std::any::type_name::<M>())
233            .finish()
234    }
235}
236
237impl<S, M> Service<reqwest::Request> for SetRequestHeader<S, M>
238where
239    S: Service<reqwest::Request, Response = reqwest::Response>,
240    M: MakeHeaderValue<reqwest::Request>,
241{
242    type Response = S::Response;
243    type Error = S::Error;
244    type Future = S::Future;
245
246    #[inline]
247    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
248        self.inner.poll_ready(cx)
249    }
250
251    fn call(&mut self, mut req: reqwest::Request) -> Self::Future {
252        self.mode.apply(&self.header_name, &mut req, &mut self.make);
253        self.inner.call(req)
254    }
255}
256
#[cfg(test)]
mod tests {
    use std::future::poll_fn;

    use http::{HeaderName, HeaderValue};
    use tower_layer::Layer;
    use tower_service::Service;
    use wiremock::{
        Mock, MockServer, ResponseTemplate,
        matchers::{method, path},
    };

    use crate::set_header::SetRequestHeaderLayer;

    /// The mock answers 200 only when `x-test-header: test-value` is present. Unmatched requests
    /// get wiremock's default 404, so a plain client sees 404 and the layered client sees 200.
    #[tokio::test]
    async fn test_set_headers() -> anyhow::Result<()> {
        let mock_server = MockServer::start().await;
        let mock_uri = mock_server.uri();

        let header_name = HeaderName::from_static("x-test-header");
        let header_value = HeaderValue::from_static("test-value");

        Mock::given(method("GET"))
            .and(path("/test"))
            .and(wiremock::matchers::header(&header_name, &header_value))
            .respond_with(ResponseTemplate::new(200))
            .mount(&mock_server)
            .await;

        let uri = format!("{mock_uri}/test");
        let request = reqwest::Request::new(reqwest::Method::GET, uri.parse()?);

        let client = reqwest::Client::new();
        // Check that the header is not set by default.
        let unlayered = request
            .try_clone()
            .expect("a GET request without a body can always be cloned");
        let response = client.execute(unlayered).await?;
        assert_eq!(response.status(), 404);

        // Check that the header will be set by the layer.
        let mut service =
            SetRequestHeaderLayer::overriding(header_name, header_value).layer(client);
        // The `Service` contract requires readiness to be polled before `call`.
        poll_fn(|cx| service.poll_ready(cx)).await?;
        let response = service.call(request).await?;
        assert_eq!(response.status(), 200);

        Ok(())
    }
}