workflow_rpc/client/protocol/
serde_json.rs

1use core::marker::PhantomData;
2
3use super::{Pending, PendingMap, ProtocolHandler};
4pub use crate::client::error::Error;
5pub use crate::client::result::Result;
6use crate::client::Interface;
7use crate::imports::*;
8use crate::messages::serde_json::*;
9
10pub type JsonResponseFn =
11    Arc<Box<(dyn Fn(Result<Value>, Option<&Duration>) -> Result<()> + Sync + Send)>>;
12
13/// Serde JSON RPC message handler and dispatcher
14pub struct JsonProtocol<Ops, Id>
15where
16    Ops: OpsT,
17    Id: IdT,
18{
19    ws: Arc<WebSocket>,
20    pending: PendingMap<Id, JsonResponseFn>,
21    interface: Option<Arc<Interface<Ops>>>,
22    // ops: PhantomData<Ops>,
23    id: PhantomData<Id>,
24}
25
26impl<Ops, Id> JsonProtocol<Ops, Id>
27where
28    Id: IdT,
29    Ops: OpsT,
30{
31    fn new(ws: Arc<WebSocket>, interface: Option<Arc<Interface<Ops>>>) -> Self {
32        JsonProtocol::<Ops, Id> {
33            ws,
34            pending: Arc::new(Mutex::new(AHashMap::new())),
35            interface,
36            // ops: PhantomData,
37            id: PhantomData,
38        }
39    }
40}
41
42type MessageInfo<Ops, Id> = (Option<Id>, Option<Ops>, Result<Value>);
43
44impl<Ops, Id> JsonProtocol<Ops, Id>
45where
46    Ops: OpsT,
47    Id: IdT,
48{
49    fn decode(&self, server_message: &str) -> Result<MessageInfo<Ops, Id>> {
50        let msg: JSONServerMessage<Ops, Id> = serde_json::from_str(server_message)?;
51
52        if let Some(error) = msg.error {
53            Ok((msg.id, None, Err(error.into())))
54        } else if msg.id.is_some() {
55            if let Some(result) = msg.params {
56                Ok((msg.id, None, Ok(result)))
57            } else {
58                Ok((msg.id, None, Err(Error::NoDataInSuccessResponse)))
59            }
60        } else if let Some(params) = msg.params {
61            Ok((None, msg.method, Ok(params)))
62        } else {
63            Ok((None, None, Err(Error::NoDataInNotificationMessage)))
64        }
65    }
66
67    pub async fn request<Req, Resp>(&self, op: Ops, req: Req) -> Result<Resp>
68    where
69        Req: MsgT,
70        Resp: MsgT,
71    {
72        let id = Id::generate();
73        let (sender, receiver) = oneshot();
74
75        {
76            let mut pending = self.pending.lock().unwrap();
77            pending.insert(
78                id.clone(),
79                Pending::new(Arc::new(Box::new(move |result, _duration| {
80                    sender.try_send(result)?;
81                    Ok(())
82                }))),
83            );
84        }
85
86        let payload = serde_json::to_value(req)?;
87        let client_message = JsonClientMessage::new(Some(id), op, payload);
88        let json = serde_json::to_string(&client_message)?;
89
90        self.ws.post(WebSocketMessage::Text(json)).await?;
91
92        let data = receiver.recv().await??;
93
94        let resp = <Resp as Deserialize>::deserialize(data)
95            .map_err(|e| Error::SerdeDeserialize(e.to_string()))?;
96        Ok(resp)
97    }
98
99    pub async fn notify<Msg>(&self, op: Ops, data: Msg) -> Result<()>
100    where
101        Msg: Serialize + Send + Sync + 'static,
102    {
103        let payload = serde_json::to_value(data)?;
104        let client_message = JsonClientMessage::<Ops, Id>::new(None, op, payload);
105        let json = serde_json::to_string(&client_message)?;
106        self.ws.post(WebSocketMessage::Text(json)).await?;
107        Ok(())
108    }
109
110    async fn handle_notification(&self, op: Ops, payload: Value) -> Result<()> {
111        if let Some(interface) = &self.interface {
112            interface
113                .call_notification_with_serde_json(&op, payload)
114                .await
115                .unwrap_or_else(|err| log_trace!("error handling server notification {}", err));
116        } else {
117            log_trace!("unable to handle server notification - interface is not initialized");
118        }
119
120        Ok(())
121    }
122}
123
124#[async_trait]
125impl<Ops, Id> ProtocolHandler<Ops> for JsonProtocol<Ops, Id>
126where
127    Ops: OpsT,
128    Id: IdT,
129{
130    fn new(ws: Arc<WebSocket>, interface: Option<Arc<Interface<Ops>>>) -> Self
131    where
132        Self: Sized,
133    {
134        JsonProtocol::new(ws, interface)
135    }
136
137    async fn handle_timeout(&self, timeout: Duration) {
138        self.pending.lock().unwrap().retain(|_, pending| {
139            if pending.timestamp.elapsed() > timeout {
140                (pending.callback)(Err(Error::Timeout), None).unwrap_or_else(|err| {
141                    log_trace!("Error in RPC callback during timeout: `{err}`")
142                });
143                false
144            } else {
145                true
146            }
147        });
148    }
149
150    async fn handle_message(&self, message: WebSocketMessage) -> Result<()> {
151        if let WebSocketMessage::Text(server_message) = message {
152            let (id, method, result) = self.decode(server_message.as_str())?;
153            if let Some(id) = id {
154                if let Some(pending) = self.pending.lock().unwrap().remove(&id) {
155                    (pending.callback)(result, Some(&pending.timestamp.elapsed()))
156                } else {
157                    Err(Error::ResponseHandler(format!("{id:?}")))
158                }
159            } else if let Some(method) = method {
160                match result {
161                    Ok(data) => self.handle_notification(method, data).await,
162                    _ => Ok(()),
163                }
164            } else {
165                Err(Error::NotificationMethod)
166            }
167        } else {
168            return Err(Error::WebSocketMessageType);
169        }
170    }
171
172    async fn handle_disconnect(&self) -> Result<()> {
173        self.pending.lock().unwrap().retain(|_, pending| {
174            (pending.callback)(Err(Error::Disconnect), None)
175                .unwrap_or_else(|err| log_trace!("Error in RPC callback during timeout: `{err}`"));
176            false
177        });
178
179        Ok(())
180    }
181}