thin_jsonrpc_client/
lib.rs

1//! # thin-jsonrpc-client
2//!
3//! This crate provides a lightweight JSON-RPC compatible client.
4#![deny(missing_docs)]
5
6/// A broadcast-style stream of decoded JSON-RPC responses.
7mod response_stream;
8mod response;
9
10/// Helpers to build parameters for a JSON-RPC request.
11pub mod params;
12/// The backend trait, to connect a client to some server.
13pub mod backend;
14/// JSON-RPC response types.
15pub mod raw_response;
16
17use futures_core::Stream;
18use futures_util::StreamExt;
19use response::ErrorObject;
20use response_stream::{ResponseStreamMaster, ResponseStreamHandle, ResponseStream};
21use raw_response::{RawResponse};
22use backend::{BackendSender, BackendReceiver};
23use params::IntoRpcParams;
24use std::sync::atomic::{ AtomicU64, Ordering };
25use std::sync::Arc;
26use std::task::Poll;
27
28pub use response::{ Response, ResponseError };
29
30/// An error handed back from [`Client::request()`].
31#[derive(Debug, derive_more::From, derive_more::Display)]
32#[non_exhaustive]
33pub enum RequestError {
34    /// An error from the backend implementation that was emitted when
35    /// attempting to send the request.
36    #[from]
37    Backend(backend::BackendError),
38    /// The connection was closed before a response was delivered.
39    ConnectionClosed,
40}
41
42impl std::error::Error for RequestError {
43    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
44        match self {
45            RequestError::Backend(e) => Some(e),
46            RequestError::ConnectionClosed => None
47        }
48    }
49}
50
51/// A JSON-RPC client. Build this by calling [`Client::from_backend()`]
52/// and providing a suitable sender and receiver.
53#[derive(Clone)]
54pub struct Client {
55    next_id: Arc<AtomicU64>,
56    sender: Arc<dyn BackendSender>,
57    stream: ResponseStreamHandle,
58}
59
60impl std::fmt::Debug for Client {
61    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62        f.debug_struct("Client")
63            .field("next_id", &self.next_id)
64            .field("sender", &"Box<dyn BackendSender>")
65            .field("stream", &self.stream)
66            .finish()
67    }
68}
69
70impl Client {
71    /// Construct a client/driver from a [`BackendSender`] and [`BackendReceiver`].
72    /// The [`ClientDriver`] handed back is a stream which needs polling in order to
73    /// drive the message receiving.
74    pub fn from_backend<S, R>(send: S, recv: R) -> (Client, ClientDriver)
75    where
76        S: BackendSender,
77        R: BackendReceiver
78    {
79        let master = ResponseStreamMaster::new(Box::new(recv));
80
81        let client = Client {
82            next_id: Arc::new(AtomicU64::new(1)),
83            sender: Arc::new(send),
84            stream: master.handle()
85        };
86        let client_driver = ClientDriver(master);
87
88        (client, client_driver)
89    }
90
91    /// Make a request to the RPC server. This will return either a response or an error.
92    pub async fn send_request<Params>(&self, method: &str, params: Params) -> Result<Response, RequestError>
93    where Params: IntoRpcParams
94    {
95        let id = self.next_id.fetch_add(1, Ordering::Relaxed).to_string();
96
97        // Build the request:
98        let request = match params.into_rpc_params() {
99            Some(params) =>
100                format!(r#"{{"jsonrpc":"2.0","id":"{id}","method":"{method}","params":{params}}}"#),
101            None =>
102                format!(r#"{{"jsonrpc":"2.0","id":"{id}","method":"{method}"}}"#),
103        };
104
105        // Subscribe to responses with the matching ID.
106        let mut response_stream = self.stream.response_stream().filter(move |res| {
107            let Some(msg_id) = res.id() else {
108                return std::future::ready(false)
109            };
110
111            std::future::ready(id == msg_id)
112        });
113
114        // Now we're set up to wait for the reply, send the request.
115        self.sender
116            .send(request.as_bytes())
117            .await
118            .map_err(|e| RequestError::Backend(e))?;
119
120        // Get the response.
121        let response = response_stream.next().await;
122
123        match response {
124            Some(res) => {
125                Ok(Response(res))
126            },
127            None => {
128                Err(RequestError::ConnectionClosed)
129            }
130        }
131    }
132
133    /// Send a notification to the RPC server. This will not wait for a response.
134    pub async fn send_notification<Params>(&mut self, method: &str, params: Params) -> Result<(), backend::BackendError>
135    where Params: IntoRpcParams
136    {
137        let notification = match params.into_rpc_params() {
138            Some(params) =>
139                format!(r#"{{"jsonrpc":"2.0","method":"{method}","params":{params}}}"#),
140            None =>
141                format!(r#"{{"jsonrpc":"2.0","method":"{method}"}}"#),
142        };
143
144        // Send the message. No need to wait for any response.
145        self.sender.send(notification.as_bytes()).await?;
146        Ok(())
147    }
148
149    /// Obtain a stream of server notifications from the backend that aren't linked to
150    /// any specific request.
151    pub fn notifications(&self) -> ServerNotifications {
152        ServerNotifications(self.stream.response_stream())
153    }
154}
155
156/// This must be polled in order to accept messages from the server.
157/// Nothing will happen unless it is. This design allows us to apply backpressure
158/// (by polling this less often), and allows us to drive multiple notification streams
159/// without requiring any specific async runtime. It will return:
160///
161/// - `Some(Ok(response))` if it successfully received a response.
162/// - `Some(Err(e))` if it failed to parse some bytes into a response, or the response
163///    was invalid.
164/// - `None` if the backend has stopped (either after a fatal error, which will be
165///   delivered just prior to this, or because the [`Client`] was dropped.
166pub struct ClientDriver(ResponseStreamMaster);
167
168/// An error was encountered while receiving messages from the backend.
169pub type ClientDriverError = response_stream::ResponseStreamError;
170
171impl Stream for ClientDriver {
172    type Item = Result<Response, ClientDriverError>;
173    fn poll_next(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Option<Self::Item>> {
174        self.0.poll_next_unpin(cx).map(|o| o.map(|r| r.map(Response)))
175    }
176}
177
178/// A struct representing messages from the server. This
179/// implements [`futures_util::Stream`] and provides a couple of
180/// additional helper methods to further filter the stream.
181pub struct ServerNotifications(ResponseStream);
182
183impl ServerNotifications {
184    /// This is analogous to [`Response::ok_into()`], but will apply to each
185    /// notification in the stream, filtering out any that don't deserialize into
186    /// the given type.
187    pub fn ok_into<R>(self) -> impl Stream<Item=R> + Send + Sync + 'static
188    where
189        R: for<'de> serde::de::Deserialize<'de> + Send + Sync + 'static,
190    {
191        self.filter_map(move |res| {
192            match res.ok_into() {
193                Err(_) => std::future::ready(None),
194                Ok(r) => std::future::ready(Some(r))
195            }
196        })
197    }
198
199    /// Like [`ServerNotifications::ok_into()`], but also accepts a filter function to
200    /// ignore any values that we're not interested in.
201    pub fn ok_into_if<R, F>(self, filter_fn: F) -> impl Stream<Item=R> + Send + Sync + 'static
202    where
203        R: for<'de> serde::de::Deserialize<'de> + Send + Sync + 'static,
204        F: Fn(&R) -> bool + Send + Sync + 'static
205    {
206        self.ok_into().filter(move |n| std::future::ready(filter_fn(n)))
207    }
208
209    /// This is analogous to [`Response::error_into()`], but will apply to each
210    /// notification in the stream, filtering out any that don't deserialize into
211    /// the given type.
212    pub fn error_into<R>(self) -> impl Stream<Item=ErrorObject<R>> + Send + Sync + 'static
213    where
214        R: for<'de> serde::de::Deserialize<'de> + Send + Sync + 'static,
215    {
216        self.filter_map(move |res| {
217            match res.error_into() {
218                Err(_) => std::future::ready(None),
219                Ok(r) => std::future::ready(Some(r))
220            }
221        })
222    }
223
224    /// Like [`ServerNotifications::error_into()`], but also accepts a filter function to
225    /// ignore any values that we're not interested in.
226    pub fn error_into_if<R, F>(self, filter_fn: F) -> impl Stream<Item=ErrorObject<R>> + Send + Sync + 'static
227    where
228        R: for<'de> serde::de::Deserialize<'de> + Send + Sync + 'static,
229        F: Fn(&ErrorObject<R>) -> bool + Send + Sync + 'static
230    {
231        self.error_into().filter(move |n| std::future::ready(filter_fn(n)))
232    }
233}
234
235
236impl Stream for ServerNotifications {
237    type Item = Response;
238    fn poll_next(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
239        loop {
240            let Poll::Ready(res) = self.0.poll_next_unpin(cx) else {
241                return Poll::Pending
242            };
243
244            let Some(res) = res else {
245                return Poll::Ready(None)
246            };
247
248            if res.id().is_some() {
249                // If the response has an ID, we filter
250                // it out. Loop and poll again because we
251                // can't return pending if the inner was ready.
252                continue
253            }
254
255            return Poll::Ready(Some(Response(res)))
256        }
257    }
258}
259
#[cfg(test)]
mod test {
    use super::*;
    use backend::mock::{ self, MockBackend };
    use futures_util::StreamExt;
    use crate::response::ErrorObject;

    type Error = Box<dyn std::error::Error + Send + Sync + 'static>;

    /// Build a client and driver wired up to a mock backend.
    fn mock_client() -> (Client, ClientDriver, MockBackend) {
        let (mock_backend, mock_send, mock_recv) = mock::build();
        let (client, driver) = Client::from_backend(mock_send, mock_recv);
        (client, driver, mock_backend)
    }

    /// Poll the driver on a background task until it finishes, printing
    /// any errors it reports along the way.
    fn drive_client(mut driver: ClientDriver) {
        tokio::spawn(async move {
            while let Some(res) = driver.next().await {
                if let Err(err) = res {
                    eprintln!("ClientDriver Error: {err}");
                }
            }
        });
    }

    #[tokio::test]
    async fn test_basic_requests() -> Result<(), Error> {
        let (client, driver, backend) = mock_client();
        drive_client(driver);

        backend
            // Echo the params back as the `result` of a hand-written response.
            .handler("echo_ok_raw", |cx, req| {
                let id = req.id.unwrap_or_else(|| "-1".to_string());
                let params = req.params;
                let res = format!(r#"{{ "jsonrpc": "2.0", "id": "{id}", "result": {params} }}"#);
                cx.send_bytes(res.into_bytes());
            })
            // Echo the params back as the `data` of a hand-written error response.
            .handler("echo_err_raw", |cx, req| {
                let id = req.id.unwrap_or_else(|| "-1".to_string());
                let params = req.params;
                let res = format!(r#"{{ "jsonrpc": "2.0", "id": "{id}", "error": {{ "code":123, "message":"Eep!", "data": {params} }} }}"#);
                cx.send_bytes(res.into_bytes());
            })
            // Add the two numbers given as params.
            .handler("add", |cx, req| {
                let (a, b): (i64, i64) = serde_json::from_str(req.params.get()).unwrap();
                cx.send_ok_response(req.id, a+b);
            });

        // Check we can decode ok response properly.
        let res: Vec<u8> = client.send_request("echo_ok_raw", (1,2,3)).await?.ok_into()?;
        assert_eq!(res, vec![1,2,3]);

        // Check we can decode error response properly.
        let err: ErrorObject<Vec<u8>> = client.send_request("echo_err_raw", (1,2,3)).await?.error_into()?;
        assert_eq!(err.code, 123);
        assert_eq!(err.message, "Eep!");
        assert_eq!(err.data, vec![1,2,3]);

        // Ensure IDs are incremented and such.
        for i in 0i64..500 {
            let res: i64 = client.send_request("add", (i, 1)).await?.ok_into()?;
            assert_eq!(res, i + 1);
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_notifications() -> Result<(), Error> {
        let (mut client, driver, backend) = mock_client();
        drive_client(driver);

        backend
            .handler("start_sending", |cx, _req| {
                for i in 0u64..20 {
                    cx.send_ok_notification(i);
                }
                // If we don't shutdown after, the streams will wait
                // forever for new messages:
                cx.shutdown();
            });

        // Subscribe to messages:
        let twos = client.notifications().ok_into_if(|n: &u64| n % 2 == 0);
        let threes = client.notifications().ok_into_if(|res: &u64| res % 3 == 0);
        let bools = client.notifications().ok_into::<bool>();

        // Start sending notifications.
        client.send_notification("start_sending", ()).await?;

        // Collect them in our streams to check:
        let twos: Vec<u64> = twos.collect().await;
        let threes: Vec<u64> = threes.collect().await;
        let bools: Vec<_> = bools.collect().await;

        let expected_twos = vec![0,2,4,6,8,10,12,14,16,18];
        let expected_threes = vec![0,3,6,9,12,15,18];

        assert_eq!(twos, expected_twos);
        assert_eq!(threes, expected_threes);
        // The notifications are numbers, so none of them decode as `bool`.
        assert!(bools.is_empty());

        Ok(())
    }

    #[tokio::test]
    async fn test_lots_of_subscriptions() {
        let (client, driver, backend) = mock_client();
        drive_client(driver);

        // Lots of listeners. These are collected eagerly so that every listener is
        // subscribed and spawned before the sender task below is even created; a
        // lazy iterator would defer all of that until `join_all` consumed it.
        let handles: Vec<_> = (0..1000)
            .map(|_| {
                let mut notifs = client.notifications().ok_into::<u64>();
                tokio::spawn(async move {
                    let mut n = 0;
                    while let Some(_res) = notifs.next().await {
                        n += 1;
                    }
                    n
                })
            })
            .collect();

        tokio::spawn(async move {
            for n in 0u64..1000 {
                backend.send_ok_notification(n);
            }
            // Else streams will wait forever:
            backend.shutdown();
        });

        let counts: Vec<_> = futures_util::future::join_all(handles).await;

        // Check that every stream saw every message:
        for count in counts {
            assert_eq!(count.unwrap(), 1000);
        }
    }
}