synapse_sdk/
middleware.rs1use anyhow::Result;
6use async_trait::async_trait;
7use bytes::Bytes;
8use std::{any::Any, collections::HashMap, sync::Arc};
9use synapse_proto::{HeaderEntry, RpcRequest};
10
11pub struct RequestState<T = ()> {
16 pub headers: Vec<HeaderEntry>,
18
19 pub sent_at_unix_ms: i64,
21
22 pub interface_id: u32,
24
25 pub method_id: u32,
27
28 pub custom: T,
30
31 extensions: HashMap<String, Box<dyn Any + Send + Sync>>,
33}
34
35impl<T> RequestState<T> {
36 pub fn from_request(request: &RpcRequest, custom: T) -> Self {
38 Self {
39 headers: request.headers.clone(),
40 sent_at_unix_ms: request.sent_at_unix_ms,
41 interface_id: request.interface_id,
42 method_id: request.method_id,
43 custom,
44 extensions: HashMap::new(),
45 }
46 }
47
48 pub fn set_extension<V: Any + Send + Sync>(&mut self, key: impl Into<String>, value: V) {
50 self.extensions.insert(key.into(), Box::new(value));
51 }
52
53 pub fn get_extension<V: Any + Send + Sync>(&self, key: &str) -> Option<&V> {
55 self.extensions.get(key).and_then(|v| v.downcast_ref::<V>())
56 }
57
58 pub fn get_extension_mut<V: Any + Send + Sync>(&mut self, key: &str) -> Option<&mut V> {
60 self.extensions
61 .get_mut(key)
62 .and_then(|v| v.downcast_mut::<V>())
63 }
64
65 pub fn get_header(&self, key: u32) -> Option<&HeaderEntry> {
67 self.headers.iter().find(|h| h.key == key)
68 }
69}
70
71impl RequestState<()> {
72 pub fn new(request: &RpcRequest) -> Self {
74 Self::from_request(request, ())
75 }
76}
77
/// An asynchronous processing step that runs against a request's [`RequestState`].
///
/// Implementors must be `Send + Sync`, which lets them be shared as
/// `Arc<dyn Middleware<T>>` (the form [`MiddlewareChain`] stores). `#[async_trait]`
/// allows the trait to declare an `async fn` and still be used as a trait object.
#[async_trait]
pub trait Middleware<T>: Send + Sync {
    /// Processes one request.
    ///
    /// `state` may be read and mutated (headers, `custom`, extensions). Later middleware
    /// in a chain see those changes. `payload` is borrowed read-only. Presumably it is
    /// the request body (`RpcRequest::payload`), but confirm this against the caller.
    ///
    /// # Errors
    /// Any `Err` returned here stops [`MiddlewareChain::process`], which propagates it;
    /// middleware registered later do not run.
    async fn process(&self, state: &mut RequestState<T>, payload: &Bytes) -> Result<()>;
}
90
91pub struct MiddlewareChain<T> {
95 middlewares: Vec<Arc<dyn Middleware<T>>>,
96}
97
98impl<T> MiddlewareChain<T> {
99 pub fn new() -> Self {
101 Self {
102 middlewares: Vec::new(),
103 }
104 }
105
106 pub fn add(&mut self, middleware: Arc<dyn Middleware<T>>) {
108 self.middlewares.push(middleware);
109 }
110
111 pub async fn process(&self, state: &mut RequestState<T>, payload: &Bytes) -> Result<()> {
113 for middleware in &self.middlewares {
114 middleware.process(state, payload).await?;
115 }
116 Ok(())
117 }
118
119 pub fn is_empty(&self) -> bool {
121 self.middlewares.is_empty()
122 }
123}
124
125impl<T> Default for MiddlewareChain<T> {
126 fn default() -> Self {
127 Self::new()
128 }
129}
130
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Mutex;

    /// Request fixture: interface 1, method 2, one string header under key 100, an
    /// empty payload and a fixed send timestamp.
    fn test_request() -> RpcRequest {
        RpcRequest {
            interface_id: 1,
            method_id: 2,
            headers: vec![HeaderEntry {
                key: 100,
                value: Some(synapse_proto::header_entry::Value::StringValue(
                    "test".to_string(),
                )),
            }],
            payload: Bytes::new(),
            sent_at_unix_ms: 12345,
        }
    }

    #[test]
    fn test_request_state_from_request() {
        let req = test_request();
        let state = RequestState::from_request(&req, "custom");
        assert_eq!(state.interface_id, 1);
        assert_eq!(state.method_id, 2);
        assert_eq!(state.sent_at_unix_ms, 12345);
        assert_eq!(state.custom, "custom");
        assert_eq!(state.headers.len(), 1);
    }

    #[test]
    fn test_request_state_new_unit() {
        let req = test_request();
        let state = RequestState::new(&req);
        // The irrefutable unit pattern proves at compile time that `new` yields
        // `RequestState<()>`. It replaces `assert_eq!(state.custom, ())`, which clippy
        // rejects with the deny-by-default `unit_cmp` lint.
        let () = state.custom;
        assert_eq!(state.interface_id, 1);
    }

    #[test]
    fn test_extensions_set_and_get() {
        let req = test_request();
        let mut state = RequestState::new(&req);
        state.set_extension("user_id", 42u64);
        assert_eq!(state.get_extension::<u64>("user_id"), Some(&42));
    }

    #[test]
    fn test_extensions_wrong_type() {
        let req = test_request();
        let mut state = RequestState::new(&req);
        state.set_extension("user_id", 42u64);
        assert!(state.get_extension::<String>("user_id").is_none());
    }

    #[test]
    fn test_extensions_missing_key() {
        let req = test_request();
        let state = RequestState::new(&req);
        assert!(state.get_extension::<u64>("nonexistent").is_none());
    }

    #[test]
    fn test_extensions_get_mut() {
        let req = test_request();
        let mut state = RequestState::new(&req);
        state.set_extension("counter", 0u32);
        if let Some(val) = state.get_extension_mut::<u32>("counter") {
            *val = 5;
        }
        assert_eq!(state.get_extension::<u32>("counter"), Some(&5));
    }

    #[test]
    fn test_get_header_found() {
        let req = test_request();
        let state = RequestState::new(&req);
        let header = state.get_header(100);
        assert!(header.is_some());
        assert_eq!(header.unwrap().key, 100);
    }

    #[test]
    fn test_get_header_not_found() {
        let req = test_request();
        let state = RequestState::new(&req);
        assert!(state.get_header(999).is_none());
    }

    /// Increments a shared counter each time it runs.
    struct CountingMiddleware {
        counter: Arc<AtomicUsize>,
    }

    #[async_trait]
    impl Middleware<()> for CountingMiddleware {
        async fn process(&self, _state: &mut RequestState<()>, _payload: &Bytes) -> Result<()> {
            self.counter.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Appends its `id` to a shared log, so tests can observe execution order.
    struct RecordingMiddleware {
        id: usize,
        log: Arc<Mutex<Vec<usize>>>,
    }

    #[async_trait]
    impl Middleware<()> for RecordingMiddleware {
        async fn process(&self, _state: &mut RequestState<()>, _payload: &Bytes) -> Result<()> {
            // The guard is a statement temporary and is dropped before returning; no await.
            self.log.lock().unwrap().push(self.id);
            Ok(())
        }
    }

    /// Stores the `String` extension `"tag" = "seen"`.
    struct TaggingMiddleware;

    #[async_trait]
    impl Middleware<()> for TaggingMiddleware {
        async fn process(&self, state: &mut RequestState<()>, _payload: &Bytes) -> Result<()> {
            state.set_extension("tag", "seen".to_string());
            Ok(())
        }
    }

    /// Fails unless an earlier middleware stored the `"tag"` extension.
    struct RequireTagMiddleware;

    #[async_trait]
    impl Middleware<()> for RequireTagMiddleware {
        async fn process(&self, state: &mut RequestState<()>, _payload: &Bytes) -> Result<()> {
            match state.get_extension::<String>("tag") {
                Some(tag) if tag == "seen" => Ok(()),
                _ => Err(anyhow::anyhow!("tag extension missing")),
            }
        }
    }

    /// Always fails.
    struct FailingMiddleware;

    #[async_trait]
    impl Middleware<()> for FailingMiddleware {
        async fn process(&self, _state: &mut RequestState<()>, _payload: &Bytes) -> Result<()> {
            Err(anyhow::anyhow!("middleware failed"))
        }
    }

    #[test]
    fn test_chain_new_is_empty() {
        let chain = MiddlewareChain::<()>::new();
        assert!(chain.is_empty());
    }

    #[test]
    fn test_chain_default_is_empty() {
        let chain = MiddlewareChain::<()>::default();
        assert!(chain.is_empty());
    }

    #[test]
    fn test_chain_not_empty_after_add() {
        let mut chain = MiddlewareChain::<()>::new();
        chain.add(Arc::new(CountingMiddleware {
            counter: Arc::new(AtomicUsize::new(0)),
        }));
        assert!(!chain.is_empty());
    }

    #[tokio::test]
    async fn test_empty_chain_succeeds() {
        let chain = MiddlewareChain::<()>::new();
        let req = test_request();
        let mut state = RequestState::new(&req);
        chain.process(&mut state, &Bytes::new()).await.unwrap();
    }

    #[tokio::test]
    async fn test_chain_executes_middleware() {
        let mut chain = MiddlewareChain::<()>::new();
        let counter = Arc::new(AtomicUsize::new(0));
        chain.add(Arc::new(CountingMiddleware {
            counter: Arc::clone(&counter),
        }));

        let req = test_request();
        let mut state = RequestState::new(&req);
        chain.process(&mut state, &Bytes::new()).await.unwrap();
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_chain_executes_in_order() {
        let log = Arc::new(Mutex::new(Vec::new()));
        let mut chain = MiddlewareChain::<()>::new();
        for id in 1..=3 {
            chain.add(Arc::new(RecordingMiddleware {
                id,
                log: Arc::clone(&log),
            }));
        }

        let req = test_request();
        let mut state = RequestState::new(&req);
        chain.process(&mut state, &Bytes::new()).await.unwrap();
        assert_eq!(*log.lock().unwrap(), vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn test_chain_state_flows_between_middleware() {
        let mut chain = MiddlewareChain::<()>::new();
        chain.add(Arc::new(TaggingMiddleware));
        chain.add(Arc::new(RequireTagMiddleware));

        let req = test_request();
        let mut state = RequestState::new(&req);
        chain.process(&mut state, &Bytes::new()).await.unwrap();
        assert_eq!(
            state.get_extension::<String>("tag").map(String::as_str),
            Some("seen")
        );
    }

    #[tokio::test]
    async fn test_chain_short_circuits_on_error() {
        let mut chain = MiddlewareChain::<()>::new();
        let counter = Arc::new(AtomicUsize::new(0));
        chain.add(Arc::new(FailingMiddleware));
        chain.add(Arc::new(CountingMiddleware {
            counter: Arc::clone(&counter),
        }));

        let req = test_request();
        let mut state = RequestState::new(&req);
        let result = chain.process(&mut state, &Bytes::new()).await;
        let err = result.expect_err("failing middleware must abort the chain");
        assert_eq!(err.to_string(), "middleware failed");
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }
}