xrpc/
client.rs

1use parking_lot::Mutex;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::fmt::Debug;
5use std::sync::Arc;
6use std::sync::atomic::{AtomicBool, Ordering};
7use std::time::Duration;
8use tokio::sync::oneshot;
9
10use crate::channel::message::MessageChannel;
11use crate::codec::{BincodeCodec, Codec};
12use crate::error::{Result, RpcError};
13use crate::message::Message;
14use crate::message::types::{MessageId, MessageType};
15use crate::streaming::{StreamManager, StreamReceiver, next_stream_id};
16
17/// RPC client for making remote procedure calls.
18pub struct RpcClient<T, C: Codec = BincodeCodec>
19where
20    T: MessageChannel<C>,
21{
22    transport: Arc<T>,
23    pending: Arc<Mutex<HashMap<MessageId, oneshot::Sender<Result<Message<C>>>>>>,
24    stream_manager: Arc<StreamManager<C>>,
25    codec: C,
26    running: Arc<AtomicBool>,
27    default_timeout: Duration,
28}
29
30impl<T: MessageChannel<BincodeCodec> + 'static> RpcClient<T, BincodeCodec> {
31    pub fn new(transport: T) -> Self {
32        Self::with_timeout(transport, Duration::from_secs(30))
33    }
34
35    pub fn with_timeout(transport: T, default_timeout: Duration) -> Self {
36        Self {
37            transport: Arc::new(transport),
38            pending: Arc::new(Mutex::new(HashMap::new())),
39            stream_manager: Arc::new(StreamManager::new()),
40            codec: BincodeCodec,
41            running: Arc::new(AtomicBool::new(false)),
42            default_timeout,
43        }
44    }
45}
46
47impl<T, C> RpcClient<T, C>
48where
49    T: MessageChannel<C> + 'static,
50    C: Codec + Clone + Default + 'static,
51{
52    pub fn with_codec(transport: T, codec: C) -> Self {
53        Self {
54            transport: Arc::new(transport),
55            pending: Arc::new(Mutex::new(HashMap::new())),
56            stream_manager: Arc::new(StreamManager::with_codec(codec.clone())),
57            codec,
58            running: Arc::new(AtomicBool::new(false)),
59            default_timeout: Duration::from_secs(30),
60        }
61    }
62
63    pub fn with_codec_and_timeout(transport: T, codec: C, default_timeout: Duration) -> Self {
64        Self {
65            transport: Arc::new(transport),
66            pending: Arc::new(Mutex::new(HashMap::new())),
67            stream_manager: Arc::new(StreamManager::with_codec(codec.clone())),
68            codec,
69            running: Arc::new(AtomicBool::new(false)),
70            default_timeout,
71        }
72    }
73
74    /// Start the client's background receive loop.
75    pub fn start(&self) -> RpcClientHandle {
76        self.running.store(true, Ordering::Release);
77
78        let transport = self.transport.clone();
79        let pending = self.pending.clone();
80        let stream_manager = self.stream_manager.clone();
81        let running = self.running.clone();
82
83        let handle = tokio::spawn(async move {
84            while running.load(Ordering::Acquire) {
85                match transport.recv().await {
86                    Ok(message) => match message.msg_type {
87                        MessageType::Reply => {
88                            if let Some(tx) = pending.lock().remove(&message.id) {
89                                let _ = tx.send(Ok(message));
90                            }
91                        }
92                        MessageType::Error => {
93                            // Route stream errors to StreamManager, others to pending
94                            if let Some(stream_id) = message.metadata.stream_id {
95                                let error_msg: String = BincodeCodec
96                                    .decode(&message.payload)
97                                    .unwrap_or_else(|_| "Unknown error".to_string());
98                                stream_manager.send_error(stream_id, error_msg);
99                            } else if let Some(tx) = pending.lock().remove(&message.id) {
100                                let _ = tx.send(Ok(message));
101                            }
102                        }
103                        MessageType::StreamChunk | MessageType::StreamEnd => {
104                            stream_manager.handle_message(&message);
105                        }
106                        _ => {}
107                    },
108                    Err(_) => {
109                        break;
110                    }
111                }
112            }
113        });
114
115        RpcClientHandle { handle }
116    }
117
118    pub fn transport(&self) -> Arc<T> {
119        self.transport.clone()
120    }
121
122    pub fn stream_manager(&self) -> Arc<StreamManager<C>> {
123        self.stream_manager.clone()
124    }
125
126    /// Make a typed RPC call.
127    pub async fn call<Req, Resp>(&self, method: &str, request: &Req) -> Result<Resp>
128    where
129        Req: Serialize,
130        Resp: for<'de> Deserialize<'de>,
131    {
132        self.call_with_timeout(method, request, self.default_timeout)
133            .await
134    }
135
136    /// Make a typed RPC call with custom timeout.
137    pub async fn call_with_timeout<Req, Resp>(
138        &self,
139        method: &str,
140        request: &Req,
141        timeout: Duration,
142    ) -> Result<Resp>
143    where
144        Req: Serialize,
145        Resp: for<'de> Deserialize<'de>,
146    {
147        let message: Message<C> = Message::call(method, request)?;
148        let msg_id = message.id;
149
150        let (tx, rx) = oneshot::channel();
151        self.pending.lock().insert(msg_id, tx);
152
153        if let Err(e) = self.transport.send(&message).await {
154            self.pending.lock().remove(&msg_id);
155            return Err(RpcError::Transport(e));
156        }
157
158        let response = tokio::time::timeout(timeout, rx)
159            .await
160            .map_err(|_| {
161                self.pending.lock().remove(&msg_id);
162                RpcError::Timeout(format!("Request {} timed out after {:?}", msg_id, timeout))
163            })?
164            .map_err(|_| RpcError::ConnectionClosed)??;
165
166        match response.msg_type {
167            MessageType::Reply => self.codec.decode(&response.payload),
168            MessageType::Error => {
169                let error_msg: String = self
170                    .codec
171                    .decode(&response.payload)
172                    .unwrap_or_else(|_| "Unknown error".to_string());
173                Err(RpcError::ServerError(error_msg))
174            }
175            _ => Err(RpcError::InvalidMessage(format!(
176                "Unexpected message type: {:?}",
177                response.msg_type
178            ))),
179        }
180    }
181
182    /// Initiate a server streaming RPC call.
183    pub async fn call_server_stream<Req, Resp>(
184        &self,
185        method: &str,
186        request: &Req,
187    ) -> Result<StreamReceiver<Resp, C>>
188    where
189        Req: Serialize,
190        Resp: for<'de> Deserialize<'de>,
191    {
192        let stream_id = next_stream_id();
193        let receiver = self.stream_manager.create_receiver::<Resp>(stream_id);
194
195        let mut message: Message<C> = Message::call(method, request)?;
196        message.metadata = message.metadata.with_stream(stream_id, 0);
197
198        self.transport
199            .send(&message)
200            .await
201            .map_err(RpcError::Transport)?;
202
203        Ok(receiver)
204    }
205
206    /// Send a one-way notification (no response expected).
207    pub async fn notify<Req: Serialize>(&self, method: &str, request: &Req) -> Result<()> {
208        let message: Message<C> = Message::notification(method, request)?;
209        self.transport
210            .send(&message)
211            .await
212            .map_err(RpcError::Transport)
213    }
214
215    /// Make a raw bytes RPC call.
216    pub async fn call_raw(&self, method: &str, payload: Vec<u8>) -> Result<Vec<u8>> {
217        self.call_raw_with_timeout(method, payload, self.default_timeout)
218            .await
219    }
220
221    /// Make a raw bytes RPC call with custom timeout.
222    pub async fn call_raw_with_timeout(
223        &self,
224        method: &str,
225        payload: Vec<u8>,
226        timeout: Duration,
227    ) -> Result<Vec<u8>> {
228        let message: Message<C> = Message::new(
229            MessageId::new(),
230            MessageType::Call,
231            method,
232            payload.into(),
233            Default::default(),
234        );
235        let msg_id = message.id;
236
237        let (tx, rx) = oneshot::channel();
238        self.pending.lock().insert(msg_id, tx);
239
240        if let Err(e) = self.transport.send(&message).await {
241            self.pending.lock().remove(&msg_id);
242            return Err(RpcError::Transport(e));
243        }
244
245        let response = tokio::time::timeout(timeout, rx)
246            .await
247            .map_err(|_| {
248                self.pending.lock().remove(&msg_id);
249                RpcError::Timeout(format!("Request {} timed out after {:?}", msg_id, timeout))
250            })?
251            .map_err(|_| RpcError::ConnectionClosed)??;
252
253        match response.msg_type {
254            MessageType::Reply => Ok(response.payload.to_vec()),
255            MessageType::Error => {
256                let error_msg: String = self
257                    .codec
258                    .decode(&response.payload)
259                    .unwrap_or_else(|_| "Unknown error".to_string());
260                Err(RpcError::ServerError(error_msg))
261            }
262            _ => Err(RpcError::InvalidMessage(format!(
263                "Unexpected message type: {:?}",
264                response.msg_type
265            ))),
266        }
267    }
268
269    pub fn is_connected(&self) -> bool {
270        self.transport.is_connected()
271    }
272
273    pub fn active_streams(&self) -> usize {
274        self.stream_manager.active_stream_count()
275    }
276
277    pub async fn close(&self) -> Result<()> {
278        self.running.store(false, Ordering::Release);
279        self.transport.close().await.map_err(RpcError::Transport)
280    }
281}
282
283impl<T, C> Debug for RpcClient<T, C>
284where
285    T: MessageChannel<C>,
286    C: Codec + Clone,
287{
288    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
289        f.debug_struct("RpcClient")
290            .field("running", &self.running.load(Ordering::Relaxed))
291            .field("pending_requests", &self.pending.lock().len())
292            .field("active_streams", &self.stream_manager.active_stream_count())
293            .finish()
294    }
295}
296
297/// Handle for the client's background task.
298pub struct RpcClientHandle {
299    handle: tokio::task::JoinHandle<()>,
300}
301
302impl RpcClientHandle {
303    pub async fn shutdown(self) {
304        self.handle.abort();
305        let _ = self.handle.await;
306    }
307}
308
#[cfg(test)]
mod tests {
    use super::*;
    use crate::channel::message::MessageChannelAdapter;
    use crate::transport::channel::{ChannelConfig, ChannelFrameTransport};

    /// Request payload of the `add` method used by the unary call test.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct AddRequest {
        a: i32,
        b: i32,
    }

    /// Response payload of the `add` method used by the unary call test.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct AddResponse {
        result: i32,
    }

    /// A unary call is answered by a hand-written server and decoded.
    #[tokio::test]
    async fn test_client_call_reply() {
        let (client_side, server_side) =
            ChannelFrameTransport::create_pair("test", ChannelConfig::default()).unwrap();
        let client_channel = MessageChannelAdapter::new(client_side);
        let server_channel = MessageChannelAdapter::new(server_side);

        let client = RpcClient::new(client_channel);
        let _handle = client.start();

        // Minimal server: answer exactly one `add` request.
        let server_task = tokio::spawn(async move {
            let incoming = server_channel.recv().await.unwrap();
            assert_eq!(incoming.method, "add");

            let request: AddRequest = incoming.deserialize_payload().unwrap();
            let sum = AddResponse {
                result: request.a + request.b,
            };

            let reply: Message = Message::reply(incoming.id, sum).unwrap();
            server_channel.send(&reply).await.unwrap();
        });

        let request = AddRequest { a: 10, b: 32 };
        let response: AddResponse = client.call("add", &request).await.unwrap();
        assert_eq!(response.result, 42);

        server_task.await.unwrap();
    }

    /// Stream chunks followed by a stream end are delivered in order.
    #[tokio::test]
    async fn test_client_server_stream() {
        let (client_side, server_side) =
            ChannelFrameTransport::create_pair("test", ChannelConfig::default()).unwrap();
        let client_channel = MessageChannelAdapter::new(client_side);
        let server_channel = Arc::new(MessageChannelAdapter::new(server_side));

        let client = RpcClient::new(client_channel);
        let _handle = client.start();

        // Minimal server: three chunks (values 1..=3), then end of stream.
        let server_side_channel = server_channel.clone();
        let server_task = tokio::spawn(async move {
            let incoming = server_side_channel.recv().await.unwrap();
            let stream_id = incoming.metadata.stream_id.unwrap();

            for n in 1..=3 {
                let chunk: Message = Message::stream_chunk(stream_id, n - 1, n as i32).unwrap();
                server_side_channel.send(&chunk).await.unwrap();
            }

            let end: Message = Message::stream_end(stream_id);
            server_side_channel.send(&end).await.unwrap();
        });

        let mut stream: StreamReceiver<i32> =
            client.call_server_stream("get_numbers", &()).await.unwrap();

        let mut received = Vec::new();
        while let Some(item) = stream.recv().await {
            received.push(item.unwrap());
        }

        assert_eq!(received, vec![1, 2, 3]);
        server_task.await.unwrap();
    }

    /// A stream error sent by the server surfaces as an `Err` stream item.
    #[tokio::test]
    async fn test_client_stream_error() {
        let (client_side, server_side) =
            ChannelFrameTransport::create_pair("test", ChannelConfig::default()).unwrap();
        let client_channel = MessageChannelAdapter::new(client_side);
        let server_channel = MessageChannelAdapter::new(server_side);

        let client = RpcClient::new(client_channel);
        let _handle = client.start();

        // Minimal server: reject the stream request with an error message.
        let server_task = tokio::spawn(async move {
            let incoming = server_channel.recv().await.unwrap();
            let stream_id = incoming.metadata.stream_id.unwrap();

            let error: Message = Message::stream_error(incoming.id, stream_id, "method not found");
            server_channel.send(&error).await.unwrap();
        });

        let mut stream: StreamReceiver<i32> = client
            .call_server_stream("unknown_method", &())
            .await
            .unwrap();

        let first = stream.recv().await;
        assert!(matches!(first, Some(Err(_))));

        server_task.await.unwrap();
    }
}
426}