// File: workflow_rpc/server/interface/notification.rs

//! Module containing the RPC [`Notification`] closure wrappers.
2use crate::imports::*;
3
4/// Base trait representing an RPC notification, used to retain
5/// notification structures in an [`Interface`](super::Interface)
6/// map without generics.
7#[async_trait]
8pub(crate) trait NotificationTrait<ServerContext, ConnectionContext>:
9    Send + Sync + 'static
10{
11    async fn call_with_borsh(
12        &self,
13        server_ctx: ServerContext,
14        connection_ctx: ConnectionContext,
15        data: &[u8],
16    ) -> ServerResult<()>;
17    async fn call_with_serde_json(
18        &self,
19        server_ctx: ServerContext,
20        connection_ctx: ConnectionContext,
21        value: Value,
22    ) -> ServerResult<()>;
23}
24
25/// Notification closure type
26pub type NotificationFn<ServerContext, ConnectionContext, Msg> = Arc<
27    Box<
28        dyn Send
29            + Sync
30            + Fn(ServerContext, ConnectionContext, Msg) -> NotificationFnReturn<()>
31            + 'static,
32    >,
33>;
34
35/// Notification closure return type
36pub type NotificationFnReturn<T> =
37    Pin<Box<(dyn Send + 'static + Future<Output = ServerResult<T>>)>>;
38
39/// RPC notification wrapper. Contains the notification closure function.
40
41pub struct Notification<ServerContext, ConnectionContext, Msg>
42where
43    ServerContext: Send + Sync + 'static,
44    Msg: BorshDeserialize + DeserializeOwned + Send + Sync + 'static,
45{
46    method: NotificationFn<ServerContext, ConnectionContext, Msg>,
47}
48
49impl<ServerContext, ConnectionContext, Msg> Notification<ServerContext, ConnectionContext, Msg>
50where
51    ServerContext: Send + Sync + 'static,
52    Msg: BorshDeserialize + DeserializeOwned + Send + Sync + 'static,
53{
54    pub fn new<FN>(method_fn: FN) -> Notification<ServerContext, ConnectionContext, Msg>
55    where
56        FN: Send
57            + Sync
58            + Fn(ServerContext, ConnectionContext, Msg) -> NotificationFnReturn<()>
59            + 'static,
60    {
61        Notification {
62            method: Arc::new(Box::new(method_fn)),
63        }
64    }
65}
66
67#[async_trait]
68impl<ServerContext, ConnectionContext, Msg> NotificationTrait<ServerContext, ConnectionContext>
69    for Notification<ServerContext, ConnectionContext, Msg>
70where
71    ConnectionContext: Clone + Send + Sync + 'static,
72    ServerContext: Send + Sync + 'static,
73    Msg: BorshDeserialize + DeserializeOwned + Send + Sync + 'static,
74{
75    async fn call_with_borsh(
76        &self,
77        server_ctx: ServerContext,
78        connection_ctx: ConnectionContext,
79        data: &[u8],
80    ) -> ServerResult<()> {
81        let req = Msg::try_from_slice(data)
82            .map_err(|err| ServerError::NotificationDeserialize(err.to_string()))?;
83        (self.method)(server_ctx, connection_ctx, req).await
84    }
85
86    async fn call_with_serde_json(
87        &self,
88        server_ctx: ServerContext,
89        connection_ctx: ConnectionContext,
90        value: Value,
91    ) -> ServerResult<()> {
92        let req: Msg = serde_json::from_value(value)
93            .map_err(|err| ServerError::NotificationDeserialize(err.to_string()))?;
94        (self.method)(server_ctx, connection_ctx, req).await
95    }
96}