1#![deny(missing_docs)]
5
6mod response_stream;
8mod response;
9
10pub mod params;
12pub mod backend;
14pub mod raw_response;
16
17use futures_core::Stream;
18use futures_util::StreamExt;
19use response::ErrorObject;
20use response_stream::{ResponseStreamMaster, ResponseStreamHandle, ResponseStream};
21use raw_response::{RawResponse};
22use backend::{BackendSender, BackendReceiver};
23use params::IntoRpcParams;
24use std::sync::atomic::{ AtomicU64, Ordering };
25use std::sync::Arc;
26use std::task::Poll;
27
28pub use response::{ Response, ResponseError };
29
30#[derive(Debug, derive_more::From, derive_more::Display)]
32#[non_exhaustive]
33pub enum RequestError {
34 #[from]
37 Backend(backend::BackendError),
38 ConnectionClosed,
40}
41
42impl std::error::Error for RequestError {
43 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
44 match self {
45 RequestError::Backend(e) => Some(e),
46 RequestError::ConnectionClosed => None
47 }
48 }
49}
50
51#[derive(Clone)]
54pub struct Client {
55 next_id: Arc<AtomicU64>,
56 sender: Arc<dyn BackendSender>,
57 stream: ResponseStreamHandle,
58}
59
60impl std::fmt::Debug for Client {
61 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62 f.debug_struct("Client")
63 .field("next_id", &self.next_id)
64 .field("sender", &"Box<dyn BackendSender>")
65 .field("stream", &self.stream)
66 .finish()
67 }
68}
69
70impl Client {
71 pub fn from_backend<S, R>(send: S, recv: R) -> (Client, ClientDriver)
75 where
76 S: BackendSender,
77 R: BackendReceiver
78 {
79 let master = ResponseStreamMaster::new(Box::new(recv));
80
81 let client = Client {
82 next_id: Arc::new(AtomicU64::new(1)),
83 sender: Arc::new(send),
84 stream: master.handle()
85 };
86 let client_driver = ClientDriver(master);
87
88 (client, client_driver)
89 }
90
91 pub async fn send_request<Params>(&self, method: &str, params: Params) -> Result<Response, RequestError>
93 where Params: IntoRpcParams
94 {
95 let id = self.next_id.fetch_add(1, Ordering::Relaxed).to_string();
96
97 let request = match params.into_rpc_params() {
99 Some(params) =>
100 format!(r#"{{"jsonrpc":"2.0","id":"{id}","method":"{method}","params":{params}}}"#),
101 None =>
102 format!(r#"{{"jsonrpc":"2.0","id":"{id}","method":"{method}"}}"#),
103 };
104
105 let mut response_stream = self.stream.response_stream().filter(move |res| {
107 let Some(msg_id) = res.id() else {
108 return std::future::ready(false)
109 };
110
111 std::future::ready(id == msg_id)
112 });
113
114 self.sender
116 .send(request.as_bytes())
117 .await
118 .map_err(|e| RequestError::Backend(e))?;
119
120 let response = response_stream.next().await;
122
123 match response {
124 Some(res) => {
125 Ok(Response(res))
126 },
127 None => {
128 Err(RequestError::ConnectionClosed)
129 }
130 }
131 }
132
133 pub async fn send_notification<Params>(&mut self, method: &str, params: Params) -> Result<(), backend::BackendError>
135 where Params: IntoRpcParams
136 {
137 let notification = match params.into_rpc_params() {
138 Some(params) =>
139 format!(r#"{{"jsonrpc":"2.0","method":"{method}","params":{params}}}"#),
140 None =>
141 format!(r#"{{"jsonrpc":"2.0","method":"{method}"}}"#),
142 };
143
144 self.sender.send(notification.as_bytes()).await?;
146 Ok(())
147 }
148
149 pub fn notifications(&self) -> ServerNotifications {
152 ServerNotifications(self.stream.response_stream())
153 }
154}
155
156pub struct ClientDriver(ResponseStreamMaster);
167
168pub type ClientDriverError = response_stream::ResponseStreamError;
170
171impl Stream for ClientDriver {
172 type Item = Result<Response, ClientDriverError>;
173 fn poll_next(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Option<Self::Item>> {
174 self.0.poll_next_unpin(cx).map(|o| o.map(|r| r.map(Response)))
175 }
176}
177
178pub struct ServerNotifications(ResponseStream);
182
183impl ServerNotifications {
184 pub fn ok_into<R>(self) -> impl Stream<Item=R> + Send + Sync + 'static
188 where
189 R: for<'de> serde::de::Deserialize<'de> + Send + Sync + 'static,
190 {
191 self.filter_map(move |res| {
192 match res.ok_into() {
193 Err(_) => std::future::ready(None),
194 Ok(r) => std::future::ready(Some(r))
195 }
196 })
197 }
198
199 pub fn ok_into_if<R, F>(self, filter_fn: F) -> impl Stream<Item=R> + Send + Sync + 'static
202 where
203 R: for<'de> serde::de::Deserialize<'de> + Send + Sync + 'static,
204 F: Fn(&R) -> bool + Send + Sync + 'static
205 {
206 self.ok_into().filter(move |n| std::future::ready(filter_fn(n)))
207 }
208
209 pub fn error_into<R>(self) -> impl Stream<Item=ErrorObject<R>> + Send + Sync + 'static
213 where
214 R: for<'de> serde::de::Deserialize<'de> + Send + Sync + 'static,
215 {
216 self.filter_map(move |res| {
217 match res.error_into() {
218 Err(_) => std::future::ready(None),
219 Ok(r) => std::future::ready(Some(r))
220 }
221 })
222 }
223
224 pub fn error_into_if<R, F>(self, filter_fn: F) -> impl Stream<Item=ErrorObject<R>> + Send + Sync + 'static
227 where
228 R: for<'de> serde::de::Deserialize<'de> + Send + Sync + 'static,
229 F: Fn(&ErrorObject<R>) -> bool + Send + Sync + 'static
230 {
231 self.error_into().filter(move |n| std::future::ready(filter_fn(n)))
232 }
233}
234
235
236impl Stream for ServerNotifications {
237 type Item = Response;
238 fn poll_next(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
239 loop {
240 let Poll::Ready(res) = self.0.poll_next_unpin(cx) else {
241 return Poll::Pending
242 };
243
244 let Some(res) = res else {
245 return Poll::Ready(None)
246 };
247
248 if res.id().is_some() {
249 continue
253 }
254
255 return Poll::Ready(Some(Response(res)))
256 }
257 }
258}
259
#[cfg(test)]
mod test {
    use super::*;
    use backend::mock::{self, MockBackend};
    use futures_util::StreamExt;
    use crate::response::ErrorObject;

    type Error = Box<dyn std::error::Error + Send + Sync + 'static>;

    /// Builds a client/driver pair wired to a fresh mock backend.
    fn mock_client() -> (Client, ClientDriver, MockBackend) {
        let (mock_backend, mock_send, mock_recv) = mock::build();
        let (client, driver) = Client::from_backend(mock_send, mock_recv);
        (client, driver, mock_backend)
    }

    /// Drains the driver in a background task, logging any errors it yields.
    fn drive_client(mut driver: ClientDriver) {
        tokio::spawn(async move {
            while let Some(res) = driver.next().await {
                if let Err(err) = res {
                    eprintln!("ClientDriver Error: {err}");
                }
            }
        });
    }

    #[tokio::test]
    async fn test_basic_requests() -> Result<(), Error> {
        let (client, driver, backend) = mock_client();
        drive_client(driver);

        backend
            .handler("echo_ok_raw", |cx, req| {
                let id = req.id.unwrap_or("-1".to_string());
                let params = req.params;
                let res = format!(r#"{{ "jsonrpc": "2.0", "id": "{id}", "result": {params} }}"#);
                cx.send_bytes(res.into_bytes());
            })
            .handler("echo_err_raw", |cx, req| {
                let id = req.id.unwrap_or("-1".to_string());
                let params = req.params;
                let res = format!(r#"{{ "jsonrpc": "2.0", "id": "{id}", "error": {{ "code":123, "message":"Eep!", "data": {params} }} }}"#);
                cx.send_bytes(res.into_bytes());
            })
            .handler("add", |cx, req| {
                let (a, b): (i64, i64) = serde_json::from_str(req.params.get()).unwrap();
                cx.send_ok_response(req.id, a + b);
            });

        let res: Vec<u8> = client.send_request("echo_ok_raw", (1, 2, 3)).await?.ok_into()?;
        assert_eq!(res, vec![1, 2, 3]);

        let err: ErrorObject<Vec<u8>> = client.send_request("echo_err_raw", (1, 2, 3)).await?.error_into()?;
        assert_eq!(err.code, 123);
        assert_eq!(err.message, "Eep!");
        assert_eq!(err.data, vec![1, 2, 3]);

        for i in 0i64..500 {
            let res: i64 = client.send_request("add", (i, 1)).await?.ok_into()?;
            assert_eq!(res, i + 1);
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_notifications() -> Result<(), Error> {
        let (mut client, driver, backend) = mock_client();
        drive_client(driver);

        backend
            .handler("start_sending", |cx, _req| {
                for i in 0u64..20 {
                    cx.send_ok_notification(i);
                }
                cx.shutdown();
            });

        let twos = client.notifications().ok_into_if(|n: &u64| n % 2 == 0);
        let threes = client.notifications().ok_into_if(|res: &u64| res % 3 == 0);
        let bools = client.notifications().ok_into::<bool>();

        client.send_notification("start_sending", ()).await?;

        let twos: Vec<u64> = twos.collect().await;
        let threes: Vec<u64> = threes.collect().await;
        let bools: Vec<_> = bools.collect().await;

        let expected_twos = vec![0, 2, 4, 6, 8, 10, 12, 14, 16, 18];
        let expected_threes = vec![0, 3, 6, 9, 12, 15, 18];

        assert_eq!(twos, expected_twos);
        assert_eq!(threes, expected_threes);
        assert!(bools.is_empty());

        Ok(())
    }

    #[tokio::test]
    async fn test_lots_of_subscriptions() {
        let (client, driver, backend) = mock_client();
        drive_client(driver);

        // Subscribe eagerly. `Iterator::map` is lazy, so without this `collect` the
        // subscriptions would only be created inside `join_all` below, after the sending
        // task has been spawned. That only worked because the default current-thread
        // test runtime does not run spawned tasks until this task yields.
        let handles: Vec<_> = (0..1000)
            .map(|_| {
                let mut notifs = client.notifications().ok_into::<u64>();
                tokio::spawn(async move {
                    let mut n = 0;
                    while let Some(_res) = notifs.next().await {
                        n += 1;
                    }
                    n
                })
            })
            .collect();

        tokio::spawn(async move {
            for n in 0u64..1000 {
                backend.send_ok_notification(n);
            }
            backend.shutdown();
        });

        let counts: Vec<_> = futures_util::future::join_all(handles).await;

        for count in counts {
            assert_eq!(count.unwrap(), 1000);
        }
    }
}