stream_tungstenite/connection/
supervisor.rs

1//! Connection supervisor - manages connection lifecycle and reconnection.
2
3use std::sync::atomic::{AtomicBool, Ordering};
4use std::sync::Arc;
5use std::time::Duration;
6use tokio::sync::broadcast;
7
8use super::connector::{Connector, DefaultConnector};
9use super::retry::{ExponentialBackoff, RetryStrategy};
10use super::state::{ConnectionSnapshot, ConnectionState};
11use crate::error::{ConnectError, DisconnectReason, SupervisorError};
12
13/// Lightweight handle to update activity timestamp without borrowing the supervisor
14#[derive(Clone)]
15pub struct ActivityHandle {
16    state: std::sync::Arc<ConnectionState>,
17}
18
19impl ActivityHandle {
20    /// Mark activity (last activity timestamp) on the underlying connection state
21    pub async fn update(&self) {
22        self.state.update_activity().await;
23    }
24}
25
26/// Connection events broadcast to listeners
27#[derive(Debug, Clone)]
28pub enum ConnectionEvent {
29    /// Starting connection attempt
30    Connecting { attempt: u32 },
31    /// Successfully connected
32    Connected { id: u64 },
33    /// Disconnected
34    Disconnected { reason: DisconnectReason },
35    /// Reconnection scheduled
36    ReconnectScheduled { delay: Duration, attempt: u32 },
37    /// Error occurred (non-fatal)
38    Error { error: ConnectError, attempt: u32 },
39    /// Fatal error - will stop reconnecting
40    FatalError { error: ConnectError },
41    /// Shutdown initiated
42    Shutdown,
43}
44
45/// Connection supervisor configuration
46#[derive(Clone)]
47pub struct SupervisorConfig {
48    /// Retry strategy
49    pub retry_strategy: Box<dyn RetryStrategy>,
50    /// Connection timeout
51    pub connect_timeout: Duration,
52    /// Whether to exit on first connection failure
53    pub exit_on_first_failure: bool,
54}
55
56impl Default for SupervisorConfig {
57    fn default() -> Self {
58        Self {
59            retry_strategy: Box::new(ExponentialBackoff::standard()),
60            connect_timeout: Duration::from_secs(30),
61            exit_on_first_failure: false,
62        }
63    }
64}
65
66impl SupervisorConfig {
67    /// Create a new supervisor config
68    #[must_use]
69    pub fn new() -> Self {
70        Self::default()
71    }
72
73    /// Set retry strategy
74    #[must_use]
75    pub fn with_retry(mut self, strategy: impl RetryStrategy + 'static) -> Self {
76        self.retry_strategy = Box::new(strategy);
77        self
78    }
79
80    /// Set connection timeout
81    #[must_use]
82    pub const fn with_connect_timeout(mut self, timeout: Duration) -> Self {
83        self.connect_timeout = timeout;
84        self
85    }
86
87    /// Set whether to exit on first connection failure
88    #[must_use]
89    pub const fn with_exit_on_first_failure(mut self, exit: bool) -> Self {
90        self.exit_on_first_failure = exit;
91        self
92    }
93
94    /// Create a fast reconnect configuration
95    #[must_use]
96    pub fn fast() -> Self {
97        Self {
98            retry_strategy: Box::new(ExponentialBackoff::fast()),
99            connect_timeout: Duration::from_secs(10),
100            exit_on_first_failure: false,
101        }
102    }
103
104    /// Create a stable connection configuration
105    #[must_use]
106    pub fn stable() -> Self {
107        Self {
108            retry_strategy: Box::new(ExponentialBackoff::conservative()),
109            connect_timeout: Duration::from_secs(60),
110            exit_on_first_failure: false,
111        }
112    }
113}
114
115/// Connection supervisor - manages the connection lifecycle
116pub struct ConnectionSupervisor<C: Connector = DefaultConnector> {
117    /// URI to connect to
118    uri: String,
119    /// Connector instance
120    connector: C,
121    /// Configuration
122    config: SupervisorConfig,
123    /// Connection state
124    state: Arc<ConnectionState>,
125    /// Event broadcaster
126    event_tx: broadcast::Sender<ConnectionEvent>,
127    /// Shutdown flag
128    shutdown: Arc<AtomicBool>,
129}
130
131impl ConnectionSupervisor<DefaultConnector> {
132    /// Create a new supervisor with default connector
133    pub fn new(uri: impl Into<String>) -> Self {
134        Self::with_connector(uri, DefaultConnector::new())
135    }
136}
137
138impl<C: Connector> ConnectionSupervisor<C> {
139    /// Create a new supervisor with custom connector
140    pub fn with_connector(uri: impl Into<String>, connector: C) -> Self {
141        let (event_tx, _) = broadcast::channel(64);
142
143        Self {
144            uri: uri.into(),
145            connector,
146            config: SupervisorConfig::default(),
147            state: Arc::new(ConnectionState::new()),
148            event_tx,
149            shutdown: Arc::new(AtomicBool::new(false)),
150        }
151    }
152
153    /// Configure the supervisor
154    #[must_use]
155    pub fn with_config(mut self, config: SupervisorConfig) -> Self {
156        self.config = config;
157        self
158    }
159
160    /// Get the URI
161    pub fn uri(&self) -> &str {
162        &self.uri
163    }
164
165    /// Get the current connection state snapshot
166    pub async fn snapshot(&self) -> ConnectionSnapshot {
167        self.state.snapshot().await
168    }
169
170    /// Check if currently connected
171    pub fn is_connected(&self) -> bool {
172        self.state.is_connected()
173    }
174
175    /// Get the current connection ID
176    pub fn connection_id(&self) -> u64 {
177        self.state.id()
178    }
179
180    /// Subscribe to connection events
181    pub fn subscribe(&self) -> broadcast::Receiver<ConnectionEvent> {
182        self.event_tx.subscribe()
183    }
184
185    /// Get a lightweight activity handle to update last-activity without borrowing self
186    pub fn activity_handle(&self) -> ActivityHandle {
187        ActivityHandle {
188            state: self.state.clone(),
189        }
190    }
191
192    /// Emit a fatal error event (used by higher layers for unrecoverable initialization failures)
193    pub fn fatal(&self, error: ConnectError) {
194        let _ = self.event_tx.send(ConnectionEvent::FatalError { error });
195    }
196
197    /// Request shutdown
198    pub fn shutdown(&self) {
199        self.state.mark_shutting_down();
200        self.shutdown.store(true, Ordering::Release);
201        let _ = self.event_tx.send(ConnectionEvent::Shutdown);
202    }
203
204    /// Check if shutdown was requested
205    pub fn is_shutdown_requested(&self) -> bool {
206        self.shutdown.load(Ordering::Acquire) || self.state.is_shutdown_requested()
207    }
208
209    /// Emit an event
210    fn emit(&self, event: ConnectionEvent) {
211        let _ = self.event_tx.send(event);
212    }
213
214    /// Connect with retry logic
215    ///
216    /// Attempts to connect to the configured URI, retrying according to the
217    /// configured retry strategy.
218    ///
219    /// # Errors
220    ///
221    /// - Returns [`SupervisorError::Shutdown`] if shutdown was requested during connection.
222    /// - Returns [`SupervisorError::Fatal`] if `exit_on_first_failure` is set and first connection fails.
223    /// - Returns [`SupervisorError::MaxRetriesExceeded`] if the retry strategy is exhausted.
224    #[allow(clippy::too_many_lines)]
225    pub async fn connect(&self) -> Result<C::Stream, SupervisorError> {
226        let mut retry_strategy = self.config.retry_strategy.clone();
227        let mut attempt = 0u32;
228        let mut is_first_attempt = true;
229
230        loop {
231            // Check for shutdown
232            if self.is_shutdown_requested() {
233                return Err(SupervisorError::Shutdown);
234            }
235
236            attempt += 1;
237
238            // Update state
239            if is_first_attempt {
240                self.state.mark_connecting();
241            } else {
242                self.state.mark_reconnecting();
243            }
244
245            self.emit(ConnectionEvent::Connecting { attempt });
246
247            // Attempt connection with timeout
248            let connect_result = tokio::time::timeout(
249                self.config.connect_timeout,
250                self.connector.connect(&self.uri),
251            )
252            .await;
253
254            match connect_result {
255                Ok(Ok((stream, _response))) => {
256                    // Success!
257                    let id = self.state.mark_connected().await;
258                    retry_strategy.reset();
259
260                    tracing::info!(
261                        uri = %self.uri,
262                        connection_id = id,
263                        attempt = attempt,
264                        "Connection established"
265                    );
266
267                    self.emit(ConnectionEvent::Connected { id });
268                    return Ok(stream);
269                }
270                Ok(Err(error)) => {
271                    // Connection error
272                    self.state.record_error(error.clone()).await;
273                    self.emit(ConnectionEvent::Error {
274                        error: error.clone(),
275                        attempt,
276                    });
277
278                    tracing::warn!(
279                        uri = %self.uri,
280                        attempt = attempt,
281                        error = ?error,
282                        "Connection failed"
283                    );
284
285                    // Check if we should retry
286                    if is_first_attempt && self.config.exit_on_first_failure {
287                        self.emit(ConnectionEvent::FatalError {
288                            error: error.clone(),
289                        });
290                        return Err(SupervisorError::Fatal(error.to_string()));
291                    }
292
293                    if let Some(delay) = retry_strategy.next_delay(&error, attempt) {
294                        self.emit(ConnectionEvent::ReconnectScheduled { delay, attempt });
295
296                        tracing::debug!(
297                            delay = ?delay,
298                            attempt = attempt,
299                            "Scheduling reconnection"
300                        );
301
302                        // Wait for delay, checking shutdown periodically
303                        tokio::time::sleep(delay).await;
304                        if self.is_shutdown_requested() {
305                            return Err(SupervisorError::Shutdown);
306                        }
307                    } else {
308                        // No more retries
309                        self.emit(ConnectionEvent::FatalError {
310                            error: error.clone(),
311                        });
312                        return Err(SupervisorError::MaxRetriesExceeded { attempts: attempt });
313                    }
314                }
315                Err(_) => {
316                    // Timeout
317                    let error = ConnectError::Timeout(self.config.connect_timeout);
318                    self.state.record_error(error.clone()).await;
319                    self.emit(ConnectionEvent::Error {
320                        error: error.clone(),
321                        attempt,
322                    });
323
324                    tracing::warn!(
325                        uri = %self.uri,
326                        attempt = attempt,
327                        timeout = ?self.config.connect_timeout,
328                        "Connection timeout"
329                    );
330
331                    if let Some(delay) = retry_strategy.next_delay(&error, attempt) {
332                        self.emit(ConnectionEvent::ReconnectScheduled { delay, attempt });
333
334                        tokio::time::sleep(delay).await;
335                        if self.is_shutdown_requested() {
336                            return Err(SupervisorError::Shutdown);
337                        }
338                    } else {
339                        self.emit(ConnectionEvent::FatalError {
340                            error: error.clone(),
341                        });
342                        return Err(SupervisorError::MaxRetriesExceeded { attempts: attempt });
343                    }
344                }
345            }
346
347            is_first_attempt = false;
348        }
349    }
350
351    /// Mark as disconnected (called externally when connection is lost)
352    pub async fn mark_disconnected(&self, reason: DisconnectReason) {
353        self.state.mark_disconnected(reason.clone()).await;
354        self.emit(ConnectionEvent::Disconnected { reason });
355    }
356
357    /// Update activity timestamp
358    pub async fn update_activity(&self) {
359        self.state.update_activity().await;
360    }
361}
362
#[cfg(test)]
mod tests {
    use super::*;

    // Exercising `connect` end to end needs a mock `Connector`; the tests
    // below only cover the parts of the API that work without one.

    /// The `fast` preset uses a 10 second connect timeout.
    #[test]
    fn test_supervisor_config() {
        let fast = SupervisorConfig::fast();
        let expected = Duration::from_secs(10);
        assert_eq!(fast.connect_timeout, expected);
    }

    /// A new supervisor remembers its URI and starts out disconnected.
    #[test]
    fn test_supervisor_creation() {
        let sup = ConnectionSupervisor::new("wss://example.com/ws");
        assert_eq!(sup.uri(), "wss://example.com/ws");
        assert!(!sup.is_connected());
    }

    /// Requesting shutdown flips `is_shutdown_requested` from false to true.
    #[test]
    fn test_supervisor_shutdown() {
        let sup = ConnectionSupervisor::new("wss://example.com/ws");
        assert!(!sup.is_shutdown_requested());

        sup.shutdown();
        assert!(sup.is_shutdown_requested());
    }
}