workflow_rpc/client/protocol/
serde_json.rs1use core::marker::PhantomData;
2
3use super::{Pending, PendingMap, ProtocolHandler};
4pub use crate::client::error::Error;
5pub use crate::client::result::Result;
6use crate::client::Interface;
7use crate::imports::*;
8use crate::messages::serde_json::*;
9
10pub type JsonResponseFn =
11 Arc<Box<(dyn Fn(Result<Value>, Option<&Duration>) -> Result<()> + Sync + Send)>>;
12
13pub struct JsonProtocol<Ops, Id>
15where
16 Ops: OpsT,
17 Id: IdT,
18{
19 ws: Arc<WebSocket>,
20 pending: PendingMap<Id, JsonResponseFn>,
21 interface: Option<Arc<Interface<Ops>>>,
22 id: PhantomData<Id>,
24}
25
26impl<Ops, Id> JsonProtocol<Ops, Id>
27where
28 Id: IdT,
29 Ops: OpsT,
30{
31 fn new(ws: Arc<WebSocket>, interface: Option<Arc<Interface<Ops>>>) -> Self {
32 JsonProtocol::<Ops, Id> {
33 ws,
34 pending: Arc::new(Mutex::new(AHashMap::new())),
35 interface,
36 id: PhantomData,
38 }
39 }
40}
41
42type MessageInfo<Ops, Id> = (Option<Id>, Option<Ops>, Result<Value>);
43
44impl<Ops, Id> JsonProtocol<Ops, Id>
45where
46 Ops: OpsT,
47 Id: IdT,
48{
49 fn decode(&self, server_message: &str) -> Result<MessageInfo<Ops, Id>> {
50 let msg: JSONServerMessage<Ops, Id> = serde_json::from_str(server_message)?;
51
52 if let Some(error) = msg.error {
53 Ok((msg.id, None, Err(error.into())))
54 } else if msg.id.is_some() {
55 if let Some(result) = msg.params {
56 Ok((msg.id, None, Ok(result)))
57 } else {
58 Ok((msg.id, None, Err(Error::NoDataInSuccessResponse)))
59 }
60 } else if let Some(params) = msg.params {
61 Ok((None, msg.method, Ok(params)))
62 } else {
63 Ok((None, None, Err(Error::NoDataInNotificationMessage)))
64 }
65 }
66
67 pub async fn request<Req, Resp>(&self, op: Ops, req: Req) -> Result<Resp>
68 where
69 Req: MsgT,
70 Resp: MsgT,
71 {
72 let id = Id::generate();
73 let (sender, receiver) = oneshot();
74
75 {
76 let mut pending = self.pending.lock().unwrap();
77 pending.insert(
78 id.clone(),
79 Pending::new(Arc::new(Box::new(move |result, _duration| {
80 sender.try_send(result)?;
81 Ok(())
82 }))),
83 );
84 }
85
86 let payload = serde_json::to_value(req)?;
87 let client_message = JsonClientMessage::new(Some(id), op, payload);
88 let json = serde_json::to_string(&client_message)?;
89
90 self.ws.post(WebSocketMessage::Text(json)).await?;
91
92 let data = receiver.recv().await??;
93
94 let resp = <Resp as Deserialize>::deserialize(data)
95 .map_err(|e| Error::SerdeDeserialize(e.to_string()))?;
96 Ok(resp)
97 }
98
99 pub async fn notify<Msg>(&self, op: Ops, data: Msg) -> Result<()>
100 where
101 Msg: Serialize + Send + Sync + 'static,
102 {
103 let payload = serde_json::to_value(data)?;
104 let client_message = JsonClientMessage::<Ops, Id>::new(None, op, payload);
105 let json = serde_json::to_string(&client_message)?;
106 self.ws.post(WebSocketMessage::Text(json)).await?;
107 Ok(())
108 }
109
110 async fn handle_notification(&self, op: Ops, payload: Value) -> Result<()> {
111 if let Some(interface) = &self.interface {
112 interface
113 .call_notification_with_serde_json(&op, payload)
114 .await
115 .unwrap_or_else(|err| log_trace!("error handling server notification {}", err));
116 } else {
117 log_trace!("unable to handle server notification - interface is not initialized");
118 }
119
120 Ok(())
121 }
122}
123
124#[async_trait]
125impl<Ops, Id> ProtocolHandler<Ops> for JsonProtocol<Ops, Id>
126where
127 Ops: OpsT,
128 Id: IdT,
129{
130 fn new(ws: Arc<WebSocket>, interface: Option<Arc<Interface<Ops>>>) -> Self
131 where
132 Self: Sized,
133 {
134 JsonProtocol::new(ws, interface)
135 }
136
137 async fn handle_timeout(&self, timeout: Duration) {
138 self.pending.lock().unwrap().retain(|_, pending| {
139 if pending.timestamp.elapsed() > timeout {
140 (pending.callback)(Err(Error::Timeout), None).unwrap_or_else(|err| {
141 log_trace!("Error in RPC callback during timeout: `{err}`")
142 });
143 false
144 } else {
145 true
146 }
147 });
148 }
149
150 async fn handle_message(&self, message: WebSocketMessage) -> Result<()> {
151 if let WebSocketMessage::Text(server_message) = message {
152 let (id, method, result) = self.decode(server_message.as_str())?;
153 if let Some(id) = id {
154 if let Some(pending) = self.pending.lock().unwrap().remove(&id) {
155 (pending.callback)(result, Some(&pending.timestamp.elapsed()))
156 } else {
157 Err(Error::ResponseHandler(format!("{id:?}")))
158 }
159 } else if let Some(method) = method {
160 match result {
161 Ok(data) => self.handle_notification(method, data).await,
162 _ => Ok(()),
163 }
164 } else {
165 Err(Error::NotificationMethod)
166 }
167 } else {
168 return Err(Error::WebSocketMessageType);
169 }
170 }
171
172 async fn handle_disconnect(&self) -> Result<()> {
173 self.pending.lock().unwrap().retain(|_, pending| {
174 (pending.callback)(Err(Error::Disconnect), None)
175 .unwrap_or_else(|err| log_trace!("Error in RPC callback during timeout: `{err}`"));
176 false
177 });
178
179 Ok(())
180 }
181}