volo_http/utils/
extension.rs

//! [`Extension`] support for inserting values into, and extracting values from, contexts
2
3use motore::{layer::Layer, service::Service};
4use volo::context::Context;
5
/// A wrapper that makes a value of type `T` available through the context's extensions.
///
/// `Extension` plays two roles:
///
/// - As a [`Layer`], it wraps a service so that a clone of the wrapped value is inserted
///   into the context's extensions before the inner service is called.
/// - With the `server` feature enabled, it is also an extractor that clones a `T` back out of
///   the server context's extensions.
///
/// # Examples
///
/// ```ignore
/// use volo_http::{
///     server::route::{get, Router},
///     utils::Extension,
/// };
///
/// #[derive(Clone)]
/// struct State {
///     foo: String,
/// }
///
/// // Handler that pulls `State` out of the context via the `Extension` extractor.
/// async fn show_state(Extension(state): Extension<State>) -> String {
///     state.foo
/// }
///
/// let router: Router = Router::new()
///     .route("/", get(show_state))
///     // Apply `Extension` as a `Layer` so every request sees the `State`.
///     .layer(Extension(State {
///         foo: String::from("bar"),
///     }));
/// ```
#[derive(Clone, Copy, Debug, Default)]
pub struct Extension<T>(pub T);
35
36impl<S, T> Layer<S> for Extension<T>
37where
38    S: Send + Sync + 'static,
39    T: Sync,
40{
41    type Service = ExtensionService<S, T>;
42
43    fn layer(self, inner: S) -> Self::Service {
44        ExtensionService { inner, ext: self.0 }
45    }
46}
47
/// The [`Service`] produced when [`Extension`] is used as a [`Layer`].
///
/// Each call inserts a clone of `ext` into the context's extensions and then delegates to
/// `inner`.
#[derive(Clone, Copy, Debug, Default)]
pub struct ExtensionService<I, T> {
    /// The wrapped service that handles the request after the value has been inserted.
    inner: I,
    /// The value cloned into the context's extensions on every call.
    ext: T,
}
54
55impl<S, Cx, Req, Resp, E, T> Service<Cx, Req> for ExtensionService<S, T>
56where
57    S: Service<Cx, Req, Response = Resp, Error = E> + Send + Sync + 'static,
58    Req: Send,
59    Cx: Context + Send,
60    T: Clone + Send + Sync + 'static,
61{
62    type Response = S::Response;
63    type Error = S::Error;
64
65    async fn call(&self, cx: &mut Cx, req: Req) -> Result<Self::Response, Self::Error> {
66        cx.extensions_mut().insert(self.ext.clone());
67        self.inner.call(cx, req).await
68    }
69}
70
#[cfg(feature = "server")]
mod server {
    use http::{StatusCode, request::Parts};
    use volo::context::Context;

    use super::Extension;
    use crate::{
        context::ServerContext,
        response::Response,
        server::{IntoResponse, extract::FromContext},
    };

    /// Extracts a clone of the `T` stored in the server context's extensions.
    ///
    /// The request parts are not inspected. If no value of type `T` is present (for example,
    /// because no matching [`Extension`] layer was applied), extraction fails with
    /// [`ExtensionRejection::NotExist`].
    impl<T> FromContext for Extension<T>
    where
        T: Clone + Send + Sync + 'static,
    {
        type Rejection = ExtensionRejection;

        async fn from_context(
            cx: &mut ServerContext,
            _parts: &mut Parts,
        ) -> Result<Self, Self::Rejection> {
            cx.extensions()
                .get::<T>()
                .cloned()
                .map(Extension)
                .ok_or(ExtensionRejection::NotExist)
        }
    }

    /// Rejection returned when an [`Extension<T>`] cannot be extracted from the context.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub enum ExtensionRejection {
        /// No value of the requested type is present in the context's extensions.
        NotExist,
    }

    impl IntoResponse for ExtensionRejection {
        /// Converts the rejection into a `500 Internal Server Error` response.
        ///
        /// A missing extension presumably reflects a server-side setup problem (the value was
        /// never inserted) rather than a malformed request, hence a 5xx status.
        fn into_response(self) -> Response {
            StatusCode::INTERNAL_SERVER_ERROR.into_response()
        }
    }
}