workflow_rpc/server/
mod.rs

1//!
2//! RPC server module (native only). This module encapsulates
3//! server-side types used to create an RPC server: [`RpcServer`],
4//! [`RpcHandler`], [`Messenger`], [`Interface`] and the
5//! protocol handlers: [`BorshProtocol`] and [`JsonProtocol`].
6//!
7
8pub mod error;
9mod interface;
10pub mod prelude;
11pub mod protocol;
12pub mod result;
13
14pub use super::error::*;
15pub use crate::encoding::Encoding;
16use crate::imports::*;
17pub use interface::{Interface, Method, Notification};
18pub use protocol::{BorshProtocol, JsonProtocol, ProtocolHandler};
19pub use std::net::SocketAddr;
20pub use tokio::sync::mpsc::UnboundedSender as TokioUnboundedSender;
21pub use workflow_core::task::spawn;
22pub use workflow_websocket::server::{
23    Error as WebSocketError, Message, Result as WebSocketResult, TcpListener, WebSocketConfig,
24    WebSocketCounters, WebSocketHandler, WebSocketReceiver, WebSocketSender, WebSocketServer,
25    WebSocketServerTrait, WebSocketSink,
26};
pub mod handshake {
    //! WebSocket handshake helpers.
    //!
    //! This module is a pure re-export: it exposes every public item of
    //! `workflow_websocket::server::handshake` under the RPC server
    //! namespace.
    pub use workflow_websocket::server::handshake::*;
}
31use crate::server::result::Result;
32
33///
34/// method!() macro for declaration of RPC method handlers
35///
36/// This macro simplifies creation of async method handler
37/// closures supplied to the RPC dispatch interface. An
/// async method closure must be *Box*ed
39/// and its result must be *Pin*ned, resulting in the following
40/// syntax:
41///
42/// ```ignore
43///
/// interface.method(MyOps::Method, Box::new(Method::new(|req: MyReq|
45///     Box::pin(
46///         async move {
47///             // ...
48///             Ok(MyResp { })
49///         }
50///     )
51/// )))
52///
53/// ```
54///
55/// The method macro adds the required Box and Pin syntax,
56/// simplifying the declaration as follows:
57///
58/// ```ignore
59/// interface.method(MyOps::Method, method!(
60///   | connection_ctx: ConnectionCtx,
61///     server_ctx: ServerContext,
62///     req: MyReq |
63/// async move {
64///     // ...
65///     Ok(MyResp { })
66/// }))
67/// ```
68///
69pub use workflow_rpc_macros::server_method as method;
70
71///
72/// notification!() macro for declaration of RPC notification handlers
73///
74/// This macro simplifies creation of async notification handler
75/// closures supplied to the RPC notification interface. An
/// async notification closure must be *Box*ed
77/// and its result must be *Pin*ned, resulting in the following
78/// syntax:
79///
80/// ```ignore
81///
/// interface.notification(MyOps::Notify, Box::new(Notification::new(|msg: MyMsg|
83///     Box::pin(
84///         async move {
85///             // ...
86///             Ok(())
87///         }
88///     )
89/// )))
90///
91/// ```
92///
93/// The notification macro adds the required Box and Pin syntax,
94/// simplifying the declaration as follows:
95///
96/// ```ignore
97/// interface.notification(MyOps::Notify, notification!(|msg: MyMsg| async move {
98///     // ...
99///     Ok(())
100/// }))
101/// ```
102///
103pub use workflow_rpc_macros::server_notification as notification;
104
/// Minimal example connection context that records nothing but the
/// address of the remote peer. It can be used to keep track of
/// connected peers.
#[derive(Clone, Debug)]
pub struct RpcContext {
    /// Socket address of the connected peer.
    pub peer: SocketAddr,
}
111
/// [`RpcHandler`] - a server-side event handler for RPC connections.
///
/// Only [`RpcHandler::handshake()`] has to be implemented. The remaining
/// methods ship with default bodies: `accept()` returns `true`,
/// `connect()` returns `Ok(())` and `disconnect()` does nothing.
#[async_trait]
pub trait RpcHandler: Send + Sync + 'static {
    /// Per-connection context type. It is returned by
    /// [`RpcHandler::handshake()`] and handed back to
    /// [`RpcHandler::disconnect()`].
    type Context: Send + Sync;

    /// Called to determine if the connection should be accepted.
    ///
    /// The default implementation accepts every peer (returns `true`).
    fn accept(&self, _peer: &SocketAddr) -> bool {
        true
    }

    /// Connection notification - issued when the server has opened a WebSocket
    /// connection, before any other interactions occur.  The supplied argument
    /// is the [`SocketAddr`] of the incoming connection. This function should
    /// return [`WebSocketResult::Ok`] if the server accepts the connection or
    /// a [`WebSocketError`] if the connection is rejected. This function can
    /// be used to reject connections based on a ban list.
    ///
    /// The default implementation returns `Ok(())`.
    async fn connect(self: Arc<Self>, _peer: &SocketAddr) -> WebSocketResult<()> {
        Ok(())
    }

    /// [`RpcHandler::handshake()`] is called right after [`RpcHandler::connect()`]
    /// and is provided with [`WebSocketSender`] and [`WebSocketReceiver`] channels
    /// which can be used to communicate with the underlying WebSocket connection
    /// to negotiate a connection. The function also receives the `&peer` ([`SocketAddr`])
    /// of the connection and a [`Messenger`] struct.  The [`Messenger`] struct can
    /// be used to post notifications to the given connection as well as to close it.
    /// If negotiation is successful, this function should return a `ConnectionContext`
    /// defined as [`Self::Context`]. This context will be supplied to all subsequent
    /// RPC calls received from this connection. The [`Messenger`] struct can be
    /// cloned and captured within the `ConnectionContext`. This allows an RPC
    /// method handler to later capture and post notifications to the connection
    /// asynchronously.
    ///
    /// This method has no default body and must be provided by the implementor.
    async fn handshake(
        self: Arc<Self>,
        peer: &SocketAddr,
        sender: &mut WebSocketSender,
        receiver: &mut WebSocketReceiver,
        messenger: Arc<Messenger>,
    ) -> WebSocketResult<Self::Context>;

    /// Disconnect notification, receives the context and the result containing
    /// the disconnection reason (can be success if the connection is closed gracefully).
    ///
    /// The default implementation does nothing.
    async fn disconnect(self: Arc<Self>, _ctx: Self::Context, _result: WebSocketResult<()>) {}
}
156
///
/// The [`Messenger`] struct is supplied to the [`RpcHandler::handshake()`] call at
/// the connection negotiation time. This structure comes in as [`Arc<Messenger>`]
/// and can be retained for later processing. It provides two methods: [`Messenger::notify`]
/// that can be used asynchronously to dispatch RPC notifications to the client
/// and [`Messenger::close`] that can be used to terminate the RPC connection with
/// the client.
///
#[derive(Debug)]
pub struct Messenger {
    /// Encoding that selects which protocol serializer is used for
    /// outgoing notifications.
    encoding: Encoding,
    /// Sink through which every outgoing message (notifications, raw
    /// messages and the close request) is sent.
    sink: WebSocketSink,
}
170
171impl Messenger {
172    pub fn new(encoding: Encoding, sink: &WebSocketSink) -> Self {
173        Self {
174            encoding,
175            sink: sink.clone(),
176        }
177    }
178
179    /// Close the WebSocket connection. The server checks for the connection channel
180    /// for the dispatch of this message and relays it to the client as well as
181    /// proactively terminates the connection.
182    pub fn close(&self) -> Result<()> {
183        self.sink.send(Message::Close(None))?;
184        Ok(())
185    }
186
187    /// Post notification message to the WebSocket connection
188    pub async fn notify<Ops, Msg>(&self, op: Ops, msg: Msg) -> Result<()>
189    where
190        Ops: OpsT,
191        Msg: BorshSerialize + BorshDeserialize + Serialize + Send + Sync + 'static,
192    {
193        match self.encoding {
194            Encoding::Borsh => {
195                self.sink
196                    .send(protocol::borsh::create_serialized_notification_message(
197                        op, msg,
198                    )?)?;
199            }
200            Encoding::SerdeJson => {
201                self.sink
202                    .send(protocol::serde_json::create_serialized_notification_message(op, msg)?)?;
203            }
204        }
205
206        Ok(())
207    }
208
209    /// Serialize message into a [`tungstenite::Message`] for direct websocket delivery.
210    /// Once serialized it can be relayed using [`Messenger::send_raw_message()`].
211    pub fn serialize_notification_message<Ops, Msg>(
212        &self,
213        op: Ops,
214        msg: Msg,
215    ) -> Result<tungstenite::Message>
216    where
217        Ops: OpsT,
218        Msg: MsgT,
219    {
220        match self.encoding {
221            Encoding::Borsh => Ok(protocol::borsh::create_serialized_notification_message(
222                op, msg,
223            )?),
224            Encoding::SerdeJson => {
225                Ok(protocol::serde_json::create_serialized_notification_message(op, msg)?)
226            }
227        }
228    }
229
230    /// Send a raw [`tungstenite::Message`] via the websocket tokio channel.
231    pub fn send_raw_message(&self, msg: tungstenite::Message) -> Result<()> {
232        self.sink.send(msg)?;
233        Ok(())
234    }
235
236    /// Provides direct access to the underlying tokio channel.
237    pub fn sink(&self) -> &WebSocketSink {
238        &self.sink
239    }
240
241    /// Get encoding of the current messenger.
242    pub fn encoding(&self) -> Encoding {
243        self.encoding
244    }
245}
246
/// WebSocket processor in charge of managing
/// WRPC Request/Response interactions.
///
/// It forwards connection lifecycle events to the user-supplied
/// [`RpcHandler`] and hands every incoming message to the protocol
/// handler.
#[derive(Clone)]
struct RpcWebSocketHandler<ServerContext, ConnectionContext, Protocol, Ops>
where
    Ops: OpsT,
    ServerContext: Clone + Send + Sync + 'static,
    ConnectionContext: Clone + Send + Sync + 'static,
    Protocol: ProtocolHandler<ServerContext, ConnectionContext, Ops> + Send + Sync + 'static,
{
    /// User-supplied handler receiving accept/connect/handshake/disconnect events.
    rpc_handler: Arc<dyn RpcHandler<Context = ConnectionContext>>,
    /// Protocol handler, built from the RPC [`Interface`], that processes
    /// each incoming message.
    protocol: Arc<Protocol>,
    /// When `true`, each incoming message is handled in a spawned task;
    /// when `false`, it is awaited inline.
    enable_async_handling: bool,
    /// Marker for the otherwise unused `ServerContext` type parameter.
    _server_ctx: PhantomData<ServerContext>,
    /// Marker for the otherwise unused `Ops` type parameter.
    _ops: PhantomData<Ops>,
}
263
264impl<ServerContext, ConnectionContext, Protocol, Ops>
265    RpcWebSocketHandler<ServerContext, ConnectionContext, Protocol, Ops>
266where
267    Ops: OpsT,
268    ServerContext: Clone + Send + Sync + 'static,
269    ConnectionContext: Clone + Send + Sync + 'static,
270    Protocol: ProtocolHandler<ServerContext, ConnectionContext, Ops> + Send + Sync + 'static,
271{
272    pub fn new(
273        rpc_handler: Arc<dyn RpcHandler<Context = ConnectionContext>>,
274        interface: Arc<Interface<ServerContext, ConnectionContext, Ops>>,
275        enable_async_handling: bool,
276    ) -> Self {
277        let protocol = Arc::new(Protocol::new(interface));
278        Self {
279            rpc_handler,
280            protocol,
281            enable_async_handling,
282            _server_ctx: PhantomData,
283            _ops: PhantomData,
284        }
285    }
286}
287
288#[async_trait]
289impl<ServerContext, ConnectionContext, Protocol, Ops> WebSocketHandler
290    for RpcWebSocketHandler<ServerContext, ConnectionContext, Protocol, Ops>
291where
292    Ops: OpsT,
293    ServerContext: Clone + Send + Sync + 'static,
294    ConnectionContext: Clone + Send + Sync + 'static,
295    Protocol: ProtocolHandler<ServerContext, ConnectionContext, Ops> + Send + Sync + 'static,
296{
297    type Context = ConnectionContext;
298
299    fn accept(&self, peer: &SocketAddr) -> bool {
300        self.rpc_handler.accept(peer)
301    }
302
303    async fn connect(self: &Arc<Self>, peer: &SocketAddr) -> WebSocketResult<()> {
304        self.rpc_handler.clone().connect(peer).await
305    }
306
307    async fn disconnect(self: &Arc<Self>, ctx: Self::Context, result: WebSocketResult<()>) {
308        self.rpc_handler.clone().disconnect(ctx, result).await
309    }
310
311    async fn handshake(
312        self: &Arc<Self>,
313        peer: &SocketAddr,
314        sender: &mut WebSocketSender,
315        receiver: &mut WebSocketReceiver,
316        sink: &WebSocketSink,
317    ) -> WebSocketResult<Self::Context> {
318        let messenger = Arc::new(Messenger::new(self.protocol.encoding(), sink));
319
320        self.rpc_handler
321            .clone()
322            .handshake(peer, sender, receiver, messenger)
323            .await
324    }
325
326    async fn message(
327        self: &Arc<Self>,
328        connection_ctx: &Self::Context,
329        msg: Message,
330        sink: &WebSocketSink,
331    ) -> WebSocketResult<()> {
332        let connection_ctx = (*connection_ctx).clone();
333        if self.enable_async_handling {
334            let sink = sink.clone();
335            let this = self.clone();
336            spawn(async move {
337                this.protocol
338                    .handle_message(connection_ctx, msg, &sink)
339                    .await
340            });
341            Ok(())
342        } else {
343            self.protocol
344                .handle_message(connection_ctx, msg, sink)
345                .await
346        }
347    }
348}
349
/// [`RpcServer`] - a server-side object that listens
/// for incoming websocket connections and delegates interaction
/// with them to the supplied interfaces: [`RpcHandler`] (for RPC server
/// management) and [`Interface`] (for method and notification dispatch).
///
/// Cloning an [`RpcServer`] clones the inner [`Arc`], so all clones refer
/// to the same underlying websocket server.
#[derive(Clone)]
pub struct RpcServer {
    /// Type-erased websocket server that every method of this struct
    /// delegates to.
    ws_server: Arc<dyn WebSocketServerTrait>,
}
358
359impl RpcServer {
360    /// Create a new [`RpcServer`] supplying an [`Arc`] of the previously-created
361    /// [`RpcHandler`] trait and the [`Interface`] struct.
362    /// This method takes 4 generics:
363    /// - `ConnectionContext`: a struct used as [`RpcHandler::Context`] to
364    ///   represent the connection. This struct is passed to each RPC method
365    ///   and notification call.
366    /// - `ServerContext`: a struct supplied to the [`Interface`] at the
367    ///   Interface creation time. This struct is passed to each RPC method
368    ///   and notification call.
369    /// - `Protocol`: A protocol type used for the RPC message serialization
370    ///   and deserialization (this can be omitted by using [`RpcServer::new_with_encoding`])
371    /// - `Ops`: A data type (index or an `enum`) representing the RPC method
372    ///   or notification.
373    pub fn new<ServerContext, ConnectionContext, Protocol, Ops>(
374        rpc_handler: Arc<dyn RpcHandler<Context = ConnectionContext>>,
375        interface: Arc<Interface<ServerContext, ConnectionContext, Ops>>,
376        counters: Option<Arc<WebSocketCounters>>,
377        enable_async_handling: bool,
378    ) -> RpcServer
379    where
380        ServerContext: Clone + Send + Sync + 'static,
381        ConnectionContext: Clone + Send + Sync + 'static,
382        Protocol: ProtocolHandler<ServerContext, ConnectionContext, Ops> + Send + Sync + 'static,
383        Ops: OpsT,
384    {
385        let ws_handler = Arc::new(RpcWebSocketHandler::<
386            ServerContext,
387            ConnectionContext,
388            Protocol,
389            Ops,
390        >::new(rpc_handler, interface, enable_async_handling));
391
392        let ws_server = WebSocketServer::new(ws_handler, counters);
393        RpcServer { ws_server }
394    }
395    /// Create a new [`RpcServer`] supplying an [`Arc`] of the previously-created
396    /// [`RpcHandler`] trait and the [`Interface`] struct.
397    /// This method takes 4 generics:
398    /// - `ConnectionContext`: a struct used as [`RpcHandler::Context`] to
399    ///   represent the connection. This struct is passed to each RPC method
400    ///   and notification call.
401    /// - `ServerContext`: a struct supplied to the [`Interface`] at the
402    ///   Interface creation time. This struct is passed to each RPC method
403    ///   and notification call.
404    /// - `Ops`: A data type (index or an `enum`) representing the RPC method
405    ///   or notification.
406    /// - `Id`: A data type representing a message `Id` - this type must implement
407    ///   the [`id::Generator`](crate::id::Generator) trait. Implementation for default
408    ///   Ids such as [`Id32`] and [`Id64`] can be found in the [`id`](crate::id) module.
409    ///
410    /// This function call receives an `encoding`: [`Encoding`] argument containing
411    /// [`Encoding::Borsh`] or [`Encoding::SerdeJson`], based on which it will
412    /// instantiate the corresponding protocol handler ([`BorshProtocol`] or
413    /// [`JsonProtocol`] respectively).
414    ///
415    /// `enable_async_handling` is a boolean flag that determines if the server
416    /// should spawn a new async task for each incoming message. If set to `false`,
417    /// the server will handle message intake synchronously where each message
418    /// is posted to the underlying handler one-at-a-time. (i.e. RPC awaits for the
419    /// message intake processing to be complete before the next message arrives).
420    /// If `true`, each message is dispatched via a new async task.
421    ///
422    pub fn new_with_encoding<ServerContext, ConnectionContext, Ops, Id>(
423        encoding: Encoding,
424        rpc_handler: Arc<dyn RpcHandler<Context = ConnectionContext>>,
425        interface: Arc<Interface<ServerContext, ConnectionContext, Ops>>,
426        counters: Option<Arc<WebSocketCounters>>,
427        enable_async_handling: bool,
428    ) -> RpcServer
429    where
430        ServerContext: Clone + Send + Sync + 'static,
431        ConnectionContext: Clone + Send + Sync + 'static,
432        Ops: OpsT,
433        Id: IdT,
434    {
435        match encoding {
436            Encoding::Borsh => {
437                RpcServer::new::<
438                    ServerContext,
439                    ConnectionContext,
440                    BorshProtocol<ServerContext, ConnectionContext, Ops, Id>,
441                    Ops,
442                >(rpc_handler, interface, counters, enable_async_handling)
443            }
444            Encoding::SerdeJson => {
445                RpcServer::new::<
446                    ServerContext,
447                    ConnectionContext,
448                    JsonProtocol<ServerContext, ConnectionContext, Ops, Id>,
449                    Ops,
450                >(rpc_handler, interface, counters, enable_async_handling)
451            }
452        }
453    }
454
455    /// Bind network interface address to the `TcpListener`
456    pub async fn bind(&self, addr: &str) -> WebSocketResult<TcpListener> {
457        let addr = addr.replace("wrpc://", "");
458        self.ws_server.clone().bind(&addr).await
459    }
460
461    /// Start listening for incoming RPC connections on the supplied `TcpListener`
462    pub async fn listen(
463        &self,
464        listener: TcpListener,
465        config: Option<WebSocketConfig>,
466    ) -> WebSocketResult<()> {
467        self.ws_server.clone().listen(listener, config).await
468    }
469
470    /// Signal the listening task to stop
471    pub fn stop(&self) -> WebSocketResult<()> {
472        self.ws_server.stop()
473    }
474
475    /// Blocks until the listening task has stopped
476    pub async fn join(&self) -> WebSocketResult<()> {
477        self.ws_server.join().await
478    }
479
480    /// Signal the listening task to stop and block
481    /// until it has stopped
482    pub async fn stop_and_join(&self) -> WebSocketResult<()> {
483        self.ws_server.stop_and_join().await
484    }
485}