turul_http_mcp_server/middleware/
stack.rs1use super::{McpMiddleware, RequestContext, SessionInjection, DispatcherResult, MiddlewareError};
4use turul_mcp_session_storage::SessionView;
5use std::sync::Arc;
6
7#[derive(Default, Clone)]
50pub struct MiddlewareStack {
51 middleware: Vec<Arc<dyn McpMiddleware>>,
52}
53
54impl MiddlewareStack {
55 pub fn new() -> Self {
57 Self::default()
58 }
59
60 pub fn push(&mut self, middleware: Arc<dyn McpMiddleware>) {
71 self.middleware.push(middleware);
72 }
73
74 pub fn len(&self) -> usize {
76 self.middleware.len()
77 }
78
79 pub fn is_empty(&self) -> bool {
81 self.middleware.is_empty()
82 }
83
84 pub async fn execute_before(
104 &self,
105 ctx: &mut RequestContext<'_>,
106 session: Option<&dyn SessionView>,
107 ) -> Result<SessionInjection, MiddlewareError> {
108 let mut combined_injection = SessionInjection::new();
109
110 for middleware in &self.middleware {
111 let mut injection = SessionInjection::new();
112 middleware.before_dispatch(ctx, session, &mut injection).await?;
113
114 for (key, value) in injection.state() {
116 combined_injection.set_state(key.clone(), value.clone());
117 }
118 for (key, value) in injection.metadata() {
119 combined_injection.set_metadata(key.clone(), value.clone());
120 }
121 }
122
123 Ok(combined_injection)
124 }
125
126 pub async fn execute_after(
144 &self,
145 ctx: &RequestContext<'_>,
146 result: &mut DispatcherResult,
147 ) -> Result<(), MiddlewareError> {
148 for middleware in self.middleware.iter().rev() {
150 middleware.after_dispatch(ctx, result).await?;
151 }
152
153 Ok(())
154 }
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use serde_json::json;

    /// Shared, ordered log of hook invocations.
    type Log = Arc<std::sync::Mutex<Vec<String>>>;

    fn new_log() -> Log {
        Arc::new(std::sync::Mutex::new(Vec::new()))
    }

    /// Builds a `CountingMiddleware` that records into `counter`.
    fn counting(id: &str, counter: &Log) -> Arc<CountingMiddleware> {
        Arc::new(CountingMiddleware {
            id: id.to_string(),
            counter: Arc::clone(counter),
        })
    }

    /// Logs `before_<id>` / `after_<id>` and injects state key `<id>` = `true`.
    struct CountingMiddleware {
        id: String,
        counter: Log,
    }

    #[async_trait]
    impl McpMiddleware for CountingMiddleware {
        async fn before_dispatch(
            &self,
            _ctx: &mut RequestContext<'_>,
            _session: Option<&dyn SessionView>,
            injection: &mut SessionInjection,
        ) -> Result<(), MiddlewareError> {
            self.counter.lock().unwrap().push(format!("before_{}", self.id));
            injection.set_state(&self.id, json!(true));
            Ok(())
        }

        async fn after_dispatch(
            &self,
            _ctx: &RequestContext<'_>,
            _result: &mut DispatcherResult,
        ) -> Result<(), MiddlewareError> {
            self.counter.lock().unwrap().push(format!("after_{}", self.id));
            Ok(())
        }
    }

    /// Returns an `unauthorized` error from whichever hook is configured to fail.
    struct ErrorMiddleware {
        error_on_before: bool,
        error_on_after: bool,
    }

    #[async_trait]
    impl McpMiddleware for ErrorMiddleware {
        async fn before_dispatch(
            &self,
            _ctx: &mut RequestContext<'_>,
            _session: Option<&dyn SessionView>,
            _injection: &mut SessionInjection,
        ) -> Result<(), MiddlewareError> {
            if self.error_on_before {
                Err(MiddlewareError::unauthorized("Test error"))
            } else {
                Ok(())
            }
        }

        async fn after_dispatch(
            &self,
            _ctx: &RequestContext<'_>,
            _result: &mut DispatcherResult,
        ) -> Result<(), MiddlewareError> {
            if self.error_on_after {
                Err(MiddlewareError::unauthorized("Test error"))
            } else {
                Ok(())
            }
        }
    }

    #[tokio::test]
    async fn test_middleware_execution_order() {
        let log = new_log();
        let mut stack = MiddlewareStack::new();
        stack.push(counting("first", &log));
        stack.push(counting("second", &log));

        let mut ctx = RequestContext::new("test/method", None);

        let injection = stack.execute_before(&mut ctx, None).await.unwrap();
        assert_eq!(injection.state().len(), 2);
        assert!(injection.state().contains_key("first"));
        assert!(injection.state().contains_key("second"));

        let mut result = DispatcherResult::Success(json!({"ok": true}));
        stack.execute_after(&ctx, &mut result).await.unwrap();

        // Comparing the whole log pins its length as well as its order, so a hook
        // firing twice (or an unexpected extra hook) fails the test.
        assert_eq!(
            *log.lock().unwrap(),
            vec!["before_first", "before_second", "after_second", "after_first"]
        );
    }

    #[tokio::test]
    async fn test_middleware_error_stops_chain() {
        let log = new_log();
        let mut stack = MiddlewareStack::new();
        stack.push(counting("first", &log));
        stack.push(Arc::new(ErrorMiddleware {
            error_on_before: true,
            error_on_after: false,
        }));
        stack.push(counting("third", &log));

        let mut ctx = RequestContext::new("test/method", None);

        let err = stack
            .execute_before(&mut ctx, None)
            .await
            .expect_err("before_dispatch error must propagate");
        assert_eq!(err, MiddlewareError::unauthorized("Test error"));

        // "third" is registered after the failing layer and must never run.
        assert_eq!(*log.lock().unwrap(), vec!["before_first"]);
    }

    #[tokio::test]
    async fn test_after_dispatch_error_stops_reverse_chain() {
        let log = new_log();
        let mut stack = MiddlewareStack::new();
        stack.push(counting("first", &log));
        stack.push(Arc::new(ErrorMiddleware {
            error_on_before: false,
            error_on_after: true,
        }));
        stack.push(counting("third", &log));

        let mut ctx = RequestContext::new("test/method", None);
        stack.execute_before(&mut ctx, None).await.unwrap();

        let mut result = DispatcherResult::Success(json!({"ok": true}));
        let err = stack
            .execute_after(&ctx, &mut result)
            .await
            .expect_err("after_dispatch error must propagate");
        assert_eq!(err, MiddlewareError::unauthorized("Test error"));

        // After-hooks unwind in reverse: "third" runs, the error layer fails, and
        // "first" is never reached.
        assert_eq!(
            *log.lock().unwrap(),
            vec!["before_first", "before_third", "after_third"]
        );
    }

    #[tokio::test]
    async fn test_empty_stack() {
        let stack = MiddlewareStack::new();
        assert!(stack.is_empty());
        assert_eq!(stack.len(), 0);

        let mut ctx = RequestContext::new("test/method", None);

        let injection = stack.execute_before(&mut ctx, None).await.unwrap();
        assert!(injection.is_empty());

        let mut result = DispatcherResult::Success(json!({"ok": true}));
        stack.execute_after(&ctx, &mut result).await.unwrap();
    }
}