use core::marker::PhantomData;
use super::{Pending, PendingMap, ProtocolHandler};
pub use crate::client::error::Error;
pub use crate::client::result::Result;
use crate::client::Interface;
use crate::imports::*;
use crate::messages::serde_json::*;
/// Callback that delivers the outcome of a pending request.
///
/// It receives either the JSON payload of the response or an error. The second argument is
/// the time elapsed since the request was registered. When the request is failed by a
/// timeout or disconnect sweep, that argument is `None`.
pub type JsonResponseFn =
    Arc<Box<dyn Fn(Result<Value>, Option<&Duration>) -> Result<()> + Send + Sync>>;
/// RPC protocol handler that exchanges JSON text frames over a shared `WebSocket`.
///
/// Each outgoing request is tracked in `pending` under its `Id` until one of three things
/// happens: a matching response arrives, the request times out, or the connection drops.
pub struct JsonProtocol<Ops: OpsT, Id: IdT> {
    /// Socket used to post outgoing messages.
    ws: Arc<WebSocket>,
    /// Response callbacks of in-flight requests, keyed by request id.
    pending: PendingMap<Id, JsonResponseFn>,
    /// Optional dispatcher for server-initiated notifications.
    interface: Option<Arc<Interface<Ops>>>,
    /// Type marker for the `Id` parameter.
    id: PhantomData<Id>,
}
impl<Ops: OpsT, Id: IdT> JsonProtocol<Ops, Id> {
    /// Builds a handler bound to `ws` with an empty pending-request table.
    ///
    /// `interface`, when supplied, is used to dispatch server notifications.
    fn new(ws: Arc<WebSocket>, interface: Option<Arc<Interface<Ops>>>) -> Self {
        let pending = Arc::new(Mutex::new(AHashMap::new()));
        Self {
            id: PhantomData,
            interface,
            pending,
            ws,
        }
    }
}
/// A decoded server message as `(response id, notification method, payload or error)`.
type MessageInfo<O, I> = (Option<I>, Option<O>, Result<Value>);
impl<Ops, Id> JsonProtocol<Ops, Id>
where
Ops: OpsT,
Id: IdT,
{
fn decode(&self, server_message: &str) -> Result<MessageInfo<Ops, Id>> {
let msg: JSONServerMessage<Ops, Id> = serde_json::from_str(server_message)?;
if let Some(error) = msg.error {
Ok((msg.id, None, Err(error.into())))
} else if msg.id.is_some() {
if let Some(result) = msg.params {
Ok((msg.id, None, Ok(result)))
} else {
Ok((msg.id, None, Err(Error::NoDataInSuccessResponse)))
}
} else if let Some(params) = msg.params {
Ok((None, msg.method, Ok(params)))
} else {
Ok((None, None, Err(Error::NoDataInNotificationMessage)))
}
}
pub async fn request<Req, Resp>(&self, op: Ops, req: Req) -> Result<Resp>
where
Req: MsgT,
Resp: MsgT,
{
let id = Id::generate();
let (sender, receiver) = oneshot();
{
let mut pending = self.pending.lock().unwrap();
pending.insert(
id.clone(),
Pending::new(Arc::new(Box::new(move |result, _duration| {
sender.try_send(result)?;
Ok(())
}))),
);
}
let payload = serde_json::to_value(req)?;
let client_message = JsonClientMessage::new(Some(id), op, payload);
let json = serde_json::to_string(&client_message)?;
self.ws.post(WebSocketMessage::Text(json)).await?;
let data = receiver.recv().await??;
let resp = <Resp as Deserialize>::deserialize(data)
.map_err(|e| Error::SerdeDeserialize(e.to_string()))?;
Ok(resp)
}
pub async fn notify<Msg>(&self, op: Ops, data: Msg) -> Result<()>
where
Msg: Serialize + Send + Sync + 'static,
{
let payload = serde_json::to_value(data)?;
let client_message = JsonClientMessage::<Ops, Id>::new(None, op, payload);
let json = serde_json::to_string(&client_message)?;
self.ws.post(WebSocketMessage::Text(json)).await?;
Ok(())
}
async fn handle_notification(&self, op: Ops, payload: Value) -> Result<()> {
if let Some(interface) = &self.interface {
interface
.call_notification_with_serde_json(&op, payload)
.await
.unwrap_or_else(|err| log_trace!("error handling server notification {}", err));
} else {
log_trace!("unable to handle server notification - interface is not initialized");
}
Ok(())
}
}
#[async_trait]
impl<Ops, Id> ProtocolHandler<Ops> for JsonProtocol<Ops, Id>
where
Ops: OpsT,
Id: IdT,
{
fn new(ws: Arc<WebSocket>, interface: Option<Arc<Interface<Ops>>>) -> Self
where
Self: Sized,
{
JsonProtocol::new(ws, interface)
}
async fn handle_timeout(&self, timeout: Duration) {
self.pending.lock().unwrap().retain(|_, pending| {
if pending.timestamp.elapsed() > timeout {
(pending.callback)(Err(Error::Timeout), None).unwrap_or_else(|err| {
log_trace!("Error in RPC callback during timeout: `{err}`")
});
false
} else {
true
}
});
}
async fn handle_message(&self, message: WebSocketMessage) -> Result<()> {
if let WebSocketMessage::Text(server_message) = message {
let (id, method, result) = self.decode(server_message.as_str())?;
if let Some(id) = id {
if let Some(pending) = self.pending.lock().unwrap().remove(&id) {
(pending.callback)(result, Some(&pending.timestamp.elapsed()))
} else {
Err(Error::ResponseHandler(format!("{id:?}")))
}
} else if let Some(method) = method {
match result {
Ok(data) => self.handle_notification(method, data).await,
_ => Ok(()),
}
} else {
Err(Error::NotificationMethod)
}
} else {
return Err(Error::WebSocketMessageType);
}
}
async fn handle_disconnect(&self) -> Result<()> {
self.pending.lock().unwrap().retain(|_, pending| {
(pending.callback)(Err(Error::Disconnect), None)
.unwrap_or_else(|err| log_trace!("Error in RPC callback during timeout: `{err}`"));
false
});
Ok(())
}
}