xrpc/
client.rs

1use parking_lot::Mutex;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::fmt::Debug;
5use std::sync::Arc;
6use std::sync::atomic::{AtomicBool, Ordering};
7use std::time::Duration;
8use tokio::sync::oneshot;
9
10use crate::codec::{BincodeCodec, Codec};
11use crate::error::{Result, RpcError};
12use crate::message::Message;
13use crate::message::types::{MessageId, MessageType};
14use crate::streaming::{StreamManager, StreamReceiver, next_stream_id};
15use crate::transport::message_transport::MessageTransport;
16
17pub struct RpcClient<T, C: Codec = BincodeCodec>
18where
19    T: MessageTransport<C>,
20{
21    transport: Arc<T>,
22    pending: Arc<Mutex<HashMap<MessageId, oneshot::Sender<Result<Message<C>>>>>>,
23    stream_manager: Arc<StreamManager<C>>,
24    codec: C,
25    running: Arc<AtomicBool>,
26    default_timeout: Duration,
27}
28
29impl<T: MessageTransport<BincodeCodec> + 'static> RpcClient<T, BincodeCodec> {
30    pub fn new(transport: T) -> Self {
31        Self::with_timeout(transport, Duration::from_secs(30))
32    }
33
34    pub fn with_timeout(transport: T, default_timeout: Duration) -> Self {
35        Self {
36            transport: Arc::new(transport),
37            pending: Arc::new(Mutex::new(HashMap::new())),
38            stream_manager: Arc::new(StreamManager::new()),
39            codec: BincodeCodec,
40            running: Arc::new(AtomicBool::new(false)),
41            default_timeout,
42        }
43    }
44}
45
46impl<T, C> RpcClient<T, C>
47where
48    T: MessageTransport<C> + 'static,
49    C: Codec + Clone + Default + 'static,
50{
51    pub fn with_codec(transport: T, codec: C) -> Self {
52        Self {
53            transport: Arc::new(transport),
54            pending: Arc::new(Mutex::new(HashMap::new())),
55            stream_manager: Arc::new(StreamManager::with_codec(codec.clone())),
56            codec,
57            running: Arc::new(AtomicBool::new(false)),
58            default_timeout: Duration::from_secs(30),
59        }
60    }
61
62    pub fn with_codec_and_timeout(transport: T, codec: C, default_timeout: Duration) -> Self {
63        Self {
64            transport: Arc::new(transport),
65            pending: Arc::new(Mutex::new(HashMap::new())),
66            stream_manager: Arc::new(StreamManager::with_codec(codec.clone())),
67            codec,
68            running: Arc::new(AtomicBool::new(false)),
69            default_timeout,
70        }
71    }
72
73    pub fn start(&self) -> RpcClientHandle {
74        self.running.store(true, Ordering::Release);
75
76        let transport = self.transport.clone();
77        let pending = self.pending.clone();
78        let stream_manager = self.stream_manager.clone();
79        let running = self.running.clone();
80
81        let handle = tokio::spawn(async move {
82            while running.load(Ordering::Acquire) {
83                match transport.recv().await {
84                    Ok(message) => match message.msg_type {
85                        MessageType::Reply => {
86                            if let Some(tx) = pending.lock().remove(&message.id) {
87                                let _ = tx.send(Ok(message));
88                            }
89                        }
90                        MessageType::Error => {
91                            // Route stream errors to StreamManager, others to pending
92                            if let Some(stream_id) = message.metadata.stream_id {
93                                let error_msg: String = BincodeCodec
94                                    .decode(&message.payload)
95                                    .unwrap_or_else(|_| "Unknown error".to_string());
96                                stream_manager.send_error(stream_id, error_msg);
97                            } else if let Some(tx) = pending.lock().remove(&message.id) {
98                                let _ = tx.send(Ok(message));
99                            }
100                        }
101                        MessageType::StreamChunk | MessageType::StreamEnd => {
102                            stream_manager.handle_message(&message);
103                        }
104                        _ => {}
105                    },
106                    Err(_) => {
107                        break;
108                    }
109                }
110            }
111        });
112
113        RpcClientHandle { handle }
114    }
115
116    pub fn transport(&self) -> Arc<T> {
117        self.transport.clone()
118    }
119
120    pub fn stream_manager(&self) -> Arc<StreamManager<C>> {
121        self.stream_manager.clone()
122    }
123
124    pub async fn call<Req, Resp>(&self, method: &str, request: &Req) -> Result<Resp>
125    where
126        Req: Serialize,
127        Resp: for<'de> Deserialize<'de>,
128    {
129        self.call_with_timeout(method, request, self.default_timeout)
130            .await
131    }
132
133    pub async fn call_with_timeout<Req, Resp>(
134        &self,
135        method: &str,
136        request: &Req,
137        timeout: Duration,
138    ) -> Result<Resp>
139    where
140        Req: Serialize,
141        Resp: for<'de> Deserialize<'de>,
142    {
143        let message: Message<C> = Message::call(method, request)?;
144        let msg_id = message.id;
145
146        let (tx, rx) = oneshot::channel();
147        self.pending.lock().insert(msg_id, tx);
148
149        if let Err(e) = self.transport.send(&message).await {
150            self.pending.lock().remove(&msg_id);
151            return Err(RpcError::Transport(e));
152        }
153
154        let response = tokio::time::timeout(timeout, rx)
155            .await
156            .map_err(|_| {
157                self.pending.lock().remove(&msg_id);
158                RpcError::Timeout(format!("Request {} timed out after {:?}", msg_id, timeout))
159            })?
160            .map_err(|_| RpcError::ConnectionClosed)??;
161
162        match response.msg_type {
163            MessageType::Reply => self.codec.decode(&response.payload),
164            MessageType::Error => {
165                let error_msg: String = self
166                    .codec
167                    .decode(&response.payload)
168                    .unwrap_or_else(|_| "Unknown error".to_string());
169                Err(RpcError::ServerError(error_msg))
170            }
171            _ => Err(RpcError::InvalidMessage(format!(
172                "Unexpected message type: {:?}",
173                response.msg_type
174            ))),
175        }
176    }
177
178    pub async fn call_server_stream<Req, Resp>(
179        &self,
180        method: &str,
181        request: &Req,
182    ) -> Result<StreamReceiver<Resp, C>>
183    where
184        Req: Serialize,
185        Resp: for<'de> Deserialize<'de>,
186    {
187        let stream_id = next_stream_id();
188        let receiver = self.stream_manager.create_receiver::<Resp>(stream_id);
189
190        let mut message: Message<C> = Message::call(method, request)?;
191        message.metadata = message.metadata.with_stream(stream_id, 0);
192
193        self.transport
194            .send(&message)
195            .await
196            .map_err(RpcError::Transport)?;
197
198        Ok(receiver)
199    }
200
201    pub async fn notify<Req: Serialize>(&self, method: &str, request: &Req) -> Result<()> {
202        let message: Message<C> = Message::notification(method, request)?;
203        self.transport
204            .send(&message)
205            .await
206            .map_err(RpcError::Transport)
207    }
208
209    pub async fn call_raw(&self, method: &str, payload: Vec<u8>) -> Result<Vec<u8>> {
210        self.call_raw_with_timeout(method, payload, self.default_timeout)
211            .await
212    }
213
214    pub async fn call_raw_with_timeout(
215        &self,
216        method: &str,
217        payload: Vec<u8>,
218        timeout: Duration,
219    ) -> Result<Vec<u8>> {
220        let message: Message<C> = Message::new(
221            MessageId::new(),
222            MessageType::Call,
223            method,
224            payload.into(),
225            Default::default(),
226        );
227        let msg_id = message.id;
228
229        let (tx, rx) = oneshot::channel();
230        self.pending.lock().insert(msg_id, tx);
231
232        if let Err(e) = self.transport.send(&message).await {
233            self.pending.lock().remove(&msg_id);
234            return Err(RpcError::Transport(e));
235        }
236
237        let response = tokio::time::timeout(timeout, rx)
238            .await
239            .map_err(|_| {
240                self.pending.lock().remove(&msg_id);
241                RpcError::Timeout(format!("Request {} timed out after {:?}", msg_id, timeout))
242            })?
243            .map_err(|_| RpcError::ConnectionClosed)??;
244
245        match response.msg_type {
246            MessageType::Reply => Ok(response.payload.to_vec()),
247            MessageType::Error => {
248                let error_msg: String = self
249                    .codec
250                    .decode(&response.payload)
251                    .unwrap_or_else(|_| "Unknown error".to_string());
252                Err(RpcError::ServerError(error_msg))
253            }
254            _ => Err(RpcError::InvalidMessage(format!(
255                "Unexpected message type: {:?}",
256                response.msg_type
257            ))),
258        }
259    }
260
261    pub fn is_connected(&self) -> bool {
262        self.transport.is_connected()
263    }
264
265    pub fn active_streams(&self) -> usize {
266        self.stream_manager.active_stream_count()
267    }
268
269    pub async fn close(&self) -> Result<()> {
270        self.running.store(false, Ordering::Release);
271        self.transport.close().await.map_err(RpcError::Transport)
272    }
273}
274
275impl<T, C> Debug for RpcClient<T, C>
276where
277    T: MessageTransport<C>,
278    C: Codec + Clone,
279{
280    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
281        f.debug_struct("RpcClient")
282            .field("running", &self.running.load(Ordering::Relaxed))
283            .field("pending_requests", &self.pending.lock().len())
284            .field("active_streams", &self.stream_manager.active_stream_count())
285            .finish()
286    }
287}
288
289pub struct RpcClientHandle {
290    handle: tokio::task::JoinHandle<()>,
291}
292
293impl RpcClientHandle {
294    pub async fn shutdown(self) {
295        self.handle.abort();
296        let _ = self.handle.await;
297    }
298}
299
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transport::channel::{ChannelConfig, ChannelTransport};
    use crate::transport::message_transport::MessageTransportAdapter;

    /// Request payload for the `add` method used in the call/reply test.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct AddRequest {
        a: i32,
        b: i32,
    }

    /// Reply payload for the `add` method used in the call/reply test.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct AddResponse {
        result: i32,
    }

    /// A unary call is answered by a reply carrying the same message id.
    #[tokio::test]
    async fn test_client_call_reply() {
        let (client_end, server_end) =
            ChannelTransport::create_pair("test", ChannelConfig::default()).unwrap();

        let server = MessageTransportAdapter::new(server_end);
        let client = RpcClient::new(MessageTransportAdapter::new(client_end));
        let _handle = client.start();

        // Fake server: answer a single `add` call with the sum.
        let server_task = tokio::spawn(async move {
            let call = server.recv().await.unwrap();
            assert_eq!(call.method, "add");

            let args: AddRequest = call.deserialize_payload().unwrap();
            let sum = AddResponse {
                result: args.a + args.b,
            };

            let reply: Message = Message::reply(call.id, sum).unwrap();
            server.send(&reply).await.unwrap();
        });

        let request = AddRequest { a: 10, b: 32 };
        let answer: AddResponse = client.call("add", &request).await.unwrap();
        assert_eq!(answer.result, 42);

        server_task.await.unwrap();
    }

    /// Stream chunks are delivered in order and the stream ends cleanly.
    #[tokio::test]
    async fn test_client_server_stream() {
        let (client_end, server_end) =
            ChannelTransport::create_pair("test", ChannelConfig::default()).unwrap();

        let server = Arc::new(MessageTransportAdapter::new(server_end));
        let client = RpcClient::new(MessageTransportAdapter::new(client_end));
        let _handle = client.start();

        // Fake server: emit three chunks followed by the end marker.
        let server_in_task = Arc::clone(&server);
        let server_task = tokio::spawn(async move {
            let call = server_in_task.recv().await.unwrap();
            let stream_id = call.metadata.stream_id.unwrap();

            for n in 1..=3 {
                let chunk: Message = Message::stream_chunk(stream_id, n - 1, n as i32).unwrap();
                server_in_task.send(&chunk).await.unwrap();
            }

            let end: Message = Message::stream_end(stream_id);
            server_in_task.send(&end).await.unwrap();
        });

        let mut numbers: StreamReceiver<i32> =
            client.call_server_stream("get_numbers", &()).await.unwrap();

        let mut received = Vec::new();
        while let Some(item) = numbers.recv().await {
            received.push(item.unwrap());
        }

        assert_eq!(received, vec![1, 2, 3]);
        server_task.await.unwrap();
    }

    /// An error message carrying a stream id reaches the stream receiver.
    #[tokio::test]
    async fn test_client_stream_error() {
        let (client_end, server_end) =
            ChannelTransport::create_pair("test", ChannelConfig::default()).unwrap();

        let server = MessageTransportAdapter::new(server_end);
        let client = RpcClient::new(MessageTransportAdapter::new(client_end));
        let _handle = client.start();

        // Server sends error with stream_id
        let server_task = tokio::spawn(async move {
            let call = server.recv().await.unwrap();
            let stream_id = call.metadata.stream_id.unwrap();

            let failure: Message = Message::stream_error(call.id, stream_id, "method not found");
            server.send(&failure).await.unwrap();
        });

        let mut stream: StreamReceiver<i32> = client
            .call_server_stream("unknown_method", &())
            .await
            .unwrap();

        // Should receive error, not hang
        let first = stream.recv().await;
        assert!(matches!(first, Some(Err(_))));

        server_task.await.unwrap();
    }
}