workflow_rpc/client/protocol/
borsh.rs1use super::{Pending, PendingMap, ProtocolHandler};
2pub use crate::client::error::Error;
3pub use crate::client::result::Result;
4use crate::client::Interface;
5use crate::imports::*;
6use crate::messages::borsh::*;
7use core::marker::PhantomData;
8
9pub type BorshResponseFn =
10 Arc<Box<(dyn Fn(Result<&[u8]>, Option<&Duration>) -> Result<()> + Sync + Send)>>;
11
12pub struct BorshProtocol<Ops, Id>
14where
15 Ops: OpsT,
16 Id: IdT,
17{
18 ws: Arc<WebSocket>,
19 pending: PendingMap<Id, BorshResponseFn>,
20 interface: Option<Arc<Interface<Ops>>>,
21 ops: PhantomData<Ops>,
22 id: PhantomData<Id>,
23}
24
25impl<Ops, Id> BorshProtocol<Ops, Id>
26where
27 Ops: OpsT,
28 Id: IdT,
29{
30 fn new(ws: Arc<WebSocket>, interface: Option<Arc<Interface<Ops>>>) -> Self {
31 BorshProtocol {
32 ws,
33 pending: Arc::new(Mutex::new(AHashMap::new())),
34 interface,
35 ops: PhantomData,
36 id: PhantomData,
37 }
38 }
39}
40
41type MessageInfo<'l, Ops, Id> = (Option<Id>, Option<Ops>, Result<&'l [u8]>);
42
43impl<Ops, Id> BorshProtocol<Ops, Id>
44where
45 Id: IdT,
46 Ops: OpsT,
47{
48 fn decode<'l>(&self, server_message: &'l [u8]) -> ServerResult<MessageInfo<'l, Ops, Id>> {
49 match BorshServerMessage::try_from(server_message) {
50 Ok(msg) => {
51 let header = msg.header;
52 match header.kind {
53 ServerMessageKind::Success => {
54 Ok((header.id, header.op, Ok(msg.payload)))
55 }
57 ServerMessageKind::Error => {
58 if let Ok(err) = ServerError::try_from_slice(msg.payload) {
59 Ok((header.id, None, Err(Error::RpcCall(err))))
60 } else {
61 Ok((header.id, None, Err(Error::ErrorDeserializingResponseData)))
62 }
63 }
64 ServerMessageKind::Notification => Ok((None, header.op, Ok(msg.payload))),
65 }
66 }
67 Err(err) => Err(ServerError::RespDeserialize(err.to_string())),
68 }
69 }
70
71 pub async fn request<Req, Resp>(&self, op: Ops, req: Req) -> Result<Resp>
72 where
73 Req: MsgT,
74 Resp: MsgT,
75 {
76 let payload = borsh::to_vec(&req).map_err(|_| Error::BorshSerialize)?;
77
78 let id = Id::generate();
79 let (sender, receiver) = oneshot();
80
81 {
82 let mut pending = self.pending.lock().unwrap();
83 pending.insert(
84 id.clone(),
85 Pending::new(Arc::new(Box::new(move |result, _duration| {
86 sender.try_send(result.map(|data| data.to_vec()))?;
87 Ok(())
88 }))),
89 );
90 }
91
92 self.ws
94 .post(to_ws_msg(BorshReqHeader::new(Some(id), op), &payload))
95 .await?;
96
97 let data = receiver.recv().await??;
98 let resp = ServerResult::<Resp>::try_from_slice(data.as_ref())
99 .map_err(|e| Error::BorshDeserialize(e.to_string()))?;
100
101 Ok(resp?)
102 }
103
104 pub async fn notify<Msg>(&self, op: Ops, payload: Msg) -> Result<()>
105 where
106 Msg: BorshSerialize + Send + Sync + 'static,
107 {
108 let payload = borsh::to_vec(&payload).map_err(|_| Error::BorshSerialize)?;
109 self.ws
110 .post(to_ws_msg(
111 BorshReqHeader::<Ops, Id>::new(None, op),
112 &payload,
113 ))
114 .await?;
115 Ok(())
116 }
117
118 async fn handle_notification(&self, op: &Ops, payload: &[u8]) -> Result<()> {
119 if let Some(interface) = &self.interface {
120 interface
121 .call_notification_with_borsh(op, payload)
122 .await
123 .unwrap_or_else(|err| log_trace!("error handling server notification {}", err));
124 } else {
125 log_trace!("unable to handle server notification - interface is not initialized");
126 }
127
128 Ok(())
129 }
130}
131
132#[async_trait]
133impl<Ops, Id> ProtocolHandler<Ops> for BorshProtocol<Ops, Id>
134where
135 Id: IdT,
136 Ops: OpsT,
137{
138 fn new(ws: Arc<WebSocket>, interface: Option<Arc<Interface<Ops>>>) -> Self
139 where
140 Self: Sized,
141 {
142 BorshProtocol::new(ws, interface)
143 }
144
145 async fn handle_timeout(&self, timeout: Duration) {
146 self.pending.lock().unwrap().retain(|_, pending| {
147 if pending.timestamp.elapsed() > timeout {
148 (pending.callback)(Err(Error::Timeout), None).unwrap_or_else(|err| {
149 log_trace!("Error in RPC callback during timeout: `{err}`")
150 });
151 false
152 } else {
153 true
154 }
155 });
156 }
157
158 async fn handle_disconnect(&self) -> Result<()> {
159 self.pending.lock().unwrap().retain(|_, pending| {
160 (pending.callback)(Err(Error::Disconnect), None)
161 .unwrap_or_else(|err| log_trace!("Error in RPC callback during timeout: `{err}`"));
162 false
163 });
164
165 Ok(())
166 }
167
168 async fn handle_message(&self, message: WebSocketMessage) -> Result<()> {
169 if let WebSocketMessage::Binary(server_message) = message {
170 let (id, op, result) = self.decode(server_message.as_slice())?;
171 if let Some(id) = id {
172 if let Some(pending) = self.pending.lock().unwrap().remove(&id) {
173 (pending.callback)(result, Some(&pending.timestamp.elapsed()))
174 } else {
175 Err(Error::ResponseHandler(format!("{id:?}")))
176 }
177 } else if let Some(op) = op {
178 match result {
179 Ok(data) => self.handle_notification(&op, data).await,
180 _ => Ok(()),
181 }
182 } else {
183 Err(Error::NotificationMethod)
184 }
185 } else {
186 return Err(Error::WebSocketMessageType);
187 }
188 }
189}