tower_async_http/add_extension.rs
1//! Middleware that clones a value into each request's [extensions].
2//!
3//! [extensions]: https://docs.rs/http/latest/http/struct.Extensions.html
4//!
5//! # Example
6//!
7//! ```
8//! use tower_async_http::add_extension::AddExtensionLayer;
//! use tower_async::{Service, ServiceBuilder};
10//! use http::{Request, Response};
11//! use http_body_util::Full;
12//! use bytes::Bytes;
13//! use std::{sync::Arc, convert::Infallible};
14//!
15//! # struct DatabaseConnectionPool;
16//! # impl DatabaseConnectionPool {
17//! # fn new() -> DatabaseConnectionPool { DatabaseConnectionPool }
18//! # }
19//! #
20//! // Shared state across all request handlers --- in this case, a pool of database connections.
21//! struct State {
22//! pool: DatabaseConnectionPool,
23//! }
24//!
25//! async fn handle(req: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> {
26//! // Grab the state from the request extensions.
27//! let state = req.extensions().get::<Arc<State>>().unwrap();
28//!
29//! Ok(Response::new(Full::<Bytes>::default()))
30//! }
31//!
32//! # #[tokio::main]
33//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
34//! // Construct the shared state.
35//! let state = State {
36//! pool: DatabaseConnectionPool::new(),
37//! };
38//!
//! let service = ServiceBuilder::new()
//!     // Share an `Arc<State>` with all requests.
//!     .layer(AddExtensionLayer::new(Arc::new(state)))
//!     .service_fn(handle);
//!
//! // Call the service.
//! let response = service
//!     .call(Request::new(Full::<Bytes>::default()))
//!     .await?;
49//! # Ok(())
50//! # }
51//! ```
52
53use http::{Request, Response};
54use tower_async_layer::Layer;
55use tower_async_service::Service;
56
/// A [`Layer`] that wraps services in [`AddExtension`], so that a clone of the
/// stored value is inserted into the [extensions] of every request they handle.
///
/// Refer to the [module docs](crate::add_extension) for a full example.
///
/// [extensions]: https://docs.rs/http/latest/http/struct.Extensions.html
#[derive(Debug, Clone, Copy)]
pub struct AddExtensionLayer<T> {
    /// Value handed (as a clone) to each service this layer wraps.
    value: T,
}
66
67impl<T> AddExtensionLayer<T> {
68 /// Create a new [`AddExtensionLayer`].
69 pub fn new(value: T) -> Self {
70 AddExtensionLayer { value }
71 }
72}
73
74impl<S, T> Layer<S> for AddExtensionLayer<T>
75where
76 T: Clone,
77{
78 type Service = AddExtension<S, T>;
79
80 fn layer(&self, inner: S) -> Self::Service {
81 AddExtension {
82 inner,
83 value: self.value.clone(),
84 }
85 }
86}
87
/// Middleware that inserts a clone of a shared value into the [extensions] of
/// each request before handing the request to the inner service.
///
/// Refer to the [module docs](crate::add_extension) for a full example.
///
/// [extensions]: https://docs.rs/http/latest/http/struct.Extensions.html
#[derive(Debug, Clone, Copy)]
pub struct AddExtension<S, T> {
    /// The wrapped service that receives the augmented request.
    inner: S,
    /// The value cloned into every request's extensions.
    value: T,
}
98
99impl<S, T> AddExtension<S, T> {
100 /// Create a new [`AddExtension`].
101 pub fn new(inner: S, value: T) -> Self {
102 Self { inner, value }
103 }
104
105 define_inner_service_accessors!();
106
107 /// Returns a new [`Layer`] that wraps services with a `AddExtension` middleware.
108 ///
109 /// [`Layer`]: tower_async_layer::Layer
110 pub fn layer(value: T) -> AddExtensionLayer<T> {
111 AddExtensionLayer::new(value)
112 }
113}
114
115impl<ResBody, ReqBody, S, T> Service<Request<ReqBody>> for AddExtension<S, T>
116where
117 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
118 T: Clone + Send + Sync + 'static,
119{
120 type Response = S::Response;
121 type Error = S::Error;
122
123 async fn call(&self, mut req: Request<ReqBody>) -> Result<Self::Response, Self::Error> {
124 req.extensions_mut().insert(self.value.clone());
125 self.inner.call(req).await
126 }
127}
128
#[cfg(test)]
mod tests {
    // Glob import of the parent module's items; the lint is allowed so that
    // names pulled in by the glob but not used here do not warn.
    #[allow(unused_imports)]
    use super::*;

    use crate::test_helpers::Body;

    use http::Response;
    use std::{convert::Infallible, sync::Arc};
    use tower_async::{service_fn, ServiceBuilder, ServiceExt};

    /// Minimal state shared with the handler through request extensions.
    struct State(i32);

    #[tokio::test]
    async fn basic() {
        let shared = Arc::new(State(1));

        // Reads the shared state back out of the extensions and echoes its
        // number as the response body.
        let handler = service_fn(|req: Request<Body>| async move {
            let state = req
                .extensions()
                .get::<Arc<State>>()
                .unwrap();
            Ok::<_, Infallible>(Response::new(state.0))
        });

        let service = ServiceBuilder::new()
            .layer(AddExtensionLayer::new(shared))
            .service(handler);

        let body = service
            .oneshot(Request::new(Body::empty()))
            .await
            .unwrap()
            .into_body();

        assert_eq!(body, 1);
    }
}