1use std::future::Future;
7use std::sync::Arc;
8use std::sync::atomic::AtomicBool;
9use std::time::Duration;
10
11use bytes::Bytes;
12use dashmap::DashMap;
13use futures::{stream::SplitSink, stream::SplitStream};
14use serde_json::json;
15use tokio::net::TcpStream;
16use tokio::sync::{Mutex, RwLock, broadcast, mpsc, oneshot};
17use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, tungstenite::Message};
18use turbomcp_protocol::types::{ElicitRequest, ElicitResult};
19use uuid::Uuid;
20
21use turbomcp_transport_traits::{
22 ConnectionState, CorrelationContext, TransportCapabilities, TransportEventEmitter,
23 TransportMessage, TransportMetrics, TransportState,
24};
25
26use super::config::WebSocketBidirectionalConfig;
27
28pub type WebSocketWriter =
31 Arc<Mutex<Option<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>>>;
32pub type WebSocketReader =
34 Arc<Mutex<Option<SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>>>>;
35
36#[derive(Debug)]
38pub struct PendingElicitation {
39 pub request_id: String,
41
42 pub request: ElicitRequest,
44
45 pub response_tx: oneshot::Sender<ElicitResult>,
47
48 pub deadline: tokio::time::Instant,
50
51 pub retry_count: u32,
53}
54
55impl PendingElicitation {
56 pub fn new(
58 request: ElicitRequest,
59 response_tx: oneshot::Sender<ElicitResult>,
60 timeout: Duration,
61 ) -> Self {
62 Self {
63 request_id: Uuid::new_v4().to_string(),
64 request,
65 response_tx,
66 deadline: tokio::time::Instant::now() + timeout,
67 retry_count: 0,
68 }
69 }
70
71 pub fn is_expired(&self) -> bool {
73 tokio::time::Instant::now() >= self.deadline
74 }
75
76 pub fn time_remaining(&self) -> Duration {
78 if self.is_expired() {
79 Duration::ZERO
80 } else {
81 self.deadline.duration_since(tokio::time::Instant::now())
82 }
83 }
84
85 pub fn increment_retry(&mut self) {
87 self.retry_count += 1;
88 }
89}
90
91#[derive(Debug)]
93pub struct WebSocketBidirectionalTransport {
94 pub state: Arc<RwLock<TransportState>>,
96
97 pub capabilities: TransportCapabilities,
99
100 pub config: Arc<parking_lot::Mutex<WebSocketBidirectionalConfig>>,
102
103 pub metrics: Arc<RwLock<TransportMetrics>>,
105
106 pub event_emitter: Arc<TransportEventEmitter>,
108
109 pub writer: WebSocketWriter,
111
112 pub reader: WebSocketReader,
114
115 pub correlations: Arc<DashMap<String, CorrelationContext>>,
117
118 pub elicitations: Arc<DashMap<String, PendingElicitation>>,
120
121 pub pending_samplings:
123 Arc<DashMap<String, oneshot::Sender<turbomcp_protocol::types::CreateMessageResult>>>,
124
125 pub pending_pings: Arc<DashMap<String, oneshot::Sender<turbomcp_protocol::types::PingResult>>>,
127
128 pub pending_roots:
130 Arc<DashMap<String, oneshot::Sender<turbomcp_protocol::types::ListRootsResult>>>,
131
132 pub connection_state: Arc<RwLock<ConnectionState>>,
134
135 pub task_handles: Arc<RwLock<Vec<tokio::task::JoinHandle<()>>>>,
137
138 pub shutdown_tx: Arc<broadcast::Sender<()>>,
147
148 pub reconnect_allowed: Arc<AtomicBool>,
157
158 pub session_id: String,
160
161 pub incoming_rx: Arc<Mutex<mpsc::Receiver<TransportMessage>>>,
168
169 pub incoming_tx: mpsc::Sender<TransportMessage>,
175}
176
177impl WebSocketBidirectionalTransport {
178 pub fn create_capabilities(config: &WebSocketBidirectionalConfig) -> TransportCapabilities {
180 TransportCapabilities {
181 max_message_size: Some(config.max_message_size),
182 supports_compression: config.enable_compression,
183 supports_streaming: true,
184 supports_bidirectional: true,
185 supports_multiplexing: true,
186 compression_algorithms: if config.enable_compression {
187 vec!["deflate".to_string(), "gzip".to_string()]
188 } else {
189 Vec::new()
190 },
191 custom: {
192 let mut custom = std::collections::HashMap::new();
193 custom.insert("elicitation".to_string(), json!(true));
194 custom.insert("sampling".to_string(), json!(true));
195 custom.insert("websocket_version".to_string(), json!("13"));
196 custom.insert(
197 "max_concurrent_elicitations".to_string(),
198 json!(config.max_concurrent_elicitations),
199 );
200 custom
201 },
202 }
203 }
204
205 pub fn pending_elicitations_count(&self) -> usize {
207 self.elicitations.len()
208 }
209
210 pub fn active_correlations_count(&self) -> usize {
212 self.correlations.len()
213 }
214
215 pub fn is_at_elicitation_capacity(&self) -> bool {
217 self.elicitations.len() >= self.config.lock().max_concurrent_elicitations
218 }
219
220 pub fn session_id(&self) -> &str {
222 &self.session_id
223 }
224
225 pub async fn is_writer_connected(&self) -> bool {
227 self.writer.lock().await.is_some()
228 }
229
230 pub async fn is_reader_connected(&self) -> bool {
232 self.reader.lock().await.is_some()
233 }
234}
235
236pub trait WebSocketStreamHandler {
238 fn setup_stream(
240 &mut self,
241 stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
242 ) -> impl Future<Output = Result<(), Box<dyn std::error::Error + Send + Sync>>> + Send;
243
244 fn handle_message(
246 &self,
247 message: Message,
248 ) -> impl Future<Output = Result<Option<Message>, Box<dyn std::error::Error + Send + Sync>>> + Send;
249}
250
251#[derive(Debug)]
253pub enum MessageProcessingResult {
254 Processed,
256 Forward(Bytes),
258 Failed(String),
260 NoAction,
262}
263
264#[derive(Debug, Clone)]
266pub struct WebSocketConnectionStats {
267 pub messages_sent: u64,
269 pub messages_received: u64,
271 pub pings_sent: u64,
273 pub pongs_received: u64,
275 pub connection_errors: u64,
277 pub reconnection_attempts: u64,
279 pub connection_state: TransportState,
281 pub connected_at: Option<std::time::SystemTime>,
283 pub last_activity: Option<std::time::SystemTime>,
285}
286
287impl Default for WebSocketConnectionStats {
288 fn default() -> Self {
289 Self {
290 messages_sent: 0,
291 messages_received: 0,
292 pings_sent: 0,
293 pongs_received: 0,
294 connection_errors: 0,
295 reconnection_attempts: 0,
296 connection_state: TransportState::Disconnected,
297 connected_at: None,
298 last_activity: None,
299 }
300 }
301}
302
303impl WebSocketConnectionStats {
304 pub fn new() -> Self {
306 Self::default()
307 }
308
309 pub fn record_message_sent(&mut self) {
311 self.messages_sent += 1;
312 self.last_activity = Some(std::time::SystemTime::now());
313 }
314
315 pub fn record_message_received(&mut self) {
317 self.messages_received += 1;
318 self.last_activity = Some(std::time::SystemTime::now());
319 }
320
321 pub fn record_ping_sent(&mut self) {
323 self.pings_sent += 1;
324 }
325
326 pub fn record_pong_received(&mut self) {
328 self.pongs_received += 1;
329 }
330
331 pub fn record_connection_error(&mut self) {
333 self.connection_errors += 1;
334 }
335
336 pub fn record_reconnection_attempt(&mut self) {
338 self.reconnection_attempts += 1;
339 }
340
341 pub fn set_connection_state(&mut self, state: TransportState) {
343 self.connection_state = state.clone();
344 if matches!(state, TransportState::Connected) {
345 self.connected_at = Some(std::time::SystemTime::now());
346 }
347 }
348
349 pub fn uptime(&self) -> Option<Duration> {
351 self.connected_at.and_then(|connected_at| {
352 std::time::SystemTime::now()
353 .duration_since(connected_at)
354 .ok()
355 })
356 }
357
358 pub fn idle_time(&self) -> Option<Duration> {
360 self.last_activity.and_then(|last_activity| {
361 std::time::SystemTime::now()
362 .duration_since(last_activity)
363 .ok()
364 })
365 }
366}
367
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal form-style elicitation request used by several tests.
    fn sample_request() -> ElicitRequest {
        use turbomcp_protocol::types::ElicitationSchema;

        ElicitRequest {
            params: turbomcp_protocol::types::ElicitRequestParams::form(
                "Test message".to_string(),
                ElicitationSchema {
                    schema_type: "object".to_string(),
                    properties: std::collections::HashMap::new(),
                    required: None,
                    additional_properties: None,
                },
                None,
                Some(true),
            ),
            task: None,
            _meta: None,
        }
    }

    #[test]
    fn test_pending_elicitation_creation() {
        let (tx, _rx) = oneshot::channel();
        let timeout = Duration::from_secs(30);

        let pending = PendingElicitation::new(sample_request(), tx, timeout);

        assert!(!pending.request_id.is_empty());
        assert_eq!(pending.retry_count, 0);
        assert!(!pending.is_expired());
        assert!(pending.time_remaining() > Duration::from_secs(25));
    }

    #[test]
    fn test_pending_elicitation_zero_timeout_is_expired() {
        let (tx, _rx) = oneshot::channel();

        let pending = PendingElicitation::new(sample_request(), tx, Duration::ZERO);

        // The deadline equals the creation instant, and the clock is monotonic.
        assert!(pending.is_expired());
        assert_eq!(pending.time_remaining(), Duration::ZERO);
    }

    #[test]
    fn test_pending_elicitation_retry_and_unique_ids() {
        let (tx_a, _rx_a) = oneshot::channel();
        let (tx_b, _rx_b) = oneshot::channel();
        let mut first = PendingElicitation::new(sample_request(), tx_a, Duration::from_secs(30));
        let second = PendingElicitation::new(sample_request(), tx_b, Duration::from_secs(30));

        first.increment_retry();
        first.increment_retry();

        assert_eq!(first.retry_count, 2);
        assert_eq!(second.retry_count, 0);
        assert_ne!(first.request_id, second.request_id);
    }

    #[test]
    fn test_websocket_connection_stats() {
        let mut stats = WebSocketConnectionStats::new();

        stats.record_message_sent();
        stats.record_message_received();
        stats.record_ping_sent();
        stats.record_pong_received();
        stats.record_connection_error();

        assert_eq!(stats.messages_sent, 1);
        assert_eq!(stats.messages_received, 1);
        assert_eq!(stats.pings_sent, 1);
        assert_eq!(stats.pongs_received, 1);
        assert_eq!(stats.connection_errors, 1);
        assert!(stats.last_activity.is_some());
    }

    #[test]
    fn test_connection_state_tracking() {
        let mut stats = WebSocketConnectionStats::new();
        assert!(stats.uptime().is_none());
        assert!(stats.idle_time().is_none());

        stats.set_connection_state(TransportState::Connected);
        assert!(matches!(stats.connection_state, TransportState::Connected));
        assert!(stats.connected_at.is_some());
        assert!(stats.uptime().is_some());

        stats.record_reconnection_attempt();
        assert_eq!(stats.reconnection_attempts, 1);
    }

    #[test]
    fn test_create_capabilities() {
        let config = WebSocketBidirectionalConfig {
            enable_compression: true,
            max_message_size: 1024 * 1024,
            max_concurrent_elicitations: 5,
            ..Default::default()
        };

        let capabilities = WebSocketBidirectionalTransport::create_capabilities(&config);

        assert!(capabilities.supports_compression);
        assert!(capabilities.supports_bidirectional);
        assert!(capabilities.supports_streaming);
        assert!(capabilities.supports_multiplexing);
        assert_eq!(capabilities.max_message_size, Some(1024 * 1024));
        assert!(!capabilities.compression_algorithms.is_empty());
        assert!(capabilities.custom.contains_key("elicitation"));
        assert_eq!(
            capabilities.custom.get("max_concurrent_elicitations"),
            Some(&json!(5))
        );
    }

    #[test]
    fn test_create_capabilities_without_compression() {
        let config = WebSocketBidirectionalConfig {
            enable_compression: false,
            ..Default::default()
        };

        let capabilities = WebSocketBidirectionalTransport::create_capabilities(&config);

        assert!(!capabilities.supports_compression);
        assert!(capabilities.compression_algorithms.is_empty());
        assert_eq!(
            capabilities.custom.get("websocket_version"),
            Some(&json!("13"))
        );
    }
}