// turul_http_mcp_server/handler.rs

//! HTTP request handler for MCP protocol

use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use bytes::Bytes;
use futures::Stream;
use http_body::Body;
use http_body_util::{BodyExt, Full, LengthLimitError, Limited};
use hyper::header::{ACCEPT, CONTENT_TYPE};
use hyper::{Method, Request, Response, StatusCode};
use tracing::{debug, error, warn};
use turul_mcp_json_rpc_server::{dispatch::parse_json_rpc_message, JsonRpcDispatcher};

use crate::{sse::SseManager, Result, ServerConfig};

19/// SSE stream body that implements hyper's Body trait
20pub struct SseStreamBody {
21    stream: Pin<Box<dyn Stream<Item = std::result::Result<String, tokio::sync::broadcast::error::RecvError>> + Send>>,
22}
23
24impl SseStreamBody {
25    pub fn new<S>(stream: S) -> Self
26    where
27        S: Stream<Item = std::result::Result<String, tokio::sync::broadcast::error::RecvError>> + Send + 'static,
28    {
29        Self {
30            stream: Box::pin(stream),
31        }
32    }
33}
34
35impl Body for SseStreamBody {
36    type Data = Bytes;
37    type Error = Box<dyn std::error::Error + Send + Sync>;
38
39    fn poll_frame(
40        mut self: Pin<&mut Self>,
41        cx: &mut Context<'_>,
42    ) -> Poll<Option<std::result::Result<http_body::Frame<Self::Data>, Self::Error>>> {
43        match self.stream.as_mut().poll_next(cx) {
44            Poll::Ready(Some(Ok(data))) => {
45                let bytes = Bytes::from(data);
46                Poll::Ready(Some(Ok(http_body::Frame::data(bytes))))
47            }
48            Poll::Ready(Some(Err(e))) => {
49                Poll::Ready(Some(Err(Box::new(e))))
50            }
51            Poll::Ready(None) => Poll::Ready(None),
52            Poll::Pending => Poll::Pending,
53        }
54    }
55}
56
57/// HTTP handler for MCP requests
58pub struct McpHttpHandler {
59    pub(crate) config: ServerConfig,
60    pub(crate) dispatcher: Arc<JsonRpcDispatcher>,
61    pub(crate) sse_manager: Arc<SseManager>,
62}
63
64impl McpHttpHandler {
65    /// Create a new handler
66    pub fn new(config: ServerConfig, dispatcher: Arc<JsonRpcDispatcher>) -> Self {
67        Self {
68            config,
69            dispatcher,
70            sse_manager: Arc::new(SseManager::new()),
71        }
72    }
73
74    /// Create a new handler with existing SSE manager
75    pub fn with_sse_manager(
76        config: ServerConfig,
77        dispatcher: Arc<JsonRpcDispatcher>,
78        sse_manager: Arc<SseManager>
79    ) -> Self {
80        Self { config, dispatcher, sse_manager }
81    }
82
83    /// Handle MCP HTTP requests
84    pub async fn handle_mcp_request(
85        &self,
86        req: Request<hyper::body::Incoming>,
87    ) -> Result<Response<Full<Bytes>>> {
88        match req.method() {
89            &Method::POST => self.handle_json_rpc_request(req).await,
90            &Method::GET => {
91                if self.config.enable_get_sse {
92                    self.handle_sse_request(req).await
93                } else {
94                    self.method_not_allowed().await
95                }
96            }
97            &Method::OPTIONS => self.handle_preflight().await,
98            _ => self.method_not_allowed().await,
99        }
100    }
101
102    /// Handle JSON-RPC requests over HTTP POST
103    async fn handle_json_rpc_request(
104        &self,
105        req: Request<hyper::body::Incoming>,
106    ) -> Result<Response<Full<Bytes>>> {
107        // Check content type
108        let content_type = req.headers()
109            .get(CONTENT_TYPE)
110            .and_then(|ct| ct.to_str().ok())
111            .unwrap_or("");
112
113        if !content_type.starts_with("application/json") {
114            warn!("Invalid content type: {}", content_type);
115            return Ok(Response::builder()
116                .status(StatusCode::BAD_REQUEST)
117                .body(Full::new(Bytes::from("Content-Type must be application/json")))
118                .unwrap());
119        }
120
121        // Read request body
122        let body = req.into_body();
123        let body_bytes = match body.collect().await {
124            Ok(collected) => collected.to_bytes(),
125            Err(err) => {
126                error!("Failed to read request body: {}", err);
127                return Ok(Response::builder()
128                    .status(StatusCode::BAD_REQUEST)
129                    .body(Full::new(Bytes::from("Failed to read request body")))
130                    .unwrap());
131            }
132        };
133
134        // Check body size
135        if body_bytes.len() > self.config.max_body_size {
136            warn!("Request body too large: {} bytes", body_bytes.len());
137            return Ok(Response::builder()
138                .status(StatusCode::PAYLOAD_TOO_LARGE)
139                .body(Full::new(Bytes::from("Request body too large")))
140                .unwrap());
141        }
142
143        // Parse as UTF-8
144        let body_str = match std::str::from_utf8(&body_bytes) {
145            Ok(s) => s,
146            Err(err) => {
147                error!("Invalid UTF-8 in request body: {}", err);
148                return Ok(Response::builder()
149                    .status(StatusCode::BAD_REQUEST)
150                    .body(Full::new(Bytes::from("Request body must be valid UTF-8")))
151                    .unwrap());
152            }
153        };
154
155        debug!("Received JSON-RPC request: {}", body_str);
156
157        // Parse JSON-RPC message
158        let message = match parse_json_rpc_message(body_str) {
159            Ok(msg) => msg,
160            Err(rpc_err) => {
161                error!("JSON-RPC parse error: {}", rpc_err);
162                let error_response = serde_json::to_string(&rpc_err)?;
163                return Ok(Response::builder()
164                    .status(StatusCode::BAD_REQUEST)
165                    .header(CONTENT_TYPE, "application/json")
166                    .body(Full::new(Bytes::from(error_response)))
167                    .unwrap());
168            }
169        };
170
171        // Handle the message
172        match message {
173            turul_mcp_json_rpc_server::dispatch::JsonRpcMessage::Request(request) => {
174                debug!("Processing JSON-RPC request: method={}", request.method);
175                let response = self.dispatcher.handle_request(request).await;
176                let response_json = serde_json::to_string(&response)?;
177
178                debug!("Sending JSON-RPC response");
179                Ok(Response::builder()
180                    .status(StatusCode::OK)
181                    .header(CONTENT_TYPE, "application/json")
182                    .body(Full::new(Bytes::from(response_json)))
183                    .unwrap())
184            }
185            turul_mcp_json_rpc_server::dispatch::JsonRpcMessage::Notification(notification) => {
186                debug!("Processing JSON-RPC notification: method={}", notification.method);
187                if let Err(err) = self.dispatcher.handle_notification(notification).await {
188                    error!("Notification handling error: {}", err);
189                }
190
191                // Notifications don't return responses
192                Ok(Response::builder()
193                    .status(StatusCode::NO_CONTENT)
194                    .body(Full::new(Bytes::new()))
195                    .unwrap())
196            }
197        }
198    }
199
200    /// Handle Server-Sent Events requests
201    async fn handle_sse_request(
202        &self,
203        req: Request<hyper::body::Incoming>,
204    ) -> Result<Response<Full<Bytes>>> {
205        // Check if client accepts SSE
206        let headers = req.headers();
207        let accept = headers
208            .get(ACCEPT)
209            .and_then(|accept| accept.to_str().ok())
210            .unwrap_or("");
211
212        if !accept.contains("text/event-stream") {
213            return Ok(Response::builder()
214                .status(StatusCode::NOT_ACCEPTABLE)
215                .body(Full::new(Bytes::from("SSE not accepted")))
216                .unwrap());
217        }
218
219        // Extract connection ID from query parameters or generate one
220        let connection_id = match req.uri().query() {
221            Some(q) => {
222                q.split('&')
223                    .find(|param| param.starts_with("connection_id="))
224                    .and_then(|param| param.split('=').nth(1))
225                    .map(|s| s.to_string())
226                    .unwrap_or_else(|| uuid::Uuid::now_v7().to_string())
227            }
228            None => uuid::Uuid::now_v7().to_string(),
229        };
230
231        debug!("Creating SSE connection: {}", connection_id);
232
233        // For now, return an initial connection message with instructions
234        // TODO: Implement actual streaming with SseStreamBody when we can change the signature
235        let initial_response = format!(
236            "event: connection\n\
237             data: {{\"type\":\"connected\",\"connection_id\":\"{}\",\"message\":\"SSE connection established\"}}\n\n\
238             event: info\n\
239             data: {{\"type\":\"info\",\"message\":\"This is a basic SSE endpoint. Full streaming will be available in a future update.\"}}\n\n",
240            connection_id
241        );
242
243        // Store the connection for later use
244        let _connection = self.sse_manager.create_connection(connection_id.clone()).await;
245
246        // Start a background task to send periodic keep-alives
247        let sse_manager = Arc::clone(&self.sse_manager);
248        let conn_id = connection_id.clone();
249        tokio::spawn(async move {
250            let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(30));
251            for _ in 0..10 { // Send 10 keep-alives then stop for this basic implementation
252                interval.tick().await;
253                sse_manager.send_keep_alive().await;
254            }
255            sse_manager.remove_connection(&conn_id).await;
256        });
257
258        Ok(Response::builder()
259            .status(StatusCode::OK)
260            .header(CONTENT_TYPE, "text/event-stream")
261            .header("Cache-Control", "no-cache")
262            .header("Connection", "keep-alive")
263            .header("Access-Control-Allow-Origin", "*")
264            .header("Access-Control-Allow-Headers", "Cache-Control")
265            .body(Full::new(Bytes::from(initial_response)))
266            .unwrap())
267    }
268
269    /// Handle OPTIONS preflight requests
270    async fn handle_preflight(&self) -> Result<Response<Full<Bytes>>> {
271        Ok(Response::builder()
272            .status(StatusCode::OK)
273            .header("Access-Control-Allow-Methods", "POST, GET, OPTIONS")
274            .header("Access-Control-Allow-Headers", "Content-Type, Accept")
275            .header("Access-Control-Max-Age", "86400")
276            .body(Full::new(Bytes::new()))
277            .unwrap())
278    }
279
280    /// Return method not allowed response
281    async fn method_not_allowed(&self) -> Result<Response<Full<Bytes>>> {
282        Ok(Response::builder()
283            .status(StatusCode::METHOD_NOT_ALLOWED)
284            .header("Allow", "POST, OPTIONS")
285            .body(Full::new(Bytes::from("Method not allowed")))
286            .unwrap())
287    }
288}
289
#[cfg(test)]
mod tests {
    use super::*;
    use turul_mcp_json_rpc_server::JsonRpcDispatcher;

    fn create_test_handler() -> McpHttpHandler {
        let config = ServerConfig::default();
        let dispatcher = Arc::new(JsonRpcDispatcher::new());
        McpHttpHandler::new(config, dispatcher)
    }

    /// Read a fully-buffered response body back into a `String`.
    async fn body_text(response: Response<Full<Bytes>>) -> String {
        let bytes = response
            .into_body()
            .collect()
            .await
            .expect("a `Full` body cannot fail")
            .to_bytes();
        String::from_utf8(bytes.to_vec()).expect("handler bodies are UTF-8")
    }

    /// Fetch a header as text, or `""` when it is missing or not visible ASCII.
    fn header_str<'a>(response: &'a Response<Full<Bytes>>, name: &str) -> &'a str {
        response
            .headers()
            .get(name)
            .and_then(|value| value.to_str().ok())
            .unwrap_or("")
    }

    #[tokio::test]
    async fn test_options_request() {
        let handler = create_test_handler();
        let response = match handler.handle_preflight().await {
            Ok(response) => response,
            Err(_) => panic!("preflight handling should not fail"),
        };

        assert_eq!(response.status(), StatusCode::OK);
        let allowed = header_str(&response, "Access-Control-Allow-Methods");
        assert!(allowed.contains("POST"), "unexpected allowed methods: {:?}", allowed);
        assert!(allowed.contains("OPTIONS"), "unexpected allowed methods: {:?}", allowed);
        assert_eq!(header_str(&response, "Access-Control-Max-Age"), "86400");
        assert!(body_text(response).await.is_empty());
    }

    #[tokio::test]
    async fn test_method_not_allowed() {
        let handler = create_test_handler();
        let response = match handler.method_not_allowed().await {
            Ok(response) => response,
            Err(_) => panic!("building a 405 response should not fail"),
        };

        assert_eq!(response.status(), StatusCode::METHOD_NOT_ALLOWED);
        let allow = header_str(&response, "Allow").to_owned();
        assert!(allow.contains("POST"), "unexpected Allow header: {:?}", allow);
        assert!(allow.contains("OPTIONS"), "unexpected Allow header: {:?}", allow);
        assert_eq!(body_text(response).await, "Method not allowed");
    }

    #[tokio::test]
    async fn test_sse_stream_body_forwards_chunks() {
        let events: Vec<std::result::Result<String, tokio::sync::broadcast::error::RecvError>> =
            vec![Ok("event: a\n\n".to_string()), Ok("event: b\n\n".to_string())];
        let body = SseStreamBody::new(futures::stream::iter(events));

        let collected = match body.collect().await {
            Ok(collected) => collected.to_bytes(),
            Err(err) => panic!("SSE body failed: {}", err),
        };
        assert_eq!(collected, Bytes::from("event: a\n\nevent: b\n\n"));
    }
}