1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
//!
//! RPC server module (native only). This module encapsulates
//! server-side types used to create an RPC server: [`RpcServer`],
//! [`RpcHandler`], [`Messenger`], [`Interface`] and the
//! protocol handlers: [`BorshProtocol`] and [`SerdeJsonProtocol`].
//!

pub mod error;
mod interface;
pub mod prelude;
mod protocol;
pub mod result;

pub use super::error::*;
pub use crate::encoding::Encoding;
use crate::imports::*;
pub use interface::{Interface, Method, Notification};
pub use protocol::{BorshProtocol, ProtocolHandler, SerdeJsonProtocol};
pub use std::net::SocketAddr;
pub use tokio::sync::mpsc::UnboundedSender as TokioUnboundedSender;
pub use workflow_websocket::server::{
    Error as WebSocketError, Message, Result as WebSocketResult, WebSocketHandler,
    WebSocketReceiver, WebSocketSender, WebSocketServer, WebSocketServerTrait, WebSocketSink,
};
pub mod handshake {
    //! WebSocket handshake helpers
    pub use workflow_websocket::server::handshake::*;
}
use crate::server::result::Result;

/// `method!()` macro for declaring RPC method handlers.
///
/// This macro simplifies the creation of async method handler
/// closures supplied to the RPC dispatch interface. Without it, an
/// async method closure must be *Box*ed and its result must be
/// *Pin*ned, resulting in the following syntax:
///
/// ```ignore
///
/// interface.method(MyOps::Method, Box::new(Method::new(|req: MyReq|
///     Box::pin(
///         async move {
///             // ...
///             Ok(MyResp { })
///         }
///     )
/// )))
///
/// ```
///
/// The method macro adds the required Box and Pin syntax,
/// simplifying the declaration as follows:
///
/// ```ignore
/// interface.method(MyOps::Method, method!(
///   | connection_ctx: ConnectionCtx,
///     server_ctx: ServerContext,
///     req: MyReq |
/// async move {
///     // ...
///     Ok(MyResp { })
/// }))
/// ```
///
pub use workflow_rpc_macros::server_method as method;

/// `notification!()` macro for declaring RPC notification handlers.
///
/// This macro simplifies the creation of async notification handler
/// closures supplied to the RPC notification interface. Without it, an
/// async notification closure must be *Box*ed and its result must be
/// *Pin*ned, resulting in the following syntax:
///
/// ```ignore
///
/// interface.notification(MyOps::Notify, Box::new(Notification::new(|msg: MyMsg|
///     Box::pin(
///         async move {
///             // ...
///             Ok(())
///         }
///     )
/// )))
///
/// ```
///
/// The notification macro adds the required Box and Pin syntax,
/// simplifying the declaration as follows:
///
/// ```ignore
/// interface.notification(MyOps::Notify, notification!(|msg: MyMsg| async move {
///     // ...
///     Ok(())
/// }))
/// ```
///
pub use workflow_rpc_macros::server_notification as notification;

/// Minimal example connection context: it records nothing but the
/// remote [`SocketAddr`] of a peer, which is enough to keep track of
/// which clients are connected.
#[derive(Clone, Debug)]
pub struct RpcContext {
    /// Socket address of the connected peer.
    pub peer: SocketAddr,
}

/// [`RpcHandler`] - a server-side event handler for RPC connections.
///
/// An implementation receives the lifecycle notifications of each WebSocket
/// connection: [`RpcHandler::connect()`], followed by
/// [`RpcHandler::handshake()`], and [`RpcHandler::disconnect()`] once the
/// connection ends.
#[async_trait]
pub trait RpcHandler: Send + Sync + 'static {
    /// Per-connection context produced by [`RpcHandler::handshake()`] and
    /// supplied to every subsequent RPC call from that connection as well as
    /// to [`RpcHandler::disconnect()`]. To be used with [`RpcServer::new`],
    /// this type must additionally be `Clone + 'static`.
    type Context: Send + Sync;

    /// Connection notification - issued when the server has opened a WebSocket
    /// connection, before any other interactions occur. The supplied argument
    /// is the [`SocketAddr`] of the incoming connection. This function should
    /// return `Ok(())` if the server accepts the connection or a
    /// [`WebSocketError`] if the connection is rejected. This function can
    /// be used to reject connections based on a ban list.
    ///
    /// The default implementation accepts every connection.
    async fn connect(self: Arc<Self>, _peer: &SocketAddr) -> WebSocketResult<()> {
        Ok(())
    }

    /// [`RpcHandler::handshake()`] is called right after [`RpcHandler::connect()`]
    /// and is provided with the [`WebSocketSender`] and [`WebSocketReceiver`]
    /// halves of the connection, which can be used to talk to the underlying
    /// WebSocket directly while negotiating the connection. The function also
    /// receives the `&peer` ([`SocketAddr`]) of the connection and a [`Messenger`].
    /// The [`Messenger`] can be used to post notifications to the given
    /// connection as well as to close it.
    ///
    /// If negotiation is successful, this function should return a `ConnectionContext`
    /// defined as [`Self::Context`]. This context will be supplied to all subsequent
    /// RPC calls received from this connection. The [`Messenger`] can be
    /// cloned and captured within the `ConnectionContext`, which allows an RPC
    /// method handler to later post notifications to the connection
    /// asynchronously.
    async fn handshake(
        self: Arc<Self>,
        peer: &SocketAddr,
        sender: &mut WebSocketSender,
        receiver: &mut WebSocketReceiver,
        messenger: Arc<Messenger>,
    ) -> WebSocketResult<Self::Context>;

    /// Disconnect notification - receives the connection context and a result
    /// carrying the disconnection reason (this can be `Ok(())` if the connection
    /// was closed gracefully).
    ///
    /// The default implementation does nothing.
    async fn disconnect(self: Arc<Self>, _ctx: Self::Context, _result: WebSocketResult<()>) {}
}

/// Per-connection handle given to [`RpcHandler::handshake()`] as an
/// [`Arc<Messenger>`] while the connection is being negotiated.
///
/// The `Arc` may be retained (for example inside the connection context)
/// so that the server can later push RPC notifications to the client with
/// [`Messenger::notify`] or terminate the connection with
/// [`Messenger::close`].
#[derive(Debug)]
pub struct Messenger {
    /// Wire encoding used to serialize outgoing notifications.
    encoding: Encoding,
    /// Outbound sink of the underlying WebSocket connection.
    sink: WebSocketSink,
}

impl Messenger {
    /// Creates a messenger that serializes notifications with `encoding`
    /// and writes them to its own clone of `sink`.
    pub fn new(encoding: Encoding, sink: &WebSocketSink) -> Self {
        let sink = sink.clone();
        Self { encoding, sink }
    }

    /// Sends `Message::Close(None)` through the connection sink, asking for
    /// the RPC connection to be terminated.
    ///
    /// # Errors
    /// Returns an error if the sink refuses the message.
    pub fn close(&self) -> Result<()> {
        let close_message = Message::Close(None);
        self.sink.send(close_message)?;
        Ok(())
    }

    /// Serializes `msg` as a notification for operation `op`, using the
    /// encoding this messenger was created with, and sends it through the
    /// connection sink.
    ///
    /// # Errors
    /// Returns an error if the notification cannot be serialized or if the
    /// sink refuses the message.
    pub async fn notify<Ops, Msg>(&self, op: Ops, msg: Msg) -> Result<()>
    where
        Ops: OpsT,
        Msg: BorshSerialize + BorshDeserialize + Serialize + Send + Sync + 'static,
    {
        let message = match self.encoding {
            Encoding::Borsh => protocol::create_notify_message_with_borsh(op, msg)?,
            Encoding::SerdeJson => protocol::create_notify_message_with_serde_json(op, msg)?,
        };
        self.sink.send(message)?;
        Ok(())
    }
}

/// WebSocket-level adapter that connects the WebSocket server to the RPC
/// layer. Connection lifecycle events are forwarded to an [`RpcHandler`],
/// while incoming WRPC request/response messages are handled by a
/// `Protocol` instance built around the RPC [`Interface`].
#[derive(Clone)]
struct RpcWebSocketHandler<ServerContext, ConnectionContext, Protocol, Ops>
where
    Ops: OpsT,
    ServerContext: Clone + Send + Sync + 'static,
    ConnectionContext: Clone + Send + Sync + 'static,
    Protocol: ProtocolHandler<ServerContext, ConnectionContext, Ops> + Send + Sync + 'static,
{
    /// User-supplied handler for connect / handshake / disconnect events.
    rpc_handler: Arc<dyn RpcHandler<Context = ConnectionContext>>,
    /// Protocol handler that processes incoming messages.
    protocol: Arc<Protocol>,
    // `ServerContext` and `Ops` are not stored in any field, so they are
    // anchored here to keep the type parameters in use.
    _server_ctx: PhantomData<ServerContext>,
    _ops: PhantomData<Ops>,
}

impl<ServerContext, ConnectionContext, Protocol, Ops>
    RpcWebSocketHandler<ServerContext, ConnectionContext, Protocol, Ops>
where
    Ops: OpsT,
    ServerContext: Clone + Send + Sync + 'static,
    ConnectionContext: Clone + Send + Sync + 'static,
    Protocol: ProtocolHandler<ServerContext, ConnectionContext, Ops> + Send + Sync + 'static,
{
    /// Builds the adapter, constructing a new `Protocol` around `interface`.
    pub fn new(
        rpc_handler: Arc<dyn RpcHandler<Context = ConnectionContext>>,
        interface: Arc<Interface<ServerContext, ConnectionContext, Ops>>,
    ) -> Self {
        Self {
            rpc_handler,
            protocol: Arc::new(Protocol::new(interface)),
            _server_ctx: PhantomData,
            _ops: PhantomData,
        }
    }
}

#[async_trait]
impl<ServerContext, ConnectionContext, Protocol, Ops> WebSocketHandler
    for RpcWebSocketHandler<ServerContext, ConnectionContext, Protocol, Ops>
where
    Ops: OpsT,
    ServerContext: Clone + Send + Sync + 'static,
    ConnectionContext: Clone + Send + Sync + 'static,
    Protocol: ProtocolHandler<ServerContext, ConnectionContext, Ops> + Send + Sync + 'static,
{
    type Context = ConnectionContext;

    /// Forwards the connection check to [`RpcHandler::connect()`].
    async fn connect(self: &Arc<Self>, peer: &SocketAddr) -> WebSocketResult<()> {
        let handler = Arc::clone(&self.rpc_handler);
        handler.connect(peer).await
    }

    /// Forwards the disconnect notification to [`RpcHandler::disconnect()`].
    async fn disconnect(self: &Arc<Self>, ctx: Self::Context, result: WebSocketResult<()>) {
        let handler = Arc::clone(&self.rpc_handler);
        handler.disconnect(ctx, result).await
    }

    /// Creates a [`Messenger`] bound to this connection's sink, using the
    /// protocol's encoding, and hands it to [`RpcHandler::handshake()`].
    async fn handshake(
        self: &Arc<Self>,
        peer: &SocketAddr,
        sender: &mut WebSocketSender,
        receiver: &mut WebSocketReceiver,
        sink: &WebSocketSink,
    ) -> WebSocketResult<Self::Context> {
        let encoding = self.protocol.encoding();
        let messenger = Arc::new(Messenger::new(encoding, sink));
        let handler = Arc::clone(&self.rpc_handler);
        handler
            .handshake(peer, sender, receiver, messenger)
            .await
    }

    /// Passes an incoming message, with an owned copy of the connection
    /// context, to the protocol handler.
    async fn message(
        self: &Arc<Self>,
        connection_ctx: &Self::Context,
        msg: Message,
        sink: &WebSocketSink,
    ) -> WebSocketResult<()> {
        let ctx = connection_ctx.clone();
        self.protocol.handle_message(ctx, msg, sink).await
    }
}

/// [`RpcServer`] - a server-side object that listens
/// for incoming WebSocket connections and delegates interaction
/// with them to the supplied interfaces: [`RpcHandler`] (for RPC server
/// management) and [`Interface`] (for method and notification dispatch).
#[derive(Clone)]
pub struct RpcServer {
    /// Type-erased WebSocket server; listening, stopping and joining are
    /// delegated to it.
    ws_server: Arc<dyn WebSocketServerTrait>,
}

impl RpcServer {
    /// Create a new [`RpcServer`] supplying an [`Arc`] of the previously-created
    /// [`RpcHandler`] trait object and the [`Interface`] struct.
    ///
    /// This method takes 4 generics:
    /// - `ServerContext`: a struct supplied to the [`Interface`] at the
    ///   Interface creation time. This struct is passed to each RPC method
    ///   and notification call.
    /// - `ConnectionContext`: a struct used as [`RpcHandler::Context`] to
    ///   represent the connection. This struct is passed to each RPC method
    ///   and notification call.
    /// - `Protocol`: a protocol type used for RPC message serialization
    ///   and deserialization (this can be omitted by using
    ///   [`RpcServer::new_with_encoding`]).
    /// - `Ops`: a data type (index or an `enum`) representing the RPC method
    ///   or notification.
    pub fn new<ServerContext, ConnectionContext, Protocol, Ops>(
        rpc_handler: Arc<dyn RpcHandler<Context = ConnectionContext>>,
        interface: Arc<Interface<ServerContext, ConnectionContext, Ops>>,
    ) -> RpcServer
    where
        ServerContext: Clone + Send + Sync + 'static,
        ConnectionContext: Clone + Send + Sync + 'static,
        Protocol: ProtocolHandler<ServerContext, ConnectionContext, Ops> + Send + Sync + 'static,
        Ops: OpsT,
    {
        let ws_handler = Arc::new(RpcWebSocketHandler::<
            ServerContext,
            ConnectionContext,
            Protocol,
            Ops,
        >::new(rpc_handler, interface));
        let ws_server = WebSocketServer::new(ws_handler);
        RpcServer { ws_server }
    }

    /// Create a new [`RpcServer`] supplying an [`Arc`] of the previously-created
    /// [`RpcHandler`] trait object and the [`Interface`] struct.
    ///
    /// This method takes 4 generics:
    /// - `ServerContext`: a struct supplied to the [`Interface`] at the
    ///   Interface creation time. This struct is passed to each RPC method
    ///   and notification call.
    /// - `ConnectionContext`: a struct used as [`RpcHandler::Context`] to
    ///   represent the connection. This struct is passed to each RPC method
    ///   and notification call.
    /// - `Ops`: a data type (index or an `enum`) representing the RPC method
    ///   or notification.
    /// - `Id`: a data type representing a message `Id`. This type must implement
    ///   the [`id::Generator`](crate::id::Generator) trait. Implementations for default
    ///   Ids such as [`Id32`] and [`Id64`] can be found in the [`id`](crate::id) module.
    ///
    /// This function receives an `encoding`: [`Encoding`] argument containing
    /// [`Encoding::Borsh`] or [`Encoding::SerdeJson`]. Based on it, the function
    /// instantiates the corresponding protocol handler ([`BorshProtocol`] or
    /// [`SerdeJsonProtocol`] respectively).
    ///
    pub fn new_with_encoding<ServerContext, ConnectionContext, Ops, Id>(
        encoding: Encoding,
        rpc_handler: Arc<dyn RpcHandler<Context = ConnectionContext>>,
        interface: Arc<Interface<ServerContext, ConnectionContext, Ops>>,
    ) -> RpcServer
    where
        ServerContext: Clone + Send + Sync + 'static,
        ConnectionContext: Clone + Send + Sync + 'static,
        Ops: OpsT,
        Id: IdT,
    {
        match encoding {
            Encoding::Borsh => RpcServer::new::<
                ServerContext,
                ConnectionContext,
                BorshProtocol<ServerContext, ConnectionContext, Ops, Id>,
                Ops,
            >(rpc_handler, interface),
            Encoding::SerdeJson => RpcServer::new::<
                ServerContext,
                ConnectionContext,
                SerdeJsonProtocol<ServerContext, ConnectionContext, Ops, Id>,
                Ops,
            >(rpc_handler, interface),
        }
    }

    /// Start listening for incoming RPC connections on `addr`.
    ///
    /// An optional leading `wrpc://` scheme is removed from `addr` before it
    /// is handed to the underlying WebSocket server. For example,
    /// `"wrpc://127.0.0.1:8080"` is passed on as `"127.0.0.1:8080"`.
    pub async fn listen(&self, addr: &str) -> WebSocketResult<()> {
        // Same result as replacing an anchored `^wrpc://` regex match, but
        // without compiling a regex (and its fallible step) on every call.
        let addr = addr.strip_prefix("wrpc://").unwrap_or(addr);
        self.ws_server.clone().listen(addr).await
    }

    /// Signal the listening task to stop.
    pub fn stop(&self) -> WebSocketResult<()> {
        self.ws_server.stop()
    }

    /// Block until the listening task has stopped.
    pub async fn join(&self) -> WebSocketResult<()> {
        self.ws_server.join().await
    }

    /// Signal the listening task to stop, then block
    /// until it has stopped.
    pub async fn stop_and_join(&self) -> WebSocketResult<()> {
        self.ws_server.stop_and_join().await
    }
}