stream_tungstenite/connection/state.rs

1//! Connection state management.
2
3use std::sync::atomic::{AtomicBool, AtomicU64, AtomicU8, Ordering};
4use std::time::{Duration, Instant};
5use tokio::sync::RwLock;
6
7use crate::error::{ConnectError, DisconnectReason};
8
/// Coarse lifecycle status of a connection.
///
/// Derives `Hash` (in addition to `Eq`) so statuses can be used as keys in
/// hash-based collections, e.g. for per-status counters.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConnectionStatus {
    /// Not connected.
    Disconnected,
    /// A connection attempt is in progress.
    Connecting,
    /// Connected and operational.
    Connected,
    /// Re-establishing the connection after a disconnection.
    Reconnecting,
    /// Shutdown has been requested and is in progress.
    ShuttingDown,
}
23
24impl std::fmt::Display for ConnectionStatus {
25    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26        match self {
27            Self::Disconnected => write!(f, "disconnected"),
28            Self::Connecting => write!(f, "connecting"),
29            Self::Connected => write!(f, "connected"),
30            Self::Reconnecting => write!(f, "reconnecting"),
31            Self::ShuttingDown => write!(f, "shutting_down"),
32        }
33    }
34}
35
36/// Internal connection state
37pub struct ConnectionState {
38    /// Unique connection ID (incremented on each connection)
39    id: AtomicU64,
40    /// Connection ID counter
41    id_counter: AtomicU64,
42    /// Current status (encoded as u8)
43    status: AtomicU8,
44    /// Number of reconnection attempts
45    reconnect_count: AtomicU64,
46    /// Total number of errors encountered
47    error_count: AtomicU64,
48    /// Whether shutdown has been requested
49    shutdown_requested: AtomicBool,
50
51    /// Time when current connection was established
52    connected_at: RwLock<Option<Instant>>,
53    /// Time of last activity (message sent/received)
54    last_activity: RwLock<Option<Instant>>,
55    /// Last error encountered
56    last_error: RwLock<Option<ConnectError>>,
57    /// Last disconnect reason
58    last_disconnect: RwLock<Option<DisconnectReason>>,
59}
60
61impl ConnectionState {
62    // Encoded status values (avoid magic numbers)
63    const STATUS_DISCONNECTED: u8 = 0;
64    const STATUS_CONNECTING: u8 = 1;
65    const STATUS_CONNECTED: u8 = 2;
66    const STATUS_RECONNECTING: u8 = 3;
67    const STATUS_SHUTTING_DOWN: u8 = 4;
68
69    /// Create a new connection state
70    #[must_use]
71    pub fn new() -> Self {
72        Self {
73            id: AtomicU64::new(0),
74            id_counter: AtomicU64::new(0),
75            status: AtomicU8::new(Self::STATUS_DISCONNECTED),
76            reconnect_count: AtomicU64::new(0),
77            error_count: AtomicU64::new(0),
78            shutdown_requested: AtomicBool::new(false),
79            connected_at: RwLock::new(None),
80            last_activity: RwLock::new(None),
81            last_error: RwLock::new(None),
82            last_disconnect: RwLock::new(None),
83        }
84    }
85
86    /// Get the current connection ID
87    #[must_use]
88    pub fn id(&self) -> u64 {
89        self.id.load(Ordering::Acquire)
90    }
91
92    /// Get the current status
93    #[must_use]
94    pub fn status(&self) -> ConnectionStatus {
95        match self.status.load(Ordering::Acquire) {
96            Self::STATUS_CONNECTING => ConnectionStatus::Connecting,
97            Self::STATUS_CONNECTED => ConnectionStatus::Connected,
98            Self::STATUS_RECONNECTING => ConnectionStatus::Reconnecting,
99            Self::STATUS_SHUTTING_DOWN => ConnectionStatus::ShuttingDown,
100            // STATUS_DISCONNECTED and any invalid value => Disconnected
101            _ => ConnectionStatus::Disconnected,
102        }
103    }
104
105    /// Check if currently connected
106    #[must_use]
107    pub fn is_connected(&self) -> bool {
108        self.status() == ConnectionStatus::Connected
109    }
110
111    /// Check if shutdown was requested
112    #[must_use]
113    pub fn is_shutdown_requested(&self) -> bool {
114        self.shutdown_requested.load(Ordering::Acquire)
115    }
116
117    /// Get reconnect count
118    #[must_use]
119    pub fn reconnect_count(&self) -> u64 {
120        self.reconnect_count.load(Ordering::Relaxed)
121    }
122
123    /// Get error count
124    #[must_use]
125    pub fn error_count(&self) -> u64 {
126        self.error_count.load(Ordering::Relaxed)
127    }
128
129    /// Mark as connecting
130    pub fn mark_connecting(&self) {
131        self.status
132            .store(Self::STATUS_CONNECTING, Ordering::Release);
133    }
134
135    /// Mark as reconnecting
136    pub fn mark_reconnecting(&self) {
137        self.status
138            .store(Self::STATUS_RECONNECTING, Ordering::Release);
139        self.reconnect_count.fetch_add(1, Ordering::Relaxed);
140    }
141
142    /// Mark as connected with new connection ID
143    pub async fn mark_connected(&self) -> u64 {
144        let new_id = self.id_counter.fetch_add(1, Ordering::Relaxed) + 1;
145        self.id.store(new_id, Ordering::Release);
146        self.status.store(Self::STATUS_CONNECTED, Ordering::Release);
147
148        let now = Instant::now();
149        *self.connected_at.write().await = Some(now);
150        *self.last_activity.write().await = Some(now);
151
152        new_id
153    }
154
155    /// Mark as disconnected
156    pub async fn mark_disconnected(&self, reason: DisconnectReason) {
157        self.status
158            .store(Self::STATUS_DISCONNECTED, Ordering::Release);
159        *self.connected_at.write().await = None;
160        *self.last_disconnect.write().await = Some(reason);
161    }
162
163    /// Mark as shutting down
164    pub fn mark_shutting_down(&self) {
165        self.shutdown_requested.store(true, Ordering::Release);
166        self.status
167            .store(Self::STATUS_SHUTTING_DOWN, Ordering::Release);
168    }
169
170    /// Update last activity time
171    pub async fn update_activity(&self) {
172        *self.last_activity.write().await = Some(Instant::now());
173    }
174
175    /// Record an error
176    pub async fn record_error(&self, error: ConnectError) {
177        self.error_count.fetch_add(1, Ordering::Relaxed);
178        *self.last_error.write().await = Some(error);
179    }
180
181    /// Check if connection is healthy (received activity within timeout)
182    pub async fn is_healthy(&self, timeout: Duration) -> bool {
183        if !self.is_connected() {
184            return false;
185        }
186
187        let last_activity = self.last_activity.read().await;
188        last_activity.is_some_and(|time| time.elapsed() < timeout)
189    }
190
191    /// Get current connection duration
192    pub async fn connection_duration(&self) -> Option<Duration> {
193        let connected_at = self.connected_at.read().await;
194        connected_at.map(|t| t.elapsed())
195    }
196
197    /// Get a snapshot of the current state
198    pub async fn snapshot(&self) -> ConnectionSnapshot {
199        let connected_at = *self.connected_at.read().await;
200        let last_activity = *self.last_activity.read().await;
201        let last_error = self.last_error.read().await.clone();
202        let last_disconnect = self.last_disconnect.read().await.clone();
203
204        ConnectionSnapshot {
205            id: self.id(),
206            status: self.status(),
207            connected_at,
208            last_activity,
209            reconnect_count: self.reconnect_count(),
210            error_count: self.error_count(),
211            last_error,
212            last_disconnect,
213            connection_duration: connected_at.map(|t| t.elapsed()),
214        }
215    }
216
217    /// Reset state (for testing)
218    pub async fn reset(&self) {
219        self.id.store(0, Ordering::Release);
220        self.status
221            .store(Self::STATUS_DISCONNECTED, Ordering::Release);
222        self.reconnect_count.store(0, Ordering::Relaxed);
223        self.error_count.store(0, Ordering::Relaxed);
224        self.shutdown_requested.store(false, Ordering::Release);
225        *self.connected_at.write().await = None;
226        *self.last_activity.write().await = None;
227        *self.last_error.write().await = None;
228        *self.last_disconnect.write().await = None;
229    }
230}
231
232impl Default for ConnectionState {
233    fn default() -> Self {
234        Self::new()
235    }
236}
237
238/// Snapshot of connection state for external consumption
239#[derive(Debug, Clone)]
240pub struct ConnectionSnapshot {
241    /// Current connection ID
242    pub id: u64,
243    /// Current status
244    pub status: ConnectionStatus,
245    /// Time when connection was established
246    pub connected_at: Option<Instant>,
247    /// Time of last activity
248    pub last_activity: Option<Instant>,
249    /// Number of reconnections
250    pub reconnect_count: u64,
251    /// Number of errors
252    pub error_count: u64,
253    /// Last error (if any)
254    pub last_error: Option<ConnectError>,
255    /// Last disconnect reason (if any)
256    pub last_disconnect: Option<DisconnectReason>,
257    /// Current connection duration
258    pub connection_duration: Option<Duration>,
259}
260
261impl ConnectionSnapshot {
262    /// Check if currently connected
263    #[must_use]
264    pub const fn is_connected(&self) -> bool {
265        matches!(self.status, ConnectionStatus::Connected)
266    }
267
268    /// Get uptime percentage (connected time / total time)
269    #[must_use]
270    pub fn uptime_ratio(&self, since: Instant) -> f64 {
271        let total_duration = since.elapsed();
272        if total_duration.is_zero() {
273            return 0.0;
274        }
275
276        let connected_duration = self.connection_duration.unwrap_or(Duration::ZERO);
277        connected_duration.as_secs_f64() / total_duration.as_secs_f64()
278    }
279}
280
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_connection_state_lifecycle() {
        let state = ConnectionState::new();

        // Initially disconnected
        assert_eq!(state.status(), ConnectionStatus::Disconnected);
        assert!(!state.is_connected());

        // Mark as connecting
        state.mark_connecting();
        assert_eq!(state.status(), ConnectionStatus::Connecting);

        // Mark as connected
        let id = state.mark_connected().await;
        assert_eq!(id, 1);
        assert_eq!(state.status(), ConnectionStatus::Connected);
        assert!(state.is_connected());

        // Mark as disconnected
        state.mark_disconnected(DisconnectReason::Normal).await;
        assert_eq!(state.status(), ConnectionStatus::Disconnected);
        assert!(!state.is_connected());
    }

    #[tokio::test]
    async fn test_connection_state_snapshot() {
        let state = ConnectionState::new();
        state.mark_connected().await;

        let snapshot = state.snapshot().await;
        assert!(snapshot.is_connected());
        assert_eq!(snapshot.id, 1);
        assert!(snapshot.connected_at.is_some());
        assert!(snapshot.last_activity.is_some());
        assert!(snapshot.connection_duration.is_some());
        assert_eq!(snapshot.reconnect_count, 0);
        assert_eq!(snapshot.error_count, 0);
    }

    #[tokio::test]
    async fn test_reconnect_counting() {
        let state = ConnectionState::new();

        state.mark_reconnecting();
        assert_eq!(state.reconnect_count(), 1);
        assert_eq!(state.status(), ConnectionStatus::Reconnecting);

        state.mark_reconnecting();
        assert_eq!(state.reconnect_count(), 2);
    }

    #[tokio::test]
    async fn test_connection_ids_increase_and_survive_disconnect() {
        let state = ConnectionState::new();
        assert_eq!(state.id(), 0);

        assert_eq!(state.mark_connected().await, 1);
        state.mark_disconnected(DisconnectReason::Normal).await;
        // The last assigned ID is retained after disconnecting.
        assert_eq!(state.id(), 1);

        assert_eq!(state.mark_connected().await, 2);
        assert_eq!(state.id(), 2);
    }

    #[tokio::test]
    async fn test_health_requires_connection_and_recent_activity() {
        let state = ConnectionState::new();
        assert!(!state.is_healthy(Duration::from_secs(60)).await);

        state.mark_connected().await;
        assert!(state.is_healthy(Duration::from_secs(60)).await);
        // `elapsed()` is never below zero, so a zero timeout is never healthy.
        assert!(!state.is_healthy(Duration::ZERO).await);

        state.mark_disconnected(DisconnectReason::Normal).await;
        assert!(!state.is_healthy(Duration::from_secs(60)).await);
        assert!(state.connection_duration().await.is_none());
    }

    #[tokio::test]
    async fn test_shutdown_and_reset() {
        let state = ConnectionState::new();
        state.mark_connected().await;
        state.mark_reconnecting();
        state.mark_shutting_down();
        assert!(state.is_shutdown_requested());
        assert_eq!(state.status(), ConnectionStatus::ShuttingDown);

        state.reset().await;
        assert!(!state.is_shutdown_requested());
        assert_eq!(state.status(), ConnectionStatus::Disconnected);
        assert_eq!(state.id(), 0);
        assert_eq!(state.reconnect_count(), 0);
        assert_eq!(state.error_count(), 0);
        assert!(state.connection_duration().await.is_none());
    }

    #[tokio::test]
    async fn test_uptime_ratio_zero_when_never_connected() {
        let since = Instant::now();
        let state = ConnectionState::new();
        let snapshot = state.snapshot().await;
        assert_eq!(snapshot.uptime_ratio(since), 0.0);
    }

    #[test]
    fn test_status_display() {
        assert_eq!(ConnectionStatus::Disconnected.to_string(), "disconnected");
        assert_eq!(ConnectionStatus::Connecting.to_string(), "connecting");
        assert_eq!(ConnectionStatus::Connected.to_string(), "connected");
        assert_eq!(ConnectionStatus::Reconnecting.to_string(), "reconnecting");
        assert_eq!(ConnectionStatus::ShuttingDown.to_string(), "shutting_down");
    }
}