stream_tungstenite/connection/
supervisor.rs1use std::sync::atomic::{AtomicBool, Ordering};
4use std::sync::Arc;
5use std::time::Duration;
6use tokio::sync::broadcast;
7
8use super::connector::{Connector, DefaultConnector};
9use super::retry::{ExponentialBackoff, RetryStrategy};
10use super::state::{ConnectionSnapshot, ConnectionState};
11use crate::error::{ConnectError, DisconnectReason, SupervisorError};
12
13#[derive(Clone)]
15pub struct ActivityHandle {
16 state: std::sync::Arc<ConnectionState>,
17}
18
19impl ActivityHandle {
20 pub async fn update(&self) {
22 self.state.update_activity().await;
23 }
24}
25
26#[derive(Debug, Clone)]
28pub enum ConnectionEvent {
29 Connecting { attempt: u32 },
31 Connected { id: u64 },
33 Disconnected { reason: DisconnectReason },
35 ReconnectScheduled { delay: Duration, attempt: u32 },
37 Error { error: ConnectError, attempt: u32 },
39 FatalError { error: ConnectError },
41 Shutdown,
43}
44
45#[derive(Clone)]
47pub struct SupervisorConfig {
48 pub retry_strategy: Box<dyn RetryStrategy>,
50 pub connect_timeout: Duration,
52 pub exit_on_first_failure: bool,
54}
55
56impl Default for SupervisorConfig {
57 fn default() -> Self {
58 Self {
59 retry_strategy: Box::new(ExponentialBackoff::standard()),
60 connect_timeout: Duration::from_secs(30),
61 exit_on_first_failure: false,
62 }
63 }
64}
65
66impl SupervisorConfig {
67 #[must_use]
69 pub fn new() -> Self {
70 Self::default()
71 }
72
73 #[must_use]
75 pub fn with_retry(mut self, strategy: impl RetryStrategy + 'static) -> Self {
76 self.retry_strategy = Box::new(strategy);
77 self
78 }
79
80 #[must_use]
82 pub const fn with_connect_timeout(mut self, timeout: Duration) -> Self {
83 self.connect_timeout = timeout;
84 self
85 }
86
87 #[must_use]
89 pub const fn with_exit_on_first_failure(mut self, exit: bool) -> Self {
90 self.exit_on_first_failure = exit;
91 self
92 }
93
94 #[must_use]
96 pub fn fast() -> Self {
97 Self {
98 retry_strategy: Box::new(ExponentialBackoff::fast()),
99 connect_timeout: Duration::from_secs(10),
100 exit_on_first_failure: false,
101 }
102 }
103
104 #[must_use]
106 pub fn stable() -> Self {
107 Self {
108 retry_strategy: Box::new(ExponentialBackoff::conservative()),
109 connect_timeout: Duration::from_secs(60),
110 exit_on_first_failure: false,
111 }
112 }
113}
114
115pub struct ConnectionSupervisor<C: Connector = DefaultConnector> {
117 uri: String,
119 connector: C,
121 config: SupervisorConfig,
123 state: Arc<ConnectionState>,
125 event_tx: broadcast::Sender<ConnectionEvent>,
127 shutdown: Arc<AtomicBool>,
129}
130
131impl ConnectionSupervisor<DefaultConnector> {
132 pub fn new(uri: impl Into<String>) -> Self {
134 Self::with_connector(uri, DefaultConnector::new())
135 }
136}
137
138impl<C: Connector> ConnectionSupervisor<C> {
139 pub fn with_connector(uri: impl Into<String>, connector: C) -> Self {
141 let (event_tx, _) = broadcast::channel(64);
142
143 Self {
144 uri: uri.into(),
145 connector,
146 config: SupervisorConfig::default(),
147 state: Arc::new(ConnectionState::new()),
148 event_tx,
149 shutdown: Arc::new(AtomicBool::new(false)),
150 }
151 }
152
153 #[must_use]
155 pub fn with_config(mut self, config: SupervisorConfig) -> Self {
156 self.config = config;
157 self
158 }
159
160 pub fn uri(&self) -> &str {
162 &self.uri
163 }
164
165 pub async fn snapshot(&self) -> ConnectionSnapshot {
167 self.state.snapshot().await
168 }
169
170 pub fn is_connected(&self) -> bool {
172 self.state.is_connected()
173 }
174
175 pub fn connection_id(&self) -> u64 {
177 self.state.id()
178 }
179
180 pub fn subscribe(&self) -> broadcast::Receiver<ConnectionEvent> {
182 self.event_tx.subscribe()
183 }
184
185 pub fn activity_handle(&self) -> ActivityHandle {
187 ActivityHandle {
188 state: self.state.clone(),
189 }
190 }
191
192 pub fn fatal(&self, error: ConnectError) {
194 let _ = self.event_tx.send(ConnectionEvent::FatalError { error });
195 }
196
197 pub fn shutdown(&self) {
199 self.state.mark_shutting_down();
200 self.shutdown.store(true, Ordering::Release);
201 let _ = self.event_tx.send(ConnectionEvent::Shutdown);
202 }
203
204 pub fn is_shutdown_requested(&self) -> bool {
206 self.shutdown.load(Ordering::Acquire) || self.state.is_shutdown_requested()
207 }
208
209 fn emit(&self, event: ConnectionEvent) {
211 let _ = self.event_tx.send(event);
212 }
213
214 #[allow(clippy::too_many_lines)]
225 pub async fn connect(&self) -> Result<C::Stream, SupervisorError> {
226 let mut retry_strategy = self.config.retry_strategy.clone();
227 let mut attempt = 0u32;
228 let mut is_first_attempt = true;
229
230 loop {
231 if self.is_shutdown_requested() {
233 return Err(SupervisorError::Shutdown);
234 }
235
236 attempt += 1;
237
238 if is_first_attempt {
240 self.state.mark_connecting();
241 } else {
242 self.state.mark_reconnecting();
243 }
244
245 self.emit(ConnectionEvent::Connecting { attempt });
246
247 let connect_result = tokio::time::timeout(
249 self.config.connect_timeout,
250 self.connector.connect(&self.uri),
251 )
252 .await;
253
254 match connect_result {
255 Ok(Ok((stream, _response))) => {
256 let id = self.state.mark_connected().await;
258 retry_strategy.reset();
259
260 tracing::info!(
261 uri = %self.uri,
262 connection_id = id,
263 attempt = attempt,
264 "Connection established"
265 );
266
267 self.emit(ConnectionEvent::Connected { id });
268 return Ok(stream);
269 }
270 Ok(Err(error)) => {
271 self.state.record_error(error.clone()).await;
273 self.emit(ConnectionEvent::Error {
274 error: error.clone(),
275 attempt,
276 });
277
278 tracing::warn!(
279 uri = %self.uri,
280 attempt = attempt,
281 error = ?error,
282 "Connection failed"
283 );
284
285 if is_first_attempt && self.config.exit_on_first_failure {
287 self.emit(ConnectionEvent::FatalError {
288 error: error.clone(),
289 });
290 return Err(SupervisorError::Fatal(error.to_string()));
291 }
292
293 if let Some(delay) = retry_strategy.next_delay(&error, attempt) {
294 self.emit(ConnectionEvent::ReconnectScheduled { delay, attempt });
295
296 tracing::debug!(
297 delay = ?delay,
298 attempt = attempt,
299 "Scheduling reconnection"
300 );
301
302 tokio::time::sleep(delay).await;
304 if self.is_shutdown_requested() {
305 return Err(SupervisorError::Shutdown);
306 }
307 } else {
308 self.emit(ConnectionEvent::FatalError {
310 error: error.clone(),
311 });
312 return Err(SupervisorError::MaxRetriesExceeded { attempts: attempt });
313 }
314 }
315 Err(_) => {
316 let error = ConnectError::Timeout(self.config.connect_timeout);
318 self.state.record_error(error.clone()).await;
319 self.emit(ConnectionEvent::Error {
320 error: error.clone(),
321 attempt,
322 });
323
324 tracing::warn!(
325 uri = %self.uri,
326 attempt = attempt,
327 timeout = ?self.config.connect_timeout,
328 "Connection timeout"
329 );
330
331 if let Some(delay) = retry_strategy.next_delay(&error, attempt) {
332 self.emit(ConnectionEvent::ReconnectScheduled { delay, attempt });
333
334 tokio::time::sleep(delay).await;
335 if self.is_shutdown_requested() {
336 return Err(SupervisorError::Shutdown);
337 }
338 } else {
339 self.emit(ConnectionEvent::FatalError {
340 error: error.clone(),
341 });
342 return Err(SupervisorError::MaxRetriesExceeded { attempts: attempt });
343 }
344 }
345 }
346
347 is_first_attempt = false;
348 }
349 }
350
351 pub async fn mark_disconnected(&self, reason: DisconnectReason) {
353 self.state.mark_disconnected(reason.clone()).await;
354 self.emit(ConnectionEvent::Disconnected { reason });
355 }
356
357 pub async fn update_activity(&self) {
359 self.state.update_activity().await;
360 }
361}
362
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_supervisor_config() {
        let config = SupervisorConfig::fast();
        assert_eq!(config.connect_timeout, Duration::from_secs(10));
    }

    /// The default and stable presets and the builder methods set the documented values.
    #[test]
    fn test_supervisor_config_presets_and_builders() {
        let default = SupervisorConfig::default();
        assert_eq!(default.connect_timeout, Duration::from_secs(30));
        assert!(!default.exit_on_first_failure);

        let stable = SupervisorConfig::stable();
        assert_eq!(stable.connect_timeout, Duration::from_secs(60));
        assert!(!stable.exit_on_first_failure);

        let custom = SupervisorConfig::new()
            .with_connect_timeout(Duration::from_secs(5))
            .with_exit_on_first_failure(true);
        assert_eq!(custom.connect_timeout, Duration::from_secs(5));
        assert!(custom.exit_on_first_failure);
    }

    #[test]
    fn test_supervisor_creation() {
        let supervisor = ConnectionSupervisor::new("wss://example.com/ws");
        assert_eq!(supervisor.uri(), "wss://example.com/ws");
        assert!(!supervisor.is_connected());
    }

    #[test]
    fn test_supervisor_shutdown() {
        let supervisor = ConnectionSupervisor::new("wss://example.com/ws");
        assert!(!supervisor.is_shutdown_requested());

        supervisor.shutdown();
        assert!(supervisor.is_shutdown_requested());
    }

    /// Subscribers registered before `shutdown` receive a `Shutdown` event.
    #[test]
    fn test_shutdown_broadcasts_event() {
        let supervisor = ConnectionSupervisor::new("wss://example.com/ws");
        let mut events = supervisor.subscribe();

        supervisor.shutdown();
        assert!(matches!(events.try_recv(), Ok(ConnectionEvent::Shutdown)));
    }
}