workflow_websocket/client/
native.rs

1use super::{
2    error::Error, message::Message, result::Result, Ack, ConnectOptions, ConnectResult,
3    ConnectStrategy, Handshake, Resolver, WebSocketConfig,
4};
5use futures::{
6    select_biased,
7    stream::{SplitSink, SplitStream},
8    FutureExt,
9};
10use futures_util::{SinkExt, StreamExt};
11use std::sync::atomic::{AtomicBool, Ordering};
12use std::sync::{Arc, Mutex};
13#[allow(unused_imports)]
14use std::time::Instant;
15use tokio::net::TcpStream;
16use tokio::time::timeout;
17use tokio_tungstenite::{
18    connect_async_with_config, tungstenite::protocol::Message as TsMessage, MaybeTlsStream,
19    WebSocketStream,
20};
21use tungstenite::protocol::WebSocketConfig as TsWebSocketConfig;
22pub use workflow_core as core;
23use workflow_core::channel::*;
24pub use workflow_log::*;
25
26impl From<Message> for tungstenite::Message {
27    fn from(message: Message) -> Self {
28        match message {
29            Message::Text(text) => text.into(),
30            Message::Binary(data) => data.into(),
31            _ => {
32                panic!("From<Message> for tungstenite::Message - invalid message type: {message:?}",)
33            }
34        }
35    }
36}
37
38impl From<tungstenite::Message> for Message {
39    fn from(message: tungstenite::Message) -> Self {
40        match message {
41            TsMessage::Text(text) => Message::Text(text),
42            TsMessage::Binary(data) => Message::Binary(data),
43            TsMessage::Close(_) => Message::Close,
44            _ => panic!(
45                "TryFrom<tungstenite::Message> for Message - invalid message type: {message:?}",
46            ),
47        }
48    }
49}
50
51impl From<WebSocketConfig> for TsWebSocketConfig {
52    fn from(config: WebSocketConfig) -> Self {
53        TsWebSocketConfig {
54            write_buffer_size: config.write_buffer_size,
55            max_write_buffer_size: config.max_write_buffer_size,
56            max_message_size: config.max_message_size,
57            max_frame_size: config.max_frame_size,
58            accept_unmasked_frames: config.accept_unmasked_frames,
59            ..Default::default()
60        }
61    }
62}
63
/// Mutable URL state of a [`WebSocketInterface`].
#[derive(Default)]
struct Settings {
    /// URL used when the connect options do not carry one; set at construction
    /// time or through `set_default_url()`.
    default_url: Option<String>,
    /// URL selected by the most recent successful `resolve_url()` call.
    current_url: Option<String>,
}
69
/// Native (tokio / tungstenite based) WebSocket client implementation.
pub struct WebSocketInterface {
    /// Default and current connection URLs.
    settings: Mutex<Settings>,
    /// Active configuration (also provides the optional resolver and handshake).
    config: Mutex<WebSocketConfig>,
    /// When `false`, the connection task stops instead of reconnecting.
    /// Set by `connect()` and cleared by `disconnect()`.
    reconnect: AtomicBool,
    /// `true` while a socket is established and the dispatcher is running.
    is_connected: AtomicBool,
    /// Channel on which incoming messages and `Open` / `Close` notifications
    /// are delivered.
    receiver_channel: Channel<Message>,
    /// Channel from which the dispatcher takes outgoing messages, each paired
    /// with an optional acknowledgement sender.
    sender_channel: Channel<(Message, Ack)>,
    /// Request / response channel pair used by `close()` to stop the dispatcher.
    shutdown: DuplexChannel<()>,
}
79
80impl WebSocketInterface {
81    pub fn new(
82        url: Option<&str>,
83        config: Option<WebSocketConfig>,
84        sender_channel: Channel<(Message, Ack)>,
85        receiver_channel: Channel<Message>,
86    ) -> Result<WebSocketInterface> {
87        let settings = Settings {
88            default_url: url.map(String::from),
89            ..Default::default()
90        };
91
92        let iface = WebSocketInterface {
93            settings: Mutex::new(settings),
94            config: Mutex::new(config.unwrap_or_default()),
95            receiver_channel,
96            sender_channel,
97            reconnect: AtomicBool::new(true),
98            is_connected: AtomicBool::new(false),
99            shutdown: DuplexChannel::unbounded(),
100        };
101
102        Ok(iface)
103    }
104
105    pub fn default_url(self: &Arc<Self>) -> Option<String> {
106        self.settings.lock().unwrap().default_url.clone()
107    }
108
109    pub fn current_url(self: &Arc<Self>) -> Option<String> {
110        self.settings.lock().unwrap().current_url.clone()
111    }
112
113    pub fn set_default_url(self: &Arc<Self>, url: &str) {
114        self.settings
115            .lock()
116            .unwrap()
117            .default_url
118            .replace(url.to_string());
119    }
120
121    pub fn set_current_url(self: &Arc<Self>, url: &str) {
122        self.settings
123            .lock()
124            .unwrap()
125            .current_url
126            .replace(url.to_string());
127    }
128
    /// Returns `true` while a socket is established, i.e. between a successful
    /// connect and the moment the dispatcher for that socket returns.
    pub fn is_connected(self: &Arc<Self>) -> bool {
        self.is_connected.load(Ordering::SeqCst)
    }
132
133    fn resolver(&self) -> Option<Arc<dyn Resolver>> {
134        self.config.lock().unwrap().resolver.clone()
135    }
136
137    fn handshake(&self) -> Option<Arc<dyn Handshake>> {
138        self.config.lock().unwrap().handshake.clone()
139    }
140
141    pub fn configure(&self, config: WebSocketConfig) {
142        *self.config.lock().unwrap() = config;
143    }
144
145    fn config(&self) -> WebSocketConfig {
146        self.config.lock().unwrap().clone()
147    }
148
149    async fn resolve_url(self: &Arc<Self>, options: &ConnectOptions) -> Result<String> {
150        let url = if let Some(url) = options.url.as_ref().or(self.default_url().as_ref()) {
151            url.clone()
152        } else if let Some(resolver) = self.resolver() {
153            resolver.resolve_url().await?
154        } else {
155            return Err(Error::MissingUrl);
156        };
157        self.set_current_url(&url);
158        Ok(url)
159    }
160
161    pub async fn connect(self: &Arc<Self>, options: ConnectOptions) -> ConnectResult<Error> {
162        let this = self.clone();
163
164        if self.is_connected.load(Ordering::SeqCst) {
165            return Err(Error::AlreadyConnected);
166        }
167
168        let (connect_trigger, connect_listener) = oneshot::<Result<()>>();
169        let mut connect_trigger = Some(connect_trigger);
170
171        this.reconnect.store(true, Ordering::SeqCst);
172
173        let block_async_connect = options.block_async_connect;
174        let ts_websocket_config = Some(self.config().into());
175
176        core::task::spawn(async move {
177            'outer: loop {
178                match this.resolve_url(&options).await {
179                    Ok(url) => {
180                        let connect_future =
181                            connect_async_with_config(&url, ts_websocket_config, false);
182                        let timeout_future = timeout(options.connect_timeout(), connect_future);
183
184                        match timeout_future.await {
185                            // connect success
186                            Ok(Ok(stream)) => {
187                                // log_trace!("connected...");
188
189                                this.is_connected.store(true, Ordering::SeqCst);
190                                let (mut ws_stream, _) = stream;
191
192                                if connect_trigger.is_some() {
193                                    connect_trigger.take().unwrap().try_send(Ok(())).ok();
194                                }
195
196                                if let Err(err) = this.dispatcher(&mut ws_stream, &options).await {
197                                    log_trace!("WebSocket dispatcher error: {}", err);
198                                }
199
200                                this.is_connected.store(false, Ordering::SeqCst);
201                            }
202                            // connect error
203                            Ok(Err(e)) => {
204                                log_trace!("WebSocket failed to connect to {}: {}", url, e);
205                                if matches!(options.strategy, ConnectStrategy::Fallback) {
206                                    if options.block_async_connect && connect_trigger.is_some() {
207                                        connect_trigger
208                                            .take()
209                                            .unwrap()
210                                            .try_send(Err(e.into()))
211                                            .ok();
212                                    }
213                                    break;
214                                }
215                                workflow_core::task::sleep(options.retry_interval()).await;
216                            }
217                            // timeout error
218                            Err(_) => {
219                                log_trace!(
220                                    "WebSocket connection timeout while connecting to {}",
221                                    url
222                                );
223                                if matches!(options.strategy, ConnectStrategy::Fallback) {
224                                    if options.block_async_connect && connect_trigger.is_some() {
225                                        connect_trigger
226                                            .take()
227                                            .unwrap()
228                                            .try_send(Err(Error::ConnectionTimeout))
229                                            .ok();
230                                    }
231                                    break;
232                                }
233                                workflow_core::task::sleep(options.retry_interval()).await;
234                            }
235                        };
236
237                        if !this.reconnect.load(Ordering::SeqCst) {
238                            break 'outer;
239                        };
240                    }
241                    Err(err) => {
242                        log_trace!("WebSocket failed to get session URL: {}", err);
243                        if !this.reconnect.load(Ordering::SeqCst) {
244                            break 'outer;
245                        } else {
246                            workflow_core::task::sleep(options.retry_interval()).await;
247                        }
248                    }
249                }
250            }
251        });
252
253        match block_async_connect {
254            true => match connect_listener.recv().await? {
255                Ok(_) => Ok(None),
256                Err(e) => Err(e),
257            },
258            false => Ok(Some(connect_listener)),
259        }
260    }
261
262    async fn handshake_impl(
263        self: &Arc<Self>,
264        ws_sender: &mut SplitSink<&mut WebSocketStream<MaybeTlsStream<TcpStream>>, TsMessage>,
265        ws_receiver: &mut SplitStream<&mut WebSocketStream<MaybeTlsStream<TcpStream>>>,
266    ) -> Result<()> {
267        if let Some(handshake) = self.handshake() {
268            let (sender_tx, sender_rx) = unbounded();
269            let (receiver_tx, receiver_rx) = unbounded();
270            let (accept_tx, accept_rx) = oneshot();
271
272            core::task::spawn(async move {
273                accept_tx
274                    .send(handshake.handshake(&sender_tx, &receiver_rx).await)
275                    .await
276                    .unwrap_or_else(|err| {
277                        log_trace!("WebSocket handshake unable to send completion: `{}`", err)
278                    });
279            });
280
281            loop {
282                select_biased! {
283                    result = accept_rx.recv().fuse() => {
284                        return result?;
285                    },
286                    msg = sender_rx.recv().fuse() => {
287                        if let Ok(msg) = msg {
288                            ws_sender.send(msg.into()).await?;
289                        }
290                    },
291                    msg = ws_receiver.next().fuse() => {
292                        if let Some(Ok(msg)) = msg {
293                            receiver_tx.send(msg.into()).await?;
294                        } else {
295                            return Err(Error::NegotiationFailure);
296                        }
297                    }
298                }
299            }
300        }
301
302        Ok(())
303    }
304
305    async fn dispatcher(
306        self: &Arc<Self>,
307        ws_stream: &mut WebSocketStream<MaybeTlsStream<TcpStream>>,
308        _options: &ConnectOptions,
309    ) -> Result<()> {
310        let (mut ws_sender, mut ws_receiver) = ws_stream.split();
311
312        self.handshake_impl(&mut ws_sender, &mut ws_receiver)
313            .await?;
314
315        #[cfg(feature = "delay-reconnect")]
316        let connection_start = Instant::now();
317        #[cfg(feature = "delay-reconnect")]
318        let mut closed_ungracefully = false;
319
320        self.receiver_channel.send(Message::Open).await?;
321
322        loop {
323            select_biased! {
324                dispatch = self.sender_channel.recv().fuse() => {
325                    if let Ok((msg,ack)) = dispatch {
326                        if let Some(ack_sender) = ack {
327                            let result = ws_sender.send(msg.into()).await
328                                .map(Arc::new)
329                                .map_err(|err|Arc::new(err.into()));
330                            ack_sender.send(result).await?;
331                        } else {
332                            ws_sender.send(msg.into()).await?;
333                        }
334                    }
335                }
336                msg = ws_receiver.next().fuse() => {
337                    match msg {
338                        Some(Ok(msg)) => {
339                            match msg {
340                                TsMessage::Binary(_) | TsMessage::Text(_) | TsMessage::Close(_) => {
341                                    self
342                                        .receiver_channel
343                                        .send(msg.into())
344                                        .await?;
345                                }
346                                TsMessage::Ping(data) => {
347                                    ws_sender.send(TsMessage::Pong(data)).await?;
348                                },
349                                TsMessage::Pong(_) => { },
350                                TsMessage::Frame(_frame) => { },
351                            }
352                        }
353                        Some(Err(e)) => {
354                            self.receiver_channel.send(Message::Close).await?;
355                            log_trace!("WebSocket error: {}", e);
356                            #[cfg(feature = "delay-reconnect")] {
357                                closed_ungracefully = true;
358                            }
359                            break;
360                        }
361                        None => {
362                            self.receiver_channel.send(Message::Close).await?;
363                            log_trace!("WebSocket connection closed");
364                            #[cfg(feature = "delay-reconnect")] {
365                                closed_ungracefully = true;
366                            }
367                            break;
368                        }
369                    }
370                }
371                _ = self.shutdown.request.receiver.recv().fuse() => {
372                    self.receiver_channel.send(Message::Close).await?;
373                    self.shutdown.response.sender.send(()).await?;
374                    break;
375                }
376            }
377        }
378
379        // if connection has closed ungracefully within 1 second, wait for retry interval
380        #[cfg(feature = "delay-reconnect")]
381        if closed_ungracefully && connection_start.elapsed().as_millis() < 1_000 {
382            workflow_core::task::sleep(_options.retry_interval()).await;
383        }
384
385        Ok(())
386    }
387
388    pub async fn close(self: &Arc<Self>) -> Result<()> {
389        // if self.inner.lock().unwrap().is_some() {
390        if self.is_connected.load(Ordering::SeqCst) {
391            // } self.inner.lock().unwrap().is_some() {
392            self.shutdown
393                .request
394                .sender
395                .send(())
396                .await
397                .unwrap_or_else(|err| {
398                    log_error!("Unable to signal WebSocket dispatcher shutdown: {}", err)
399                });
400            self.shutdown
401                .response
402                .receiver
403                .recv()
404                .await
405                .unwrap_or_else(|err| {
406                    log_error!("Unable to receive WebSocket dispatcher shutdown: {}", err)
407                });
408        }
409
410        Ok(())
411    }
412
413    pub async fn disconnect(self: &Arc<Self>) -> Result<()> {
414        self.reconnect.store(false, Ordering::SeqCst);
415        self.close().await?;
416        Ok(())
417    }
418
419    pub fn trigger_abort(self: &Arc<Self>) -> Result<()> {
420        if self.is_connected.load(Ordering::SeqCst) {
421            self.receiver_channel.try_send(Message::Close)?;
422        }
423        Ok(())
424    }
425}