workflow_rpc/client/protocol/borsh.rs

1use super::{Pending, PendingMap, ProtocolHandler};
2pub use crate::client::error::Error;
3pub use crate::client::result::Result;
4use crate::client::Interface;
5use crate::imports::*;
6use crate::messages::borsh::*;
7use core::marker::PhantomData;
8
9pub type BorshResponseFn =
10    Arc<Box<(dyn Fn(Result<&[u8]>, Option<&Duration>) -> Result<()> + Sync + Send)>>;
11
12/// Borsh RPC message handler and dispatcher
13pub struct BorshProtocol<Ops, Id>
14where
15    Ops: OpsT,
16    Id: IdT,
17{
18    ws: Arc<WebSocket>,
19    pending: PendingMap<Id, BorshResponseFn>,
20    interface: Option<Arc<Interface<Ops>>>,
21    ops: PhantomData<Ops>,
22    id: PhantomData<Id>,
23}
24
25impl<Ops, Id> BorshProtocol<Ops, Id>
26where
27    Ops: OpsT,
28    Id: IdT,
29{
30    fn new(ws: Arc<WebSocket>, interface: Option<Arc<Interface<Ops>>>) -> Self {
31        BorshProtocol {
32            ws,
33            pending: Arc::new(Mutex::new(AHashMap::new())),
34            interface,
35            ops: PhantomData,
36            id: PhantomData,
37        }
38    }
39}
40
41type MessageInfo<'l, Ops, Id> = (Option<Id>, Option<Ops>, Result<&'l [u8]>);
42
43impl<Ops, Id> BorshProtocol<Ops, Id>
44where
45    Id: IdT,
46    Ops: OpsT,
47{
48    fn decode<'l>(&self, server_message: &'l [u8]) -> ServerResult<MessageInfo<'l, Ops, Id>> {
49        match BorshServerMessage::try_from(server_message) {
50            Ok(msg) => {
51                let header = msg.header;
52                match header.kind {
53                    ServerMessageKind::Success => {
54                        Ok((header.id, header.op, Ok(msg.payload)))
55                        // Ok((Some(header.id), header.op.clone(), Ok(msg.data)))
56                    }
57                    ServerMessageKind::Error => {
58                        if let Ok(err) = ServerError::try_from_slice(msg.payload) {
59                            Ok((header.id, None, Err(Error::RpcCall(err))))
60                        } else {
61                            Ok((header.id, None, Err(Error::ErrorDeserializingResponseData)))
62                        }
63                    }
64                    ServerMessageKind::Notification => Ok((None, header.op, Ok(msg.payload))),
65                }
66            }
67            Err(err) => Err(ServerError::RespDeserialize(err.to_string())),
68        }
69    }
70
71    pub async fn request<Req, Resp>(&self, op: Ops, req: Req) -> Result<Resp>
72    where
73        Req: MsgT,
74        Resp: MsgT,
75    {
76        let payload = borsh::to_vec(&req).map_err(|_| Error::BorshSerialize)?;
77
78        let id = Id::generate();
79        let (sender, receiver) = oneshot();
80
81        {
82            let mut pending = self.pending.lock().unwrap();
83            pending.insert(
84                id.clone(),
85                Pending::new(Arc::new(Box::new(move |result, _duration| {
86                    sender.try_send(result.map(|data| data.to_vec()))?;
87                    Ok(())
88                }))),
89            );
90        }
91
92        // TODO - post error into sender if ws.send() fails
93        self.ws
94            .post(to_ws_msg(BorshReqHeader::new(Some(id), op), &payload))
95            .await?;
96
97        let data = receiver.recv().await??;
98        let resp = ServerResult::<Resp>::try_from_slice(data.as_ref())
99            .map_err(|e| Error::BorshDeserialize(e.to_string()))?;
100
101        Ok(resp?)
102    }
103
104    pub async fn notify<Msg>(&self, op: Ops, payload: Msg) -> Result<()>
105    where
106        Msg: BorshSerialize + Send + Sync + 'static,
107    {
108        let payload = borsh::to_vec(&payload).map_err(|_| Error::BorshSerialize)?;
109        self.ws
110            .post(to_ws_msg(
111                BorshReqHeader::<Ops, Id>::new(None, op),
112                &payload,
113            ))
114            .await?;
115        Ok(())
116    }
117
118    async fn handle_notification(&self, op: &Ops, payload: &[u8]) -> Result<()> {
119        if let Some(interface) = &self.interface {
120            interface
121                .call_notification_with_borsh(op, payload)
122                .await
123                .unwrap_or_else(|err| log_trace!("error handling server notification {}", err));
124        } else {
125            log_trace!("unable to handle server notification - interface is not initialized");
126        }
127
128        Ok(())
129    }
130}
131
132#[async_trait]
133impl<Ops, Id> ProtocolHandler<Ops> for BorshProtocol<Ops, Id>
134where
135    Id: IdT,
136    Ops: OpsT,
137{
138    fn new(ws: Arc<WebSocket>, interface: Option<Arc<Interface<Ops>>>) -> Self
139    where
140        Self: Sized,
141    {
142        BorshProtocol::new(ws, interface)
143    }
144
145    async fn handle_timeout(&self, timeout: Duration) {
146        self.pending.lock().unwrap().retain(|_, pending| {
147            if pending.timestamp.elapsed() > timeout {
148                (pending.callback)(Err(Error::Timeout), None).unwrap_or_else(|err| {
149                    log_trace!("Error in RPC callback during timeout: `{err}`")
150                });
151                false
152            } else {
153                true
154            }
155        });
156    }
157
158    async fn handle_disconnect(&self) -> Result<()> {
159        self.pending.lock().unwrap().retain(|_, pending| {
160            (pending.callback)(Err(Error::Disconnect), None)
161                .unwrap_or_else(|err| log_trace!("Error in RPC callback during timeout: `{err}`"));
162            false
163        });
164
165        Ok(())
166    }
167
168    async fn handle_message(&self, message: WebSocketMessage) -> Result<()> {
169        if let WebSocketMessage::Binary(server_message) = message {
170            let (id, op, result) = self.decode(server_message.as_slice())?;
171            if let Some(id) = id {
172                if let Some(pending) = self.pending.lock().unwrap().remove(&id) {
173                    (pending.callback)(result, Some(&pending.timestamp.elapsed()))
174                } else {
175                    Err(Error::ResponseHandler(format!("{id:?}")))
176                }
177            } else if let Some(op) = op {
178                match result {
179                    Ok(data) => self.handle_notification(&op, data).await,
180                    _ => Ok(()),
181                }
182            } else {
183                Err(Error::NotificationMethod)
184            }
185        } else {
186            return Err(Error::WebSocketMessageType);
187        }
188    }
189}