workflow_websocket/client/wasm.rs

1use super::{
2    bindings::WebSocket as W3CWebSocket,
3    error::Error,
4    message::{Ack, Message},
5    result::Result,
6    ConnectOptions, ConnectResult, Handshake, Resolver, WebSocketConfig,
7};
8use futures::{select, select_biased, FutureExt};
9use js_sys::{ArrayBuffer, Uint8Array};
10use std::ops::Deref;
11use std::sync::{
12    atomic::{AtomicBool, Ordering},
13    Arc, Mutex,
14};
15use wasm_bindgen::JsCast;
16use web_sys::{
17    CloseEvent as WsCloseEvent, ErrorEvent as WsErrorEvent, MessageEvent as WsMessageEvent,
18};
19use workflow_core::runtime::*;
20use workflow_core::{
21    channel::{oneshot, unbounded, Channel, DuplexChannel, Sender},
22    task::spawn,
23};
24use workflow_log::*;
25use workflow_wasm::callback::*;
26
27impl TryFrom<WsMessageEvent> for Message {
28    type Error = Error;
29
30    fn try_from(event: WsMessageEvent) -> std::result::Result<Self, Self::Error> {
31        match event.data() {
32            data if data.is_instance_of::<ArrayBuffer>() => {
33                let buffer = Uint8Array::new(data.unchecked_ref());
34                Ok(Message::Binary(buffer.to_vec()))
35            }
36            data if data.is_string() => match data.as_string() {
37                Some(text) => Ok(Message::Text(text)),
38                None => Err(Error::DataEncoding),
39            },
40            _ => Err(Error::DataType),
41        }
42    }
43}
44
45#[derive(Clone)]
46pub struct WebSocket(W3CWebSocket);
47unsafe impl Send for WebSocket {}
48unsafe impl Sync for WebSocket {}
49impl Deref for WebSocket {
50    type Target = W3CWebSocket;
51    fn deref(&self) -> &W3CWebSocket {
52        &self.0
53    }
54}
55
56impl WebSocket {
57    #[allow(dead_code)]
58    const CONNECTING: u16 = W3CWebSocket::CONNECTING;
59    #[allow(dead_code)]
60    const OPEN: u16 = W3CWebSocket::OPEN;
61    #[allow(dead_code)]
62    const CLOSING: u16 = W3CWebSocket::CLOSING;
63    #[allow(dead_code)]
64    const CLOSED: u16 = W3CWebSocket::CLOSED;
65
66    #[allow(dead_code)]
67    pub fn new(url: &str) -> Result<Self> {
68        Ok(WebSocket(W3CWebSocket::new(url)?))
69    }
70
71    pub fn new_with_config(url: &str, config: &WebSocketConfig) -> Result<Self> {
72        Ok(WebSocket(W3CWebSocket::new_with_config(url, config)?))
73    }
74
75    fn cleanup(&self) {
76        self.set_onopen(None);
77        self.set_onclose(None);
78        self.set_onerror(None);
79        self.set_onmessage(None);
80    }
81}
82
83impl From<W3CWebSocket> for WebSocket {
84    fn from(ws: W3CWebSocket) -> Self {
85        WebSocket(ws)
86    }
87}
88
/// URL bookkeeping for the connection manager.
#[derive(Default)]
struct Settings {
    /// URL used when a connection request does not specify one.
    default_url: Option<String>,
    /// URL selected by the most recent connection attempt.
    current_url: Option<String>,
}
96
97#[allow(dead_code)]
98struct Inner {
99    ws: WebSocket,
100    callbacks: CallbackMap,
101}
102
103unsafe impl Send for Inner {}
104unsafe impl Sync for Inner {}
105
106pub struct WebSocketInterface {
107    inner: Arc<Mutex<Option<Inner>>>,
108    settings: Arc<Mutex<Settings>>,
109    config: Mutex<WebSocketConfig>,
110    reconnect: AtomicBool,
111    is_connected: AtomicBool,
112    event_channel: Channel<Message>,
113    sender_channel: Channel<(Message, Ack)>,
114    receiver_channel: Channel<Message>,
115    dispatcher_shutdown: DuplexChannel,
116}
117
118impl WebSocketInterface {
119    pub fn new(
120        url: Option<&str>,
121        config: Option<WebSocketConfig>,
122        sender_channel: Channel<(Message, Ack)>,
123        receiver_channel: Channel<Message>,
124    ) -> Result<WebSocketInterface> {
125        sanity_checks()?;
126
127        let settings = Settings {
128            default_url: url.map(String::from),
129            ..Default::default()
130        };
131
132        let iface = WebSocketInterface {
133            inner: Arc::new(Mutex::new(None)),
134            settings: Arc::new(Mutex::new(settings)),
135            config: Mutex::new(config.unwrap_or_default()),
136            sender_channel,
137            receiver_channel,
138            event_channel: Channel::unbounded(),
139            reconnect: AtomicBool::new(true),
140            is_connected: AtomicBool::new(false),
141            dispatcher_shutdown: DuplexChannel::unbounded(),
142        };
143
144        Ok(iface)
145    }
146
147    pub fn default_url(self: &Arc<Self>) -> Option<String> {
148        self.settings.lock().unwrap().default_url.clone()
149    }
150
151    pub fn current_url(self: &Arc<Self>) -> Option<String> {
152        self.settings.lock().unwrap().current_url.clone()
153    }
154
155    pub fn set_default_url(self: &Arc<Self>, url: &str) {
156        self.settings
157            .lock()
158            .unwrap()
159            .default_url
160            .replace(url.to_string());
161    }
162
163    pub fn set_current_url(self: &Arc<Self>, url: &str) {
164        self.settings
165            .lock()
166            .unwrap()
167            .current_url
168            .replace(url.to_string());
169    }
170
171    pub fn is_connected(self: &Arc<Self>) -> bool {
172        self.is_connected.load(Ordering::SeqCst)
173    }
174
175    fn resolver(&self) -> Option<Arc<dyn Resolver>> {
176        self.config.lock().unwrap().resolver.clone()
177    }
178
179    fn handshake(&self) -> Option<Arc<dyn Handshake>> {
180        self.config.lock().unwrap().handshake.clone()
181    }
182
183    pub fn configure(&self, config: WebSocketConfig) {
184        *self.config.lock().unwrap() = config;
185    }
186
187    async fn resolve_url(self: &Arc<Self>, options: &ConnectOptions) -> Result<String> {
188        let url = if let Some(url) = options.url.as_ref().or(self.default_url().as_ref()) {
189            url.clone()
190        } else if let Some(resolver) = self.resolver() {
191            resolver.resolve_url().await?
192        } else {
193            return Err(Error::MissingUrl);
194        };
195        self.set_current_url(&url);
196        Ok(url)
197    }
198
199    pub async fn connect(self: &Arc<Self>, options: ConnectOptions) -> ConnectResult<Error> {
200        let (connect_trigger, connect_listener) = oneshot::<Result<()>>();
201
202        let connect_trigger = Arc::new(Mutex::new(Some(connect_trigger)));
203        self.connect_impl(options.clone(), connect_trigger).await?;
204
205        match options.block_async_connect {
206            true => match connect_listener.recv().await? {
207                Ok(_) => Ok(None),
208                Err(e) => Err(e),
209            },
210            false => Ok(Some(connect_listener)),
211        }
212    }
213
214    fn retry_connect_impl(
215        self: Arc<Self>,
216        options: ConnectOptions,
217        connect_trigger: Arc<Mutex<Option<Sender<Result<()>>>>>,
218    ) -> futures::future::BoxFuture<'static, Result<()>> {
219        Box::pin(async move { self.connect_impl(options, connect_trigger).await })
220            as futures::future::BoxFuture<'static, Result<()>>
221    }
222
223    async fn connect_impl(
224        self: &Arc<Self>,
225        options: ConnectOptions,
226        connect_trigger: Arc<Mutex<Option<Sender<Result<()>>>>>,
227    ) -> Result<()> {
228        if self.inner.lock().unwrap().is_some() {
229            log_warn!("WebSocket::connect() called while already initialized");
230
231            return Err(Error::AlreadyInitialized);
232        }
233
234        self.reconnect.store(true, Ordering::SeqCst);
235
236        let url = match self.resolve_url(&options).await {
237            Ok(url) => url,
238            Err(err) => {
239                log_trace!("WebSocket unable to resolve URL: {err}");
240                let self_ = self.clone();
241
242                if options.strategy.is_fallback() {
243                    self.reconnect.store(false, Ordering::SeqCst);
244
245                    // let connect_trigger = connect_trigger.lock().unwrap().take();
246                    // if let Some(connect_trigger) = connect_trigger {
247                    //     connect_trigger.send(Err(err)).await.ok();
248                    // }
249
250                    return Err(err);
251                }
252
253                let connect_trigger_ = connect_trigger.clone();
254                spawn(async move {
255                    // if reconnect is true, we sleep for reconnect interval and try to reconnect
256                    if self_.reconnect.load(Ordering::SeqCst) {
257                        workflow_core::task::sleep(
258                            options
259                                .retry_interval
260                                .unwrap_or(std::time::Duration::from_millis(1000)),
261                        )
262                        .await;
263                        // check again if reconnect may have been disabled during sleep
264                        if self_.reconnect.load(Ordering::SeqCst) {
265                            self_
266                                .retry_connect_impl(options, connect_trigger_)
267                                .await
268                                .ok();
269                        }
270                    }
271                });
272
273                return Ok(());
274            }
275        };
276
277        let mut inner = self.inner.lock().unwrap();
278
279        let ws = WebSocket::new_with_config(&url, &self.config.lock().unwrap())?;
280        ws.set_binary_type(web_sys::BinaryType::Arraybuffer);
281
282        // - Message
283        let event_sender_ = self.event_channel.sender.clone();
284        let onmessage = callback!(move |event: WsMessageEvent| {
285            let msg: Message = event.try_into().expect("MessageEvent Error");
286            event_sender_.try_send(msg).unwrap_or_else(|err| {
287                log_trace!("WebSocket unable to try_send() `message` to event channel: `{err}`")
288            });
289        });
290        ws.set_onmessage(Some(onmessage.as_ref()));
291
292        // - Error
293        let onerror = callback!(move |_event: WsErrorEvent| {
294            // log_trace!("WS - error event: {:?}", _event);
295        });
296        ws.set_onerror(Some(onerror.as_ref()));
297
298        // - Open
299        let event_sender_ = self.event_channel.sender.clone();
300        let onopen = callback!(move || {
301            event_sender_.try_send(Message::Open).unwrap_or_else(|err| {
302                log_trace!("WebSocket unable to try_send() `open` to event channel: `{err}`")
303            });
304        });
305        ws.set_onopen(Some(onopen.as_ref()));
306
307        // - Close
308        let event_sender_ = self.event_channel.sender.clone();
309        let onclose = callback!(move |_event: WsCloseEvent| {
310            // log_trace!("WS - close event: {:?}", _event);
311            event_sender_
312                .try_send(Message::Close)
313                .unwrap_or_else(|err| {
314                    log_trace!("WebSocket unable to try_send() `close` to event channel: `{err}`")
315                });
316        });
317        ws.set_onclose(Some(onclose.as_ref()));
318
319        let callbacks = CallbackMap::new();
320        callbacks.retain(onmessage)?;
321        callbacks.retain(onerror)?;
322        callbacks.retain(onopen)?;
323        callbacks.retain(onclose)?;
324
325        *inner = Some(Inner {
326            ws: ws.clone(),
327            callbacks,
328        });
329
330        let self_ = self.clone();
331        spawn(async move {
332            self_
333                .dispatcher_task(&ws, options.clone(), connect_trigger.clone())
334                .await
335                .unwrap_or_else(|err| log_trace!("WebSocket error: {err}"));
336            // if reconnect is true, we sleep for reconnect interval and try to reconnect
337            if self_.reconnect.load(Ordering::SeqCst) {
338                workflow_core::task::sleep(
339                    options
340                        .retry_interval
341                        .unwrap_or(std::time::Duration::from_millis(1000)),
342                )
343                .await;
344                // check again if reconnect may have been disabled during sleep
345                if self_.reconnect.load(Ordering::SeqCst) {
346                    self_.reconnect(options, connect_trigger).await.ok();
347                }
348            }
349        });
350
351        Ok(())
352    }
353
354    fn ws(self: &Arc<Self>) -> Option<WebSocket> {
355        self.inner
356            .lock()
357            .expect("WebSocket:: inner lock failure")
358            .as_ref()
359            .map(|inner| inner.ws.clone())
360    }
361
362    #[allow(dead_code)]
363    pub fn try_send(self: &Arc<Self>, message: &Message) -> Result<()> {
364        if let Some(ws) = self.ws() {
365            ws.try_send(message)?;
366            Ok(())
367        } else {
368            Err(Error::NotConnected)
369        }
370    }
371
372    async fn handshake_impl(self: &Arc<Self>, ws: &WebSocket) -> Result<()> {
373        if let Some(handshake) = self.handshake() {
374            let (sender_tx, sender_rx) = unbounded();
375            let (receiver_tx, receiver_rx) = unbounded();
376            let (accept_tx, accept_rx) = oneshot();
377
378            spawn(async move {
379                accept_tx
380                    .send(handshake.handshake(&sender_tx, &receiver_rx).await)
381                    .await
382                    .unwrap_or_else(|err| {
383                        log_trace!("WebSocket handshake unable to send completion: `{}`", err)
384                    });
385            });
386
387            loop {
388                select_biased! {
389                    result = accept_rx.recv().fuse() => {
390                        return result?;
391                    },
392                    msg = sender_rx.recv().fuse() => {
393                        if let Ok(msg) = msg {
394                            ws.try_send(&msg)?;
395                        }
396                    },
397                    msg = self.event_channel.recv().fuse() => {
398                        if let Ok(msg) = msg {
399                            receiver_tx.send(msg).await?;
400                        }
401                    }
402                }
403            }
404        }
405
406        Ok(())
407    }
408
409    async fn dispatcher_task(
410        self: &Arc<Self>,
411        ws: &WebSocket,
412        options: ConnectOptions,
413        connect_trigger: Arc<Mutex<Option<Sender<Result<()>>>>>,
414    ) -> Result<()> {
415        'outer: loop {
416            select! {
417                _ = self.dispatcher_shutdown.request.receiver.recv().fuse() => {
418                    break 'outer;
419                },
420                msg = self.event_channel.recv().fuse() => {
421                    match msg {
422                        Ok(msg) => {
423                            match msg {
424                                Message::Binary(_) | Message::Text(_) => {
425                                    self.receiver_channel.sender.send(msg).await.unwrap();
426                                },
427                                Message::Open => {
428                                    // log_info!("WebSocket Message::Open");
429                                    // handle handshake failure
430                                    if let Err(err) = self.handshake_impl(ws).await {
431                                        log_info!("WebSocket handshake negotiation error: {err}");
432
433                                        if options.strategy.is_fallback() {
434                                            self.reconnect.store(false, Ordering::SeqCst);
435                                        }
436
437                                        let connect_trigger = connect_trigger.lock().unwrap().take();
438                                        if let Some(connect_trigger) = connect_trigger {
439                                            connect_trigger.send(Err(err)).await.ok();
440                                        }
441
442                                        return Err(Error::NegotiationFailure);
443                                    }
444
445                                    self.is_connected.store(true, Ordering::SeqCst);
446
447                                    let connect_trigger = connect_trigger.lock().unwrap().take();
448                                    if let Some(connect_trigger) = connect_trigger {
449                                        connect_trigger.send(Ok(())).await.ok();
450                                    }
451
452                                    self.receiver_channel.sender.send(msg).await.unwrap();
453                                },
454                                Message::Close => {
455                                    // log_info!("WebSocket Message::Close");
456
457                                    if let Some(inner) = self.inner.lock().unwrap().take() {
458                                        inner.ws.cleanup();
459                                    }
460
461                                    if self.is_connected.load(Ordering::SeqCst) {
462                                        self.is_connected.store(false, Ordering::SeqCst);
463                                        self.receiver_channel.sender.send(msg).await.unwrap();
464                                    } else if options.strategy.is_fallback() && options.block_async_connect {
465                                        // if we never connected and receiver Close while
466                                        // the strategy is Fallback, we disable reconnect
467                                        self.reconnect.store(false, Ordering::SeqCst);
468
469                                        let connect_trigger = connect_trigger.lock().unwrap().take();
470                                        if let Some(connect_trigger) = connect_trigger {
471                                            connect_trigger.send(Err(Error::Connect(self.current_url().unwrap()))).await.ok();
472                                        }
473                                    }
474
475                                    break 'outer;
476                                }
477                            }
478                        }
479                        Err(err) => {
480                            log_error!("WebSocket dispatcher channel error: {err}");
481                        }
482                    }
483                },
484                msg = self.sender_channel.receiver.recv().fuse() => {
485
486                    if let Ok((msg, ack)) = msg {
487
488                        // if ws.ready_state() != WebSocket::OPEN {
489                        //     return Err(Error::NotConnected);
490                        // }
491
492                        if let Some(ack) = ack {
493                            let result = ws
494                                .try_send(&msg)
495                                .map(Arc::new)
496                                .map_err(Arc::new);
497                            ack.send(result).await.unwrap_or_else(|err| {
498                                log_trace!("WebSocket error producing message ack {:?}", err)
499                            });
500                        } else {
501                            ws.try_send(&msg).unwrap_or_else(|err| {
502                                log_trace!("WebSocket unable to send `raw ws` message: `{err}`")
503                            });
504                        }
505                    }
506                }
507            }
508        }
509
510        Ok(())
511    }
512
513    async fn _shutdown(self: &Arc<Self>) -> Result<()> {
514        self.dispatcher_shutdown
515            .signal(())
516            .await
517            .map_err(|_| Error::DispatcherSignal)?;
518
519        Ok(())
520    }
521
522    pub async fn close(self: &Arc<Self>) -> Result<()> {
523        if let Some(inner) = self.inner.lock().unwrap().take() {
524            inner.ws.cleanup();
525            inner.ws.close_if_open()?;
526        }
527
528        if self.is_connected.load(Ordering::SeqCst) {
529            self.event_channel.try_send(Message::Close)?;
530        }
531
532        Ok(())
533    }
534
535    async fn reconnect(
536        self: &Arc<Self>,
537        options: ConnectOptions,
538        connect_trigger: Arc<Mutex<Option<Sender<Result<()>>>>>,
539    ) -> Result<()> {
540        self.close().await?;
541
542        self.clone()
543            .retry_connect_impl(options, connect_trigger)
544            .await?;
545
546        Ok(())
547    }
548
549    pub async fn disconnect(self: &Arc<Self>) -> Result<()> {
550        self.reconnect.store(false, Ordering::SeqCst);
551        self.close().await.ok();
552        Ok(())
553    }
554
555    pub fn trigger_abort(self: &Arc<Self>) -> Result<()> {
556        if self.is_connected.load(Ordering::SeqCst) {
557            if let Some(ws) = self.ws() {
558                ws.close_if_open()?;
559            }
560        }
561        Ok(())
562    }
563}
564
565impl Drop for WebSocketInterface {
566    fn drop(&mut self) {}
567}
568
569trait TrySendMessage {
570    fn try_send(&self, message: &Message) -> Result<()>;
571}
572
573impl TrySendMessage for WebSocket {
574    fn try_send(&self, message: &Message) -> Result<()> {
575        match message {
576            Message::Binary(data) => {
577                if is_cross_origin_isolated() {
578                    // Create a non-shared ArrayBuffer for cross-origin isolated environments (Flutter).
579                    let array_buffer: ArrayBuffer = ArrayBuffer::new(data.len() as u32);
580                    let uint8_array = Uint8Array::new(&array_buffer);
581                    uint8_array.copy_from(&data[..]);
582                    self.send_with_array_buffer(&array_buffer)
583                        .map_err(|e| e.into())
584                } else {
585                    self.send_with_u8_array(data).map_err(|e| e.into())
586                }
587            }
588            Message::Text(text) => self.send_with_str(text).map_err(|e| e.into()),
589            _ => {
590                panic!("WebSocket trying to convert unsupported message type: `{message:?}`");
591            }
592        }
593    }
594}
595
596fn w3c_websocket_available() -> Result<bool> {
597    Ok(js_sys::Reflect::get(&js_sys::global(), &"WebSocket".into())
598        .map(|v| !v.is_falsy())
599        .unwrap_or(false))
600}
601
602fn sanity_checks() -> Result<()> {
603    if !w3c_websocket_available()? {
604        if is_node() {
605            log_info!("");
606            log_info!("+------------------------------------------------------------");
607            log_info!("|");
608            log_info!("| w3c websocket is not available");
609            log_info!("|");
610            log_info!("| Please include `WebSocket` module as a project dependency");
611            log_info!("| and add the following line to your Node.js script:");
612            log_info!("|");
613            log_info!("| `globalThis.WebSocket = require(\"websocket\").w3cwebsocket;`");
614            log_info!("|");
615            log_info!("| (or use any other w3c-compatible module)");
616            log_info!("|");
617            log_info!("+------------------------------------------------------------");
618            log_info!("");
619        } else {
620            log_info!("");
621            log_error!("w3c websocket is not available");
622            log_info!("");
623        }
624        panic!("w3c websocket is not available");
625    }
626    Ok(())
627}