signway_server/
gateway_callbacks.rs

1use std::fmt::{Display, Formatter};
2
3use async_trait::async_trait;
4use bytes::Bytes;
5use http_body_util::Full;
6use hyper::http::{request, response};
7use hyper::Response;
8use url::Url;
9
10pub enum CallbackResult {
11    EarlyResponse(Response<Full<Bytes>>),
12    Empty,
13}
14
15#[async_trait]
16pub trait OnRequest: Sync + Send {
17    async fn call(&self, id: &str, req: &request::Parts) -> CallbackResult;
18}
19
20#[async_trait]
21pub trait OnSuccess: Sync + Send {
22    async fn call(&self, id: &str, res: &response::Parts) -> CallbackResult;
23}
24
/// Direction label attached to a [`BytesTransferredInfo`] record.
///
/// A plain fieldless enum, so it is `Copy` and comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BytesTransferredKind {
    In,
    Out,
}

impl BytesTransferredKind {
    /// Returns the upper-case label for this kind: `"IN"` or `"OUT"`.
    pub fn as_str(&self) -> &'static str {
        match self {
            BytesTransferredKind::In => "IN",
            BytesTransferredKind::Out => "OUT",
        }
    }
}

impl Display for BytesTransferredKind {
    /// Writes `"IN"` or `"OUT"`.
    ///
    /// Uses [`Formatter::pad`] so that width, fill, alignment and precision
    /// in the format spec are honoured exactly as they are for `str`
    /// (e.g. `format!("{:>4}", kind)` yields `"  IN"`). A plain
    /// `write!(f, "{}", …)` would silently discard those flags.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.pad(self.as_str())
    }
}
43
44#[derive(Clone, Debug)]
45pub struct BytesTransferredInfo {
46    pub id: String,
47    pub proxy_url: Url,
48    pub bytes: usize,
49    pub kind: BytesTransferredKind,
50}
51
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::atomic::AtomicU64;
    use std::sync::atomic::Ordering::SeqCst;

    use async_trait::async_trait;
    use hyper::http::{request, response};
    use hyper::{Request, StatusCode};

    use crate::_test_tools::tests::{InMemorySecretGetter, ReqBuilder};
    use crate::gateway_callbacks::{CallbackResult, OnRequest, OnSuccess};
    use crate::sw_body::SwBody;
    use crate::{HeaderMap, SecretGetterResult, SignwayServer};

    /// Server whose in-memory secret store knows id `"foo"` with secret `"bar"`.
    fn server() -> SignwayServer {
        SignwayServer::from_env(InMemorySecretGetter(HashMap::from([(
            "foo".to_string(),
            SecretGetterResult {
                secret: "bar".to_string(),
                headers_extension: HeaderMap::new(),
            },
        )])))
    }

    /// POST request with body `"foo"` and `Content-Length: 3`, signed as
    /// id `"foo"` / secret `"bar"` for `http://localhost:3000`.
    fn req() -> Request<SwBody> {
        ReqBuilder::default()
            .query("page", "1")
            .header("Content-Length", "3")
            .post()
            .sign("foo", "bar", "http://localhost:3000")
            .unwrap()
            .body("foo")
            .build()
            .unwrap()
    }

    /// Parses the `Content-Length` header of `headers` as a `u64`.
    ///
    /// Panics (failing the test with a descriptive message) if the header is
    /// absent, is not visible ASCII, or is not an unsigned integer.
    fn content_length(headers: &hyper::http::HeaderMap) -> u64 {
        headers
            .get(hyper::http::header::CONTENT_LENGTH)
            .expect("missing content-length header")
            .to_str()
            .expect("content-length header is not visible ASCII")
            .parse()
            .expect("content-length header is not an unsigned integer")
    }

    /// Callback that adds the `Content-Length` of every head it receives to
    /// the referenced counter.
    struct SizeCollector<'a>(&'a AtomicU64);

    #[async_trait]
    impl<'a> OnRequest for SizeCollector<'a> {
        async fn call(&self, _id: &str, req: &request::Parts) -> CallbackResult {
            self.0.fetch_add(content_length(&req.headers), SeqCst);
            CallbackResult::Empty
        }
    }

    #[async_trait]
    impl<'a> OnSuccess for SizeCollector<'a> {
        async fn call(&self, _id: &str, res: &response::Parts) -> CallbackResult {
            self.0.fetch_add(content_length(&res.headers), SeqCst);
            CallbackResult::Empty
        }
    }

    #[tokio::test]
    async fn test_on_request() {
        // Each test owns its own static counter so tests running in parallel
        // cannot observe each other's increments.
        static COUNTER: AtomicU64 = AtomicU64::new(0);
        let size_collector = SizeCollector(&COUNTER);

        let response = server()
            .on_request(size_collector)
            .handler(req())
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        // Matches the `Content-Length: 3` header set by `req()`.
        assert_eq!(COUNTER.load(SeqCst), 3);
    }

    #[tokio::test]
    async fn test_on_success() {
        static COUNTER: AtomicU64 = AtomicU64::new(0);
        let size_collector = SizeCollector(&COUNTER);

        let response = server()
            .on_success(size_collector)
            .handler(req())
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        // NOTE(review): 396 is the Content-Length of the response produced for
        // this request via http://localhost:3000; it couples this test to that
        // upstream's exact payload.
        assert_eq!(COUNTER.load(SeqCst), 396);
    }
}