turul_http_mcp_server/
handler.rs1use std::sync::Arc;
4use std::pin::Pin;
5use std::task::{Context, Poll};
6
7use hyper::{Request, Response, Method, StatusCode};
8use http_body_util::Full;
9use bytes::Bytes;
10use hyper::header::{CONTENT_TYPE, ACCEPT};
11use http_body_util::BodyExt;
12use tracing::{debug, warn, error};
13use futures::Stream;
14use http_body::Body;
15
16use crate::{Result, ServerConfig, sse::SseManager};
17use turul_mcp_json_rpc_server::{JsonRpcDispatcher, dispatch::parse_json_rpc_message};
18
19pub struct SseStreamBody {
21 stream: Pin<Box<dyn Stream<Item = std::result::Result<String, tokio::sync::broadcast::error::RecvError>> + Send>>,
22}
23
24impl SseStreamBody {
25 pub fn new<S>(stream: S) -> Self
26 where
27 S: Stream<Item = std::result::Result<String, tokio::sync::broadcast::error::RecvError>> + Send + 'static,
28 {
29 Self {
30 stream: Box::pin(stream),
31 }
32 }
33}
34
35impl Body for SseStreamBody {
36 type Data = Bytes;
37 type Error = Box<dyn std::error::Error + Send + Sync>;
38
39 fn poll_frame(
40 mut self: Pin<&mut Self>,
41 cx: &mut Context<'_>,
42 ) -> Poll<Option<std::result::Result<http_body::Frame<Self::Data>, Self::Error>>> {
43 match self.stream.as_mut().poll_next(cx) {
44 Poll::Ready(Some(Ok(data))) => {
45 let bytes = Bytes::from(data);
46 Poll::Ready(Some(Ok(http_body::Frame::data(bytes))))
47 }
48 Poll::Ready(Some(Err(e))) => {
49 Poll::Ready(Some(Err(Box::new(e))))
50 }
51 Poll::Ready(None) => Poll::Ready(None),
52 Poll::Pending => Poll::Pending,
53 }
54 }
55}
56
57pub struct McpHttpHandler {
59 pub(crate) config: ServerConfig,
60 pub(crate) dispatcher: Arc<JsonRpcDispatcher>,
61 pub(crate) sse_manager: Arc<SseManager>,
62}
63
64impl McpHttpHandler {
65 pub fn new(config: ServerConfig, dispatcher: Arc<JsonRpcDispatcher>) -> Self {
67 Self {
68 config,
69 dispatcher,
70 sse_manager: Arc::new(SseManager::new()),
71 }
72 }
73
74 pub fn with_sse_manager(
76 config: ServerConfig,
77 dispatcher: Arc<JsonRpcDispatcher>,
78 sse_manager: Arc<SseManager>
79 ) -> Self {
80 Self { config, dispatcher, sse_manager }
81 }
82
83 pub async fn handle_mcp_request(
85 &self,
86 req: Request<hyper::body::Incoming>,
87 ) -> Result<Response<Full<Bytes>>> {
88 match req.method() {
89 &Method::POST => self.handle_json_rpc_request(req).await,
90 &Method::GET => {
91 if self.config.enable_get_sse {
92 self.handle_sse_request(req).await
93 } else {
94 self.method_not_allowed().await
95 }
96 }
97 &Method::OPTIONS => self.handle_preflight().await,
98 _ => self.method_not_allowed().await,
99 }
100 }
101
102 async fn handle_json_rpc_request(
104 &self,
105 req: Request<hyper::body::Incoming>,
106 ) -> Result<Response<Full<Bytes>>> {
107 let content_type = req.headers()
109 .get(CONTENT_TYPE)
110 .and_then(|ct| ct.to_str().ok())
111 .unwrap_or("");
112
113 if !content_type.starts_with("application/json") {
114 warn!("Invalid content type: {}", content_type);
115 return Ok(Response::builder()
116 .status(StatusCode::BAD_REQUEST)
117 .body(Full::new(Bytes::from("Content-Type must be application/json")))
118 .unwrap());
119 }
120
121 let body = req.into_body();
123 let body_bytes = match body.collect().await {
124 Ok(collected) => collected.to_bytes(),
125 Err(err) => {
126 error!("Failed to read request body: {}", err);
127 return Ok(Response::builder()
128 .status(StatusCode::BAD_REQUEST)
129 .body(Full::new(Bytes::from("Failed to read request body")))
130 .unwrap());
131 }
132 };
133
134 if body_bytes.len() > self.config.max_body_size {
136 warn!("Request body too large: {} bytes", body_bytes.len());
137 return Ok(Response::builder()
138 .status(StatusCode::PAYLOAD_TOO_LARGE)
139 .body(Full::new(Bytes::from("Request body too large")))
140 .unwrap());
141 }
142
143 let body_str = match std::str::from_utf8(&body_bytes) {
145 Ok(s) => s,
146 Err(err) => {
147 error!("Invalid UTF-8 in request body: {}", err);
148 return Ok(Response::builder()
149 .status(StatusCode::BAD_REQUEST)
150 .body(Full::new(Bytes::from("Request body must be valid UTF-8")))
151 .unwrap());
152 }
153 };
154
155 debug!("Received JSON-RPC request: {}", body_str);
156
157 let message = match parse_json_rpc_message(body_str) {
159 Ok(msg) => msg,
160 Err(rpc_err) => {
161 error!("JSON-RPC parse error: {}", rpc_err);
162 let error_response = serde_json::to_string(&rpc_err)?;
163 return Ok(Response::builder()
164 .status(StatusCode::BAD_REQUEST)
165 .header(CONTENT_TYPE, "application/json")
166 .body(Full::new(Bytes::from(error_response)))
167 .unwrap());
168 }
169 };
170
171 match message {
173 turul_mcp_json_rpc_server::dispatch::JsonRpcMessage::Request(request) => {
174 debug!("Processing JSON-RPC request: method={}", request.method);
175 let response = self.dispatcher.handle_request(request).await;
176 let response_json = serde_json::to_string(&response)?;
177
178 debug!("Sending JSON-RPC response");
179 Ok(Response::builder()
180 .status(StatusCode::OK)
181 .header(CONTENT_TYPE, "application/json")
182 .body(Full::new(Bytes::from(response_json)))
183 .unwrap())
184 }
185 turul_mcp_json_rpc_server::dispatch::JsonRpcMessage::Notification(notification) => {
186 debug!("Processing JSON-RPC notification: method={}", notification.method);
187 if let Err(err) = self.dispatcher.handle_notification(notification).await {
188 error!("Notification handling error: {}", err);
189 }
190
191 Ok(Response::builder()
193 .status(StatusCode::NO_CONTENT)
194 .body(Full::new(Bytes::new()))
195 .unwrap())
196 }
197 }
198 }
199
200 async fn handle_sse_request(
202 &self,
203 req: Request<hyper::body::Incoming>,
204 ) -> Result<Response<Full<Bytes>>> {
205 let headers = req.headers();
207 let accept = headers
208 .get(ACCEPT)
209 .and_then(|accept| accept.to_str().ok())
210 .unwrap_or("");
211
212 if !accept.contains("text/event-stream") {
213 return Ok(Response::builder()
214 .status(StatusCode::NOT_ACCEPTABLE)
215 .body(Full::new(Bytes::from("SSE not accepted")))
216 .unwrap());
217 }
218
219 let connection_id = match req.uri().query() {
221 Some(q) => {
222 q.split('&')
223 .find(|param| param.starts_with("connection_id="))
224 .and_then(|param| param.split('=').nth(1))
225 .map(|s| s.to_string())
226 .unwrap_or_else(|| uuid::Uuid::now_v7().to_string())
227 }
228 None => uuid::Uuid::now_v7().to_string(),
229 };
230
231 debug!("Creating SSE connection: {}", connection_id);
232
233 let initial_response = format!(
236 "event: connection\n\
237 data: {{\"type\":\"connected\",\"connection_id\":\"{}\",\"message\":\"SSE connection established\"}}\n\n\
238 event: info\n\
239 data: {{\"type\":\"info\",\"message\":\"This is a basic SSE endpoint. Full streaming will be available in a future update.\"}}\n\n",
240 connection_id
241 );
242
243 let _connection = self.sse_manager.create_connection(connection_id.clone()).await;
245
246 let sse_manager = Arc::clone(&self.sse_manager);
248 let conn_id = connection_id.clone();
249 tokio::spawn(async move {
250 let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(30));
251 for _ in 0..10 { interval.tick().await;
253 sse_manager.send_keep_alive().await;
254 }
255 sse_manager.remove_connection(&conn_id).await;
256 });
257
258 Ok(Response::builder()
259 .status(StatusCode::OK)
260 .header(CONTENT_TYPE, "text/event-stream")
261 .header("Cache-Control", "no-cache")
262 .header("Connection", "keep-alive")
263 .header("Access-Control-Allow-Origin", "*")
264 .header("Access-Control-Allow-Headers", "Cache-Control")
265 .body(Full::new(Bytes::from(initial_response)))
266 .unwrap())
267 }
268
269 async fn handle_preflight(&self) -> Result<Response<Full<Bytes>>> {
271 Ok(Response::builder()
272 .status(StatusCode::OK)
273 .header("Access-Control-Allow-Methods", "POST, GET, OPTIONS")
274 .header("Access-Control-Allow-Headers", "Content-Type, Accept")
275 .header("Access-Control-Max-Age", "86400")
276 .body(Full::new(Bytes::new()))
277 .unwrap())
278 }
279
280 async fn method_not_allowed(&self) -> Result<Response<Full<Bytes>>> {
282 Ok(Response::builder()
283 .status(StatusCode::METHOD_NOT_ALLOWED)
284 .header("Allow", "POST, OPTIONS")
285 .body(Full::new(Bytes::from("Method not allowed")))
286 .unwrap())
287 }
288}
289
#[cfg(test)]
mod tests {
    use super::*;
    use turul_mcp_json_rpc_server::JsonRpcDispatcher;

    /// Builds a handler with default configuration and an empty dispatcher.
    fn create_test_handler() -> McpHttpHandler {
        let config = ServerConfig::default();
        let dispatcher = Arc::new(JsonRpcDispatcher::new());
        McpHttpHandler::new(config, dispatcher)
    }

    /// Returns a header's value as `&str`, or `""` when it is missing or not visible ASCII.
    fn header_str<'a>(response: &'a Response<Full<Bytes>>, name: &str) -> &'a str {
        response
            .headers()
            .get(name)
            .and_then(|value| value.to_str().ok())
            .unwrap_or("")
    }

    #[tokio::test]
    async fn test_options_request() {
        let handler = create_test_handler();
        // Match instead of `unwrap` so the test does not require the crate error to be `Debug`.
        let response = match handler.handle_preflight().await {
            Ok(response) => response,
            Err(_) => panic!("preflight handling should not fail"),
        };

        assert_eq!(response.status(), StatusCode::OK);
        let methods = header_str(&response, "Access-Control-Allow-Methods");
        assert!(methods.contains("POST"), "methods = {methods:?}");
        assert!(methods.contains("OPTIONS"), "methods = {methods:?}");
        assert_eq!(header_str(&response, "Access-Control-Max-Age"), "86400");
    }

    #[tokio::test]
    async fn test_method_not_allowed() {
        let handler = create_test_handler();
        let response = match handler.method_not_allowed().await {
            Ok(response) => response,
            Err(_) => panic!("building a 405 response should not fail"),
        };

        assert_eq!(response.status(), StatusCode::METHOD_NOT_ALLOWED);
        let allow = header_str(&response, "Allow");
        assert!(allow.contains("POST"), "allow = {allow:?}");
        assert!(allow.contains("OPTIONS"), "allow = {allow:?}");
    }
}