1pub mod error;
6mod interface;
7pub mod prelude;
8mod protocol;
9pub mod result;
10pub use crate::client::error::Error;
11pub use crate::client::result::Result;
12
13use crate::imports::*;
14use futures_util::select_biased;
15pub use interface::{Interface, Notification};
16use protocol::ProtocolHandler;
17pub use protocol::{BorshProtocol, JsonProtocol};
18use std::fmt::Debug;
19use std::str::FromStr;
20use workflow_core::{channel::Multiplexer, task::yield_now};
21pub use workflow_websocket::client::{
22 ConnectOptions, ConnectResult, ConnectStrategy, Resolver, ResolverResult, WebSocketConfig,
23 WebSocketError,
24};
25
26#[cfg(feature = "wasm32-sdk")]
27pub use workflow_websocket::client::options::IConnectOptions;
28
29pub use workflow_rpc_macros::client_notification as notification;
62
/// Connection-state control events broadcast on the client's optional `ctl_multiplexer`.
///
/// The textual form (see the `Display` / `FromStr` impls) is `"connect"` / `"disconnect"`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Ctl {
    /// Broadcast when the WebSocket receiver reports an `Open` message.
    Connect,
    /// Broadcast when the WebSocket receiver reports a `Close` message.
    Disconnect,
}
68
69impl std::fmt::Display for Ctl {
70 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71 match self {
72 Ctl::Connect => write!(f, "connect"),
73 Ctl::Disconnect => write!(f, "disconnect"),
74 }
75 }
76}
77
78impl FromStr for Ctl {
79 type Err = Error;
80
81 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
82 match s {
83 "connect" => Ok(Ctl::Connect),
84 "disconnect" => Ok(Ctl::Disconnect),
85 _ => Err(Error::InvalidEvent(s.to_string())),
86 }
87 }
88}
89
/// Asynchronous callback that receives a notification as a byte slice.
///
/// Implementors must be `Send + Sync + 'static`.
/// NOTE(review): nothing in this section calls this trait; presumably `data` is the
/// still-encoded notification payload — confirm against the protocol handlers.
#[async_trait]
pub trait NotificationHandler: Send + Sync + 'static {
    /// Handles one notification; an `Err` reports that handling failed.
    async fn handle_notification(&self, data: &[u8]) -> Result<()>;
}
94
95#[derive(Default)]
96pub struct Options<'url> {
97 pub ctl_multiplexer: Option<Multiplexer<Ctl>>,
98 pub url: Option<&'url str>,
99}
100
101impl<'url> Options<'url> {
102 pub fn new() -> Self {
103 Self::default()
104 }
105
106 pub fn with_url(mut self, url: &'url str) -> Self {
107 self.url = Some(url);
108 self
109 }
110
111 pub fn with_ctl_multiplexer(mut self, ctl_multiplexer: Multiplexer<Ctl>) -> Self {
112 self.ctl_multiplexer = Some(ctl_multiplexer);
113 self
114 }
115}
116
/// State shared (behind an `Arc`) between `RpcClient` and its background tasks.
struct Inner<Ops> {
    /// WebSocket; `RpcClient::new` also hands a clone of this `Arc` to the protocol handler.
    ws: Arc<WebSocket>,
    /// Set by `start()`; cleared by `shutdown()` after both tasks have been stopped.
    is_running: AtomicBool,
    /// Updated by the receiver task on `Open` / `Close` messages.
    /// NOTE(review): `RpcClient::is_connected` reads `ws.is_connected()` rather than this flag.
    is_connected: AtomicBool,
    /// `true` while the receiver task's loop is active.
    receiver_is_running: AtomicBool,
    /// `true` while the timeout task's loop is active.
    timeout_is_running: AtomicBool,
    /// Request/response channel pair used to stop the receiver task and acknowledge completion.
    receiver_shutdown: DuplexChannel,
    /// Request/response channel pair used to stop the timeout task and acknowledge completion.
    timeout_shutdown: DuplexChannel,
    /// Milliseconds the timeout task sleeps between checks (initialized to 5 000).
    timeout_timer_interval: AtomicU64,
    /// Milliseconds passed as a `Duration` to `ProtocolHandler::handle_timeout` (initialized to 60 000).
    timeout_duration: AtomicU64,
    /// Optional multiplexer on which `Ctl` connect/disconnect events are broadcast.
    ctl_multiplexer: Option<Multiplexer<Ctl>>,
    /// Type-erased protocol handler driven by the receiver and timeout tasks.
    protocol: Arc<dyn ProtocolHandler<Ops>>,
}
130
131impl<Ops> Inner<Ops>
132where
133 Ops: OpsT,
134{
135 fn new<T>(
136 ws: Arc<WebSocket>,
137 protocol: Arc<dyn ProtocolHandler<Ops>>,
138 options: Options,
139 ) -> Result<Self>
140 where
141 T: ProtocolHandler<Ops> + Send + Sync + 'static,
142 {
143 let inner = Inner {
144 ws,
145 is_running: AtomicBool::new(false),
146 is_connected: AtomicBool::new(false),
147 receiver_is_running: AtomicBool::new(false),
148 receiver_shutdown: DuplexChannel::oneshot(),
149 timeout_is_running: AtomicBool::new(false),
150 timeout_shutdown: DuplexChannel::oneshot(),
151 timeout_duration: AtomicU64::new(60_000),
152 timeout_timer_interval: AtomicU64::new(5_000),
153 ctl_multiplexer: options.ctl_multiplexer,
154 protocol,
155 };
156
157 Ok(inner)
158 }
159
160 #[inline]
161 pub fn is_running(&self) -> bool {
162 self.is_running.load(Ordering::SeqCst)
163 }
164
165 pub fn start(self: &Arc<Self>) -> Result<()> {
166 if !self.is_running.load(Ordering::Relaxed) {
167 self.is_running.store(true, Ordering::SeqCst);
168 self.clone().timeout_task();
169 self.clone().receiver_task();
170 } else {
171 log_warn!("wRPC services are already running: rpc::start() was called multiple times");
172 }
173 Ok(())
174 }
175
176 pub async fn shutdown(self: &Arc<Self>) -> Result<()> {
177 self.ws.disconnect().await?;
178 yield_now().await;
179 if self.is_running.load(Ordering::Relaxed) {
180 self.stop_timeout().await?;
181 self.stop_receiver().await?;
182 self.is_running.store(false, Ordering::SeqCst);
183 }
184 Ok(())
185 }
186
187 fn timeout_task(self: Arc<Self>) {
188 self.timeout_is_running.store(true, Ordering::SeqCst);
189 workflow_core::task::spawn(async move {
190 'outer: loop {
191 let timeout_timer_interval =
192 Duration::from_millis(self.timeout_timer_interval.load(Ordering::SeqCst));
193 select_biased! {
194 _ = workflow_core::task::sleep(timeout_timer_interval).fuse() => {
195 let timeout = Duration::from_millis(self.timeout_duration.load(Ordering::Relaxed));
196 self.protocol.handle_timeout(timeout).await;
197 },
198 _ = self.timeout_shutdown.request.receiver.recv().fuse() => {
199 break 'outer;
200 },
201 }
202 }
203
204 self.timeout_is_running.store(false, Ordering::SeqCst);
205 self.timeout_shutdown.response.sender.send(()).await.unwrap_or_else(|err|
206 log_error!("wRPC client - unable to signal shutdown completion for timeout task: `{err}`"));
207 });
208 }
209
210 fn receiver_task(self: Arc<Self>) {
211 self.receiver_is_running.store(true, Ordering::SeqCst);
212 let receiver_rx = self.ws.receiver_rx().clone();
213 workflow_core::task::spawn(async move {
214 'outer: loop {
215 select_biased! {
216 msg = receiver_rx.recv().fuse() => {
217 match msg {
218 Ok(msg) => {
219 match msg {
220 WebSocketMessage::Binary(_) | WebSocketMessage::Text(_) => {
221 self.protocol.handle_message(msg).await
222 .unwrap_or_else(|err|log_trace!("wRPC error: `{err}`"));
223 }
224 WebSocketMessage::Open => {
225 self.is_connected.store(true, Ordering::SeqCst);
226 if let Some(ctl_channel) = &self.ctl_multiplexer {
227 ctl_channel.try_broadcast(Ctl::Connect).expect("ctl_channel.try_broadcast(Ctl::Connect)");
228 }
229 }
230 WebSocketMessage::Close => {
231 self.is_connected.store(false, Ordering::SeqCst);
232
233 self.protocol.handle_disconnect().await.unwrap_or_else(|err|{
234 log_error!("wRPC error during protocol disconnect: {err}");
235 });
236
237 if let Some(ctl_channel) = &self.ctl_multiplexer {
238 ctl_channel.try_broadcast(Ctl::Disconnect).expect("ctl_channel.try_broadcast(Ctl::Disconnect)");
239 }
240 }
241 }
242 },
243 Err(err) => {
244 log_error!("wRPC client receiver channel error: {err}");
245 break 'outer;
246 }
247 }
248 },
249 _ = self.receiver_shutdown.request.receiver.recv().fuse() => {
250 break 'outer;
251 },
252
253 }
254 }
255
256 self.receiver_is_running.store(false, Ordering::SeqCst);
257 self.receiver_shutdown.response.sender.send(()).await.unwrap_or_else(|err|
258 log_error!("wRPC client - unable to signal shutdown completion for receiver task: `{err}`")
259 );
260 });
261 }
262
263 async fn stop_receiver(&self) -> Result<()> {
264 if !self.receiver_is_running.load(Ordering::SeqCst) {
265 return Ok(());
266 }
267
268 self.receiver_shutdown
269 .signal(())
270 .await
271 .unwrap_or_else(|err| {
272 log_error!("wRPC client unable to signal receiver shutdown: `{err}`")
273 });
274
275 Ok(())
276 }
277
278 async fn stop_timeout(&self) -> Result<()> {
279 if !self.timeout_is_running.load(Ordering::SeqCst) {
280 return Ok(());
281 }
282
283 self.timeout_shutdown
284 .signal(())
285 .await
286 .unwrap_or_else(|err| {
287 log_error!("wRPC client unable to signal timeout shutdown: `{err}`")
288 });
289
290 Ok(())
291 }
292}
293
/// Concrete protocol handler held by `RpcClient`.
///
/// It is recovered from the type-erased `Arc<dyn ProtocolHandler<Ops>>` so that the
/// typed `notify` / `request` methods of the Borsh or JSON implementation can be called.
#[derive(Clone)]
enum Protocol<Ops, Id>
where
    Ops: OpsT,
    Id: IdT,
{
    Borsh(Arc<BorshProtocol<Ops, Id>>),
    Json(Arc<JsonProtocol<Ops, Id>>),
}
303
304impl<Ops, Id> From<Arc<dyn ProtocolHandler<Ops>>> for Protocol<Ops, Id>
305where
306 Ops: OpsT,
307 Id: IdT,
308{
309 fn from(protocol: Arc<dyn ProtocolHandler<Ops>>) -> Self {
310 if let Ok(protocol) = protocol.clone().downcast_arc::<BorshProtocol<Ops, Id>>() {
311 Protocol::Borsh(protocol)
312 } else if let Ok(protocol) = protocol.clone().downcast_arc::<JsonProtocol<Ops, Id>>() {
313 Protocol::Json(protocol)
314 } else {
315 panic!()
316 }
317 }
318}
319
/// wRPC client: a WebSocket paired with a Borsh or JSON protocol handler.
///
/// `Id` defaults to `Id64`. Cloning is cheap: clones share the same `Inner` state
/// and protocol handler through their `Arc`s.
#[derive(Clone)]
pub struct RpcClient<Ops, Id = Id64>
where
    Ops: OpsT,
    Id: IdT,
{
    /// Shared state and background-task control.
    inner: Arc<Inner<Ops>>,
    /// Concrete protocol used by `notify` and `call`.
    protocol: Protocol<Ops, Id>,
    ops: PhantomData<Ops>,
    id: PhantomData<Id>,
}
331
332impl<Ops, Id> RpcClient<Ops, Id>
333where
334 Ops: OpsT,
335 Id: IdT,
336{
337 pub fn new_with_encoding(
349 encoding: Encoding,
350 interface: Option<Arc<Interface<Ops>>>,
351 options: Options,
352 config: Option<WebSocketConfig>,
353 ) -> Result<RpcClient<Ops, Id>> {
354 match encoding {
355 Encoding::Borsh => Self::new::<BorshProtocol<Ops, Id>>(interface, options, config),
356 Encoding::SerdeJson => Self::new::<JsonProtocol<Ops, Id>>(interface, options, config),
357 }
358 }
359
360 pub fn new<T>(
372 interface: Option<Arc<Interface<Ops>>>,
373 options: Options,
374 config: Option<WebSocketConfig>,
375 ) -> Result<RpcClient<Ops, Id>>
376 where
377 T: ProtocolHandler<Ops> + Send + Sync + 'static,
378 {
379 let url = options.url.map(sanitize_url).transpose()?;
380
381 let ws = Arc::new(WebSocket::new(url.as_deref(), config)?);
382 let protocol: Arc<dyn ProtocolHandler<Ops>> = Arc::new(T::new(ws.clone(), interface));
383 let inner = Arc::new(Inner::new::<T>(ws, protocol.clone(), options)?);
384
385 let client = RpcClient::<Ops, Id> {
386 inner,
387 protocol: protocol.into(),
388 ops: PhantomData,
389 id: PhantomData,
390 };
391
392 Ok(client)
393 }
394
395 pub async fn connect(&self, options: ConnectOptions) -> ConnectResult<Error> {
397 if !self.inner.is_running() {
398 self.inner.start()?;
399 }
400 Ok(self.inner.ws.connect(options).await?)
401 }
402
403 pub async fn shutdown(&self) -> Result<()> {
405 self.inner.shutdown().await?;
406 Ok(())
407 }
408
409 pub fn ctl_multiplexer(&self) -> &Option<Multiplexer<Ctl>> {
410 &self.inner.ctl_multiplexer
411 }
412
413 pub fn is_connected(&self) -> bool {
415 self.inner.ws.is_connected()
416 }
417
418 pub fn url(&self) -> Option<String> {
420 self.inner.ws.url()
421 }
422
423 pub fn set_url(&self, url: &str) -> Result<()> {
428 self.inner.ws.set_url(url);
429 Ok(())
430 }
431
432 pub fn configure(&self, config: WebSocketConfig) {
436 self.inner.ws.configure(config);
437 }
438
439 pub async fn notify<Msg>(&self, op: Ops, payload: Msg) -> Result<()>
447 where
448 Msg: BorshSerialize + Serialize + Send + Sync + 'static,
449 {
450 if !self.is_connected() {
451 return Err(WebSocketError::NotConnected.into());
452 }
453
454 match &self.protocol {
455 Protocol::Borsh(protocol) => {
456 protocol.notify(op, payload).await?;
457 }
458 Protocol::Json(protocol) => {
459 protocol.notify(op, payload).await?;
460 }
461 }
462
463 Ok(())
464 }
465
466 pub async fn call<Req, Resp>(&self, op: Ops, req: Req) -> Result<Resp>
475 where
476 Req: MsgT,
477 Resp: MsgT,
478 {
479 if !self.is_connected() {
480 return Err(WebSocketError::NotConnected.into());
481 }
482
483 match &self.protocol {
484 Protocol::Borsh(protocol) => Ok(protocol.request(op, req).await?),
485 Protocol::Json(protocol) => Ok(protocol.request(op, req).await?),
486 }
487 }
488
489 pub fn trigger_abort(&self) -> Result<()> {
493 Ok(self.inner.ws.trigger_abort()?)
494 }
495}
496
497fn sanitize_url(url: &str) -> Result<String> {
498 let url = url
499 .replace("wrpc://", "ws://")
500 .replace("wrpcs://", "wss://");
501 Ok(url)
502}