1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
use std::convert::Infallible;

use hyper::{body::Incoming, StatusCode};
use motore::{layer::Layer, service::Service};

use crate::{extract::FromContext, response::IntoResponse, HttpContext, Response};

/// A value of type `T` shared with request handlers.
///
/// As a layer, it wraps a service in an [`ExtensionService`], which inserts a clone of the
/// value into the request context's extensions on each call. As an extractor, it clones
/// that value back out of the context.
#[derive(Clone, Copy, Debug, Default)]
pub struct Extension<T>(pub T);

impl<S, T> Layer<S> for Extension<T>
where
    S: Service<HttpContext, Incoming, Response = Response> + Send + Sync + 'static,
    T: Sync,
{
    type Service = ExtensionService<S, T>;

    fn layer(self, inner: S) -> Self::Service {
        ExtensionService { inner, ext: self.0 }
    }
}

/// Service produced by layering an [`Extension`].
///
/// It owns the wrapped service and the value that is cloned into the request context
/// before each request is delegated.
#[derive(Clone, Copy, Debug, Default)]
pub struct ExtensionService<I, T> {
    /// The wrapped service that handles the request after the value is inserted.
    inner: I,
    /// The value cloned into `cx.extensions` on every call.
    ext: T,
}

/// On each request, stores a clone of the configured value in `cx.extensions`. It then
/// delegates to the wrapped service and returns that service's result unchanged.
impl<S, T> Service<HttpContext, Incoming> for ExtensionService<S, T>
where
    T: Clone + Send + Sync + 'static,
    S: Service<HttpContext, Incoming, Response = Response, Error = Infallible>
        + Send
        + Sync
        + 'static,
{
    type Error = S::Error;
    type Response = S::Response;

    async fn call<'s, 'cx>(
        &'s self,
        cx: &'cx mut HttpContext,
        req: Incoming,
    ) -> Result<Self::Response, Self::Error> {
        // The value must be in the context *before* delegating, so the inner service
        // (and any `Extension<T>` extractor it runs) can find it.
        let shared = T::clone(&self.ext);
        cx.extensions.insert(shared);
        S::call(&self.inner, cx, req).await
    }
}

/// Extracts an `Extension<T>` by cloning the `T` stored in `cx.extensions`.
///
/// The application state is not consulted. If no value of type `T` is present,
/// extraction fails with [`ExtensionRejection::NotExist`].
impl<T, S> FromContext<S> for Extension<T>
where
    T: Clone + Send + Sync + 'static,
    S: Sync,
{
    type Rejection = ExtensionRejection;

    async fn from_context(cx: &mut HttpContext, _state: &S) -> Result<Self, Self::Rejection> {
        let Some(stored) = cx.extensions.get::<T>() else {
            return Err(ExtensionRejection::NotExist);
        };
        Ok(Self(T::clone(stored)))
    }
}

/// Reason why extracting an [`Extension`] from the request context failed.
///
/// Derives `Debug` (and the cheap value traits) so rejections can be logged, compared
/// in tests, and passed around freely. Previously the public type had no derives.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExtensionRejection {
    /// No value of the requested type was present in the context's extensions.
    NotExist,
}

/// Converts a rejection into an HTTP response with an empty-bodied status.
///
/// Every variant maps to `500 Internal Server Error`. Presumably a missing extension
/// indicates a server-side wiring problem (e.g. the layer was not installed) rather than
/// a client error. NOTE(review): confirm this is the intended policy.
impl IntoResponse for ExtensionRejection {
    fn into_response(self) -> Response {
        let status = match self {
            Self::NotExist => StatusCode::INTERNAL_SERVER_ERROR,
        };
        status.into_response()
    }
}