signway_server/
gateway_callbacks.rs1use std::fmt::{Display, Formatter};
2
3use async_trait::async_trait;
4use bytes::Bytes;
5use http_body_util::Full;
6use hyper::http::{request, response};
7use hyper::Response;
8use url::Url;
9
10pub enum CallbackResult {
11 EarlyResponse(Response<Full<Bytes>>),
12 Empty,
13}
14
15#[async_trait]
16pub trait OnRequest: Sync + Send {
17 async fn call(&self, id: &str, req: &request::Parts) -> CallbackResult;
18}
19
20#[async_trait]
21pub trait OnSuccess: Sync + Send {
22 async fn call(&self, id: &str, res: &response::Parts) -> CallbackResult;
23}
24
/// Direction tag attached to a byte-transfer record.
#[derive(Debug, Clone)]
pub enum BytesTransferredKind {
    /// Rendered as `IN` by `Display`.
    In,
    /// Rendered as `OUT` by `Display`.
    Out,
}

impl Display for BytesTransferredKind {
    /// Writes the upper-case label of the variant (`IN` / `OUT`).
    ///
    /// The label goes through [`Formatter::pad`] so that width, fill and
    /// alignment requested by the caller (e.g. `{:>5}`) are honoured; plain
    /// `{}` output is unchanged.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::In => "IN",
            Self::Out => "OUT",
        };
        f.pad(label)
    }
}
43
44#[derive(Clone, Debug)]
45pub struct BytesTransferredInfo {
46 pub id: String,
47 pub proxy_url: Url,
48 pub bytes: usize,
49 pub kind: BytesTransferredKind,
50}
51
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::atomic::{AtomicU64, Ordering};

    use async_trait::async_trait;
    use hyper::http::{request, response};
    use hyper::{Request, StatusCode};

    use crate::_test_tools::tests::{InMemorySecretGetter, ReqBuilder};
    use crate::gateway_callbacks::{CallbackResult, OnRequest, OnSuccess};
    use crate::sw_body::SwBody;
    use crate::{HeaderMap, SecretGetterResult, SignwayServer};

    /// Builds a server backed by an in-memory secret store holding a single
    /// entry: id `foo` with secret `bar` and no extra headers.
    fn server() -> SignwayServer {
        let secrets = HashMap::from([(
            "foo".to_string(),
            SecretGetterResult {
                secret: "bar".to_string(),
                headers_extension: HeaderMap::new(),
            },
        )]);
        SignwayServer::from_env(InMemorySecretGetter(secrets))
    }

    /// Builds a signed POST request with body `foo` and `Content-Length: 3`.
    fn req() -> Request<SwBody> {
        ReqBuilder::default()
            .query("page", "1")
            .header("Content-Length", "3")
            .post()
            .sign("foo", "bar", "http://localhost:3000")
            .unwrap()
            .body("foo")
            .build()
            .unwrap()
    }

    /// Callback that sums every `content-length` value it sees into the
    /// borrowed counter.
    struct SizeCollector<'a>(&'a AtomicU64);

    impl<'a> SizeCollector<'a> {
        /// Parses a raw `content-length` value and adds it to the counter.
        /// Panics if the value is not a valid `u64`.
        fn record(&self, raw_length: &str) {
            let bytes: u64 = raw_length.parse().unwrap();
            self.0.fetch_add(bytes, Ordering::SeqCst);
        }
    }

    #[async_trait]
    impl<'a> OnRequest for SizeCollector<'a> {
        async fn call(&self, _id: &str, req: &request::Parts) -> CallbackResult {
            let header = req.headers.get("content-length").unwrap();
            self.record(header.to_str().unwrap());
            CallbackResult::Empty
        }
    }

    #[async_trait]
    impl<'a> OnSuccess for SizeCollector<'a> {
        async fn call(&self, _id: &str, res: &response::Parts) -> CallbackResult {
            let header = res.headers.get("content-length").unwrap();
            self.record(header.to_str().unwrap());
            CallbackResult::Empty
        }
    }

    /// The request hook must observe the request's `Content-Length` (3).
    #[tokio::test]
    async fn test_on_request() {
        static REQUEST_BYTES: AtomicU64 = AtomicU64::new(0);

        let response = server()
            .on_request(SizeCollector(&REQUEST_BYTES))
            .handler(req())
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(REQUEST_BYTES.load(Ordering::SeqCst), 3);
    }

    /// The success hook must observe the response's `content-length`.
    /// The expected value (396) depends on the response served to this test;
    /// its origin is not visible in this module.
    #[tokio::test]
    async fn test_on_success() {
        static RESPONSE_BYTES: AtomicU64 = AtomicU64::new(0);

        let response = server()
            .on_success(SizeCollector(&RESPONSE_BYTES))
            .handler(req())
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(RESPONSE_BYTES.load(Ordering::SeqCst), 396);
    }
}
139}