spacegate_kernel/helper_layers/function.rs

1pub mod handler;
2
3use futures_util::future::BoxFuture;
4use futures_util::Future;
5use hyper::{service::Service, Request, Response};
6use std::{convert::Infallible, sync::Arc};
7use tower_layer::Layer;
8
9use crate::{ArcHyperService, SgBody};
10
11/// see [`FnLayer`]
12pub trait FnLayerMethod: Send + 'static {
13    fn call(&self, req: Request<SgBody>, inner: Inner) -> impl Future<Output = Response<SgBody>> + Send;
14}
15
16impl<T> FnLayerMethod for Arc<T>
17where
18    T: FnLayerMethod + std::marker::Sync,
19{
20    #[inline]
21    async fn call(&self, req: Request<SgBody>, inner: Inner) -> Response<SgBody> {
22        self.as_ref().call(req, inner).await
23    }
24}
/// Adapts a handler function (anything implementing [`handler::HandlerFn`]) into an
/// [`FnLayerMethod`].
///
/// `T` and `Fut` are type-level markers only; they are held in a
/// `PhantomData<fn(T) -> Fut>` and select which `HandlerFn<T, Fut>` implementation is used.
pub struct Handler<H, T, Fut> {
    handler: H,
    marker: std::marker::PhantomData<fn(T) -> Fut>,
}

// Implemented by hand: `#[derive(Debug)]` would require `H`, `T` and `Fut` to all be
// `Debug`. `T` and `Fut` only live inside `PhantomData`, and `Fut` is typically the
// anonymous future of an `async fn`, which does not implement `Debug`, so the derived
// impl was practically never usable. Printing the handler's type name needs no bounds.
impl<H, T, Fut> std::fmt::Debug for Handler<H, T, Fut> {
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        out.debug_struct("Handler")
            .field("handler", &format_args!("{}", std::any::type_name::<H>()))
            .finish_non_exhaustive()
    }
}
30
31impl<H, T, Fut> Clone for Handler<H, T, Fut>
32where
33    H: Clone,
34{
35    fn clone(&self) -> Self {
36        Self {
37            handler: self.handler.clone(),
38            marker: std::marker::PhantomData,
39        }
40    }
41}
42
43impl<H, T, Fut> Handler<H, T, Fut> {
44    pub const fn new(handler: H) -> Self {
45        Self {
46            handler,
47            marker: std::marker::PhantomData,
48        }
49    }
50}
51
52impl<H, T, Fut> FnLayerMethod for Handler<H, T, Fut>
53where
54    T: 'static,
55    H: handler::HandlerFn<T, Fut> + Send + Clone + 'static,
56    Fut: Future<Output = Response<SgBody>> + Send + 'static,
57{
58    #[inline]
59    fn call(&self, req: Request<SgBody>, inner: Inner) -> impl Future<Output = Response<SgBody>> + Send {
60        (self.handler).apply(req, inner)
61    }
62}
63
64/// see [`FnLayer`]
65#[derive(Debug)]
66pub struct Closure<F, Fut>
67where
68    F: Fn(Request<SgBody>, Inner) -> Fut + Send + Sync + Clone + 'static,
69    Fut: Future<Output = Response<SgBody>> + Send + 'static,
70{
71    pub f: F,
72}
73
74impl<F, Fut> Closure<F, Fut>
75where
76    F: Fn(Request<SgBody>, Inner) -> Fut + Send + Sync + Clone + 'static,
77    Fut: Future<Output = Response<SgBody>> + Send + 'static,
78{
79    pub const fn new(f: F) -> Self {
80        Self { f }
81    }
82}
83
84impl<F, Fut> From<F> for Closure<F, Fut>
85where
86    F: Fn(Request<SgBody>, Inner) -> Fut + Send + Sync + Clone + 'static,
87    Fut: Future<Output = Response<SgBody>> + Send + 'static,
88{
89    fn from(value: F) -> Self {
90        Closure { f: value }
91    }
92}
93
94impl<F, Fut> Clone for Closure<F, Fut>
95where
96    F: Fn(Request<SgBody>, Inner) -> Fut + Send + Sync + Clone + 'static,
97    Fut: Future<Output = Response<SgBody>> + Send + 'static,
98{
99    fn clone(&self) -> Self {
100        Self { f: self.f.clone() }
101    }
102}
103
104impl<F, Fut> FnLayerMethod for Closure<F, Fut>
105where
106    F: Fn(Request<SgBody>, Inner) -> Fut + Send + Sync + Clone + 'static,
107    Fut: Future<Output = Response<SgBody>> + Send + 'static,
108{
109    #[inline]
110    async fn call(&self, req: Request<SgBody>, inner: Inner) -> Response<SgBody> {
111        (self.f)(req, inner).await
112    }
113}
114
/// A [`Layer`] whose behaviour is supplied by an [`FnLayerMethod`].
///
/// The services it produces ([`FnService`]) clone the method and call it once per request,
/// passing the request together with an [`Inner`] handle to the wrapped service.
///
/// For example, a layer that adds a header to every response can be written as a closure:
/// ```
/// # use spacegate_kernel::helper_layers::function::FnLayer;
/// # use hyper::http::header::HeaderValue;
/// let layer = FnLayer::new_closure(move |req, inner| {
///    async move {
///        let mut resp = inner.call(req).await;
///        resp.headers_mut().insert("server", HeaderValue::from_static("potato"));
///        resp
///    }
/// });
/// ```
///
/// The same behaviour can also be packaged in a type that implements `FnLayerMethod`:
/// ```
/// # use spacegate_kernel::{helper_layers::function::{FnLayer, FnLayerMethod, Inner}, SgRequest, SgResponse};
/// # use hyper::http::header::HeaderValue;
/// struct MyPlugin;
/// impl FnLayerMethod for MyPlugin {
///    async fn call(&self, req: SgRequest, inner: Inner) -> SgResponse {
///       let mut resp = inner.call(req).await;
///       resp.headers_mut().insert("server", HeaderValue::from_static("potato"));
///       resp
///    }
/// }
/// let layer = FnLayer::new(MyPlugin);
/// ```
#[derive(Debug, Clone)]
pub struct FnLayer<M> {
    method: M,
}
148
149impl<M> FnLayer<M> {
150    pub const fn new(method: M) -> Self {
151        Self { method }
152    }
153}
154
155impl<F, Fut> FnLayer<Closure<F, Fut>>
156where
157    F: Fn(Request<SgBody>, Inner) -> Fut + Send + Sync + Clone + 'static,
158    Fut: Future<Output = Response<SgBody>> + Send + 'static,
159{
160    pub const fn new_closure(f: F) -> Self {
161        Self::new(Closure::new(f))
162    }
163}
164
165impl<H, F, Fut> FnLayer<Handler<H, F, Fut>>
166where
167    Handler<H, F, Fut>: FnLayerMethod,
168{
169    pub const fn new_handler(h: H) -> Self {
170        Self::new(Handler::new(h))
171    }
172}
173
174impl<M, S> Layer<S> for FnLayer<M>
175where
176    M: FnLayerMethod + Clone,
177    S: Service<Request<SgBody>, Error = Infallible, Response = Response<SgBody>> + Send + Sync + Clone + 'static,
178    <S as Service<Request<SgBody>>>::Future: Future<Output = Result<Response<SgBody>, Infallible>> + 'static + Send,
179{
180    type Service = FnService<M>;
181
182    fn layer(&self, inner: S) -> Self::Service {
183        FnService {
184            m: self.method.clone(),
185            inner: ArcHyperService::new(inner),
186        }
187    }
188}
189
190/// The corresponded server for [`FnLayer`]
191#[derive(Debug, Clone)]
192pub struct FnService<M> {
193    m: M,
194    inner: ArcHyperService,
195}
196
197impl<M> Service<Request<SgBody>> for FnService<M>
198where
199    M: FnLayerMethod + Clone,
200{
201    type Response = Response<SgBody>;
202
203    type Error = Infallible;
204
205    type Future = BoxFuture<'static, Result<Response<SgBody>, Infallible>>;
206
207    #[inline]
208    fn call(&self, req: Request<SgBody>) -> Self::Future {
209        let next = Inner { inner: self.inner.clone() };
210        let method = self.m.clone();
211        Box::pin(async move { Ok(method.call(req, next).await) })
212    }
213}
214
215/// A shared hyper service wrapper
216#[derive(Debug, Clone)]
217pub struct Inner {
218    inner: ArcHyperService,
219}
220
221impl Inner {
222    #[inline]
223    pub fn new(inner: ArcHyperService) -> Self {
224        Inner { inner }
225    }
226
227    /// Call the inner service and get the response
228    #[inline]
229    pub async fn call(self, req: Request<SgBody>) -> Response<SgBody> {
230        // just infallible
231        unsafe { self.inner.call(req).await.unwrap_unchecked() }
232    }
233
234    #[inline]
235    /// Unwrap the inner service
236    pub fn into_inner(self) -> ArcHyperService {
237        self.inner
238    }
239}
240
#[cfg(test)]
mod test {
    use std::{collections::HashMap, sync::Arc};

    use hyper::{header::HeaderValue, Method, StatusCode, Uri};

    use super::*;
    use crate::{BoxLayer, Extract};

    /// Sample plugin: prints the request's `host` header and, when one is configured,
    /// a message for the response status.
    #[derive(Debug, Default, Clone)]
    pub struct MyPlugin {
        status_message: HashMap<StatusCode, String>,
    }

    impl FnLayerMethod for MyPlugin {
        async fn call(&self, req: Request<SgBody>, inner: Inner) -> Response<SgBody> {
            if let Some(Ok(host)) = req.headers().get("host").map(HeaderValue::to_str) {
                println!("{host}");
            }
            let resp = inner.call(req).await;
            if let Some(message) = self.status_message.get(&resp.status()) {
                println!("{message}");
            }
            resp
        }
    }

    /// Checks that every `FnLayer` constructor yields something `BoxLayer` accepts.
    #[test]
    fn test_fn_layer() {
        #[derive(Debug, Clone)]
        struct Server {
            // Only read through the derived `Debug` impl, which `dead_code` analysis ignores.
            #[allow(dead_code)]
            pub name: String,
        }

        impl Extract for Option<Server> {
            fn extract(req: &Request<SgBody>) -> Self {
                req.headers()
                    .get("server")
                    .and_then(|value| value.to_str().ok())
                    .map(|name| Server { name: name.to_string() })
            }
        }

        async fn custom_handler(req: Request<SgBody>, inner: Inner, uri: Uri, method: Method, server: Option<Server>) -> Response<SgBody> {
            tokio::spawn(async move { println!("{method} // {uri} // {server:?}") });
            inner.call(req).await
        }

        async fn empty_custom_handler(req: Request<SgBody>, inner: Inner) -> Response<SgBody> {
            inner.call(req).await
        }

        let status_message = Arc::new(<HashMap<StatusCode, String>>::default());

        let plugin_layer = BoxLayer::new(FnLayer::new(MyPlugin::default()));
        let closure_layer = BoxLayer::new(FnLayer::new_closure(move |req, inner| {
            if let Some(Ok(host)) = req.headers().get("host").map(HeaderValue::to_str) {
                println!("{host}");
            }
            let status_message = status_message.clone();
            async move {
                let resp = inner.call(req).await;
                if let Some(message) = status_message.get(&resp.status()) {
                    println!("{message}");
                }
                resp
            }
        }));
        let handler_layer = BoxLayer::new(FnLayer::new_handler(custom_handler));
        let empty_handler_layer = BoxLayer::new(FnLayer::new_handler(empty_custom_handler));

        drop(plugin_layer);
        drop(closure_layer);
        drop(handler_layer);
        drop(empty_handler_layer);
    }
}