Skip to main content

sui_http/middleware/callback/
service.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use super::CallbackLayer;
5use super::MakeCallbackHandler;
6use super::RequestBody;
7use super::ResponseBody;
8use super::ResponseFuture;
9use http::Request;
10use http::Response;
11use std::task::Context;
12use std::task::Poll;
13use tower::Service;
14
/// A [`Service`] middleware that attaches callbacks to the service it wraps.
///
/// An example is available in the [module docs](crate::middleware::callback).
///
/// [`Service`]: tower::Service
#[derive(Clone, Copy, Debug)]
pub struct Callback<S, M> {
    /// The wrapped service that requests are forwarded to.
    pub(crate) inner: S,
    /// Factory whose `make_handler` is invoked once per request to produce the
    /// request-side and response-side handlers.
    pub(crate) make_callback_handler: M,
}
25
26impl<S, M> Callback<S, M> {
27    /// Create a new [`Callback`].
28    pub fn new(inner: S, make_callback_handler: M) -> Self {
29        Self {
30            inner,
31            make_callback_handler,
32        }
33    }
34
35    /// Returns a new [`Layer`] that wraps services with a [`CallbackLayer`] middleware.
36    ///
37    /// [`Layer`]: tower::layer::Layer
38    pub fn layer(make_handler: M) -> CallbackLayer<M>
39    where
40        M: MakeCallbackHandler,
41    {
42        CallbackLayer::new(make_handler)
43    }
44
45    /// Gets a reference to the underlying service.
46    pub fn inner(&self) -> &S {
47        &self.inner
48    }
49
50    /// Gets a mutable reference to the underlying service.
51    pub fn inner_mut(&mut self) -> &mut S {
52        &mut self.inner
53    }
54
55    /// Consumes `self`, returning the underlying service.
56    pub fn into_inner(self) -> S {
57        self.inner
58    }
59}
60
61impl<S, M, ReqBody, ResponseBodyT> Service<Request<ReqBody>> for Callback<S, M>
62where
63    S: Service<
64            Request<RequestBody<ReqBody, M::RequestHandler>>,
65            Response = Response<ResponseBodyT>,
66            Error: std::fmt::Display + 'static,
67        >,
68    M: MakeCallbackHandler,
69    ReqBody: http_body::Body<Error: std::fmt::Display + 'static>,
70    ResponseBodyT: http_body::Body<Error: std::fmt::Display + 'static>,
71{
72    type Response = Response<ResponseBody<ResponseBodyT, M::ResponseHandler>>;
73    type Error = S::Error;
74    type Future = ResponseFuture<S::Future, M::ResponseHandler>;
75
76    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
77        self.inner.poll_ready(cx)
78    }
79
80    fn call(&mut self, request: Request<ReqBody>) -> Self::Future {
81        let (head, body) = request.into_parts();
82        let (req_handler, resp_handler) = self.make_callback_handler.make_handler(&head);
83        let wrapped_body = RequestBody {
84            inner: body,
85            handler: req_handler,
86            ended: false,
87        };
88        let request = Request::from_parts(head, wrapped_body);
89
90        ResponseFuture {
91            inner: self.inner.call(request),
92            handler: Some(resp_handler),
93        }
94    }
95}