ultrafast_mcp_transport/streamable_http/
server.rs

//! HTTP transport server implementation
//!
//! This module provides an MCP Streamable HTTP server: clients POST JSON-RPC messages to
//! `/mcp`, open server-sent-event streams with GET, and end a session with DELETE.
//! Sessions are identified by the `mcp-session-id` header.
5
6use axum::{
7    Json,
8    extract::State,
9    http::{StatusCode, header::HeaderMap},
10    response::{IntoResponse, Response, Sse, sse::Event},
11    routing::Router,
12};
13use bytes::Bytes;
14use futures::stream::{self, Stream};
15use std::sync::Arc;
16use tokio::sync::broadcast;
17use tower_http::cors::CorsLayer;
18use tracing::{error, info};
19
20use ultrafast_mcp_core::{
21    protocol::{
22        jsonrpc::{JsonRpcError, JsonRpcMessage, JsonRpcRequest, JsonRpcResponse},
23        version::PROTOCOL_VERSION,
24    },
25    utils::{generate_event_id, generate_session_id},
26    validation::{validate_origin, validate_protocol_version, validate_session_id},
27};
28use ultrafast_mcp_monitoring::metrics::RequestTimer;
29use ultrafast_mcp_monitoring::{MetricsCollector, MonitoringSystem};
30
31use crate::{Result, Transport, TransportError};
32use async_trait::async_trait;
33
/// Settings for [`HttpTransportServer`]: listen address, CORS, protocol version and
/// security/monitoring switches.
#[derive(Debug, Clone)]
pub struct HttpTransportConfig {
    /// Host name or IP address the listener binds to.
    pub host: String,
    /// TCP port the listener binds to.
    pub port: u16,
    /// When `true`, a permissive CORS layer is installed on the router.
    pub cors_enabled: bool,
    /// Value sent back in the `mcp-protocol-version` response header.
    pub protocol_version: String,
    /// Origin pattern handed to the core `validate_origin` check (`None` = no explicit pattern).
    pub allow_origin: Option<String>,
    /// Monitoring switch; how it is consumed depends on the server code — verify at call sites.
    pub monitoring_enabled: bool,
    /// When `true`, every SSE event carries a generated event id.
    pub enable_sse_resumability: bool,
}
45
46impl Default for HttpTransportConfig {
47    fn default() -> Self {
48        Self {
49            host: "127.0.0.1".to_string(),
50            port: 8080,
51            cors_enabled: true,
52            protocol_version: PROTOCOL_VERSION.to_string(),
53            allow_origin: Some("http://localhost:*".to_string()),
54            monitoring_enabled: true,
55            enable_sse_resumability: true,
56        }
57    }
58}
59
60/// Shared state for HTTP transport
61#[derive(Clone)]
62pub struct HttpTransportState {
63    pub message_sender: broadcast::Sender<(String, JsonRpcMessage)>,
64    pub response_sender: broadcast::Sender<(String, JsonRpcMessage)>,
65    pub config: HttpTransportConfig,
66    pub metrics: Option<Arc<MetricsCollector>>,
67    pub monitoring: Option<Arc<MonitoringSystem>>,
68    pub session_store: Arc<tokio::sync::RwLock<std::collections::HashMap<String, SessionInfo>>>,
69}
70
/// Book-keeping kept for each known session in the transport's session store.
#[derive(Debug, Clone)]
pub struct SessionInfo {
    /// Wall-clock time at which this entry was created.
    pub created_at: std::time::SystemTime,
    /// Last `Last-Event-ID` a client reconnected with, if any.
    pub last_event_id: Option<String>,
    /// Stream identifiers for this session (NOTE(review): not populated in this file — confirm usage).
    pub active_streams: std::collections::HashSet<String>,
}

impl SessionInfo {
    /// Builds a fresh entry stamped with the current time and no stream state.
    pub fn new() -> Self {
        let created_at = std::time::SystemTime::now();
        Self {
            created_at,
            last_event_id: None,
            active_streams: Default::default(),
        }
    }
}

impl Default for SessionInfo {
    /// Same as [`SessionInfo::new`].
    fn default() -> Self {
        SessionInfo::new()
    }
}
94
95/// HTTP transport server implementation
96pub struct HttpTransportServer {
97    state: HttpTransportState,
98    message_receiver: broadcast::Receiver<(String, JsonRpcMessage)>,
99}
100
101impl HttpTransportServer {
102    pub fn new(config: HttpTransportConfig) -> Self {
103        let (message_sender, message_receiver) = broadcast::channel(1000);
104        let (response_sender, _) = broadcast::channel(1000);
105
106        let state = HttpTransportState {
107            message_sender,
108            response_sender,
109            config,
110            metrics: None,
111            monitoring: None,
112            session_store: Arc::new(tokio::sync::RwLock::new(std::collections::HashMap::new())),
113        };
114
115        Self {
116            state,
117            message_receiver,
118        }
119    }
120
121    pub fn with_metrics(mut self, metrics: Arc<MetricsCollector>) -> Self {
122        self.state.metrics = Some(metrics);
123        self
124    }
125
126    pub fn with_monitoring(mut self, monitoring: Arc<MonitoringSystem>) -> Self {
127        self.state.monitoring = Some(monitoring);
128        self
129    }
130
131    pub fn get_message_receiver(&self) -> broadcast::Receiver<(String, JsonRpcMessage)> {
132        self.state.message_sender.subscribe()
133    }
134
135    pub fn get_message_sender(&self) -> broadcast::Sender<(String, JsonRpcMessage)> {
136        self.state.message_sender.clone()
137    }
138
139    pub fn get_response_sender(&self) -> broadcast::Sender<(String, JsonRpcMessage)> {
140        self.state.response_sender.clone()
141    }
142
143    pub fn get_state(&self) -> HttpTransportState {
144        self.state.clone()
145    }
146
147    pub fn get_metrics(&self) -> Option<Arc<MetricsCollector>> {
148        self.state.metrics.clone()
149    }
150
151    pub fn get_monitoring(&self) -> Option<Arc<MonitoringSystem>> {
152        self.state.monitoring.clone()
153    }
154
155    pub async fn run(self) -> Result<()> {
156        info!(
157            "Starting HTTP transport server on {}:{}",
158            self.state.config.host, self.state.config.port
159        );
160
161        let app = self.create_router();
162        let addr = (self.state.config.host.as_str(), self.state.config.port);
163
164        let listener = tokio::net::TcpListener::bind(addr).await.map_err(|e| {
165            TransportError::InitializationError {
166                message: format!("Failed to bind to address: {e}"),
167            }
168        })?;
169
170        // Start monitoring HTTP server if enabled
171        if let Some(monitoring) = &self.state.monitoring {
172            let monitoring_addr =
173                format!("{}:{}", self.state.config.host, self.state.config.port + 1)
174                    .parse()
175                    .map_err(|e| TransportError::InitializationError {
176                        message: format!("Failed to parse monitoring address: {e}"),
177                    })?;
178
179            let monitoring_clone = monitoring.clone();
180            tokio::spawn(async move {
181                if let Err(e) = monitoring_clone.start_http_server(monitoring_addr).await {
182                    error!("Failed to start monitoring server: {}", e);
183                }
184            });
185        }
186
187        axum::serve(listener, app.into_make_service())
188            .await
189            .map_err(|e| TransportError::InitializationError {
190                message: format!("Server failed: {e}"),
191            })?;
192
193        Ok(())
194    }
195
196    fn create_router(&self) -> Router {
197        let state = Arc::new(self.state.clone());
198        let mut router = Router::new()
199            .route("/mcp", axum::routing::post(handle_mcp_post))
200            .route("/mcp", axum::routing::get(handle_mcp_get))
201            .route("/mcp", axum::routing::delete(handle_mcp_delete));
202
203        if self.state.config.cors_enabled {
204            router = router.layer(CorsLayer::permissive());
205        }
206
207        router.with_state(state)
208    }
209}
210
211#[async_trait]
212impl Transport for HttpTransportServer {
213    async fn send_message(&mut self, message: JsonRpcMessage) -> Result<()> {
214        // Broadcast to all connected sessions
215        let _ = self.state.message_sender.send(("*".to_string(), message));
216        Ok(())
217    }
218
219    async fn receive_message(&mut self) -> Result<JsonRpcMessage> {
220        match self.message_receiver.recv().await {
221            Ok((_, message)) => Ok(message),
222            Err(_) => Err(TransportError::ConnectionClosed),
223        }
224    }
225
226    async fn close(&mut self) -> Result<()> {
227        Ok(())
228    }
229}
230
231/// Extract session ID from headers
232fn extract_session_id(headers: &HeaderMap) -> Option<String> {
233    headers
234        .get("mcp-session-id")
235        .and_then(|v| v.to_str().ok())
236        .map(|s| s.to_string())
237}
238
239/// Extract protocol version from headers
240fn extract_protocol_version(headers: &HeaderMap) -> Option<String> {
241    headers
242        .get("mcp-protocol-version")
243        .and_then(|v| v.to_str().ok())
244        .map(|s| s.to_string())
245}
246
247/// Extract Last-Event-ID from headers for SSE resumability
248fn extract_last_event_id(headers: &HeaderMap) -> Option<String> {
249    headers
250        .get("last-event-id")
251        .and_then(|v| v.to_str().ok())
252        .map(|s| s.to_string())
253}
254
255// Validation functions moved to ultrafast_mcp_core::validation
256
257/// Validate Origin header for security using core validation
258fn validate_origin_header(headers: &HeaderMap, config: &HttpTransportConfig) -> bool {
259    let origin = headers.get("origin").and_then(|v| v.to_str().ok());
260
261    validate_origin(origin, config.allow_origin.as_deref(), &config.host)
262}
263
264/// Validate protocol version header using core validation
265fn validate_protocol_version_header(version: &str) -> bool {
266    validate_protocol_version(version).is_ok()
267}
268
269/// Validate session ID format using core validation  
270fn validate_session_id_header(session_id: &str) -> bool {
271    validate_session_id(session_id).is_ok()
272}
273
274// Session ID and event ID generation functions moved to ultrafast_mcp_core::utils
275
276async fn handle_mcp_post(
277    State(state): State<Arc<HttpTransportState>>,
278    headers: HeaderMap,
279    body: Bytes,
280) -> impl IntoResponse {
281    // Start request timer for monitoring
282    let timer = state
283        .metrics
284        .as_ref()
285        .map(|metrics| RequestTimer::start("mcp_post", metrics.clone()));
286
287    let result = handle_mcp_post_internal(state, headers, body).await;
288
289    // Record metrics
290    if let Some(timer) = timer {
291        let success = result.status() == StatusCode::OK;
292        timer.finish(success).await;
293    }
294
295    result
296}
297
298async fn handle_mcp_post_internal(
299    state: Arc<HttpTransportState>,
300    headers: HeaderMap,
301    body: Bytes,
302) -> Response {
303    // Validate Origin header
304    if !validate_origin_header(&headers, &state.config) {
305        return (
306            StatusCode::FORBIDDEN,
307            Json(JsonRpcResponse::error(
308                JsonRpcError::new(-32000, "Origin not allowed".to_string()),
309                None,
310            )),
311        )
312            .into_response();
313    }
314
315    // Validate protocol version header if present
316    if let Some(protocol_version) = extract_protocol_version(&headers) {
317        if !validate_protocol_version_header(&protocol_version) {
318            return (
319                StatusCode::BAD_REQUEST,
320                Json(JsonRpcResponse::error(
321                    JsonRpcError::new(
322                        -32000,
323                        format!("Unsupported protocol version: {protocol_version}"),
324                    ),
325                    None,
326                )),
327            )
328                .into_response();
329        }
330    }
331
332    // Check if this is an initial connection (initialize request or empty body)
333    let is_initial_connection = body.is_empty() || {
334        if let Ok(message) = serde_json::from_slice::<JsonRpcMessage>(&body) {
335            matches!(message, JsonRpcMessage::Request(req) if req.method == "initialize")
336        } else {
337            false
338        }
339    };
340
341    let session_id = if is_initial_connection {
342        extract_session_id(&headers).unwrap_or_else(generate_session_id)
343    } else {
344        match extract_session_id(&headers) {
345            Some(id) => {
346                if !validate_session_id_header(&id) {
347                    return Json(JsonRpcResponse::error(
348                        JsonRpcError::new(-32000, "Invalid session ID format".to_string()),
349                        None,
350                    ))
351                    .into_response();
352                }
353                id
354            }
355            None => {
356                return Json(JsonRpcResponse::error(
357                    JsonRpcError::new(-32000, "Missing session ID".to_string()),
358                    None,
359                ))
360                .into_response();
361            }
362        }
363    };
364
365    // Store session info
366    {
367        let mut sessions = state.session_store.write().await;
368        sessions
369            .entry(session_id.clone())
370            .or_insert_with(SessionInfo::new);
371    }
372
373    // Try to parse the body as a JSON-RPC message
374    let message: std::result::Result<JsonRpcMessage, serde_json::Error> =
375        serde_json::from_slice(&body);
376    let message = match message {
377        Ok(msg) => msg,
378        Err(_) => {
379            return Json(JsonRpcResponse::error(
380                JsonRpcError::new(-32700, "Parse error: Invalid JSON-RPC message".to_string()),
381                None,
382            ))
383            .into_response();
384        }
385    };
386
387    info!(
388        "Processing POST request for session {}: {:?}",
389        session_id, message
390    );
391    match message {
392        JsonRpcMessage::Request(request) => {
393            handle_jsonrpc_request(state, session_id, request).await
394        }
395        JsonRpcMessage::Notification(_) | JsonRpcMessage::Response(_) => {
396            handle_notification_or_response(state, session_id, message).await
397        }
398    }
399}
400
401async fn handle_mcp_get(
402    State(state): State<Arc<HttpTransportState>>,
403    headers: HeaderMap,
404) -> impl IntoResponse {
405    if !validate_origin_header(&headers, &state.config) {
406        return (
407            StatusCode::FORBIDDEN,
408            Json(JsonRpcResponse::error(
409                JsonRpcError::new(-32000, "Origin not allowed".to_string()),
410                None,
411            )),
412        )
413            .into_response();
414    }
415
416    // Validate protocol version header if present
417    if let Some(protocol_version) = extract_protocol_version(&headers) {
418        if !validate_protocol_version_header(&protocol_version) {
419            return (
420                StatusCode::BAD_REQUEST,
421                Json(JsonRpcResponse::error(
422                    JsonRpcError::new(
423                        -32000,
424                        format!("Unsupported protocol version: {protocol_version}"),
425                    ),
426                    None,
427                )),
428            )
429                .into_response();
430        }
431    }
432
433    let session_id = extract_session_id(&headers).unwrap_or_else(generate_session_id);
434    let last_event_id = extract_last_event_id(&headers);
435
436    info!(
437        "Processing GET request for session {} (SSE stream){}",
438        session_id,
439        last_event_id
440            .as_ref()
441            .map(|id| format!(", resuming from event {id}"))
442            .unwrap_or_default()
443    );
444
445    // Store session info
446    {
447        let mut sessions = state.session_store.write().await;
448        let session_info = sessions
449            .entry(session_id.clone())
450            .or_insert_with(SessionInfo::new);
451        if let Some(event_id) = &last_event_id {
452            session_info.last_event_id = Some(event_id.clone());
453        }
454    }
455
456    let stream = create_sse_stream(state, session_id, last_event_id);
457    Sse::new(stream).into_response()
458}
459
460async fn handle_mcp_delete(
461    State(state): State<Arc<HttpTransportState>>,
462    headers: HeaderMap,
463) -> impl IntoResponse {
464    if !validate_origin_header(&headers, &state.config) {
465        return (
466            StatusCode::FORBIDDEN,
467            Json(JsonRpcResponse::error(
468                JsonRpcError::new(-32000, "Origin not allowed".to_string()),
469                None,
470            )),
471        )
472            .into_response();
473    }
474
475    // Validate protocol version header if present
476    if let Some(protocol_version) = extract_protocol_version(&headers) {
477        if !validate_protocol_version_header(&protocol_version) {
478            return (
479                StatusCode::BAD_REQUEST,
480                Json(JsonRpcResponse::error(
481                    JsonRpcError::new(
482                        -32000,
483                        format!("Unsupported protocol version: {protocol_version}"),
484                    ),
485                    None,
486                )),
487            )
488                .into_response();
489        }
490    }
491
492    let session_id = extract_session_id(&headers).unwrap_or_else(generate_session_id);
493
494    // Remove session from store
495    {
496        let mut sessions = state.session_store.write().await;
497        sessions.remove(&session_id);
498    }
499
500    info!("Terminating session: {}", session_id);
501    StatusCode::OK.into_response()
502}
503
504/// Handle JSON-RPC requests
505async fn handle_jsonrpc_request(
506    state: Arc<HttpTransportState>,
507    session_id: String,
508    request: JsonRpcRequest,
509) -> Response {
510    // Create a response receiver for this specific request
511    let mut response_receiver = state.response_sender.subscribe();
512
513    // Send message to server for processing
514    if let Err(e) = state
515        .message_sender
516        .send((session_id.clone(), JsonRpcMessage::Request(request.clone())))
517    {
518        error!("Failed to send message to server: {}", e);
519        return Json(JsonRpcResponse::error(
520            JsonRpcError::new(-32000, format!("Failed to process message: {e}")),
521            request.id,
522        ))
523        .into_response();
524    }
525
526    // Wait for response from server with timeout
527    match tokio::time::timeout(
528        std::time::Duration::from_secs(30), // Increased timeout to 30 seconds
529        response_receiver.recv(),
530    )
531    .await
532    {
533        Ok(Ok((response_session_id, response_message))) => {
534            if response_session_id == session_id || response_session_id == "*" {
535                // Return the actual response from the server
536                match response_message {
537                    JsonRpcMessage::Response(response) => {
538                        info!("Sending response back to client: {:?}", response);
539                        (
540                            StatusCode::OK,
541                            [
542                                ("mcp-session-id", response_session_id),
543                                (
544                                    "mcp-protocol-version",
545                                    state.config.protocol_version.clone(),
546                                ),
547                            ],
548                            Json(response),
549                        )
550                            .into_response()
551                    }
552                    _ => {
553                        // Unexpected message type
554                        error!("Unexpected response type: {:?}", response_message);
555                        Json(JsonRpcResponse::error(
556                            JsonRpcError::new(-32000, "Unexpected response type".to_string()),
557                            request.id,
558                        ))
559                        .into_response()
560                    }
561                }
562            } else {
563                // Wrong session
564                error!(
565                    "Received response for wrong session: expected {}, got {}",
566                    session_id, response_session_id
567                );
568                Json(JsonRpcResponse::error(
569                    JsonRpcError::new(-32000, "Session mismatch".to_string()),
570                    request.id,
571                ))
572                .into_response()
573            }
574        }
575        Ok(Err(e)) => {
576            error!("Failed to receive response: {}", e);
577            Json(JsonRpcResponse::error(
578                JsonRpcError::new(-32000, format!("Failed to receive response: {e}")),
579                request.id,
580            ))
581            .into_response()
582        }
583        Err(_) => {
584            error!("Request timeout waiting for response from server");
585            Json(JsonRpcResponse::error(
586                JsonRpcError::new(-32000, "Request timeout".to_string()),
587                request.id,
588            ))
589            .into_response()
590        }
591    }
592}
593
594/// Handle notifications and responses
595async fn handle_notification_or_response(
596    state: Arc<HttpTransportState>,
597    session_id: String,
598    message: JsonRpcMessage,
599) -> Response {
600    // Send message to server for processing
601    if let Err(e) = state.message_sender.send((session_id.clone(), message)) {
602        error!("Failed to send message to server: {}", e);
603        return (
604            StatusCode::BAD_REQUEST,
605            Json(JsonRpcResponse::error(
606                JsonRpcError::new(-32000, format!("Failed to process message: {e}")),
607                None,
608            )),
609        )
610            .into_response();
611    }
612
613    // Return 202 Accepted for notifications and responses
614    (StatusCode::ACCEPTED, [("mcp-session-id", session_id)]).into_response()
615}
616
617/// Create SSE stream for server-to-client communication
618fn create_sse_stream(
619    state: Arc<HttpTransportState>,
620    session_id: String,
621    last_event_id: Option<String>,
622) -> impl Stream<Item = std::result::Result<Event, axum::Error>> {
623    let response_receiver = state.response_sender.subscribe();
624    let enable_resumability = state.config.enable_sse_resumability;
625
626    stream::unfold(
627        (
628            response_receiver,
629            session_id,
630            last_event_id,
631            enable_resumability,
632        ),
633        |(mut receiver, session_id, last_event_id, enable_resumability)| async move {
634            match receiver.recv().await {
635                Ok((msg_session_id, message)) => {
636                    if msg_session_id == session_id || msg_session_id == "*" {
637                        let event_data = serde_json::to_string(&message).unwrap_or_default();
638                        let mut event = Event::default().data(event_data);
639
640                        // Add event ID for resumability if enabled
641                        if enable_resumability {
642                            let event_id = generate_event_id();
643                            event = event.id(event_id);
644                        }
645
646                        // Add keep-alive comment
647                        event = event.comment("keep-alive");
648
649                        Some((
650                            Ok(event),
651                            (receiver, session_id, last_event_id, enable_resumability),
652                        ))
653                    } else {
654                        // Skip messages for other sessions, send keep-alive comment
655                        Some((
656                            Ok(Event::default().comment("keep-alive")),
657                            (receiver, session_id, last_event_id, enable_resumability),
658                        ))
659                    }
660                }
661                Err(_) => None, // Connection closed
662            }
663        },
664    )
665}