1use std::collections::HashMap;
13use std::sync::Arc;
14use std::pin::Pin;
15
16use hyper::{Request, Response, Method, StatusCode, HeaderMap};
17use http_body_util::Full;
18use bytes::Bytes;
19use hyper::header::{CONTENT_TYPE, ACCEPT};
20use tracing::{warn, info};
21use futures::Stream;
22use http_body::Body;
23use serde_json::Value;
24
25use crate::ServerConfig;
26
/// Protocol revisions this module recognises.
///
/// Each variant corresponds to the date string returned by
/// [`McpProtocolVersion::as_str`]; the `supports_*` methods report which
/// features a given revision enables.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum McpProtocolVersion {
    /// Revision `"2024-11-05"`; none of the `supports_*` checks return `true` for it.
    V2024_11_05,
    /// Revision `"2025-03-26"`; reports support for streamable HTTP only.
    V2025_03_26,
    /// Revision `"2025-06-18"`; reports support for every feature checked in
    /// this file and is the value returned by `Default`.
    V2025_06_18,
}
37
38impl McpProtocolVersion {
39 pub fn from_str(s: &str) -> Option<Self> {
41 match s {
42 "2024-11-05" => Some(Self::V2024_11_05),
43 "2025-03-26" => Some(Self::V2025_03_26),
44 "2025-06-18" => Some(Self::V2025_06_18),
45 _ => None,
46 }
47 }
48
49 pub fn as_str(&self) -> &'static str {
51 match self {
52 Self::V2024_11_05 => "2024-11-05",
53 Self::V2025_03_26 => "2025-03-26",
54 Self::V2025_06_18 => "2025-06-18",
55 }
56 }
57
58 pub fn supports_streamable_http(&self) -> bool {
60 matches!(self, Self::V2025_03_26 | Self::V2025_06_18)
61 }
62
63 pub fn supports_meta_fields(&self) -> bool {
65 matches!(self, Self::V2025_06_18)
66 }
67
68 pub fn supports_cursors(&self) -> bool {
70 matches!(self, Self::V2025_06_18)
71 }
72
73 pub fn supports_progress_tokens(&self) -> bool {
75 matches!(self, Self::V2025_06_18)
76 }
77
78 pub fn supports_elicitation(&self) -> bool {
80 matches!(self, Self::V2025_06_18)
81 }
82
83 pub fn supported_features(&self) -> Vec<&'static str> {
85 let mut features = vec![];
86 if self.supports_streamable_http() {
87 features.push("streamable-http");
88 }
89 if self.supports_meta_fields() {
90 features.push("_meta-fields");
91 }
92 if self.supports_cursors() {
93 features.push("cursor-pagination");
94 }
95 if self.supports_progress_tokens() {
96 features.push("progress-tokens");
97 }
98 if self.supports_elicitation() {
99 features.push("elicitation");
100 }
101 features
102 }
103}
104
105impl Default for McpProtocolVersion {
106 fn default() -> Self {
107 Self::V2025_06_18
108 }
109}
110
/// Per-request information extracted from HTTP headers.
///
/// `StreamableHttpContext::from_request` fills this in from a request; the
/// fields are public, so a value may also be built by hand.
#[derive(Debug, Clone)]
pub struct StreamableHttpContext {
    /// Protocol revision taken from the `MCP-Protocol-Version` header, or the
    /// default revision when that header is missing or unrecognised.
    pub protocol_version: McpProtocolVersion,
    /// Value of the `Mcp-Session-Id` header, if one was sent.
    pub session_id: Option<String>,
    /// Whether the `Accept` header mentions `text/event-stream`.
    pub wants_streaming: bool,
    /// Whether the `Accept` header mentions `application/json` or `*/*`.
    pub accepts_json: bool,
    /// Request headers as name/value strings; headers whose value cannot be
    /// read as a string are left out.
    pub headers: HashMap<String, String>,
}
125
126impl StreamableHttpContext {
127 pub fn from_request<T>(req: &Request<T>) -> Self {
129 let headers = req.headers();
130
131 let protocol_version = headers
133 .get("MCP-Protocol-Version")
134 .and_then(|h| h.to_str().ok())
135 .and_then(McpProtocolVersion::from_str)
136 .unwrap_or_default();
137
138 let session_id = headers
140 .get("Mcp-Session-Id")
141 .and_then(|h| h.to_str().ok())
142 .map(|s| s.to_string());
143
144 let accept_header = headers
146 .get(ACCEPT)
147 .and_then(|h| h.to_str().ok())
148 .unwrap_or_default()
149 .to_ascii_lowercase();
150
151 let wants_streaming = accept_header.contains("text/event-stream");
152 let accepts_json = accept_header.contains("application/json") || accept_header.contains("*/*");
153
154 let mut header_map = HashMap::new();
156 for (name, value) in headers.iter() {
157 if let Ok(value_str) = value.to_str() {
158 header_map.insert(name.to_string(), value_str.to_string());
159 }
160 }
161
162 Self {
163 protocol_version,
164 session_id,
165 wants_streaming,
166 accepts_json,
167 headers: header_map,
168 }
169 }
170
171 pub fn is_streamable_compatible(&self) -> bool {
173 self.protocol_version.supports_streamable_http() &&
174 self.wants_streaming &&
175 self.session_id.is_some()
176 }
177
178 pub fn validate(&self) -> std::result::Result<(), String> {
180 if !self.accepts_json {
181 return Err("Accept header must include application/json".to_string());
182 }
183
184 if self.wants_streaming && !self.protocol_version.supports_streamable_http() {
185 return Err(format!(
186 "Protocol version {} does not support streamable HTTP",
187 self.protocol_version.as_str()
188 ));
189 }
190
191 if self.wants_streaming && self.session_id.is_none() {
192 return Err("Mcp-Session-Id header required for streaming requests".to_string());
193 }
194
195 Ok(())
196 }
197
198 pub fn response_headers(&self) -> HeaderMap {
200 let mut headers = HeaderMap::new();
201
202 headers.insert(
204 "MCP-Protocol-Version",
205 self.protocol_version.as_str().parse().unwrap()
206 );
207
208 if let Some(session_id) = &self.session_id {
210 headers.insert(
211 "Mcp-Session-Id",
212 session_id.parse().unwrap()
213 );
214 }
215
216 let features = self.protocol_version.supported_features();
218 if !features.is_empty() {
219 headers.insert(
220 "MCP-Capabilities",
221 features.join(",").parse().unwrap()
222 );
223 }
224
225 headers
226 }
227}
228
/// The kinds of reply a handler can produce before conversion into an HTTP
/// response by `StreamableResponse::into_response`.
pub enum StreamableResponse {
    /// A single JSON document.
    Json(Value),
    /// A boxed, pinned stream whose items are either a JSON value or an error message.
    Stream(Pin<Box<dyn Stream<Item = std::result::Result<Value, String>> + Send>>),
    /// A failure, carrying the HTTP status to use and a message for the client.
    Error { status: StatusCode, message: String },
}
238
239impl std::fmt::Debug for StreamableResponse {
240 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
241 match self {
242 Self::Json(value) => f.debug_tuple("Json").field(value).finish(),
243 Self::Stream(_) => f.debug_tuple("Stream").field(&"<stream>").finish(),
244 Self::Error { status, message } => f.debug_struct("Error")
245 .field("status", status)
246 .field("message", message)
247 .finish(),
248 }
249 }
250}
251
252impl StreamableResponse {
253 pub fn into_response(self, context: &StreamableHttpContext) -> Response<Full<Bytes>> {
255 let mut response_headers = context.response_headers();
256
257 match self {
258 StreamableResponse::Json(json) => {
259 response_headers.insert(CONTENT_TYPE, "application/json".parse().unwrap());
260
261 let body = serde_json::to_string(&json)
262 .unwrap_or_else(|_| r#"{"error": "Failed to serialize response"}"#.to_string());
263
264 Response::builder()
265 .status(StatusCode::OK)
266 .body(Full::new(Bytes::from(body)))
267 .unwrap()
268 }
269
270 StreamableResponse::Stream(_stream) => {
271 response_headers.insert(CONTENT_TYPE, "text/event-stream".parse().unwrap());
273 response_headers.insert("Cache-Control", "no-cache, no-transform".parse().unwrap());
274 response_headers.insert("Connection", "keep-alive".parse().unwrap());
275
276 Response::builder()
280 .status(StatusCode::ACCEPTED)
281 .body(Full::new(Bytes::from("Streaming response accepted")))
282 .unwrap()
283 }
284
285 StreamableResponse::Error { status, message } => {
286 response_headers.insert(CONTENT_TYPE, "application/json".parse().unwrap());
287
288 let error_json = serde_json::json!({
289 "error": {
290 "code": status.as_u16(),
291 "message": message
292 }
293 });
294
295 let body = serde_json::to_string(&error_json)
296 .unwrap_or_else(|_| r#"{"error": {"code": 500, "message": "Internal server error"}}"#.to_string());
297
298 Response::builder()
299 .status(status)
300 .body(Full::new(Bytes::from(body)))
301 .unwrap()
302 }
303 }
304 }
305}
306
/// Entry point that dispatches streamable-HTTP requests by method and by
/// whether the request qualifies for the streaming path.
pub struct StreamableHttpHandler {
    // Shared server configuration; not read by any handler in this file yet,
    // hence the dead-code allowance.
    #[allow(dead_code)] config: Arc<ServerConfig>,
}
312
313impl StreamableHttpHandler {
314 pub fn new(config: Arc<ServerConfig>) -> Self {
315 Self { config }
316 }
317
318 pub async fn handle_request<T>(&self, req: Request<T>) -> Response<Full<Bytes>>
320 where
321 T: Body + Send + 'static,
322 T::Data: Send,
323 T::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
324 {
325 let context = StreamableHttpContext::from_request(&req);
327
328 info!(
329 "Streamable HTTP request: method={}, protocol={}, session={:?}, streaming={}",
330 req.method(),
331 context.protocol_version.as_str(),
332 context.session_id,
333 context.wants_streaming
334 );
335
336 if let Err(error) = context.validate() {
338 warn!("Invalid streamable HTTP request: {}", error);
339 return StreamableResponse::Error {
340 status: StatusCode::BAD_REQUEST,
341 message: error,
342 }.into_response(&context);
343 }
344
345 match (req.method(), context.is_streamable_compatible()) {
347 (&Method::GET, true) => self.handle_streaming_get(req, context).await,
348 (&Method::POST, true) => self.handle_streaming_post(req, context).await,
349 (&Method::POST, false) => self.handle_json_post(req, context).await,
350 (&Method::DELETE, _) => self.handle_session_delete(req, context).await,
351 _ => {
352 StreamableResponse::Error {
353 status: StatusCode::METHOD_NOT_ALLOWED,
354 message: "Method not allowed for this endpoint".to_string(),
355 }.into_response(&context)
356 }
357 }
358 }
359
360 async fn handle_streaming_get<T>(&self, _req: Request<T>, context: StreamableHttpContext) -> Response<Full<Bytes>>
362 where
363 T: Body + Send + 'static,
364 {
365 info!("Opening streaming connection for session: {:?}", context.session_id);
366
367 StreamableResponse::Json(serde_json::json!({
375 "status": "streaming_connection_opened",
376 "session_id": context.session_id,
377 "protocol_version": context.protocol_version.as_str(),
378 "note": "Streaming implementation pending"
379 })).into_response(&context)
380 }
381
382 async fn handle_streaming_post<T>(&self, _req: Request<T>, context: StreamableHttpContext) -> Response<Full<Bytes>>
384 where
385 T: Body + Send + 'static,
386 {
387 info!("Handling streaming POST for session: {:?}", context.session_id);
388
389 StreamableResponse::Json(serde_json::json!({
397 "status": "streaming_post_accepted",
398 "session_id": context.session_id,
399 "protocol_version": context.protocol_version.as_str(),
400 "note": "Streaming POST implementation pending"
401 })).into_response(&context)
402 }
403
404 async fn handle_json_post<T>(&self, _req: Request<T>, context: StreamableHttpContext) -> Response<Full<Bytes>>
406 where
407 T: Body + Send + 'static,
408 {
409 info!("Handling JSON POST (non-streaming)");
410
411 StreamableResponse::Json(serde_json::json!({
419 "status": "json_post_handled",
420 "protocol_version": context.protocol_version.as_str(),
421 "streaming": false,
422 "note": "JSON POST implementation pending"
423 })).into_response(&context)
424 }
425
426 async fn handle_session_delete<T>(&self, _req: Request<T>, context: StreamableHttpContext) -> Response<Full<Bytes>>
428 where
429 T: Body + Send + 'static,
430 {
431 if let Some(session_id) = &context.session_id {
432 info!("Deleting session: {}", session_id);
433
434 StreamableResponse::Json(serde_json::json!({
439 "status": "session_deleted",
440 "session_id": session_id,
441 "note": "Session cleanup implementation pending"
442 })).into_response(&context)
443 } else {
444 StreamableResponse::Error {
445 status: StatusCode::BAD_REQUEST,
446 message: "Mcp-Session-Id header required for session deletion".to_string(),
447 }.into_response(&context)
448 }
449 }
450}
451
#[cfg(test)]
mod tests {
    use super::*;

    /// Known version strings map to their variants; anything else is rejected.
    #[test]
    fn test_protocol_version_parsing() {
        let known = [
            ("2024-11-05", McpProtocolVersion::V2024_11_05),
            ("2025-03-26", McpProtocolVersion::V2025_03_26),
            ("2025-06-18", McpProtocolVersion::V2025_06_18),
        ];

        for (text, expected) in known.iter() {
            assert_eq!(McpProtocolVersion::from_str(text), Some(*expected));
        }

        assert_eq!(McpProtocolVersion::from_str("invalid"), None);
    }

    /// Each revision reports the expected feature set.
    #[test]
    fn test_version_capabilities() {
        // (version, streamable HTTP, meta fields)
        let expectations = [
            (McpProtocolVersion::V2024_11_05, false, false),
            (McpProtocolVersion::V2025_03_26, true, false),
            (McpProtocolVersion::V2025_06_18, true, true),
        ];

        for (version, streamable, meta) in expectations.iter() {
            assert_eq!(version.supports_streamable_http(), *streamable);
            assert_eq!(version.supports_meta_fields(), *meta);
        }

        let latest = McpProtocolVersion::V2025_06_18;
        assert!(latest.supports_cursors());
        assert!(latest.supports_progress_tokens());
        assert!(latest.supports_elicitation());
    }

    /// A fully specified streaming context validates; dropping JSON
    /// acceptance, using a revision without streaming, or omitting the
    /// session id each makes validation fail.
    #[test]
    fn test_context_validation() {
        let valid = StreamableHttpContext {
            protocol_version: McpProtocolVersion::V2025_06_18,
            session_id: Some("test-session".to_string()),
            wants_streaming: true,
            accepts_json: true,
            headers: HashMap::new(),
        };
        assert!(valid.validate().is_ok());

        let no_json = StreamableHttpContext {
            accepts_json: false,
            ..valid.clone()
        };
        assert!(no_json.validate().is_err());

        let old_protocol = StreamableHttpContext {
            protocol_version: McpProtocolVersion::V2024_11_05,
            ..valid.clone()
        };
        assert!(old_protocol.validate().is_err());

        let no_session = StreamableHttpContext {
            session_id: None,
            ..valid.clone()
        };
        assert!(no_session.validate().is_err());
    }
}