turbomcp_transport/
bidirectional.rs

1//! Bidirectional transport implementation with server-initiated request support
2//!
3//! This module provides enhanced transport capabilities for MCP 2025-06-18 protocol
4//! including server-initiated requests, message correlation, and protocol direction validation.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::Duration;
9
10use async_trait::async_trait;
11use dashmap::DashMap;
12use serde::{Deserialize, Serialize};
13use tokio::sync::{RwLock, mpsc, oneshot};
14use tokio::time::timeout;
15use turbomcp_protocol::ServerInitiatedType;
16use uuid::Uuid;
17
18use crate::core::{
19    BidirectionalTransport, Transport, TransportCapabilities, TransportError, TransportMessage,
20    TransportResult, TransportState, TransportType,
21};
22
23/// Message direction in the transport layer
24#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
25pub enum MessageDirection {
26    /// Client to server message
27    ClientToServer,
28    /// Server to client message
29    ServerToClient,
30}
31
32/// Correlation context for request-response patterns
33#[derive(Debug)]
34pub struct CorrelationContext {
35    /// Unique correlation ID
36    pub correlation_id: String,
37    /// Original request message ID
38    pub request_id: String,
39    /// Response channel (not cloneable, so we don't derive Clone)
40    pub response_tx: Option<oneshot::Sender<TransportMessage>>,
41    /// Timeout duration
42    pub timeout: Duration,
43    /// Creation timestamp
44    pub created_at: std::time::Instant,
45}
46
47/// Enhanced bidirectional transport wrapper
48#[derive(Debug)]
49pub struct BidirectionalTransportWrapper<T: Transport> {
50    /// Inner transport implementation
51    inner: T,
52    /// Message direction for this transport
53    direction: MessageDirection,
54    /// Active correlations for request-response
55    correlations: Arc<DashMap<String, CorrelationContext>>,
56    /// Server-initiated request handlers (using String keys instead of ServerInitiatedType)
57    server_handlers: Arc<DashMap<String, mpsc::Sender<TransportMessage>>>,
58    /// Protocol direction validator
59    validator: Arc<ProtocolDirectionValidator>,
60    /// Message router
61    router: Arc<MessageRouter>,
62    /// Connection state
63    state: Arc<RwLock<ConnectionState>>,
64}
65
66/// Connection state for bidirectional communication
67#[derive(Debug, Clone, Default)]
68pub struct ConnectionState {
69    /// Whether server-initiated requests are enabled
70    pub server_initiated_enabled: bool,
71    /// Active server-initiated request IDs
72    pub active_server_requests: Vec<String>,
73    /// Pending elicitations
74    pub pending_elicitations: Vec<String>,
75    /// Connection metadata
76    pub metadata: HashMap<String, serde_json::Value>,
77}
78
/// Table of which JSON-RPC methods may travel in which direction.
#[derive(Debug)]
pub struct ProtocolDirectionValidator {
    /// Methods only a client may send.
    client_to_server: Vec<String>,
    /// Methods only a server may send.
    server_to_client: Vec<String>,
    /// Methods either peer may send.
    bidirectional: Vec<String>,
}
89
90impl Default for ProtocolDirectionValidator {
91    fn default() -> Self {
92        Self::new()
93    }
94}
95
96impl ProtocolDirectionValidator {
97    /// Create a new validator with MCP protocol rules
98    pub fn new() -> Self {
99        Self {
100            client_to_server: vec![
101                "initialize".to_string(),
102                "initialized".to_string(),
103                "tools/call".to_string(),
104                "resources/read".to_string(),
105                "prompts/get".to_string(),
106                "completion/complete".to_string(),
107                "resources/templates/list".to_string(),
108            ],
109            server_to_client: vec![
110                "sampling/createMessage".to_string(),
111                "roots/list".to_string(),
112                "elicitation/create".to_string(),
113                "notifications/message".to_string(),
114                "notifications/resources/updated".to_string(),
115                "notifications/tools/updated".to_string(),
116            ],
117            bidirectional: vec![
118                "ping".to_string(),
119                "notifications/cancelled".to_string(),
120                "notifications/progress".to_string(),
121            ],
122        }
123    }
124
125    /// Validate message direction
126    pub fn validate(&self, message_type: &str, direction: MessageDirection) -> bool {
127        // Check bidirectional first
128        if self.bidirectional.contains(&message_type.to_string()) {
129            return true;
130        }
131
132        match direction {
133            MessageDirection::ClientToServer => {
134                self.client_to_server.contains(&message_type.to_string())
135            }
136            MessageDirection::ServerToClient => {
137                self.server_to_client.contains(&message_type.to_string())
138            }
139        }
140    }
141
142    /// Get allowed direction for a message type
143    pub fn get_allowed_direction(&self, message_type: &str) -> Option<MessageDirection> {
144        if self.bidirectional.contains(&message_type.to_string()) {
145            // Bidirectional messages can go either way
146            return None;
147        }
148
149        if self.client_to_server.contains(&message_type.to_string()) {
150            return Some(MessageDirection::ClientToServer);
151        }
152
153        if self.server_to_client.contains(&message_type.to_string()) {
154            return Some(MessageDirection::ServerToClient);
155        }
156
157        None
158    }
159}
160
161/// Message router for bidirectional communication
162pub struct MessageRouter {
163    /// Route table for message types
164    routes: DashMap<String, RouteHandler>,
165    /// Default handler for unrouted messages
166    default_handler: Option<RouteHandler>,
167}
168
169impl std::fmt::Debug for MessageRouter {
170    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
171        f.debug_struct("MessageRouter")
172            .field("routes_count", &self.routes.len())
173            .field("has_default_handler", &self.default_handler.is_some())
174            .finish()
175    }
176}
177
178/// Route handler for messages
179type RouteHandler = Arc<dyn Fn(TransportMessage) -> RouteAction + Send + Sync>;
180
181/// Action to take for a routed message
182#[derive(Debug, Clone)]
183pub enum RouteAction {
184    /// Forward the message
185    Forward,
186    /// Handle locally
187    Handle(String), // Handler name
188    /// Drop the message
189    Drop,
190    /// Transform and forward
191    Transform(TransportMessage),
192}
193
194impl Default for MessageRouter {
195    fn default() -> Self {
196        Self::new()
197    }
198}
199
200impl MessageRouter {
201    /// Create a new message router
202    pub fn new() -> Self {
203        Self {
204            routes: DashMap::new(),
205            default_handler: None,
206        }
207    }
208
209    /// Add a route for a message type
210    pub fn add_route<F>(&self, message_type: String, handler: F)
211    where
212        F: Fn(TransportMessage) -> RouteAction + Send + Sync + 'static,
213    {
214        self.routes.insert(message_type, Arc::new(handler));
215    }
216
217    /// Route a message
218    pub fn route(&self, message: &TransportMessage) -> RouteAction {
219        // Extract message type from the message
220        // This would need to parse the message content
221        let message_type = extract_message_type(message);
222
223        if let Some(handler) = self.routes.get(&message_type) {
224            handler(message.clone())
225        } else if let Some(ref default) = self.default_handler {
226            default(message.clone())
227        } else {
228            RouteAction::Forward
229        }
230    }
231}
232
233/// Extract message type from transport message
234fn extract_message_type(message: &TransportMessage) -> String {
235    // Current implementation: Basic JSON-RPC method extraction (works for message routing)
236    // Enhanced JSON-RPC parsing can be added in future iterations as needed
237    // Current implementation handles the essential method extraction for routing
238    if let Ok(json) = serde_json::from_slice::<serde_json::Value>(&message.payload)
239        && let Some(method) = json.get("method").and_then(|m| m.as_str())
240    {
241        return method.to_string();
242    }
243    "unknown".to_string()
244}
245
246impl<T: Transport> BidirectionalTransportWrapper<T> {
247    /// Create a new bidirectional transport wrapper
248    pub fn new(inner: T, direction: MessageDirection) -> Self {
249        Self {
250            inner,
251            direction,
252            correlations: Arc::new(DashMap::new()),
253            server_handlers: Arc::new(DashMap::new()),
254            validator: Arc::new(ProtocolDirectionValidator::new()),
255            router: Arc::new(MessageRouter::new()),
256            state: Arc::new(RwLock::new(ConnectionState::default())),
257        }
258    }
259
260    /// Register a handler for server-initiated requests
261    pub fn register_server_handler(
262        &self,
263        request_type: ServerInitiatedType,
264        handler: mpsc::Sender<TransportMessage>,
265    ) {
266        let key = match request_type {
267            ServerInitiatedType::Sampling => "sampling/createMessage",
268            ServerInitiatedType::Roots => "roots/list",
269            ServerInitiatedType::Elicitation => "elicitation/create",
270            ServerInitiatedType::Ping => "ping",
271        };
272        self.server_handlers.insert(key.to_string(), handler);
273    }
274
275    /// Process incoming message with direction validation
276    async fn process_incoming(&self, message: TransportMessage) -> TransportResult<()> {
277        let message_type = extract_message_type(&message);
278
279        // Validate direction
280        if !self.validator.validate(&message_type, self.direction) {
281            return Err(TransportError::ProtocolError(format!(
282                "Invalid message direction for {}: expected {:?}",
283                message_type, self.direction
284            )));
285        }
286
287        // Check for correlation
288        if let Some(correlation_id) = extract_correlation_id(&message)
289            && let Some((_, context)) = self.correlations.remove(&correlation_id)
290        {
291            // This is a response to a previous request
292            if let Some(tx) = context.response_tx {
293                let _ = tx.send(message);
294            }
295            return Ok(());
296        }
297
298        // Route the message
299        match self.router.route(&message) {
300            RouteAction::Forward => {
301                // Forward to standard processing
302                self.handle_standard_message(message).await
303            }
304            RouteAction::Handle(handler_name) => {
305                // Route to specific handler
306                self.handle_with_handler(message, &handler_name).await
307            }
308            RouteAction::Drop => Ok(()),
309            RouteAction::Transform(transformed) => {
310                // Process transformed message
311                self.handle_standard_message(transformed).await
312            }
313        }
314    }
315
316    /// Handle standard message processing
317    async fn handle_standard_message(&self, message: TransportMessage) -> TransportResult<()> {
318        // Check if this is a server-initiated request
319        let message_type = extract_message_type(&message);
320        if let Some(handler) = self.server_handlers.get(&message_type) {
321            handler
322                .send(message)
323                .await
324                .map_err(|e| TransportError::Internal(e.to_string()))?;
325        }
326        Ok(())
327    }
328
329    /// Handle message with specific handler
330    async fn handle_with_handler(
331        &self,
332        _message: TransportMessage,
333        _handler_name: &str,
334    ) -> TransportResult<()> {
335        // This would route to registered handlers
336        // Implementation depends on handler registration system
337        Ok(())
338    }
339
340    /// Send a server-initiated request
341    pub async fn send_server_request(
342        &self,
343        _request_type: ServerInitiatedType,
344        message: TransportMessage,
345        timeout_duration: Duration,
346    ) -> TransportResult<TransportMessage> {
347        // Validate this is allowed from server
348        if self.direction != MessageDirection::ServerToClient {
349            return Err(TransportError::ProtocolError(
350                "Cannot send server-initiated request from client transport".to_string(),
351            ));
352        }
353
354        // Create correlation context
355        let correlation_id = Uuid::new_v4().to_string();
356        let (tx, rx) = oneshot::channel();
357
358        let context = CorrelationContext {
359            correlation_id: correlation_id.clone(),
360            request_id: Uuid::new_v4().to_string(),
361            response_tx: Some(tx),
362            timeout: timeout_duration,
363            created_at: std::time::Instant::now(),
364        };
365
366        self.correlations.insert(correlation_id.clone(), context);
367
368        // Send the message
369        self.inner.send(message).await?;
370
371        // Wait for response with timeout
372        match timeout(timeout_duration, rx).await {
373            Ok(Ok(response)) => Ok(response),
374            Ok(Err(_)) => Err(TransportError::Internal(
375                "Response channel closed".to_string(),
376            )),
377            Err(_) => {
378                self.correlations.remove(&correlation_id);
379                Err(TransportError::Timeout)
380            }
381        }
382    }
383
384    /// Enable server-initiated requests
385    pub async fn enable_server_initiated(&self) {
386        let mut state = self.state.write().await;
387        state.server_initiated_enabled = true;
388    }
389
390    /// Check if server-initiated requests are enabled
391    pub async fn is_server_initiated_enabled(&self) -> bool {
392        let state = self.state.read().await;
393        state.server_initiated_enabled
394    }
395}
396
397// Helper functions
398
399/// Extract correlation ID from message
400fn extract_correlation_id(message: &TransportMessage) -> Option<String> {
401    if let Ok(json) = serde_json::from_slice::<serde_json::Value>(&message.payload) {
402        json.get("correlation_id")
403            .and_then(|id| id.as_str())
404            .map(|s| s.to_string())
405    } else {
406        None
407    }
408}
409
410/// Detect server-initiated request type
411#[allow(dead_code)]
412fn detect_server_initiated_type(message: &TransportMessage) -> Option<ServerInitiatedType> {
413    let message_type = extract_message_type(message);
414
415    match message_type.as_str() {
416        "sampling/createMessage" => Some(ServerInitiatedType::Sampling),
417        "roots/list" => Some(ServerInitiatedType::Roots),
418        "elicitation/create" => Some(ServerInitiatedType::Elicitation),
419        "ping" => Some(ServerInitiatedType::Ping),
420        _ => None,
421    }
422}
423
424// Implement Transport trait for the wrapper
425#[async_trait]
426impl<T: Transport> Transport for BidirectionalTransportWrapper<T> {
427    fn transport_type(&self) -> TransportType {
428        self.inner.transport_type()
429    }
430
431    fn capabilities(&self) -> &TransportCapabilities {
432        self.inner.capabilities()
433    }
434
435    async fn state(&self) -> TransportState {
436        self.inner.state().await
437    }
438
439    async fn connect(&self) -> TransportResult<()> {
440        self.inner.connect().await
441    }
442
443    async fn disconnect(&self) -> TransportResult<()> {
444        // Clean up correlations
445        self.correlations.clear();
446        self.inner.disconnect().await
447    }
448
449    async fn send(&self, message: TransportMessage) -> TransportResult<()> {
450        // Validate direction before sending
451        let message_type = extract_message_type(&message);
452        if !self.validator.validate(&message_type, self.direction) {
453            return Err(TransportError::ProtocolError(format!(
454                "Cannot send {} in direction {:?}",
455                message_type, self.direction
456            )));
457        }
458        self.inner.send(message).await
459    }
460
461    async fn receive(&self) -> TransportResult<Option<TransportMessage>> {
462        if let Some(message) = self.inner.receive().await? {
463            self.process_incoming(message.clone()).await?;
464            Ok(Some(message))
465        } else {
466            Ok(None)
467        }
468    }
469
470    async fn metrics(&self) -> crate::core::TransportMetrics {
471        self.inner.metrics().await
472    }
473}
474
475// Implement BidirectionalTransport trait
476#[async_trait]
477impl<T: Transport> BidirectionalTransport for BidirectionalTransportWrapper<T> {
478    async fn send_request(
479        &self,
480        message: TransportMessage,
481        timeout_duration: Option<Duration>,
482    ) -> TransportResult<TransportMessage> {
483        let timeout_duration = timeout_duration.unwrap_or(Duration::from_secs(30));
484
485        // Create correlation
486        let correlation_id = Uuid::new_v4().to_string();
487        let (tx, rx) = oneshot::channel();
488
489        let context = CorrelationContext {
490            correlation_id: correlation_id.clone(),
491            request_id: Uuid::new_v4().to_string(),
492            response_tx: Some(tx),
493            timeout: timeout_duration,
494            created_at: std::time::Instant::now(),
495        };
496
497        self.correlations.insert(correlation_id.clone(), context);
498
499        // Send message
500        self.send(message).await?;
501
502        // Wait for response
503        match timeout(timeout_duration, rx).await {
504            Ok(Ok(response)) => Ok(response),
505            Ok(Err(_)) => Err(TransportError::Internal(
506                "Response channel closed".to_string(),
507            )),
508            Err(_) => {
509                self.correlations.remove(&correlation_id);
510                Err(TransportError::Timeout)
511            }
512        }
513    }
514
515    async fn start_correlation(&self, correlation_id: String) -> TransportResult<()> {
516        let context = CorrelationContext {
517            correlation_id: correlation_id.clone(),
518            request_id: Uuid::new_v4().to_string(),
519            response_tx: None,
520            timeout: Duration::from_secs(30),
521            created_at: std::time::Instant::now(),
522        };
523
524        self.correlations.insert(correlation_id, context);
525        Ok(())
526    }
527
528    async fn stop_correlation(&self, correlation_id: &str) -> TransportResult<()> {
529        self.correlations.remove(correlation_id);
530        Ok(())
531    }
532}
533
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_protocol_direction_validator() {
        use super::MessageDirection::{ClientToServer, ServerToClient};

        let rules = ProtocolDirectionValidator::new();

        // Only a client may originate `tools/call`.
        assert!(rules.validate("tools/call", ClientToServer));
        assert!(!rules.validate("tools/call", ServerToClient));

        // Only a server may originate `sampling/createMessage`.
        assert!(rules.validate("sampling/createMessage", ServerToClient));
        assert!(!rules.validate("sampling/createMessage", ClientToServer));

        // `ping` is accepted in both directions.
        for direction in [ClientToServer, ServerToClient] {
            assert!(rules.validate("ping", direction));
        }
    }

    #[test]
    fn test_message_router() {
        let router = MessageRouter::new();
        router.add_route("test".to_string(), |_| {
            RouteAction::Handle("test_handler".to_string())
        });

        let probe = TransportMessage {
            id: turbomcp_protocol::MessageId::from("test-message-id"),
            payload: br#"{"method": "test"}"#.to_vec().into(),
            metadata: Default::default(),
        };

        let RouteAction::Handle(handler) = router.route(&probe) else {
            panic!("Expected Handle action");
        };
        assert_eq!(handler, "test_handler");
    }

    #[tokio::test]
    async fn test_connection_state() {
        let ConnectionState {
            server_initiated_enabled,
            active_server_requests,
            pending_elicitations,
            ..
        } = ConnectionState::default();

        assert!(!server_initiated_enabled);
        assert!(active_server_requests.is_empty());
        assert!(pending_elicitations.is_empty());
    }
}