workflow_rpc/client/
mod.rs

1//!
2//! RPC client (operates uniformly in native and WASM-browser environments).
3//!
4
5pub mod error;
6mod interface;
7pub mod prelude;
8mod protocol;
9pub mod result;
10pub use crate::client::error::Error;
11pub use crate::client::result::Result;
12
13use crate::imports::*;
14use futures_util::select_biased;
15pub use interface::{Interface, Notification};
16use protocol::ProtocolHandler;
17pub use protocol::{BorshProtocol, JsonProtocol};
18use std::fmt::Debug;
19use std::str::FromStr;
20use workflow_core::{channel::Multiplexer, task::yield_now};
21pub use workflow_websocket::client::{
22    ConnectOptions, ConnectResult, ConnectStrategy, Resolver, ResolverResult, WebSocketConfig,
23    WebSocketError,
24};
25
26#[cfg(feature = "wasm32-sdk")]
27pub use workflow_websocket::client::options::IConnectOptions;
28
29///
30/// notification!() macro for declaration of RPC notification handlers
31///
32/// This macro simplifies creation of async notification handler
33/// closures supplied to the RPC notification interface. An
/// async notification closure must be *Box*ed
35/// and its result must be *Pin*ned, resulting in the following
36/// syntax:
37///
38/// ```ignore
39///
40/// interface.notification(Box::new(Notification::new(|msg: MyMsg|
41///     Box::pin(
42///         async move {
43///             // ...
44///             Ok(())
45///         }
46///     )
47/// )))
48///
49/// ```
50///
51/// The notification macro adds the required Box and Pin syntax,
52/// simplifying the declaration as follows:
53///
54/// ```ignore
55/// interface.notification(notification!(|msg: MyMsg| async move {
56///     // ...
57///     Ok(())
58/// }))
59/// ```
60///
61pub use workflow_rpc_macros::client_notification as notification;
62
/// Connection control events emitted by the RPC client.
///
/// The client's receiver task broadcasts these on the optional control
/// multiplexer when the underlying WebSocket opens or closes.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Ctl {
    /// The underlying WebSocket reported that it is open.
    Connect,
    /// The underlying WebSocket reported that it is closed.
    Disconnect,
}
68
69impl std::fmt::Display for Ctl {
70    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71        match self {
72            Ctl::Connect => write!(f, "connect"),
73            Ctl::Disconnect => write!(f, "disconnect"),
74        }
75    }
76}
77
78impl FromStr for Ctl {
79    type Err = Error;
80
81    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
82        match s {
83            "connect" => Ok(Ctl::Connect),
84            "disconnect" => Ok(Ctl::Disconnect),
85            _ => Err(Error::InvalidEvent(s.to_string())),
86        }
87    }
88}
89
90#[async_trait]
91pub trait NotificationHandler: Send + Sync + 'static {
92    async fn handle_notification(&self, data: &[u8]) -> Result<()>;
93}
94
95#[derive(Default)]
96pub struct Options<'url> {
97    pub ctl_multiplexer: Option<Multiplexer<Ctl>>,
98    pub url: Option<&'url str>,
99}
100
101impl<'url> Options<'url> {
102    pub fn new() -> Self {
103        Self::default()
104    }
105
106    pub fn with_url(mut self, url: &'url str) -> Self {
107        self.url = Some(url);
108        self
109    }
110
111    pub fn with_ctl_multiplexer(mut self, ctl_multiplexer: Multiplexer<Ctl>) -> Self {
112        self.ctl_multiplexer = Some(ctl_multiplexer);
113        self
114    }
115}
116
117struct Inner<Ops> {
118    ws: Arc<WebSocket>,
119    is_running: AtomicBool,
120    is_connected: AtomicBool,
121    receiver_is_running: AtomicBool,
122    timeout_is_running: AtomicBool,
123    receiver_shutdown: DuplexChannel,
124    timeout_shutdown: DuplexChannel,
125    timeout_timer_interval: AtomicU64,
126    timeout_duration: AtomicU64,
127    ctl_multiplexer: Option<Multiplexer<Ctl>>,
128    protocol: Arc<dyn ProtocolHandler<Ops>>,
129}
130
131impl<Ops> Inner<Ops>
132where
133    Ops: OpsT,
134{
135    fn new<T>(
136        ws: Arc<WebSocket>,
137        protocol: Arc<dyn ProtocolHandler<Ops>>,
138        options: Options,
139    ) -> Result<Self>
140    where
141        T: ProtocolHandler<Ops> + Send + Sync + 'static,
142    {
143        let inner = Inner {
144            ws,
145            is_running: AtomicBool::new(false),
146            is_connected: AtomicBool::new(false),
147            receiver_is_running: AtomicBool::new(false),
148            receiver_shutdown: DuplexChannel::oneshot(),
149            timeout_is_running: AtomicBool::new(false),
150            timeout_shutdown: DuplexChannel::oneshot(),
151            timeout_duration: AtomicU64::new(60_000),
152            timeout_timer_interval: AtomicU64::new(5_000),
153            ctl_multiplexer: options.ctl_multiplexer,
154            protocol,
155        };
156
157        Ok(inner)
158    }
159
160    #[inline]
161    pub fn is_running(&self) -> bool {
162        self.is_running.load(Ordering::SeqCst)
163    }
164
165    pub fn start(self: &Arc<Self>) -> Result<()> {
166        if !self.is_running.load(Ordering::Relaxed) {
167            self.is_running.store(true, Ordering::SeqCst);
168            self.clone().timeout_task();
169            self.clone().receiver_task();
170        } else {
171            log_warn!("wRPC services are already running: rpc::start() was called multiple times");
172        }
173        Ok(())
174    }
175
176    pub async fn shutdown(self: &Arc<Self>) -> Result<()> {
177        self.ws.disconnect().await?;
178        yield_now().await;
179        if self.is_running.load(Ordering::Relaxed) {
180            self.stop_timeout().await?;
181            self.stop_receiver().await?;
182            self.is_running.store(false, Ordering::SeqCst);
183        }
184        Ok(())
185    }
186
187    fn timeout_task(self: Arc<Self>) {
188        self.timeout_is_running.store(true, Ordering::SeqCst);
189        workflow_core::task::spawn(async move {
190            'outer: loop {
191                let timeout_timer_interval =
192                    Duration::from_millis(self.timeout_timer_interval.load(Ordering::SeqCst));
193                select_biased! {
194                    _ = workflow_core::task::sleep(timeout_timer_interval).fuse() => {
195                        let timeout = Duration::from_millis(self.timeout_duration.load(Ordering::Relaxed));
196                        self.protocol.handle_timeout(timeout).await;
197                    },
198                    _ = self.timeout_shutdown.request.receiver.recv().fuse() => {
199                        break 'outer;
200                    },
201                }
202            }
203
204            self.timeout_is_running.store(false, Ordering::SeqCst);
205            self.timeout_shutdown.response.sender.send(()).await.unwrap_or_else(|err|
206                log_error!("wRPC client - unable to signal shutdown completion for timeout task: `{err}`"));
207        });
208    }
209
210    fn receiver_task(self: Arc<Self>) {
211        self.receiver_is_running.store(true, Ordering::SeqCst);
212        let receiver_rx = self.ws.receiver_rx().clone();
213        workflow_core::task::spawn(async move {
214            'outer: loop {
215                select_biased! {
216                    msg = receiver_rx.recv().fuse() => {
217                        match msg {
218                            Ok(msg) => {
219                                match msg {
220                                    WebSocketMessage::Binary(_) | WebSocketMessage::Text(_) => {
221                                        self.protocol.handle_message(msg).await
222                                        .unwrap_or_else(|err|log_trace!("wRPC error: `{err}`"));
223                                    }
224                                    WebSocketMessage::Open => {
225                                        self.is_connected.store(true, Ordering::SeqCst);
226                                        if let Some(ctl_channel) = &self.ctl_multiplexer {
227                                            ctl_channel.try_broadcast(Ctl::Connect).expect("ctl_channel.try_broadcast(Ctl::Connect)");
228                                        }
229                                    }
230                                    WebSocketMessage::Close => {
231                                        self.is_connected.store(false, Ordering::SeqCst);
232
233                                        self.protocol.handle_disconnect().await.unwrap_or_else(|err|{
234                                            log_error!("wRPC error during protocol disconnect: {err}");
235                                        });
236
237                                        if let Some(ctl_channel) = &self.ctl_multiplexer {
238                                            ctl_channel.try_broadcast(Ctl::Disconnect).expect("ctl_channel.try_broadcast(Ctl::Disconnect)");
239                                        }
240                                    }
241                                }
242                            },
243                            Err(err) => {
244                                log_error!("wRPC client receiver channel error: {err}");
245                                break 'outer;
246                            }
247                        }
248                    },
249                    _ = self.receiver_shutdown.request.receiver.recv().fuse() => {
250                        break 'outer;
251                    },
252
253                }
254            }
255
256            self.receiver_is_running.store(false, Ordering::SeqCst);
257            self.receiver_shutdown.response.sender.send(()).await.unwrap_or_else(|err|
258                log_error!("wRPC client - unable to signal shutdown completion for receiver task: `{err}`")
259            );
260        });
261    }
262
263    async fn stop_receiver(&self) -> Result<()> {
264        if !self.receiver_is_running.load(Ordering::SeqCst) {
265            return Ok(());
266        }
267
268        self.receiver_shutdown
269            .signal(())
270            .await
271            .unwrap_or_else(|err| {
272                log_error!("wRPC client unable to signal receiver shutdown: `{err}`")
273            });
274
275        Ok(())
276    }
277
278    async fn stop_timeout(&self) -> Result<()> {
279        if !self.timeout_is_running.load(Ordering::SeqCst) {
280            return Ok(());
281        }
282
283        self.timeout_shutdown
284            .signal(())
285            .await
286            .unwrap_or_else(|err| {
287                log_error!("wRPC client unable to signal timeout shutdown: `{err}`")
288            });
289
290        Ok(())
291    }
292}
293
294#[derive(Clone)]
295enum Protocol<Ops, Id>
296where
297    Ops: OpsT,
298    Id: IdT,
299{
300    Borsh(Arc<BorshProtocol<Ops, Id>>),
301    Json(Arc<JsonProtocol<Ops, Id>>),
302}
303
304impl<Ops, Id> From<Arc<dyn ProtocolHandler<Ops>>> for Protocol<Ops, Id>
305where
306    Ops: OpsT,
307    Id: IdT,
308{
309    fn from(protocol: Arc<dyn ProtocolHandler<Ops>>) -> Self {
310        if let Ok(protocol) = protocol.clone().downcast_arc::<BorshProtocol<Ops, Id>>() {
311            Protocol::Borsh(protocol)
312        } else if let Ok(protocol) = protocol.clone().downcast_arc::<JsonProtocol<Ops, Id>>() {
313            Protocol::Json(protocol)
314        } else {
315            panic!()
316        }
317    }
318}
319
320#[derive(Clone)]
321pub struct RpcClient<Ops, Id = Id64>
322where
323    Ops: OpsT,
324    Id: IdT,
325{
326    inner: Arc<Inner<Ops>>,
327    protocol: Protocol<Ops, Id>,
328    ops: PhantomData<Ops>,
329    id: PhantomData<Id>,
330}
331
332impl<Ops, Id> RpcClient<Ops, Id>
333where
334    Ops: OpsT,
335    Id: IdT,
336{
337    ///
338    /// Create new wRPC client connecting to the supplied URL
339    ///
340    /// This function accepts the [`Encoding`] enum argument denoting the underlying
341    /// protocol that will be used by the client. Current variants supported
342    /// are:
343    ///
344    /// - [`Encoding::Borsh`]
345    /// - [`Encoding::SerdeJson`]
346    ///
347    ///
348    pub fn new_with_encoding(
349        encoding: Encoding,
350        interface: Option<Arc<Interface<Ops>>>,
351        options: Options,
352        config: Option<WebSocketConfig>,
353    ) -> Result<RpcClient<Ops, Id>> {
354        match encoding {
355            Encoding::Borsh => Self::new::<BorshProtocol<Ops, Id>>(interface, options, config),
356            Encoding::SerdeJson => Self::new::<JsonProtocol<Ops, Id>>(interface, options, config),
357        }
358    }
359
360    ///
361    /// Create new wRPC client connecting to the supplied URL.
362    ///
363    /// This function accepts a generic denoting the underlying
364    /// protocol that will be used by the client. Current protocols
365    /// supported are:
366    ///
367    /// - [`BorshProtocol`]
368    /// - [`JsonProtocol`]
369    ///
370    ///
371    pub fn new<T>(
372        interface: Option<Arc<Interface<Ops>>>,
373        options: Options,
374        config: Option<WebSocketConfig>,
375    ) -> Result<RpcClient<Ops, Id>>
376    where
377        T: ProtocolHandler<Ops> + Send + Sync + 'static,
378    {
379        let url = options.url.map(sanitize_url).transpose()?;
380
381        let ws = Arc::new(WebSocket::new(url.as_deref(), config)?);
382        let protocol: Arc<dyn ProtocolHandler<Ops>> = Arc::new(T::new(ws.clone(), interface));
383        let inner = Arc::new(Inner::new::<T>(ws, protocol.clone(), options)?);
384
385        let client = RpcClient::<Ops, Id> {
386            inner,
387            protocol: protocol.into(),
388            ops: PhantomData,
389            id: PhantomData,
390        };
391
392        Ok(client)
393    }
394
395    /// Connect to the target wRPC endpoint (websocket address)
396    pub async fn connect(&self, options: ConnectOptions) -> ConnectResult<Error> {
397        if !self.inner.is_running() {
398            self.inner.start()?;
399        }
400        Ok(self.inner.ws.connect(options).await?)
401    }
402
403    /// Stop wRPC client services
404    pub async fn shutdown(&self) -> Result<()> {
405        self.inner.shutdown().await?;
406        Ok(())
407    }
408
409    pub fn ctl_multiplexer(&self) -> &Option<Multiplexer<Ctl>> {
410        &self.inner.ctl_multiplexer
411    }
412
413    /// Test if the underlying WebSocket is currently open
414    pub fn is_connected(&self) -> bool {
415        self.inner.ws.is_connected()
416    }
417
418    /// Obtain the current URL of the underlying WebSocket
419    pub fn url(&self) -> Option<String> {
420        self.inner.ws.url()
421    }
422
423    /// Change the URL of the underlying WebSocket
424    /// (applicable only to the next connection).
425    /// Alternatively, the new URL can be supplied
426    /// in the `connect()` method using [`ConnectOptions`].
427    pub fn set_url(&self, url: &str) -> Result<()> {
428        self.inner.ws.set_url(url);
429        Ok(())
430    }
431
432    /// Change the configuration of the underlying WebSocket.
433    /// This method can be used to alter the configuration
434    /// for the next connection.
435    pub fn configure(&self, config: WebSocketConfig) {
436        self.inner.ws.configure(config);
437    }
438
439    ///
440    /// Issue an async Notification to the server (no response is expected)
441    ///
442    /// Following are the trait requirements on the arguments:
443    /// - `Ops`: [`OpsT`]
444    /// - `Msg`: [`MsgT`]
445    ///
446    pub async fn notify<Msg>(&self, op: Ops, payload: Msg) -> Result<()>
447    where
448        Msg: BorshSerialize + Serialize + Send + Sync + 'static,
449    {
450        if !self.is_connected() {
451            return Err(WebSocketError::NotConnected.into());
452        }
453
454        match &self.protocol {
455            Protocol::Borsh(protocol) => {
456                protocol.notify(op, payload).await?;
457            }
458            Protocol::Json(protocol) => {
459                protocol.notify(op, payload).await?;
460            }
461        }
462
463        Ok(())
464    }
465
466    ///
467    /// Issue an async wRPC call and wait for response.
468    ///
469    /// Following are the trait requirements on the arguments:
470    /// - `Ops`: [`OpsT`]
471    /// - `Req`: [`MsgT`]
472    /// - `Resp`: [`MsgT`]
473    ///
474    pub async fn call<Req, Resp>(&self, op: Ops, req: Req) -> Result<Resp>
475    where
476        Req: MsgT,
477        Resp: MsgT,
478    {
479        if !self.is_connected() {
480            return Err(WebSocketError::NotConnected.into());
481        }
482
483        match &self.protocol {
484            Protocol::Borsh(protocol) => Ok(protocol.request(op, req).await?),
485            Protocol::Json(protocol) => Ok(protocol.request(op, req).await?),
486        }
487    }
488
489    /// Triggers a disconnection on the underlying WebSocket.
490    /// This is intended for debug purposes only.
491    /// Can be used to test application reconnection logic.
492    pub fn trigger_abort(&self) -> Result<()> {
493        Ok(self.inner.ws.trigger_abort()?)
494    }
495}
496
497fn sanitize_url(url: &str) -> Result<String> {
498    let url = url
499        .replace("wrpc://", "ws://")
500        .replace("wrpcs://", "wss://");
501    Ok(url)
502}