turul_http_mcp_server/middleware/
stack.rs1use super::{DispatcherResult, McpMiddleware, MiddlewareError, RequestContext, SessionInjection};
4use std::sync::Arc;
5use turul_mcp_session_storage::SessionView;
6
7#[derive(Default, Clone)]
50pub struct MiddlewareStack {
51 middleware: Vec<Arc<dyn McpMiddleware>>,
52}
53
54impl MiddlewareStack {
55 pub fn new() -> Self {
57 Self::default()
58 }
59
60 pub fn push(&mut self, middleware: Arc<dyn McpMiddleware>) {
71 self.middleware.push(middleware);
72 }
73
74 pub fn len(&self) -> usize {
76 self.middleware.len()
77 }
78
79 pub fn is_empty(&self) -> bool {
81 self.middleware.is_empty()
82 }
83
84 pub async fn execute_before(
104 &self,
105 ctx: &mut RequestContext<'_>,
106 session: Option<&dyn SessionView>,
107 ) -> Result<SessionInjection, MiddlewareError> {
108 let mut combined_injection = SessionInjection::new();
109
110 for middleware in &self.middleware {
111 let mut injection = SessionInjection::new();
112 middleware
113 .before_dispatch(ctx, session, &mut injection)
114 .await?;
115
116 for (key, value) in injection.state() {
118 combined_injection.set_state(key.clone(), value.clone());
119 }
120 for (key, value) in injection.metadata() {
121 combined_injection.set_metadata(key.clone(), value.clone());
122 }
123 }
124
125 Ok(combined_injection)
126 }
127
128 pub async fn execute_after(
146 &self,
147 ctx: &RequestContext<'_>,
148 result: &mut DispatcherResult,
149 ) -> Result<(), MiddlewareError> {
150 for middleware in self.middleware.iter().rev() {
152 middleware.after_dispatch(ctx, result).await?;
153 }
154
155 Ok(())
156 }
157}
158
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use serde_json::json;

    /// Shared, ordered record of which hooks ran.
    type Log = Arc<std::sync::Mutex<Vec<String>>>;

    /// Appends `before_<id>` / `after_<id>` to the log and sets state key `<id>` to `true`.
    struct CountingMiddleware {
        id: String,
        counter: Log,
    }

    #[async_trait]
    impl McpMiddleware for CountingMiddleware {
        async fn before_dispatch(
            &self,
            _ctx: &mut RequestContext<'_>,
            _session: Option<&dyn SessionView>,
            injection: &mut SessionInjection,
        ) -> Result<(), MiddlewareError> {
            self.counter
                .lock()
                .unwrap()
                .push(format!("before_{}", self.id));
            injection.set_state(&self.id, json!(true));
            Ok(())
        }

        async fn after_dispatch(
            &self,
            _ctx: &RequestContext<'_>,
            _result: &mut DispatcherResult,
        ) -> Result<(), MiddlewareError> {
            self.counter
                .lock()
                .unwrap()
                .push(format!("after_{}", self.id));
            Ok(())
        }
    }

    /// Fails the selected hook(s) with `MiddlewareError::unauthorized("Test error")`.
    struct ErrorMiddleware {
        error_on_before: bool,
        error_on_after: bool,
    }

    #[async_trait]
    impl McpMiddleware for ErrorMiddleware {
        async fn before_dispatch(
            &self,
            _ctx: &mut RequestContext<'_>,
            _session: Option<&dyn SessionView>,
            _injection: &mut SessionInjection,
        ) -> Result<(), MiddlewareError> {
            if self.error_on_before {
                Err(MiddlewareError::unauthorized("Test error"))
            } else {
                Ok(())
            }
        }

        async fn after_dispatch(
            &self,
            _ctx: &RequestContext<'_>,
            _result: &mut DispatcherResult,
        ) -> Result<(), MiddlewareError> {
            if self.error_on_after {
                Err(MiddlewareError::unauthorized("Test error"))
            } else {
                Ok(())
            }
        }
    }

    /// Writes a fixed value under the state key `"shared"`, to observe merge precedence.
    struct SharedKeyMiddleware {
        value: i64,
    }

    #[async_trait]
    impl McpMiddleware for SharedKeyMiddleware {
        async fn before_dispatch(
            &self,
            _ctx: &mut RequestContext<'_>,
            _session: Option<&dyn SessionView>,
            injection: &mut SessionInjection,
        ) -> Result<(), MiddlewareError> {
            injection.set_state("shared", json!(self.value));
            Ok(())
        }
    }

    /// Builds a `CountingMiddleware` with the given id that records into `counter`.
    fn counting(id: &str, counter: &Log) -> Arc<dyn McpMiddleware> {
        Arc::new(CountingMiddleware {
            id: id.to_string(),
            counter: Arc::clone(counter),
        })
    }

    #[tokio::test]
    async fn test_middleware_execution_order() {
        let counter: Log = Arc::new(std::sync::Mutex::new(Vec::new()));
        let mut stack = MiddlewareStack::new();
        stack.push(counting("first", &counter));
        stack.push(counting("second", &counter));

        let mut ctx = RequestContext::new("test/method", None);

        let injection = stack.execute_before(&mut ctx, None).await.unwrap();
        assert_eq!(injection.state().len(), 2);
        assert!(injection.state().contains_key("first"));
        assert!(injection.state().contains_key("second"));

        let mut result = DispatcherResult::Success(json!({"ok": true}));
        stack.execute_after(&ctx, &mut result).await.unwrap();

        // Compare the whole log so duplicated or unexpected hook calls also fail.
        let log = counter.lock().unwrap();
        assert_eq!(
            *log,
            ["before_first", "before_second", "after_second", "after_first"]
        );
    }

    #[tokio::test]
    async fn test_middleware_error_stops_chain() {
        let counter: Log = Arc::new(std::sync::Mutex::new(Vec::new()));
        let mut stack = MiddlewareStack::new();
        stack.push(counting("first", &counter));
        stack.push(Arc::new(ErrorMiddleware {
            error_on_before: true,
            error_on_after: false,
        }));
        stack.push(counting("third", &counter));

        let mut ctx = RequestContext::new("test/method", None);

        let result = stack.execute_before(&mut ctx, None).await;
        assert_eq!(
            result.unwrap_err(),
            MiddlewareError::unauthorized("Test error")
        );

        // "third" must never run once the error short-circuits the chain.
        let log = counter.lock().unwrap();
        assert_eq!(*log, ["before_first"]);
    }

    #[tokio::test]
    async fn test_after_dispatch_error_stops_unwinding() {
        let counter: Log = Arc::new(std::sync::Mutex::new(Vec::new()));
        let mut stack = MiddlewareStack::new();
        stack.push(counting("first", &counter));
        stack.push(Arc::new(ErrorMiddleware {
            error_on_before: false,
            error_on_after: true,
        }));
        stack.push(counting("third", &counter));

        let ctx = RequestContext::new("test/method", None);
        let mut result = DispatcherResult::Success(json!({"ok": true}));

        let outcome = stack.execute_after(&ctx, &mut result).await;
        assert_eq!(
            outcome.unwrap_err(),
            MiddlewareError::unauthorized("Test error")
        );

        // Reverse order: "third" runs, then the error stops unwinding before "first".
        let log = counter.lock().unwrap();
        assert_eq!(*log, ["after_third"]);
    }

    #[tokio::test]
    async fn test_later_middleware_overrides_earlier_state() {
        let mut stack = MiddlewareStack::new();
        stack.push(Arc::new(SharedKeyMiddleware { value: 1 }));
        stack.push(Arc::new(SharedKeyMiddleware { value: 2 }));

        let mut ctx = RequestContext::new("test/method", None);

        let injection = stack.execute_before(&mut ctx, None).await.unwrap();
        assert_eq!(injection.state().len(), 1);
        assert_eq!(injection.state().get("shared"), Some(&json!(2)));
    }

    #[tokio::test]
    async fn test_empty_stack() {
        let stack = MiddlewareStack::new();
        assert!(stack.is_empty());
        assert_eq!(stack.len(), 0);

        let mut ctx = RequestContext::new("test/method", None);

        let injection = stack.execute_before(&mut ctx, None).await.unwrap();
        assert!(injection.is_empty());

        let mut result = DispatcherResult::Success(json!({"ok": true}));
        stack.execute_after(&ctx, &mut result).await.unwrap();
    }
}