1use axum::{
7 Json,
8 extract::State,
9 http::{StatusCode, header::HeaderMap},
10 response::{IntoResponse, Response, Sse, sse::Event},
11 routing::Router,
12};
13use bytes::Bytes;
14use futures::stream::{self, Stream};
15use std::sync::Arc;
16use tokio::sync::broadcast;
17use tower_http::cors::CorsLayer;
18use tracing::{error, info};
19
20use ultrafast_mcp_core::{
21 protocol::{
22 jsonrpc::{JsonRpcError, JsonRpcMessage, JsonRpcRequest, JsonRpcResponse},
23 version::PROTOCOL_VERSION,
24 },
25 utils::{generate_event_id, generate_session_id},
26 validation::{validate_origin, validate_protocol_version, validate_session_id},
27};
28use ultrafast_mcp_monitoring::metrics::RequestTimer;
29use ultrafast_mcp_monitoring::{MetricsCollector, MonitoringSystem};
30
31use crate::{Result, Transport, TransportError};
32use async_trait::async_trait;
33
34#[derive(Debug, Clone)]
36pub struct HttpTransportConfig {
37 pub host: String,
38 pub port: u16,
39 pub cors_enabled: bool,
40 pub protocol_version: String,
41 pub allow_origin: Option<String>,
42 pub monitoring_enabled: bool,
43 pub enable_sse_resumability: bool,
44}
45
46impl Default for HttpTransportConfig {
47 fn default() -> Self {
48 Self {
49 host: "127.0.0.1".to_string(),
50 port: 8080,
51 cors_enabled: true,
52 protocol_version: PROTOCOL_VERSION.to_string(),
53 allow_origin: Some("http://localhost:*".to_string()),
54 monitoring_enabled: true,
55 enable_sse_resumability: true,
56 }
57 }
58}
59
/// State shared (behind an `Arc`) by every `/mcp` handler.
#[derive(Clone)]
pub struct HttpTransportState {
    /// Carries `(session_id, message)` pairs from the HTTP handlers towards the server side.
    pub message_sender: broadcast::Sender<(String, JsonRpcMessage)>,
    /// Carries `(session_id, message)` pairs back to the handlers; listeners accept
    /// messages addressed to their own session or to `"*"`.
    pub response_sender: broadcast::Sender<(String, JsonRpcMessage)>,
    /// Transport configuration.
    pub config: HttpTransportConfig,
    /// Optional collector used to time `POST /mcp` requests.
    pub metrics: Option<Arc<MetricsCollector>>,
    /// Optional monitoring system whose HTTP server is started by `HttpTransportServer::run`.
    pub monitoring: Option<Arc<MonitoringSystem>>,
    /// Known sessions keyed by session ID.
    pub session_store: Arc<tokio::sync::RwLock<std::collections::HashMap<String, SessionInfo>>>,
}
70
/// Per-session bookkeeping kept in `HttpTransportState::session_store`.
#[derive(Debug, Clone)]
pub struct SessionInfo {
    /// Wall-clock time at which this record was created.
    pub created_at: std::time::SystemTime,
    /// Most recent `Last-Event-ID` reported by the client, if any.
    pub last_event_id: Option<String>,
    /// Stream identifiers attached to the session.
    /// NOTE(review): nothing visible in this file inserts into this set — confirm usage.
    pub active_streams: std::collections::HashSet<String>,
}

impl SessionInfo {
    /// Creates an empty record stamped with the current time.
    pub fn new() -> Self {
        Self::default()
    }
}

impl Default for SessionInfo {
    /// Current timestamp, no event ID, no active streams.
    fn default() -> Self {
        let created_at = std::time::SystemTime::now();
        Self {
            created_at,
            last_event_id: None,
            active_streams: std::collections::HashSet::new(),
        }
    }
}
94
95pub struct HttpTransportServer {
97 state: HttpTransportState,
98 message_receiver: broadcast::Receiver<(String, JsonRpcMessage)>,
99}
100
101impl HttpTransportServer {
102 pub fn new(config: HttpTransportConfig) -> Self {
103 let (message_sender, message_receiver) = broadcast::channel(1000);
104 let (response_sender, _) = broadcast::channel(1000);
105
106 let state = HttpTransportState {
107 message_sender,
108 response_sender,
109 config,
110 metrics: None,
111 monitoring: None,
112 session_store: Arc::new(tokio::sync::RwLock::new(std::collections::HashMap::new())),
113 };
114
115 Self {
116 state,
117 message_receiver,
118 }
119 }
120
121 pub fn with_metrics(mut self, metrics: Arc<MetricsCollector>) -> Self {
122 self.state.metrics = Some(metrics);
123 self
124 }
125
126 pub fn with_monitoring(mut self, monitoring: Arc<MonitoringSystem>) -> Self {
127 self.state.monitoring = Some(monitoring);
128 self
129 }
130
131 pub fn get_message_receiver(&self) -> broadcast::Receiver<(String, JsonRpcMessage)> {
132 self.state.message_sender.subscribe()
133 }
134
135 pub fn get_message_sender(&self) -> broadcast::Sender<(String, JsonRpcMessage)> {
136 self.state.message_sender.clone()
137 }
138
139 pub fn get_response_sender(&self) -> broadcast::Sender<(String, JsonRpcMessage)> {
140 self.state.response_sender.clone()
141 }
142
143 pub fn get_state(&self) -> HttpTransportState {
144 self.state.clone()
145 }
146
147 pub fn get_metrics(&self) -> Option<Arc<MetricsCollector>> {
148 self.state.metrics.clone()
149 }
150
151 pub fn get_monitoring(&self) -> Option<Arc<MonitoringSystem>> {
152 self.state.monitoring.clone()
153 }
154
155 pub async fn run(self) -> Result<()> {
156 info!(
157 "Starting HTTP transport server on {}:{}",
158 self.state.config.host, self.state.config.port
159 );
160
161 let app = self.create_router();
162 let addr = (self.state.config.host.as_str(), self.state.config.port);
163
164 let listener = tokio::net::TcpListener::bind(addr).await.map_err(|e| {
165 TransportError::InitializationError {
166 message: format!("Failed to bind to address: {e}"),
167 }
168 })?;
169
170 if let Some(monitoring) = &self.state.monitoring {
172 let monitoring_addr =
173 format!("{}:{}", self.state.config.host, self.state.config.port + 1)
174 .parse()
175 .map_err(|e| TransportError::InitializationError {
176 message: format!("Failed to parse monitoring address: {e}"),
177 })?;
178
179 let monitoring_clone = monitoring.clone();
180 tokio::spawn(async move {
181 if let Err(e) = monitoring_clone.start_http_server(monitoring_addr).await {
182 error!("Failed to start monitoring server: {}", e);
183 }
184 });
185 }
186
187 axum::serve(listener, app.into_make_service())
188 .await
189 .map_err(|e| TransportError::InitializationError {
190 message: format!("Server failed: {e}"),
191 })?;
192
193 Ok(())
194 }
195
196 fn create_router(&self) -> Router {
197 let state = Arc::new(self.state.clone());
198 let mut router = Router::new()
199 .route("/mcp", axum::routing::post(handle_mcp_post))
200 .route("/mcp", axum::routing::get(handle_mcp_get))
201 .route("/mcp", axum::routing::delete(handle_mcp_delete));
202
203 if self.state.config.cors_enabled {
204 router = router.layer(CorsLayer::permissive());
205 }
206
207 router.with_state(state)
208 }
209}
210
211#[async_trait]
212impl Transport for HttpTransportServer {
213 async fn send_message(&mut self, message: JsonRpcMessage) -> Result<()> {
214 let _ = self.state.message_sender.send(("*".to_string(), message));
216 Ok(())
217 }
218
219 async fn receive_message(&mut self) -> Result<JsonRpcMessage> {
220 match self.message_receiver.recv().await {
221 Ok((_, message)) => Ok(message),
222 Err(_) => Err(TransportError::ConnectionClosed),
223 }
224 }
225
226 async fn close(&mut self) -> Result<()> {
227 Ok(())
228 }
229}
230
231fn extract_session_id(headers: &HeaderMap) -> Option<String> {
233 headers
234 .get("mcp-session-id")
235 .and_then(|v| v.to_str().ok())
236 .map(|s| s.to_string())
237}
238
239fn extract_protocol_version(headers: &HeaderMap) -> Option<String> {
241 headers
242 .get("mcp-protocol-version")
243 .and_then(|v| v.to_str().ok())
244 .map(|s| s.to_string())
245}
246
247fn extract_last_event_id(headers: &HeaderMap) -> Option<String> {
249 headers
250 .get("last-event-id")
251 .and_then(|v| v.to_str().ok())
252 .map(|s| s.to_string())
253}
254
255fn validate_origin_header(headers: &HeaderMap, config: &HttpTransportConfig) -> bool {
259 let origin = headers.get("origin").and_then(|v| v.to_str().ok());
260
261 validate_origin(origin, config.allow_origin.as_deref(), &config.host)
262}
263
264fn validate_protocol_version_header(version: &str) -> bool {
266 validate_protocol_version(version).is_ok()
267}
268
269fn validate_session_id_header(session_id: &str) -> bool {
271 validate_session_id(session_id).is_ok()
272}
273
274async fn handle_mcp_post(
277 State(state): State<Arc<HttpTransportState>>,
278 headers: HeaderMap,
279 body: Bytes,
280) -> impl IntoResponse {
281 let timer = state
283 .metrics
284 .as_ref()
285 .map(|metrics| RequestTimer::start("mcp_post", metrics.clone()));
286
287 let result = handle_mcp_post_internal(state, headers, body).await;
288
289 if let Some(timer) = timer {
291 let success = result.status() == StatusCode::OK;
292 timer.finish(success).await;
293 }
294
295 result
296}
297
298async fn handle_mcp_post_internal(
299 state: Arc<HttpTransportState>,
300 headers: HeaderMap,
301 body: Bytes,
302) -> Response {
303 if !validate_origin_header(&headers, &state.config) {
305 return (
306 StatusCode::FORBIDDEN,
307 Json(JsonRpcResponse::error(
308 JsonRpcError::new(-32000, "Origin not allowed".to_string()),
309 None,
310 )),
311 )
312 .into_response();
313 }
314
315 if let Some(protocol_version) = extract_protocol_version(&headers) {
317 if !validate_protocol_version_header(&protocol_version) {
318 return (
319 StatusCode::BAD_REQUEST,
320 Json(JsonRpcResponse::error(
321 JsonRpcError::new(
322 -32000,
323 format!("Unsupported protocol version: {protocol_version}"),
324 ),
325 None,
326 )),
327 )
328 .into_response();
329 }
330 }
331
332 let is_initial_connection = body.is_empty() || {
334 if let Ok(message) = serde_json::from_slice::<JsonRpcMessage>(&body) {
335 matches!(message, JsonRpcMessage::Request(req) if req.method == "initialize")
336 } else {
337 false
338 }
339 };
340
341 let session_id = if is_initial_connection {
342 extract_session_id(&headers).unwrap_or_else(generate_session_id)
343 } else {
344 match extract_session_id(&headers) {
345 Some(id) => {
346 if !validate_session_id_header(&id) {
347 return Json(JsonRpcResponse::error(
348 JsonRpcError::new(-32000, "Invalid session ID format".to_string()),
349 None,
350 ))
351 .into_response();
352 }
353 id
354 }
355 None => {
356 return Json(JsonRpcResponse::error(
357 JsonRpcError::new(-32000, "Missing session ID".to_string()),
358 None,
359 ))
360 .into_response();
361 }
362 }
363 };
364
365 {
367 let mut sessions = state.session_store.write().await;
368 sessions
369 .entry(session_id.clone())
370 .or_insert_with(SessionInfo::new);
371 }
372
373 let message: std::result::Result<JsonRpcMessage, serde_json::Error> =
375 serde_json::from_slice(&body);
376 let message = match message {
377 Ok(msg) => msg,
378 Err(_) => {
379 return Json(JsonRpcResponse::error(
380 JsonRpcError::new(-32700, "Parse error: Invalid JSON-RPC message".to_string()),
381 None,
382 ))
383 .into_response();
384 }
385 };
386
387 info!(
388 "Processing POST request for session {}: {:?}",
389 session_id, message
390 );
391 match message {
392 JsonRpcMessage::Request(request) => {
393 handle_jsonrpc_request(state, session_id, request).await
394 }
395 JsonRpcMessage::Notification(_) | JsonRpcMessage::Response(_) => {
396 handle_notification_or_response(state, session_id, message).await
397 }
398 }
399}
400
401async fn handle_mcp_get(
402 State(state): State<Arc<HttpTransportState>>,
403 headers: HeaderMap,
404) -> impl IntoResponse {
405 if !validate_origin_header(&headers, &state.config) {
406 return (
407 StatusCode::FORBIDDEN,
408 Json(JsonRpcResponse::error(
409 JsonRpcError::new(-32000, "Origin not allowed".to_string()),
410 None,
411 )),
412 )
413 .into_response();
414 }
415
416 if let Some(protocol_version) = extract_protocol_version(&headers) {
418 if !validate_protocol_version_header(&protocol_version) {
419 return (
420 StatusCode::BAD_REQUEST,
421 Json(JsonRpcResponse::error(
422 JsonRpcError::new(
423 -32000,
424 format!("Unsupported protocol version: {protocol_version}"),
425 ),
426 None,
427 )),
428 )
429 .into_response();
430 }
431 }
432
433 let session_id = extract_session_id(&headers).unwrap_or_else(generate_session_id);
434 let last_event_id = extract_last_event_id(&headers);
435
436 info!(
437 "Processing GET request for session {} (SSE stream){}",
438 session_id,
439 last_event_id
440 .as_ref()
441 .map(|id| format!(", resuming from event {id}"))
442 .unwrap_or_default()
443 );
444
445 {
447 let mut sessions = state.session_store.write().await;
448 let session_info = sessions
449 .entry(session_id.clone())
450 .or_insert_with(SessionInfo::new);
451 if let Some(event_id) = &last_event_id {
452 session_info.last_event_id = Some(event_id.clone());
453 }
454 }
455
456 let stream = create_sse_stream(state, session_id, last_event_id);
457 Sse::new(stream).into_response()
458}
459
460async fn handle_mcp_delete(
461 State(state): State<Arc<HttpTransportState>>,
462 headers: HeaderMap,
463) -> impl IntoResponse {
464 if !validate_origin_header(&headers, &state.config) {
465 return (
466 StatusCode::FORBIDDEN,
467 Json(JsonRpcResponse::error(
468 JsonRpcError::new(-32000, "Origin not allowed".to_string()),
469 None,
470 )),
471 )
472 .into_response();
473 }
474
475 if let Some(protocol_version) = extract_protocol_version(&headers) {
477 if !validate_protocol_version_header(&protocol_version) {
478 return (
479 StatusCode::BAD_REQUEST,
480 Json(JsonRpcResponse::error(
481 JsonRpcError::new(
482 -32000,
483 format!("Unsupported protocol version: {protocol_version}"),
484 ),
485 None,
486 )),
487 )
488 .into_response();
489 }
490 }
491
492 let session_id = extract_session_id(&headers).unwrap_or_else(generate_session_id);
493
494 {
496 let mut sessions = state.session_store.write().await;
497 sessions.remove(&session_id);
498 }
499
500 info!("Terminating session: {}", session_id);
501 StatusCode::OK.into_response()
502}
503
504async fn handle_jsonrpc_request(
506 state: Arc<HttpTransportState>,
507 session_id: String,
508 request: JsonRpcRequest,
509) -> Response {
510 let mut response_receiver = state.response_sender.subscribe();
512
513 if let Err(e) = state
515 .message_sender
516 .send((session_id.clone(), JsonRpcMessage::Request(request.clone())))
517 {
518 error!("Failed to send message to server: {}", e);
519 return Json(JsonRpcResponse::error(
520 JsonRpcError::new(-32000, format!("Failed to process message: {e}")),
521 request.id,
522 ))
523 .into_response();
524 }
525
526 match tokio::time::timeout(
528 std::time::Duration::from_secs(30), response_receiver.recv(),
530 )
531 .await
532 {
533 Ok(Ok((response_session_id, response_message))) => {
534 if response_session_id == session_id || response_session_id == "*" {
535 match response_message {
537 JsonRpcMessage::Response(response) => {
538 info!("Sending response back to client: {:?}", response);
539 (
540 StatusCode::OK,
541 [
542 ("mcp-session-id", response_session_id),
543 (
544 "mcp-protocol-version",
545 state.config.protocol_version.clone(),
546 ),
547 ],
548 Json(response),
549 )
550 .into_response()
551 }
552 _ => {
553 error!("Unexpected response type: {:?}", response_message);
555 Json(JsonRpcResponse::error(
556 JsonRpcError::new(-32000, "Unexpected response type".to_string()),
557 request.id,
558 ))
559 .into_response()
560 }
561 }
562 } else {
563 error!(
565 "Received response for wrong session: expected {}, got {}",
566 session_id, response_session_id
567 );
568 Json(JsonRpcResponse::error(
569 JsonRpcError::new(-32000, "Session mismatch".to_string()),
570 request.id,
571 ))
572 .into_response()
573 }
574 }
575 Ok(Err(e)) => {
576 error!("Failed to receive response: {}", e);
577 Json(JsonRpcResponse::error(
578 JsonRpcError::new(-32000, format!("Failed to receive response: {e}")),
579 request.id,
580 ))
581 .into_response()
582 }
583 Err(_) => {
584 error!("Request timeout waiting for response from server");
585 Json(JsonRpcResponse::error(
586 JsonRpcError::new(-32000, "Request timeout".to_string()),
587 request.id,
588 ))
589 .into_response()
590 }
591 }
592}
593
594async fn handle_notification_or_response(
596 state: Arc<HttpTransportState>,
597 session_id: String,
598 message: JsonRpcMessage,
599) -> Response {
600 if let Err(e) = state.message_sender.send((session_id.clone(), message)) {
602 error!("Failed to send message to server: {}", e);
603 return (
604 StatusCode::BAD_REQUEST,
605 Json(JsonRpcResponse::error(
606 JsonRpcError::new(-32000, format!("Failed to process message: {e}")),
607 None,
608 )),
609 )
610 .into_response();
611 }
612
613 (StatusCode::ACCEPTED, [("mcp-session-id", session_id)]).into_response()
615}
616
617fn create_sse_stream(
619 state: Arc<HttpTransportState>,
620 session_id: String,
621 last_event_id: Option<String>,
622) -> impl Stream<Item = std::result::Result<Event, axum::Error>> {
623 let response_receiver = state.response_sender.subscribe();
624 let enable_resumability = state.config.enable_sse_resumability;
625
626 stream::unfold(
627 (
628 response_receiver,
629 session_id,
630 last_event_id,
631 enable_resumability,
632 ),
633 |(mut receiver, session_id, last_event_id, enable_resumability)| async move {
634 match receiver.recv().await {
635 Ok((msg_session_id, message)) => {
636 if msg_session_id == session_id || msg_session_id == "*" {
637 let event_data = serde_json::to_string(&message).unwrap_or_default();
638 let mut event = Event::default().data(event_data);
639
640 if enable_resumability {
642 let event_id = generate_event_id();
643 event = event.id(event_id);
644 }
645
646 event = event.comment("keep-alive");
648
649 Some((
650 Ok(event),
651 (receiver, session_id, last_event_id, enable_resumability),
652 ))
653 } else {
654 Some((
656 Ok(Event::default().comment("keep-alive")),
657 (receiver, session_id, last_event_id, enable_resumability),
658 ))
659 }
660 }
661 Err(_) => None, }
663 },
664 )
665}