workflow_rpc/server/mod.rs
1//!
2//! RPC server module (native only). This module encapsulates
3//! server-side types used to create an RPC server: [`RpcServer`],
4//! [`RpcHandler`], [`Messenger`], [`Interface`] and the
5//! protocol handlers: [`BorshProtocol`] and [`JsonProtocol`].
6//!
7
8pub mod error;
9mod interface;
10pub mod prelude;
11pub mod protocol;
12pub mod result;
13
14pub use super::error::*;
15pub use crate::encoding::Encoding;
16use crate::imports::*;
17pub use interface::{Interface, Method, Notification};
18pub use protocol::{BorshProtocol, JsonProtocol, ProtocolHandler};
19pub use std::net::SocketAddr;
20pub use tokio::sync::mpsc::UnboundedSender as TokioUnboundedSender;
21pub use workflow_core::task::spawn;
22pub use workflow_websocket::server::{
23 Error as WebSocketError, Message, Result as WebSocketResult, TcpListener, WebSocketConfig,
24 WebSocketCounters, WebSocketHandler, WebSocketReceiver, WebSocketSender, WebSocketServer,
25 WebSocketServerTrait, WebSocketSink,
26};
27pub mod handshake {
28 //! WebSocket handshake helpers
29 pub use workflow_websocket::server::handshake::*;
30}
31use crate::server::result::Result;
32
///
/// method!() macro for declaration of RPC method handlers
///
/// This macro simplifies the creation of async method handler
/// closures supplied to the RPC dispatch interface. Without it, an
/// async method closure must be *Box*ed and its result must be
/// *Pin*ned, resulting in the following syntax:
///
/// ```ignore
///
/// interface.method(MyOps::Method, Box::new(Method::new(
///     |connection_ctx: ConnectionCtx, server_ctx: ServerContext, req: MyReq|
///         Box::pin(
///             async move {
///                 // ...
///                 Ok(MyResp { })
///             }
///         )
/// )))
///
/// ```
///
/// The method macro adds the required Box and Pin syntax,
/// simplifying the declaration as follows:
///
/// ```ignore
/// interface.method(MyOps::Method, method!(
///     | connection_ctx: ConnectionCtx,
///     server_ctx: ServerContext,
///     req: MyReq |
///     async move {
///         // ...
///         Ok(MyResp { })
///     }))
/// ```
///
69pub use workflow_rpc_macros::server_method as method;
70
///
/// notification!() macro for declaration of RPC notification handlers
///
/// This macro simplifies the creation of async notification handler
/// closures supplied to the RPC notification interface. Without it, an
/// async notification closure must be *Box*ed and its result must be
/// *Pin*ned, resulting in the following syntax:
///
/// ```ignore
///
/// interface.notification(MyOps::Notify, Box::new(Notification::new(|msg: MyMsg|
///     Box::pin(
///         async move {
///             // ...
///             Ok(())
///         }
///     )
/// )))
///
/// ```
///
/// The notification macro adds the required Box and Pin syntax,
/// simplifying the declaration as follows:
///
/// ```ignore
/// interface.notification(MyOps::Notify, notification!(|msg: MyMsg| async move {
///     // ...
///     Ok(())
/// }))
/// ```
///
103pub use workflow_rpc_macros::server_notification as notification;
104
/// Minimal example connection context: it only records the remote
/// address of a peer, which is enough to keep track of who is connected.
#[derive(Clone, Debug)]
pub struct RpcContext {
    /// Socket address of the connected peer.
    pub peer: SocketAddr,
}
111
112/// [`RpcHandler`] - a server-side event handler for RPC connections.
113#[async_trait]
114pub trait RpcHandler: Send + Sync + 'static {
115 type Context: Send + Sync;
116
117 /// Called to determine if the connection should be accepted.
118 fn accept(&self, _peer: &SocketAddr) -> bool {
119 true
120 }
121
122 /// Connection notification - issued when the server has opened a WebSocket
123 /// connection, before any other interactions occur. The supplied argument
124 /// is the [`SocketAddr`] of the incoming connection. This function should
125 /// return [`WebSocketResult::Ok`] if the server accepts connection or
126 /// [`WebSocketError`] if the connection is rejected. This function can
127 /// be used to reject connections based on a ban list.
128 async fn connect(self: Arc<Self>, _peer: &SocketAddr) -> WebSocketResult<()> {
129 Ok(())
130 }
131
132 /// [`RpcHandler::handshake()`] is called right after [`RpcHandler::connect()`]
133 /// and is provided with a [`WebSocketSender`] and [`WebSocketReceiver`] channels
134 /// which can be used to communicate with the underlying WebSocket connection
135 /// to negotiate a connection. The function also receives the `&peer` ([`SocketAddr`])
136 /// of the connection and a [`Messenger`] struct. The [`Messenger`] struct can
137 /// be used to post notifications to the given connection as well as to close it.
138 /// If negotiation is successful, this function should return a `ConnectionContext`
139 /// defined as [`Self::Context`]. This context will be supplied to all subsequent
140 /// RPC calls received from this connection. The [`Messenger`] struct can be
141 /// cloned and captured within the `ConnectionContext`. This allows an RPC
142 /// method handler to later capture and post notifications to the connection
143 /// asynchronously.
144 async fn handshake(
145 self: Arc<Self>,
146 peer: &SocketAddr,
147 sender: &mut WebSocketSender,
148 receiver: &mut WebSocketReceiver,
149 messenger: Arc<Messenger>,
150 ) -> WebSocketResult<Self::Context>;
151
152 /// Disconnect notification, receives the context and the result containing
153 /// the disconnection reason (can be success if the connection is closed gracefully)
154 async fn disconnect(self: Arc<Self>, _ctx: Self::Context, _result: WebSocketResult<()>) {}
155}
156
157///
158/// The [`Messenger`] struct is supplied to the [`RpcHandler::handshake()`] call at
159/// the connection negotiation time. This structure comes in as [`Arc<Messenger>`]
160/// and can be retained for later processing. It provides two methods: [`Messenger::notify`]
161/// that can be used asynchronously to dispatch RPC notifications to the client
162/// and [`Messenger::close`] that can be used to terminate the RPC connection with
163/// the client.
164///
165#[derive(Debug)]
166pub struct Messenger {
167 encoding: Encoding,
168 sink: WebSocketSink,
169}
170
171impl Messenger {
172 pub fn new(encoding: Encoding, sink: &WebSocketSink) -> Self {
173 Self {
174 encoding,
175 sink: sink.clone(),
176 }
177 }
178
179 /// Close the WebSocket connection. The server checks for the connection channel
180 /// for the dispatch of this message and relays it to the client as well as
181 /// proactively terminates the connection.
182 pub fn close(&self) -> Result<()> {
183 self.sink.send(Message::Close(None))?;
184 Ok(())
185 }
186
187 /// Post notification message to the WebSocket connection
188 pub async fn notify<Ops, Msg>(&self, op: Ops, msg: Msg) -> Result<()>
189 where
190 Ops: OpsT,
191 Msg: BorshSerialize + BorshDeserialize + Serialize + Send + Sync + 'static,
192 {
193 match self.encoding {
194 Encoding::Borsh => {
195 self.sink
196 .send(protocol::borsh::create_serialized_notification_message(
197 op, msg,
198 )?)?;
199 }
200 Encoding::SerdeJson => {
201 self.sink
202 .send(protocol::serde_json::create_serialized_notification_message(op, msg)?)?;
203 }
204 }
205
206 Ok(())
207 }
208
209 /// Serialize message into a [`tungstenite::Message`] for direct websocket delivery.
210 /// Once serialized it can be relayed using [`Messenger::send_raw_message()`].
211 pub fn serialize_notification_message<Ops, Msg>(
212 &self,
213 op: Ops,
214 msg: Msg,
215 ) -> Result<tungstenite::Message>
216 where
217 Ops: OpsT,
218 Msg: MsgT,
219 {
220 match self.encoding {
221 Encoding::Borsh => Ok(protocol::borsh::create_serialized_notification_message(
222 op, msg,
223 )?),
224 Encoding::SerdeJson => {
225 Ok(protocol::serde_json::create_serialized_notification_message(op, msg)?)
226 }
227 }
228 }
229
230 /// Send a raw [`tungstenite::Message`] via the websocket tokio channel.
231 pub fn send_raw_message(&self, msg: tungstenite::Message) -> Result<()> {
232 self.sink.send(msg)?;
233 Ok(())
234 }
235
236 /// Provides direct access to the underlying tokio channel.
237 pub fn sink(&self) -> &WebSocketSink {
238 &self.sink
239 }
240
241 /// Get encoding of the current messenger.
242 pub fn encoding(&self) -> Encoding {
243 self.encoding
244 }
245}
246
247/// WebSocket processor in charge of managing
248/// WRPC Request/Response interactions.
249#[derive(Clone)]
250struct RpcWebSocketHandler<ServerContext, ConnectionContext, Protocol, Ops>
251where
252 Ops: OpsT,
253 ServerContext: Clone + Send + Sync + 'static,
254 ConnectionContext: Clone + Send + Sync + 'static,
255 Protocol: ProtocolHandler<ServerContext, ConnectionContext, Ops> + Send + Sync + 'static,
256{
257 rpc_handler: Arc<dyn RpcHandler<Context = ConnectionContext>>,
258 protocol: Arc<Protocol>,
259 enable_async_handling: bool,
260 _server_ctx: PhantomData<ServerContext>,
261 _ops: PhantomData<Ops>,
262}
263
264impl<ServerContext, ConnectionContext, Protocol, Ops>
265 RpcWebSocketHandler<ServerContext, ConnectionContext, Protocol, Ops>
266where
267 Ops: OpsT,
268 ServerContext: Clone + Send + Sync + 'static,
269 ConnectionContext: Clone + Send + Sync + 'static,
270 Protocol: ProtocolHandler<ServerContext, ConnectionContext, Ops> + Send + Sync + 'static,
271{
272 pub fn new(
273 rpc_handler: Arc<dyn RpcHandler<Context = ConnectionContext>>,
274 interface: Arc<Interface<ServerContext, ConnectionContext, Ops>>,
275 enable_async_handling: bool,
276 ) -> Self {
277 let protocol = Arc::new(Protocol::new(interface));
278 Self {
279 rpc_handler,
280 protocol,
281 enable_async_handling,
282 _server_ctx: PhantomData,
283 _ops: PhantomData,
284 }
285 }
286}
287
288#[async_trait]
289impl<ServerContext, ConnectionContext, Protocol, Ops> WebSocketHandler
290 for RpcWebSocketHandler<ServerContext, ConnectionContext, Protocol, Ops>
291where
292 Ops: OpsT,
293 ServerContext: Clone + Send + Sync + 'static,
294 ConnectionContext: Clone + Send + Sync + 'static,
295 Protocol: ProtocolHandler<ServerContext, ConnectionContext, Ops> + Send + Sync + 'static,
296{
297 type Context = ConnectionContext;
298
299 fn accept(&self, peer: &SocketAddr) -> bool {
300 self.rpc_handler.accept(peer)
301 }
302
303 async fn connect(self: &Arc<Self>, peer: &SocketAddr) -> WebSocketResult<()> {
304 self.rpc_handler.clone().connect(peer).await
305 }
306
307 async fn disconnect(self: &Arc<Self>, ctx: Self::Context, result: WebSocketResult<()>) {
308 self.rpc_handler.clone().disconnect(ctx, result).await
309 }
310
311 async fn handshake(
312 self: &Arc<Self>,
313 peer: &SocketAddr,
314 sender: &mut WebSocketSender,
315 receiver: &mut WebSocketReceiver,
316 sink: &WebSocketSink,
317 ) -> WebSocketResult<Self::Context> {
318 let messenger = Arc::new(Messenger::new(self.protocol.encoding(), sink));
319
320 self.rpc_handler
321 .clone()
322 .handshake(peer, sender, receiver, messenger)
323 .await
324 }
325
326 async fn message(
327 self: &Arc<Self>,
328 connection_ctx: &Self::Context,
329 msg: Message,
330 sink: &WebSocketSink,
331 ) -> WebSocketResult<()> {
332 let connection_ctx = (*connection_ctx).clone();
333 if self.enable_async_handling {
334 let sink = sink.clone();
335 let this = self.clone();
336 spawn(async move {
337 this.protocol
338 .handle_message(connection_ctx, msg, &sink)
339 .await
340 });
341 Ok(())
342 } else {
343 self.protocol
344 .handle_message(connection_ctx, msg, sink)
345 .await
346 }
347 }
348}
349
350/// [`RpcServer`] - a server-side object that listens
351/// for incoming websocket connections and delegates interaction
352/// with them to the supplied interfaces: [`RpcHandler`] (for RPC server
353/// management) and [`Interface`] (for method and notification dispatch).
354#[derive(Clone)]
355pub struct RpcServer {
356 ws_server: Arc<dyn WebSocketServerTrait>,
357}
358
359impl RpcServer {
360 /// Create a new [`RpcServer`] supplying an [`Arc`] of the previously-created
361 /// [`RpcHandler`] trait and the [`Interface`] struct.
362 /// This method takes 4 generics:
363 /// - `ConnectionContext`: a struct used as [`RpcHandler::Context`] to
364 /// represent the connection. This struct is passed to each RPC method
365 /// and notification call.
366 /// - `ServerContext`: a struct supplied to the [`Interface`] at the
367 /// Interface creation time. This struct is passed to each RPC method
368 /// and notification call.
369 /// - `Protocol`: A protocol type used for the RPC message serialization
370 /// and deserialization (this can be omitted by using [`RpcServer::new_with_encoding`])
371 /// - `Ops`: A data type (index or an `enum`) representing the RPC method
372 /// or notification.
373 pub fn new<ServerContext, ConnectionContext, Protocol, Ops>(
374 rpc_handler: Arc<dyn RpcHandler<Context = ConnectionContext>>,
375 interface: Arc<Interface<ServerContext, ConnectionContext, Ops>>,
376 counters: Option<Arc<WebSocketCounters>>,
377 enable_async_handling: bool,
378 ) -> RpcServer
379 where
380 ServerContext: Clone + Send + Sync + 'static,
381 ConnectionContext: Clone + Send + Sync + 'static,
382 Protocol: ProtocolHandler<ServerContext, ConnectionContext, Ops> + Send + Sync + 'static,
383 Ops: OpsT,
384 {
385 let ws_handler = Arc::new(RpcWebSocketHandler::<
386 ServerContext,
387 ConnectionContext,
388 Protocol,
389 Ops,
390 >::new(rpc_handler, interface, enable_async_handling));
391
392 let ws_server = WebSocketServer::new(ws_handler, counters);
393 RpcServer { ws_server }
394 }
395 /// Create a new [`RpcServer`] supplying an [`Arc`] of the previously-created
396 /// [`RpcHandler`] trait and the [`Interface`] struct.
397 /// This method takes 4 generics:
398 /// - `ConnectionContext`: a struct used as [`RpcHandler::Context`] to
399 /// represent the connection. This struct is passed to each RPC method
400 /// and notification call.
401 /// - `ServerContext`: a struct supplied to the [`Interface`] at the
402 /// Interface creation time. This struct is passed to each RPC method
403 /// and notification call.
404 /// - `Ops`: A data type (index or an `enum`) representing the RPC method
405 /// or notification.
406 /// - `Id`: A data type representing a message `Id` - this type must implement
407 /// the [`id::Generator`](crate::id::Generator) trait. Implementation for default
408 /// Ids such as [`Id32`] and [`Id64`] can be found in the [`id`](crate::id) module.
409 ///
410 /// This function call receives an `encoding`: [`Encoding`] argument containing
411 /// [`Encoding::Borsh`] or [`Encoding::SerdeJson`], based on which it will
412 /// instantiate the corresponding protocol handler ([`BorshProtocol`] or
413 /// [`JsonProtocol`] respectively).
414 ///
415 /// `enable_async_handling` is a boolean flag that determines if the server
416 /// should spawn a new async task for each incoming message. If set to `false`,
417 /// the server will handle message intake synchronously where each message
418 /// is posted to the underlying handler one-at-a-time. (i.e. RPC awaits for the
419 /// message intake processing to be complete before the next message arrives).
420 /// If `true`, each message is dispatched via a new async task.
421 ///
422 pub fn new_with_encoding<ServerContext, ConnectionContext, Ops, Id>(
423 encoding: Encoding,
424 rpc_handler: Arc<dyn RpcHandler<Context = ConnectionContext>>,
425 interface: Arc<Interface<ServerContext, ConnectionContext, Ops>>,
426 counters: Option<Arc<WebSocketCounters>>,
427 enable_async_handling: bool,
428 ) -> RpcServer
429 where
430 ServerContext: Clone + Send + Sync + 'static,
431 ConnectionContext: Clone + Send + Sync + 'static,
432 Ops: OpsT,
433 Id: IdT,
434 {
435 match encoding {
436 Encoding::Borsh => {
437 RpcServer::new::<
438 ServerContext,
439 ConnectionContext,
440 BorshProtocol<ServerContext, ConnectionContext, Ops, Id>,
441 Ops,
442 >(rpc_handler, interface, counters, enable_async_handling)
443 }
444 Encoding::SerdeJson => {
445 RpcServer::new::<
446 ServerContext,
447 ConnectionContext,
448 JsonProtocol<ServerContext, ConnectionContext, Ops, Id>,
449 Ops,
450 >(rpc_handler, interface, counters, enable_async_handling)
451 }
452 }
453 }
454
455 /// Bind network interface address to the `TcpListener`
456 pub async fn bind(&self, addr: &str) -> WebSocketResult<TcpListener> {
457 let addr = addr.replace("wrpc://", "");
458 self.ws_server.clone().bind(&addr).await
459 }
460
461 /// Start listening for incoming RPC connections on the supplied `TcpListener`
462 pub async fn listen(
463 &self,
464 listener: TcpListener,
465 config: Option<WebSocketConfig>,
466 ) -> WebSocketResult<()> {
467 self.ws_server.clone().listen(listener, config).await
468 }
469
470 /// Signal the listening task to stop
471 pub fn stop(&self) -> WebSocketResult<()> {
472 self.ws_server.stop()
473 }
474
475 /// Blocks until the listening task has stopped
476 pub async fn join(&self) -> WebSocketResult<()> {
477 self.ws_server.join().await
478 }
479
480 /// Signal the listening task to stop and block
481 /// until it has stopped
482 pub async fn stop_and_join(&self) -> WebSocketResult<()> {
483 self.ws_server.stop_and_join().await
484 }
485}