Skip to main content

turbomcp_websocket/
types.rs

1//! Core types and type aliases for WebSocket bidirectional transport
2//!
3//! This module defines the core types used throughout the WebSocket transport
4//! implementation, including stream type aliases and pending request structures.
5
6use std::future::Future;
7use std::sync::Arc;
8use std::sync::atomic::AtomicBool;
9use std::time::Duration;
10
11use bytes::Bytes;
12use dashmap::DashMap;
13use futures::{stream::SplitSink, stream::SplitStream};
14use serde_json::json;
15use tokio::net::TcpStream;
16use tokio::sync::{Mutex, RwLock, broadcast, mpsc, oneshot};
17use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, tungstenite::Message};
18use turbomcp_protocol::types::{ElicitRequest, ElicitResult};
19use uuid::Uuid;
20
21use turbomcp_transport_traits::{
22    ConnectionState, CorrelationContext, TransportCapabilities, TransportEventEmitter,
23    TransportMessage, TransportMetrics, TransportState,
24};
25
26use super::config::WebSocketBidirectionalConfig;
27
28// Type aliases to reduce complexity and improve readability
29/// WebSocket writer handle for sending messages (thread-safe, async-safe)
30pub type WebSocketWriter =
31    Arc<Mutex<Option<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>>>;
32/// WebSocket reader handle for receiving messages (thread-safe, async-safe)
33pub type WebSocketReader =
34    Arc<Mutex<Option<SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>>>>;
35
36/// Pending elicitation request
37#[derive(Debug)]
38pub struct PendingElicitation {
39    /// Request ID for correlation
40    pub request_id: String,
41
42    /// The elicitation request
43    pub request: ElicitRequest,
44
45    /// Response channel
46    pub response_tx: oneshot::Sender<ElicitResult>,
47
48    /// Timeout deadline
49    pub deadline: tokio::time::Instant,
50
51    /// Retry count
52    pub retry_count: u32,
53}
54
55impl PendingElicitation {
56    /// Create a new pending elicitation
57    pub fn new(
58        request: ElicitRequest,
59        response_tx: oneshot::Sender<ElicitResult>,
60        timeout: Duration,
61    ) -> Self {
62        Self {
63            request_id: Uuid::new_v4().to_string(),
64            request,
65            response_tx,
66            deadline: tokio::time::Instant::now() + timeout,
67            retry_count: 0,
68        }
69    }
70
71    /// Check if the elicitation has expired
72    pub fn is_expired(&self) -> bool {
73        tokio::time::Instant::now() >= self.deadline
74    }
75
76    /// Get time remaining until expiration
77    pub fn time_remaining(&self) -> Duration {
78        if self.is_expired() {
79            Duration::ZERO
80        } else {
81            self.deadline.duration_since(tokio::time::Instant::now())
82        }
83    }
84
85    /// Increment retry count
86    pub fn increment_retry(&mut self) {
87        self.retry_count += 1;
88    }
89}
90
91/// WebSocket bidirectional transport implementation
92#[derive(Debug)]
93pub struct WebSocketBidirectionalTransport {
94    /// Transport state
95    pub state: Arc<RwLock<TransportState>>,
96
97    /// Transport capabilities
98    pub capabilities: TransportCapabilities,
99
100    /// Configuration (mutex for interior mutability)
101    pub config: Arc<parking_lot::Mutex<WebSocketBidirectionalConfig>>,
102
103    /// Metrics collector
104    pub metrics: Arc<RwLock<TransportMetrics>>,
105
106    /// Event emitter for transport events
107    pub event_emitter: Arc<TransportEventEmitter>,
108
109    /// WebSocket write half (sender)
110    pub writer: WebSocketWriter,
111
112    /// WebSocket read half (receiver)
113    pub reader: WebSocketReader,
114
115    /// Active correlations for request-response patterns
116    pub correlations: Arc<DashMap<String, CorrelationContext>>,
117
118    /// Pending elicitation requests
119    pub elicitations: Arc<DashMap<String, PendingElicitation>>,
120
121    /// Pending sampling requests
122    pub pending_samplings:
123        Arc<DashMap<String, oneshot::Sender<turbomcp_protocol::types::CreateMessageResult>>>,
124
125    /// Pending ping requests
126    pub pending_pings: Arc<DashMap<String, oneshot::Sender<turbomcp_protocol::types::PingResult>>>,
127
128    /// Pending roots list requests
129    pub pending_roots:
130        Arc<DashMap<String, oneshot::Sender<turbomcp_protocol::types::ListRootsResult>>>,
131
132    /// Connection state
133    pub connection_state: Arc<RwLock<ConnectionState>>,
134
135    /// Background task handles
136    pub task_handles: Arc<RwLock<Vec<tokio::task::JoinHandle<()>>>>,
137
138    /// Shutdown signal broadcaster (allows multiple receivers via subscribe())
139    ///
140    /// This broadcast channel enables all background tasks to receive shutdown signals.
141    /// Each task calls `shutdown_tx.subscribe()` to get its own receiver, then uses
142    /// `tokio::select!` to listen for the shutdown signal alongside its main logic.
143    ///
144    /// When `disconnect()` is called, it sends a shutdown signal that wakes all tasks,
145    /// allowing them to perform graceful cleanup before exiting.
146    pub shutdown_tx: Arc<broadcast::Sender<()>>,
147
148    /// Controls whether automatic reconnection is allowed
149    ///
150    /// This flag is set based on the initial config.reconnect.enabled value,
151    /// but can be permanently disabled by calling disconnect().
152    ///
153    /// Defense-in-depth: Even if shutdown signals are missed or state transitions
154    /// are delayed, this atomic flag ensures reconnection tasks will stop when
155    /// user explicitly calls disconnect().
156    pub reconnect_allowed: Arc<AtomicBool>,
157
158    /// Session ID for this connection
159    pub session_id: String,
160
161    /// Channel receiver for incoming messages (consumed by `Transport::receive()`)
162    ///
163    /// The background `spawn_message_reader_task()` reads from the WebSocket stream
164    /// and forwards non-correlation messages to this channel. This eliminates the
165    /// race condition where both the background task and `receive()` compete to
166    /// read from the same WebSocket stream.
167    pub incoming_rx: Arc<Mutex<mpsc::Receiver<TransportMessage>>>,
168
169    /// Channel sender for incoming messages (used by `spawn_message_reader_task()`)
170    ///
171    /// This sender is cloned and given to the background message reader task.
172    /// The task forwards all messages that aren't handled by correlation routing
173    /// to this channel for `Transport::receive()` to consume.
174    pub incoming_tx: mpsc::Sender<TransportMessage>,
175}
176
177impl WebSocketBidirectionalTransport {
178    /// Create transport capabilities for WebSocket bidirectional transport
179    pub fn create_capabilities(config: &WebSocketBidirectionalConfig) -> TransportCapabilities {
180        TransportCapabilities {
181            max_message_size: Some(config.max_message_size),
182            supports_compression: config.enable_compression,
183            supports_streaming: true,
184            supports_bidirectional: true,
185            supports_multiplexing: true,
186            compression_algorithms: if config.enable_compression {
187                vec!["deflate".to_string(), "gzip".to_string()]
188            } else {
189                Vec::new()
190            },
191            custom: {
192                let mut custom = std::collections::HashMap::new();
193                custom.insert("elicitation".to_string(), json!(true));
194                custom.insert("sampling".to_string(), json!(true));
195                custom.insert("websocket_version".to_string(), json!("13"));
196                custom.insert(
197                    "max_concurrent_elicitations".to_string(),
198                    json!(config.max_concurrent_elicitations),
199                );
200                custom
201            },
202        }
203    }
204
205    /// Get the current number of pending elicitations
206    pub fn pending_elicitations_count(&self) -> usize {
207        self.elicitations.len()
208    }
209
210    /// Get the current number of active correlations
211    pub fn active_correlations_count(&self) -> usize {
212        self.correlations.len()
213    }
214
215    /// Check if the transport is at elicitation capacity
216    pub fn is_at_elicitation_capacity(&self) -> bool {
217        self.elicitations.len() >= self.config.lock().max_concurrent_elicitations
218    }
219
220    /// Get session ID
221    pub fn session_id(&self) -> &str {
222        &self.session_id
223    }
224
225    /// Check if WebSocket is connected
226    pub async fn is_writer_connected(&self) -> bool {
227        self.writer.lock().await.is_some()
228    }
229
230    /// Check if WebSocket reader is available
231    pub async fn is_reader_connected(&self) -> bool {
232        self.reader.lock().await.is_some()
233    }
234}
235
236/// Trait for types that can be used as WebSocket stream endpoints
237pub trait WebSocketStreamHandler {
238    /// Setup the WebSocket stream
239    fn setup_stream(
240        &mut self,
241        stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
242    ) -> impl Future<Output = Result<(), Box<dyn std::error::Error + Send + Sync>>> + Send;
243
244    /// Handle incoming WebSocket message
245    fn handle_message(
246        &self,
247        message: Message,
248    ) -> impl Future<Output = Result<Option<Message>, Box<dyn std::error::Error + Send + Sync>>> + Send;
249}
250
251/// WebSocket message processing result
252#[derive(Debug)]
253pub enum MessageProcessingResult {
254    /// Message was processed successfully
255    Processed,
256    /// Message should be forwarded to the application
257    Forward(Bytes),
258    /// Message processing failed
259    Failed(String),
260    /// No action needed (e.g., for ping/pong)
261    NoAction,
262}
263
264/// Connection statistics for monitoring
265#[derive(Debug, Clone)]
266pub struct WebSocketConnectionStats {
267    /// Number of messages sent
268    pub messages_sent: u64,
269    /// Number of messages received
270    pub messages_received: u64,
271    /// Number of ping messages sent
272    pub pings_sent: u64,
273    /// Number of pong messages received
274    pub pongs_received: u64,
275    /// Number of connection errors
276    pub connection_errors: u64,
277    /// Number of reconnection attempts
278    pub reconnection_attempts: u64,
279    /// Current connection state
280    pub connection_state: TransportState,
281    /// Time when connection was established
282    pub connected_at: Option<std::time::SystemTime>,
283    /// Time of last message activity
284    pub last_activity: Option<std::time::SystemTime>,
285}
286
287impl Default for WebSocketConnectionStats {
288    fn default() -> Self {
289        Self {
290            messages_sent: 0,
291            messages_received: 0,
292            pings_sent: 0,
293            pongs_received: 0,
294            connection_errors: 0,
295            reconnection_attempts: 0,
296            connection_state: TransportState::Disconnected,
297            connected_at: None,
298            last_activity: None,
299        }
300    }
301}
302
303impl WebSocketConnectionStats {
304    /// Create new connection statistics
305    pub fn new() -> Self {
306        Self::default()
307    }
308
309    /// Record a sent message
310    pub fn record_message_sent(&mut self) {
311        self.messages_sent += 1;
312        self.last_activity = Some(std::time::SystemTime::now());
313    }
314
315    /// Record a received message
316    pub fn record_message_received(&mut self) {
317        self.messages_received += 1;
318        self.last_activity = Some(std::time::SystemTime::now());
319    }
320
321    /// Record a sent ping
322    pub fn record_ping_sent(&mut self) {
323        self.pings_sent += 1;
324    }
325
326    /// Record a received pong
327    pub fn record_pong_received(&mut self) {
328        self.pongs_received += 1;
329    }
330
331    /// Record a connection error
332    pub fn record_connection_error(&mut self) {
333        self.connection_errors += 1;
334    }
335
336    /// Record a reconnection attempt
337    pub fn record_reconnection_attempt(&mut self) {
338        self.reconnection_attempts += 1;
339    }
340
341    /// Set connection state
342    pub fn set_connection_state(&mut self, state: TransportState) {
343        self.connection_state = state.clone();
344        if matches!(state, TransportState::Connected) {
345            self.connected_at = Some(std::time::SystemTime::now());
346        }
347    }
348
349    /// Get connection uptime
350    pub fn uptime(&self) -> Option<Duration> {
351        self.connected_at.and_then(|connected_at| {
352            std::time::SystemTime::now()
353                .duration_since(connected_at)
354                .ok()
355        })
356    }
357
358    /// Get idle time since last activity
359    pub fn idle_time(&self) -> Option<Duration> {
360        self.last_activity.and_then(|last_activity| {
361            std::time::SystemTime::now()
362                .duration_since(last_activity)
363                .ok()
364        })
365    }
366}
367
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pending_elicitation_creation() {
        use turbomcp_protocol::types::ElicitationSchema;

        let request = ElicitRequest {
            params: turbomcp_protocol::types::ElicitRequestParams::form(
                "Test message".to_string(),
                ElicitationSchema {
                    schema_type: "object".to_string(),
                    properties: std::collections::HashMap::new(),
                    required: None,
                    additional_properties: None,
                },
                None,
                Some(true),
            ),
            task: None,
            _meta: None,
        };
        let (tx, _rx) = oneshot::channel();
        let timeout = Duration::from_secs(30);

        let mut pending = PendingElicitation::new(request, tx, timeout);

        assert!(!pending.request_id.is_empty());
        assert_eq!(pending.retry_count, 0);
        assert!(!pending.is_expired());
        assert!(pending.time_remaining() > Duration::from_secs(25));
        // The remaining time can never exceed the timeout it was created with.
        assert!(pending.time_remaining() <= timeout);

        pending.increment_retry();
        assert_eq!(pending.retry_count, 1);
    }

    #[test]
    fn test_websocket_connection_stats() {
        let mut stats = WebSocketConnectionStats::new();

        stats.record_message_sent();
        stats.record_message_received();
        stats.record_ping_sent();
        stats.record_pong_received();
        stats.record_connection_error();
        stats.record_reconnection_attempt();

        assert_eq!(stats.messages_sent, 1);
        assert_eq!(stats.messages_received, 1);
        assert_eq!(stats.pings_sent, 1);
        assert_eq!(stats.pongs_received, 1);
        assert_eq!(stats.connection_errors, 1);
        assert_eq!(stats.reconnection_attempts, 1);
        assert!(stats.last_activity.is_some());
    }

    #[test]
    fn test_connection_stats_defaults_and_state() {
        let mut stats = WebSocketConnectionStats::default();

        assert!(matches!(stats.connection_state, TransportState::Disconnected));
        assert!(stats.connected_at.is_none());
        assert!(stats.uptime().is_none());
        assert!(stats.idle_time().is_none());

        stats.set_connection_state(TransportState::Connected);
        assert!(matches!(stats.connection_state, TransportState::Connected));
        assert!(stats.connected_at.is_some());
    }

    #[test]
    fn test_create_capabilities() {
        let config = WebSocketBidirectionalConfig {
            enable_compression: true,
            max_message_size: 1024 * 1024,
            max_concurrent_elicitations: 5,
            ..Default::default()
        };

        let capabilities = WebSocketBidirectionalTransport::create_capabilities(&config);

        assert!(capabilities.supports_compression);
        assert!(capabilities.supports_bidirectional);
        assert!(capabilities.supports_streaming);
        assert!(capabilities.supports_multiplexing);
        assert_eq!(capabilities.max_message_size, Some(1024 * 1024));
        assert!(!capabilities.compression_algorithms.is_empty());
        assert!(capabilities.custom.contains_key("elicitation"));
        assert!(capabilities.custom.contains_key("sampling"));
        assert_eq!(
            capabilities.custom.get("websocket_version"),
            Some(&json!("13"))
        );
        assert_eq!(
            capabilities.custom.get("max_concurrent_elicitations"),
            Some(&json!(5))
        );
    }

    #[test]
    fn test_create_capabilities_without_compression() {
        let config = WebSocketBidirectionalConfig {
            enable_compression: false,
            ..Default::default()
        };

        let capabilities = WebSocketBidirectionalTransport::create_capabilities(&config);

        assert!(!capabilities.supports_compression);
        assert!(capabilities.compression_algorithms.is_empty());
    }
}