1#![doc = include_str!("../examples/set_header.rs")]
13use std::{
18 fmt,
19 task::{Context, Poll},
20};
21
22use http::{HeaderName, HeaderValue};
23use tower_layer::Layer;
24use tower_service::Service;
25
26pub trait MakeHeaderValue<T> {
34 fn make_header_value(&mut self, message: &T) -> Option<HeaderValue>;
36}
37
38impl<F, T> MakeHeaderValue<T> for F
39where
40 F: FnMut(&T) -> Option<HeaderValue>,
41{
42 fn make_header_value(&mut self, message: &T) -> Option<HeaderValue> {
43 self(message)
44 }
45}
46
47impl<T> MakeHeaderValue<T> for HeaderValue {
48 fn make_header_value(&mut self, _message: &T) -> Option<HeaderValue> {
49 Some(self.clone())
50 }
51}
52
53impl<T> MakeHeaderValue<T> for Option<HeaderValue> {
54 fn make_header_value(&mut self, _message: &T) -> Option<HeaderValue> {
55 self.clone()
56 }
57}
58
/// Policy describing how a produced header value is merged into a request.
#[derive(Debug, Clone, Copy)]
enum InsertHeaderMode {
    /// Set the header with `HeaderMap::insert`, replacing any existing values.
    Override,
    /// Add the value with `HeaderMap::append`, keeping any existing values.
    Append,
    /// Set the header only when the request does not already carry it.
    IfNotPresent,
}
65
66impl InsertHeaderMode {
67 fn apply<M>(self, header_name: &HeaderName, target: &mut reqwest::Request, make: &mut M)
68 where
69 M: MakeHeaderValue<reqwest::Request>,
70 {
71 match self {
72 InsertHeaderMode::Override => {
73 if let Some(value) = make.make_header_value(target) {
74 target.headers_mut().insert(header_name.clone(), value);
75 }
76 }
77 InsertHeaderMode::IfNotPresent => {
78 if !target.headers().contains_key(header_name)
79 && let Some(value) = make.make_header_value(target)
80 {
81 target.headers_mut().insert(header_name.clone(), value);
82 }
83 }
84 InsertHeaderMode::Append => {
85 if let Some(value) = make.make_header_value(target) {
86 target.headers_mut().append(header_name.clone(), value);
87 }
88 }
89 }
90 }
91}
92
93#[doc = include_str!("../examples/set_header.rs")]
101pub struct SetRequestHeaderLayer<M> {
103 header_name: HeaderName,
104 make: M,
105 mode: InsertHeaderMode,
106}
107
108impl<M> fmt::Debug for SetRequestHeaderLayer<M> {
109 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
110 f.debug_struct("SetRequestHeaderLayer")
111 .field("header_name", &self.header_name)
112 .field("mode", &self.mode)
113 .field("make", &std::any::type_name::<M>())
114 .finish()
115 }
116}
117
118impl<M> SetRequestHeaderLayer<M> {
119 pub fn overriding(header_name: HeaderName, make: M) -> Self {
124 Self::new(header_name, make, InsertHeaderMode::Override)
125 }
126
127 pub fn appending(header_name: HeaderName, make: M) -> Self {
132 Self::new(header_name, make, InsertHeaderMode::Append)
133 }
134
135 pub fn if_not_present(header_name: HeaderName, make: M) -> Self {
139 Self::new(header_name, make, InsertHeaderMode::IfNotPresent)
140 }
141
142 fn new(header_name: HeaderName, make: M, mode: InsertHeaderMode) -> Self {
143 Self {
144 header_name,
145 make,
146 mode,
147 }
148 }
149}
150
151impl<S, M> Layer<S> for SetRequestHeaderLayer<M>
152where
153 M: Clone,
154{
155 type Service = SetRequestHeader<S, M>;
156
157 fn layer(&self, inner: S) -> Self::Service {
158 SetRequestHeader {
159 inner,
160 header_name: self.header_name.clone(),
161 make: self.make.clone(),
162 mode: self.mode,
163 }
164 }
165}
166
167impl<M> Clone for SetRequestHeaderLayer<M>
168where
169 M: Clone,
170{
171 fn clone(&self) -> Self {
172 Self {
173 make: self.make.clone(),
174 header_name: self.header_name.clone(),
175 mode: self.mode,
176 }
177 }
178}
179
180#[derive(Clone)]
182pub struct SetRequestHeader<S, M> {
183 inner: S,
184 header_name: HeaderName,
185 make: M,
186 mode: InsertHeaderMode,
187}
188
189impl<S, M> SetRequestHeader<S, M> {
190 pub fn overriding(inner: S, header_name: HeaderName, make: M) -> Self {
195 Self::new(inner, header_name, make, InsertHeaderMode::Override)
196 }
197
198 pub fn appending(inner: S, header_name: HeaderName, make: M) -> Self {
203 Self::new(inner, header_name, make, InsertHeaderMode::Append)
204 }
205
206 pub fn if_not_present(inner: S, header_name: HeaderName, make: M) -> Self {
210 Self::new(inner, header_name, make, InsertHeaderMode::IfNotPresent)
211 }
212
213 fn new(inner: S, header_name: HeaderName, make: M, mode: InsertHeaderMode) -> Self {
214 Self {
215 inner,
216 header_name,
217 make,
218 mode,
219 }
220 }
221}
222
223impl<S, M> fmt::Debug for SetRequestHeader<S, M>
224where
225 S: fmt::Debug,
226{
227 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
228 f.debug_struct("SetRequestHeader")
229 .field("inner", &self.inner)
230 .field("header_name", &self.header_name)
231 .field("mode", &self.mode)
232 .field("make", &std::any::type_name::<M>())
233 .finish()
234 }
235}
236
237impl<S, M> Service<reqwest::Request> for SetRequestHeader<S, M>
238where
239 S: Service<reqwest::Request, Response = reqwest::Response>,
240 M: MakeHeaderValue<reqwest::Request>,
241{
242 type Response = S::Response;
243 type Error = S::Error;
244 type Future = S::Future;
245
246 #[inline]
247 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
248 self.inner.poll_ready(cx)
249 }
250
251 fn call(&mut self, mut req: reqwest::Request) -> Self::Future {
252 self.mode.apply(&self.header_name, &mut req, &mut self.make);
253 self.inner.call(req)
254 }
255}
256
#[cfg(test)]
mod tests {
    use http::{HeaderName, HeaderValue};
    use tower_layer::Layer;
    use tower_service::Service;
    use wiremock::{
        Mock, MockServer, ResponseTemplate,
        matchers::{method, path},
    };

    use crate::set_header::SetRequestHeaderLayer;

    /// The mock only answers `200` when `x-test-header: test-value` is present.
    /// A plain request must therefore miss it (`404`), while the same request
    /// sent through the layer must hit it (`200`).
    #[tokio::test]
    async fn test_set_headers() -> anyhow::Result<()> {
        let mock_server = MockServer::start().await;
        let mock_uri = mock_server.uri();

        let header_name = HeaderName::from_static("x-test-header");
        let header_value = HeaderValue::from_static("test-value");

        Mock::given(method("GET"))
            .and(path("/test"))
            .and(wiremock::matchers::header(&header_name, &header_value))
            .respond_with(ResponseTemplate::new(200))
            .mount(&mock_server)
            .await;

        let uri = format!("{mock_uri}/test");
        let request = reqwest::Request::new(reqwest::Method::GET, uri.parse()?);
        let client = reqwest::Client::new();

        // Without the layer the header is missing, so no mock matches.
        let unlayered = request
            .try_clone()
            .expect("a GET request without a body is always clonable");
        let response = client.execute(unlayered).await?;
        assert_eq!(response.status(), 404);

        let mut service =
            SetRequestHeaderLayer::overriding(header_name, header_value).layer(client);
        // Honour the `Service` contract: the service must report readiness
        // before `call` is invoked.
        std::future::poll_fn(|cx| service.poll_ready(cx)).await?;
        let response = service.call(request).await?;
        assert_eq!(response.status(), 200);

        Ok(())
    }
}