1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
//! An adapter that makes a tower [`Service`] into a [`Handler`].

use tower::{Service, ServiceExt};
use viz_core::{Body, BoxError, Bytes, Error, Handler, HttpBody, Request, Response, Result};

mod service;
pub use service::HandlerService;

mod middleware;
pub use middleware::Middleware;

mod layer;
pub use layer::Layered;

/// Adapter that lets a tower [`Service`] be used wherever a [`Handler`] is expected.
///
/// The wrapped service is stored as the single tuple field.
#[derive(Clone, Debug)]
pub struct ServiceHandler<S>(S);

impl<S> ServiceHandler<S> {
    /// Wraps `service` in a [`ServiceHandler`].
    ///
    /// This is a `const fn`, so it can also be used in constant contexts.
    pub const fn new(service: S) -> Self {
        ServiceHandler(service)
    }
}

#[viz_core::async_trait]
impl<O, S> Handler<Request> for ServiceHandler<S>
where
    O: HttpBody + Send + 'static,
    O::Data: Into<Bytes>,
    O::Error: Into<BoxError>,
    S: Service<Request, Response = Response<O>> + Send + Sync + Clone + 'static,
    S::Future: Send,
    S::Error: Into<BoxError>,
{
    type Output = Result<Response>;

    async fn call(&self, req: Request) -> Self::Output {
        self.0
            .clone()
            .oneshot(req)
            .await
            .map_err(Error::boxed)
            .map(|resp| resp.map(Body::wrap))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        sync::{
            atomic::{AtomicU64, Ordering},
            Arc,
        },
        time::Duration,
    };
    use tower::util::{MapErrLayer, MapRequestLayer, MapResponseLayer};
    use tower::{service_fn, ServiceBuilder};
    use tower_http::{
        limit::RequestBodyLimitLayer,
        request_id::{MakeRequestId, RequestId, SetRequestIdLayer},
        timeout::TimeoutLayer,
    };
    use viz_core::{
        Body, BoxHandler, Handler, HandlerExt, IntoResponse, Request, RequestExt, Response,
    };

    /// Request-id generator backed by a counter that all clones share.
    #[derive(Clone, Default, Debug)]
    struct MyMakeRequestId {
        // The next id to hand out. It is shared through the `Arc` and bumped atomically.
        counter: Arc<AtomicU64>,
    }

    impl MakeRequestId for MyMakeRequestId {
        fn make_request_id<B>(&mut self, _: &Request<B>) -> Option<RequestId> {
            let id = self.counter.fetch_add(1, Ordering::SeqCst);
            // The counter renders as ASCII digits, which are expected to always parse
            // into the value type `RequestId::new` takes.
            let header_value = id.to_string().parse().unwrap();
            Some(RequestId::new(header_value))
        }
    }

    /// Echoes the request body back as the response.
    async fn hello(mut req: Request) -> Result<Response> {
        Ok(req.bytes().await?.into_response())
    }

    #[tokio::test]
    async fn tower_service_into_handler() {
        // Layers are listed outermost first. The 1-byte body limit is what the
        // assertions below exercise.
        let svc = ServiceBuilder::new()
            .layer(RequestBodyLimitLayer::new(1))
            .layer(MapErrLayer::new(Error::from))
            .layer(SetRequestIdLayer::x_request_id(MyMakeRequestId::default()))
            .layer(MapResponseLayer::new(IntoResponse::into_response))
            .layer(MapRequestLayer::new(|req: Request<_>| req.map(Body::wrap)))
            .layer(TimeoutLayer::new(Duration::from_secs(10)))
            .service(service_fn(hello));

        let handler = ServiceHandler::new(svc);

        // A two-byte body is over the limit, so the call is expected to fail.
        let too_big = Request::new(Body::Full("12".into()));
        let rejected = handler.call(too_big).await;
        assert!(rejected.is_err());

        // A one-byte body fits, and the boxed handler must accept it.
        let boxed: BoxHandler = handler.boxed();
        let fits = Request::new(Body::Full("1".into()));
        let accepted = boxed.call(fits).await;
        assert!(accepted.is_ok());
    }
}