1use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::Duration;
9
10use async_trait::async_trait;
11use dashmap::DashMap;
12use serde::{Deserialize, Serialize};
13use tokio::sync::{RwLock, mpsc, oneshot};
14use tokio::time::timeout;
15use turbomcp_protocol::ServerInitiatedType;
16use uuid::Uuid;
17
18use crate::core::{
19 BidirectionalTransport, Transport, TransportCapabilities, TransportError, TransportMessage,
20 TransportResult, TransportState, TransportType,
21};
22
23#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
25pub enum MessageDirection {
26 ClientToServer,
28 ServerToClient,
30}
31
32#[derive(Debug)]
34pub struct CorrelationContext {
35 pub correlation_id: String,
37 pub request_id: String,
39 pub response_tx: Option<oneshot::Sender<TransportMessage>>,
41 pub timeout: Duration,
43 pub created_at: std::time::Instant,
45}
46
47#[derive(Debug)]
49pub struct BidirectionalTransportWrapper<T: Transport> {
50 inner: T,
52 direction: MessageDirection,
54 correlations: Arc<DashMap<String, CorrelationContext>>,
56 server_handlers: Arc<DashMap<String, mpsc::Sender<TransportMessage>>>,
58 validator: Arc<ProtocolDirectionValidator>,
60 router: Arc<MessageRouter>,
62 state: Arc<RwLock<ConnectionState>>,
64}
65
66#[derive(Debug, Clone, Default)]
68pub struct ConnectionState {
69 pub server_initiated_enabled: bool,
71 pub active_server_requests: Vec<String>,
73 pub pending_elicitations: Vec<String>,
75 pub metadata: HashMap<String, serde_json::Value>,
77}
78
/// Lookup tables describing which protocol methods may travel in which
/// direction.
#[derive(Debug)]
pub struct ProtocolDirectionValidator {
    /// Methods that only a client may send to a server.
    client_to_server: Vec<String>,
    /// Methods that only a server may send to a client.
    server_to_client: Vec<String>,
    /// Methods that are accepted in either direction.
    bidirectional: Vec<String>,
}
89
90impl Default for ProtocolDirectionValidator {
91 fn default() -> Self {
92 Self::new()
93 }
94}
95
96impl ProtocolDirectionValidator {
97 pub fn new() -> Self {
99 Self {
100 client_to_server: vec![
101 "initialize".to_string(),
102 "initialized".to_string(),
103 "tools/call".to_string(),
104 "resources/read".to_string(),
105 "prompts/get".to_string(),
106 "completion/complete".to_string(),
107 "resources/templates/list".to_string(),
108 ],
109 server_to_client: vec![
110 "sampling/createMessage".to_string(),
111 "roots/list".to_string(),
112 "elicitation/create".to_string(),
113 "notifications/message".to_string(),
114 "notifications/resources/updated".to_string(),
115 "notifications/tools/updated".to_string(),
116 ],
117 bidirectional: vec![
118 "ping".to_string(),
119 "notifications/cancelled".to_string(),
120 "notifications/progress".to_string(),
121 ],
122 }
123 }
124
125 pub fn validate(&self, message_type: &str, direction: MessageDirection) -> bool {
127 if self.bidirectional.contains(&message_type.to_string()) {
129 return true;
130 }
131
132 match direction {
133 MessageDirection::ClientToServer => {
134 self.client_to_server.contains(&message_type.to_string())
135 }
136 MessageDirection::ServerToClient => {
137 self.server_to_client.contains(&message_type.to_string())
138 }
139 }
140 }
141
142 pub fn get_allowed_direction(&self, message_type: &str) -> Option<MessageDirection> {
144 if self.bidirectional.contains(&message_type.to_string()) {
145 return None;
147 }
148
149 if self.client_to_server.contains(&message_type.to_string()) {
150 return Some(MessageDirection::ClientToServer);
151 }
152
153 if self.server_to_client.contains(&message_type.to_string()) {
154 return Some(MessageDirection::ServerToClient);
155 }
156
157 None
158 }
159}
160
161pub struct MessageRouter {
163 routes: DashMap<String, RouteHandler>,
165 default_handler: Option<RouteHandler>,
167}
168
169impl std::fmt::Debug for MessageRouter {
170 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
171 f.debug_struct("MessageRouter")
172 .field("routes_count", &self.routes.len())
173 .field("has_default_handler", &self.default_handler.is_some())
174 .finish()
175 }
176}
177
178type RouteHandler = Arc<dyn Fn(TransportMessage) -> RouteAction + Send + Sync>;
180
181#[derive(Debug, Clone)]
183pub enum RouteAction {
184 Forward,
186 Handle(String), Drop,
190 Transform(TransportMessage),
192}
193
194impl Default for MessageRouter {
195 fn default() -> Self {
196 Self::new()
197 }
198}
199
200impl MessageRouter {
201 pub fn new() -> Self {
203 Self {
204 routes: DashMap::new(),
205 default_handler: None,
206 }
207 }
208
209 pub fn add_route<F>(&self, message_type: String, handler: F)
211 where
212 F: Fn(TransportMessage) -> RouteAction + Send + Sync + 'static,
213 {
214 self.routes.insert(message_type, Arc::new(handler));
215 }
216
217 pub fn route(&self, message: &TransportMessage) -> RouteAction {
219 let message_type = extract_message_type(message);
222
223 if let Some(handler) = self.routes.get(&message_type) {
224 handler(message.clone())
225 } else if let Some(ref default) = self.default_handler {
226 default(message.clone())
227 } else {
228 RouteAction::Forward
229 }
230 }
231}
232
233fn extract_message_type(message: &TransportMessage) -> String {
235 if let Ok(json) = serde_json::from_slice::<serde_json::Value>(&message.payload)
239 && let Some(method) = json.get("method").and_then(|m| m.as_str())
240 {
241 return method.to_string();
242 }
243 "unknown".to_string()
244}
245
246impl<T: Transport> BidirectionalTransportWrapper<T> {
247 pub fn new(inner: T, direction: MessageDirection) -> Self {
249 Self {
250 inner,
251 direction,
252 correlations: Arc::new(DashMap::new()),
253 server_handlers: Arc::new(DashMap::new()),
254 validator: Arc::new(ProtocolDirectionValidator::new()),
255 router: Arc::new(MessageRouter::new()),
256 state: Arc::new(RwLock::new(ConnectionState::default())),
257 }
258 }
259
260 pub fn register_server_handler(
262 &self,
263 request_type: ServerInitiatedType,
264 handler: mpsc::Sender<TransportMessage>,
265 ) {
266 let key = match request_type {
267 ServerInitiatedType::Sampling => "sampling/createMessage",
268 ServerInitiatedType::Roots => "roots/list",
269 ServerInitiatedType::Elicitation => "elicitation/create",
270 ServerInitiatedType::Ping => "ping",
271 };
272 self.server_handlers.insert(key.to_string(), handler);
273 }
274
275 async fn process_incoming(&self, message: TransportMessage) -> TransportResult<()> {
277 let message_type = extract_message_type(&message);
278
279 if !self.validator.validate(&message_type, self.direction) {
281 return Err(TransportError::ProtocolError(format!(
282 "Invalid message direction for {}: expected {:?}",
283 message_type, self.direction
284 )));
285 }
286
287 if let Some(correlation_id) = extract_correlation_id(&message)
289 && let Some((_, context)) = self.correlations.remove(&correlation_id)
290 {
291 if let Some(tx) = context.response_tx {
293 let _ = tx.send(message);
294 }
295 return Ok(());
296 }
297
298 match self.router.route(&message) {
300 RouteAction::Forward => {
301 self.handle_standard_message(message).await
303 }
304 RouteAction::Handle(handler_name) => {
305 self.handle_with_handler(message, &handler_name).await
307 }
308 RouteAction::Drop => Ok(()),
309 RouteAction::Transform(transformed) => {
310 self.handle_standard_message(transformed).await
312 }
313 }
314 }
315
316 async fn handle_standard_message(&self, message: TransportMessage) -> TransportResult<()> {
318 let message_type = extract_message_type(&message);
320 if let Some(handler) = self.server_handlers.get(&message_type) {
321 handler
322 .send(message)
323 .await
324 .map_err(|e| TransportError::Internal(e.to_string()))?;
325 }
326 Ok(())
327 }
328
329 async fn handle_with_handler(
331 &self,
332 _message: TransportMessage,
333 _handler_name: &str,
334 ) -> TransportResult<()> {
335 Ok(())
338 }
339
340 pub async fn send_server_request(
342 &self,
343 _request_type: ServerInitiatedType,
344 message: TransportMessage,
345 timeout_duration: Duration,
346 ) -> TransportResult<TransportMessage> {
347 if self.direction != MessageDirection::ServerToClient {
349 return Err(TransportError::ProtocolError(
350 "Cannot send server-initiated request from client transport".to_string(),
351 ));
352 }
353
354 let correlation_id = Uuid::new_v4().to_string();
356 let (tx, rx) = oneshot::channel();
357
358 let context = CorrelationContext {
359 correlation_id: correlation_id.clone(),
360 request_id: Uuid::new_v4().to_string(),
361 response_tx: Some(tx),
362 timeout: timeout_duration,
363 created_at: std::time::Instant::now(),
364 };
365
366 self.correlations.insert(correlation_id.clone(), context);
367
368 self.inner.send(message).await?;
370
371 match timeout(timeout_duration, rx).await {
373 Ok(Ok(response)) => Ok(response),
374 Ok(Err(_)) => Err(TransportError::Internal(
375 "Response channel closed".to_string(),
376 )),
377 Err(_) => {
378 self.correlations.remove(&correlation_id);
379 Err(TransportError::Timeout)
380 }
381 }
382 }
383
384 pub async fn enable_server_initiated(&self) {
386 let mut state = self.state.write().await;
387 state.server_initiated_enabled = true;
388 }
389
390 pub async fn is_server_initiated_enabled(&self) -> bool {
392 let state = self.state.read().await;
393 state.server_initiated_enabled
394 }
395}
396
397fn extract_correlation_id(message: &TransportMessage) -> Option<String> {
401 if let Ok(json) = serde_json::from_slice::<serde_json::Value>(&message.payload) {
402 json.get("correlation_id")
403 .and_then(|id| id.as_str())
404 .map(|s| s.to_string())
405 } else {
406 None
407 }
408}
409
410#[allow(dead_code)]
412fn detect_server_initiated_type(message: &TransportMessage) -> Option<ServerInitiatedType> {
413 let message_type = extract_message_type(message);
414
415 match message_type.as_str() {
416 "sampling/createMessage" => Some(ServerInitiatedType::Sampling),
417 "roots/list" => Some(ServerInitiatedType::Roots),
418 "elicitation/create" => Some(ServerInitiatedType::Elicitation),
419 "ping" => Some(ServerInitiatedType::Ping),
420 _ => None,
421 }
422}
423
424#[async_trait]
426impl<T: Transport> Transport for BidirectionalTransportWrapper<T> {
427 fn transport_type(&self) -> TransportType {
428 self.inner.transport_type()
429 }
430
431 fn capabilities(&self) -> &TransportCapabilities {
432 self.inner.capabilities()
433 }
434
435 async fn state(&self) -> TransportState {
436 self.inner.state().await
437 }
438
439 async fn connect(&self) -> TransportResult<()> {
440 self.inner.connect().await
441 }
442
443 async fn disconnect(&self) -> TransportResult<()> {
444 self.correlations.clear();
446 self.inner.disconnect().await
447 }
448
449 async fn send(&self, message: TransportMessage) -> TransportResult<()> {
450 let message_type = extract_message_type(&message);
452 if !self.validator.validate(&message_type, self.direction) {
453 return Err(TransportError::ProtocolError(format!(
454 "Cannot send {} in direction {:?}",
455 message_type, self.direction
456 )));
457 }
458 self.inner.send(message).await
459 }
460
461 async fn receive(&self) -> TransportResult<Option<TransportMessage>> {
462 if let Some(message) = self.inner.receive().await? {
463 self.process_incoming(message.clone()).await?;
464 Ok(Some(message))
465 } else {
466 Ok(None)
467 }
468 }
469
470 async fn metrics(&self) -> crate::core::TransportMetrics {
471 self.inner.metrics().await
472 }
473}
474
475#[async_trait]
477impl<T: Transport> BidirectionalTransport for BidirectionalTransportWrapper<T> {
478 async fn send_request(
479 &self,
480 message: TransportMessage,
481 timeout_duration: Option<Duration>,
482 ) -> TransportResult<TransportMessage> {
483 let timeout_duration = timeout_duration.unwrap_or(Duration::from_secs(30));
484
485 let correlation_id = Uuid::new_v4().to_string();
487 let (tx, rx) = oneshot::channel();
488
489 let context = CorrelationContext {
490 correlation_id: correlation_id.clone(),
491 request_id: Uuid::new_v4().to_string(),
492 response_tx: Some(tx),
493 timeout: timeout_duration,
494 created_at: std::time::Instant::now(),
495 };
496
497 self.correlations.insert(correlation_id.clone(), context);
498
499 self.send(message).await?;
501
502 match timeout(timeout_duration, rx).await {
504 Ok(Ok(response)) => Ok(response),
505 Ok(Err(_)) => Err(TransportError::Internal(
506 "Response channel closed".to_string(),
507 )),
508 Err(_) => {
509 self.correlations.remove(&correlation_id);
510 Err(TransportError::Timeout)
511 }
512 }
513 }
514
515 async fn start_correlation(&self, correlation_id: String) -> TransportResult<()> {
516 let context = CorrelationContext {
517 correlation_id: correlation_id.clone(),
518 request_id: Uuid::new_v4().to_string(),
519 response_tx: None,
520 timeout: Duration::from_secs(30),
521 created_at: std::time::Instant::now(),
522 };
523
524 self.correlations.insert(correlation_id, context);
525 Ok(())
526 }
527
528 async fn stop_correlation(&self, correlation_id: &str) -> TransportResult<()> {
529 self.correlations.remove(correlation_id);
530 Ok(())
531 }
532}
533
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a transport message whose payload is the given raw bytes.
    fn message_with_payload(payload: &[u8]) -> TransportMessage {
        TransportMessage {
            id: turbomcp_protocol::MessageId::from("test-message-id"),
            payload: payload.to_vec().into(),
            metadata: Default::default(),
        }
    }

    #[test]
    fn test_protocol_direction_validator() {
        let validator = ProtocolDirectionValidator::new();

        // Client-only method.
        assert!(validator.validate("tools/call", MessageDirection::ClientToServer));
        assert!(!validator.validate("tools/call", MessageDirection::ServerToClient));

        // Server-only method.
        assert!(validator.validate("sampling/createMessage", MessageDirection::ServerToClient));
        assert!(!validator.validate("sampling/createMessage", MessageDirection::ClientToServer));

        // Bidirectional method.
        assert!(validator.validate("ping", MessageDirection::ClientToServer));
        assert!(validator.validate("ping", MessageDirection::ServerToClient));
    }

    #[test]
    fn test_allowed_direction_lookup() {
        let validator = ProtocolDirectionValidator::new();

        assert_eq!(
            validator.get_allowed_direction("tools/call"),
            Some(MessageDirection::ClientToServer)
        );
        assert_eq!(
            validator.get_allowed_direction("roots/list"),
            Some(MessageDirection::ServerToClient)
        );
        // Bidirectional and unlisted methods have no single direction.
        assert_eq!(validator.get_allowed_direction("ping"), None);
        assert_eq!(validator.get_allowed_direction("no/such/method"), None);
    }

    #[test]
    fn test_message_router() {
        let router = MessageRouter::new();

        router.add_route("test".to_string(), |_msg| {
            RouteAction::Handle("test_handler".to_string())
        });

        let message = message_with_payload(br#"{"method": "test"}"#);

        match router.route(&message) {
            RouteAction::Handle(handler) => assert_eq!(handler, "test_handler"),
            _ => panic!("Expected Handle action"),
        }
    }

    #[test]
    fn test_message_router_forwards_unrouted() {
        let router = MessageRouter::new();
        let message = message_with_payload(br#"{"method": "unrouted"}"#);

        assert!(matches!(router.route(&message), RouteAction::Forward));
    }

    #[test]
    fn test_payload_helpers() {
        let request = message_with_payload(br#"{"method": "ping", "correlation_id": "abc"}"#);
        assert_eq!(extract_message_type(&request), "ping");
        assert_eq!(extract_correlation_id(&request), Some("abc".to_string()));

        let garbage = message_with_payload(b"not json");
        assert_eq!(extract_message_type(&garbage), "unknown");
        assert_eq!(extract_correlation_id(&garbage), None);
    }

    // Nothing here awaits, so a plain synchronous test is enough.
    #[test]
    fn test_connection_state() {
        let state = ConnectionState::default();
        assert!(!state.server_initiated_enabled);
        assert!(state.active_server_requests.is_empty());
        assert!(state.pending_elicitations.is_empty());
        assert!(state.metadata.is_empty());
    }
}