1use async_trait::async_trait;
2use bytes::Bytes;
3use futures::Stream;
4use parking_lot::RwLock;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::future::Future;
8use std::sync::Arc;
9use tokio::sync::mpsc;
10
11use crate::channel::message::MessageChannel;
12use crate::codec::{BincodeCodec, Codec};
13use crate::error::{Result, RpcError};
14use crate::message::Message;
15use crate::message::metadata::MessageMetadata;
16use crate::message::types::{MessageId, MessageType};
17use crate::streaming::{StreamId, next_stream_id};
18
19#[async_trait]
20pub trait Handler<C: Codec>: Send + Sync {
21 async fn handle(&self, request: Message<C>, codec: &C) -> Result<Message<C>>;
22 fn method_name(&self) -> &str;
23}
24
25#[async_trait]
26pub trait StreamHandler<C: Codec>: Send + Sync {
27 async fn handle(
28 &self,
29 request: Message<C>,
30 sender: ServerStreamSender<C>,
31 codec: &C,
32 ) -> Result<()>;
33 fn method_name(&self) -> &str;
34}
35
36pub struct FnHandler<F, C> {
37 method: String,
38 func: Arc<F>,
39 _codec: std::marker::PhantomData<C>,
40}
41
42impl<F, Fut, C> FnHandler<F, C>
43where
44 F: Fn(Message<C>) -> Fut + Send + Sync + 'static,
45 Fut: Future<Output = Result<Message<C>>> + Send + 'static,
46 C: Codec,
47{
48 pub fn new(method: impl Into<String>, func: F) -> Self {
49 Self {
50 method: method.into(),
51 func: Arc::new(func),
52 _codec: std::marker::PhantomData,
53 }
54 }
55}
56
57#[async_trait]
58impl<F, Fut, C: Codec + Default> Handler<C> for FnHandler<F, C>
59where
60 F: Fn(Message<C>) -> Fut + Send + Sync + 'static,
61 Fut: Future<Output = Result<Message<C>>> + Send + 'static,
62{
63 async fn handle(&self, request: Message<C>, _codec: &C) -> Result<Message<C>> {
64 (self.func)(request).await
65 }
66
67 fn method_name(&self) -> &str {
68 &self.method
69 }
70}
71
72pub struct TypedHandler<Req, Resp, F, C> {
73 method: String,
74 func: Arc<F>,
75 _phantom: std::marker::PhantomData<(Req, Resp, C)>,
76}
77
78impl<Req, Resp, F, Fut, C> TypedHandler<Req, Resp, F, C>
79where
80 Req: for<'de> Deserialize<'de> + Send + 'static,
81 Resp: Serialize + Send + 'static,
82 F: Fn(Req) -> Fut + Send + Sync + 'static,
83 Fut: Future<Output = Result<Resp>> + Send + 'static,
84 C: Codec,
85{
86 pub fn new(method: impl Into<String>, func: F) -> Self {
87 Self {
88 method: method.into(),
89 func: Arc::new(func),
90 _phantom: std::marker::PhantomData,
91 }
92 }
93}
94
95#[async_trait]
96impl<Req, Resp, F, Fut, C> Handler<C> for TypedHandler<Req, Resp, F, C>
97where
98 Req: for<'de> Deserialize<'de> + Send + Sync + 'static,
99 Resp: Serialize + Send + Sync + 'static,
100 F: Fn(Req) -> Fut + Send + Sync + 'static,
101 Fut: Future<Output = Result<Resp>> + Send + 'static,
102 C: Codec + Default,
103{
104 async fn handle(&self, request: Message<C>, codec: &C) -> Result<Message<C>> {
105 let req: Req = codec.decode(&request.payload)?;
106 let resp = (self.func)(req).await?;
107 let payload = codec.encode(&resp)?;
108 Ok(Message::new(
109 request.id,
110 MessageType::Reply,
111 "",
112 Bytes::from(payload),
113 MessageMetadata::new(),
114 ))
115 }
116
117 fn method_name(&self) -> &str {
118 &self.method
119 }
120}
121
122pub struct ServerStreamSender<C: Codec> {
123 stream_id: StreamId,
124 tx: mpsc::UnboundedSender<Bytes>,
125 sequence: std::sync::atomic::AtomicU64,
126 codec: C,
127}
128
129impl<C: Codec> ServerStreamSender<C> {
130 fn new(stream_id: StreamId, tx: mpsc::UnboundedSender<Bytes>, codec: C) -> Self {
131 Self {
132 stream_id,
133 tx,
134 sequence: std::sync::atomic::AtomicU64::new(0),
135 codec,
136 }
137 }
138
139 pub fn stream_id(&self) -> StreamId {
140 self.stream_id
141 }
142
143 pub fn send<T: Serialize>(&self, data: T) -> Result<()> {
144 let seq = self
145 .sequence
146 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
147 let payload = self.codec.encode(&data)?;
148 let chunk: Message = Message::new(
149 MessageId::new(),
150 MessageType::StreamChunk,
151 "",
152 Bytes::from(payload),
153 MessageMetadata::new().with_stream(self.stream_id, seq),
154 );
155 let encoded = chunk.encode().map_err(RpcError::Transport)?;
156
157 self.tx
158 .send(encoded.freeze())
159 .map_err(|_| RpcError::StreamError("Stream closed".to_string()))
160 }
161
162 pub fn end(&self) -> Result<()> {
163 let end_msg: Message = Message::stream_end(self.stream_id);
164 let encoded = end_msg.encode().map_err(RpcError::Transport)?;
165
166 self.tx
167 .send(encoded.freeze())
168 .map_err(|_| RpcError::StreamError("Stream closed".to_string()))
169 }
170}
171
172pub struct TypedStreamHandler<Req, Item, F, C> {
173 method: String,
174 func: Arc<F>,
175 _phantom: std::marker::PhantomData<(Req, Item, C)>,
176}
177
178impl<Req, Item, F, S, C> TypedStreamHandler<Req, Item, F, C>
179where
180 Req: for<'de> Deserialize<'de> + Send + 'static,
181 Item: Serialize + Send + 'static,
182 S: Stream<Item = Result<Item>> + Send + 'static,
183 F: Fn(Req) -> S + Send + Sync + 'static,
184 C: Codec,
185{
186 pub fn new(method: impl Into<String>, func: F) -> Self {
187 Self {
188 method: method.into(),
189 func: Arc::new(func),
190 _phantom: std::marker::PhantomData,
191 }
192 }
193}
194
195#[async_trait]
196impl<Req, Item, F, S, C> StreamHandler<C> for TypedStreamHandler<Req, Item, F, C>
197where
198 Req: for<'de> Deserialize<'de> + Send + Sync + 'static,
199 Item: Serialize + Send + Sync + 'static,
200 S: Stream<Item = Result<Item>> + Send + 'static,
201 F: Fn(Req) -> S + Send + Sync + 'static,
202 C: Codec + Default,
203{
204 async fn handle(
205 &self,
206 request: Message<C>,
207 sender: ServerStreamSender<C>,
208 codec: &C,
209 ) -> Result<()> {
210 use futures::StreamExt;
211
212 let req: Req = codec.decode(&request.payload)?;
213 let mut stream = Box::pin((self.func)(req));
214
215 while let Some(result) = stream.next().await {
216 match result {
217 Ok(item) => sender.send(item)?,
218 Err(e) => return Err(e),
219 }
220 }
221
222 sender.end()?;
223 Ok(())
224 }
225
226 fn method_name(&self) -> &str {
227 &self.method
228 }
229}
230
231pub struct FnStreamHandler<F, C> {
232 method: String,
233 func: Arc<F>,
234 _codec: std::marker::PhantomData<C>,
235}
236
237impl<F, Fut, C> FnStreamHandler<F, C>
238where
239 F: Fn(Message<C>, ServerStreamSender<C>) -> Fut + Send + Sync + 'static,
240 Fut: Future<Output = Result<()>> + Send + 'static,
241 C: Codec,
242{
243 pub fn new(method: impl Into<String>, func: F) -> Self {
244 Self {
245 method: method.into(),
246 func: Arc::new(func),
247 _codec: std::marker::PhantomData,
248 }
249 }
250}
251
252#[async_trait]
253impl<F, Fut, C> StreamHandler<C> for FnStreamHandler<F, C>
254where
255 F: Fn(Message<C>, ServerStreamSender<C>) -> Fut + Send + Sync + 'static,
256 Fut: Future<Output = Result<()>> + Send + 'static,
257 C: Codec + Default,
258{
259 async fn handle(
260 &self,
261 request: Message<C>,
262 sender: ServerStreamSender<C>,
263 _codec: &C,
264 ) -> Result<()> {
265 (self.func)(request, sender).await
266 }
267
268 fn method_name(&self) -> &str {
269 &self.method
270 }
271}
272
273pub struct RpcServer<C: Codec = BincodeCodec> {
274 handlers: Arc<RwLock<HashMap<String, Arc<dyn Handler<C>>>>>,
275 stream_handlers: Arc<RwLock<HashMap<String, Arc<dyn StreamHandler<C>>>>>,
276 codec: C,
277}
278
279impl RpcServer<BincodeCodec> {
280 pub fn new() -> Self {
281 Self {
282 handlers: Arc::new(RwLock::new(HashMap::new())),
283 stream_handlers: Arc::new(RwLock::new(HashMap::new())),
284 codec: BincodeCodec,
285 }
286 }
287}
288
289impl<C: Codec + Clone + Default + 'static> RpcServer<C> {
290 pub fn with_codec(codec: C) -> Self {
291 Self {
292 handlers: Arc::new(RwLock::new(HashMap::new())),
293 stream_handlers: Arc::new(RwLock::new(HashMap::new())),
294 codec,
295 }
296 }
297
298 pub fn register(&self, handler: Arc<dyn Handler<C>>) {
299 let method = handler.method_name().to_string();
300 self.handlers.write().insert(method, handler);
301 }
302
303 pub fn register_fn<F, Fut>(&self, method: impl Into<String>, func: F)
304 where
305 F: Fn(Message<C>) -> Fut + Send + Sync + 'static,
306 Fut: Future<Output = Result<Message<C>>> + Send + 'static,
307 {
308 let handler: Arc<FnHandler<F, C>> = Arc::new(FnHandler::new(method, func));
309 self.register(handler);
310 }
311
312 pub fn register_typed<Req, Resp, F, Fut>(&self, method: impl Into<String>, func: F)
313 where
314 Req: for<'de> Deserialize<'de> + Send + Sync + 'static,
315 Resp: Serialize + Send + Sync + 'static,
316 F: Fn(Req) -> Fut + Send + Sync + 'static,
317 Fut: Future<Output = Result<Resp>> + Send + 'static,
318 {
319 let handler: Arc<TypedHandler<Req, Resp, F, C>> = Arc::new(TypedHandler::new(method, func));
320 self.register(handler);
321 }
322
323 pub fn register_stream<Req, Item, F, S>(&self, method: impl Into<String>, func: F)
324 where
325 Req: for<'de> Deserialize<'de> + Send + Sync + 'static,
326 Item: Serialize + Send + Sync + 'static,
327 S: Stream<Item = Result<Item>> + Send + 'static,
328 F: Fn(Req) -> S + Send + Sync + 'static,
329 {
330 let method = method.into();
331 let handler: Arc<TypedStreamHandler<Req, Item, F, C>> =
332 Arc::new(TypedStreamHandler::new(method.clone(), func));
333 self.stream_handlers.write().insert(method, handler);
334 }
335
336 pub fn register_stream_fn<F, Fut>(&self, method: impl Into<String>, func: F)
337 where
338 F: Fn(Message<C>, ServerStreamSender<C>) -> Fut + Send + Sync + 'static,
339 Fut: Future<Output = Result<()>> + Send + 'static,
340 {
341 let method = method.into();
342 let handler: Arc<FnStreamHandler<F, C>> =
343 Arc::new(FnStreamHandler::new(method.clone(), func));
344 self.stream_handlers.write().insert(method, handler);
345 }
346
347 pub async fn handle_message<T: MessageChannel<C>>(
348 &self,
349 message: Message<C>,
350 transport: &T,
351 ) -> Option<Message<C>> {
352 match message.msg_type {
353 MessageType::Call => {
354 if message.metadata.stream_id.is_some() {
355 self.handle_stream_call(message, transport).await;
356 return None;
357 }
358
359 let handler = self.handlers.read().get(&message.method).cloned();
360 match handler {
361 Some(h) => match h.handle(message.clone(), &self.codec).await {
362 Ok(response) => Some(response),
363 Err(e) => Some(Message::error(message.id, e.to_string())),
364 },
365 None => Some(Message::error(
366 message.id,
367 format!("Method not found: {}", message.method),
368 )),
369 }
370 }
371 MessageType::Notification => {
372 let handler = self.handlers.read().get(&message.method).cloned();
373 if let Some(h) = handler {
374 let _ = h.handle(message, &self.codec).await;
375 }
376 None
377 }
378 _ => None,
379 }
380 }
381
382 async fn handle_stream_call<T: MessageChannel<C>>(&self, message: Message<C>, transport: &T) {
383 let stream_id = message.metadata.stream_id.unwrap_or_else(next_stream_id);
384 let handler = self.stream_handlers.read().get(&message.method).cloned();
385
386 let Some(h) = handler else {
387 let error = Message::stream_error(
388 message.id,
389 stream_id,
390 format!("Stream method not found: {}", message.method),
391 );
392 let _ = transport.send(&error).await;
393 return;
394 };
395
396 let (tx, mut rx) = mpsc::unbounded_channel::<Bytes>();
397 let sender = ServerStreamSender::new(stream_id, tx, self.codec.clone());
398
399 let transport_send = async {
400 while let Some(data) = rx.recv().await {
401 if let Ok(msg) = Message::<C>::decode(&data[..]) {
402 let _ = transport.send(&msg).await;
403 }
404 }
405 };
406
407 let codec = self.codec.clone();
408 let handler_task = async {
409 if let Err(e) = h.handle(message.clone(), sender, &codec).await {
410 let error = Message::stream_error(message.id, stream_id, e.to_string());
411 let _ = transport.send(&error).await;
412 }
413 };
414
415 tokio::join!(handler_task, transport_send);
416 }
417
418 pub async fn serve<T: MessageChannel<C>>(&self, transport: Arc<T>) -> Result<()> {
419 loop {
420 let message = transport.recv().await.map_err(RpcError::Transport)?;
421
422 if let Some(response) = self.handle_message(message, transport.as_ref()).await {
423 transport
424 .send(&response)
425 .await
426 .map_err(RpcError::Transport)?;
427 }
428 }
429 }
430
431 pub fn spawn_handler<T: MessageChannel<C> + 'static>(&self, transport: T) -> ServerHandle {
432 let handlers = self.handlers.clone();
433 let stream_handlers = self.stream_handlers.clone();
434 let codec = self.codec.clone();
435 let transport = Arc::new(transport);
436
437 let handle = tokio::spawn(async move {
438 let server = RpcServer {
439 handlers,
440 stream_handlers,
441 codec,
442 };
443 let _ = server.serve(transport).await;
444 });
445
446 ServerHandle { handle }
447 }
448
449 pub fn handler_count(&self) -> usize {
450 self.handlers.read().len() + self.stream_handlers.read().len()
451 }
452}
453
454impl Default for RpcServer<BincodeCodec> {
455 fn default() -> Self {
456 Self::new()
457 }
458}
459
460pub struct ServerHandle {
461 handle: tokio::task::JoinHandle<()>,
462}
463
464impl ServerHandle {
465 pub async fn shutdown(self) {
466 self.handle.abort();
467 let _ = self.handle.await;
468 }
469
470 pub fn is_finished(&self) -> bool {
471 self.handle.is_finished()
472 }
473}
474
#[cfg(test)]
mod tests {
    use super::*;
    use crate::channel::message::MessageChannelAdapter;
    use crate::streaming::StreamReceiver;
    use crate::transport::channel::{ChannelConfig, ChannelFrameTransport};

    /// Request payload for the `add` test method.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct AddRequest {
        a: i32,
        b: i32,
    }

    /// Response payload for the `add` test method.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct AddResponse {
        result: i32,
    }

    /// A typed unary call is decoded, handled, and answered with a `Reply`.
    #[tokio::test]
    async fn test_server_typed_handler() {
        let config = ChannelConfig::default();
        let (t1, t2) = ChannelFrameTransport::create_pair("test", config).unwrap();

        let client_channel = MessageChannelAdapter::new(t1);
        let server_channel = MessageChannelAdapter::new(t2);

        let server = RpcServer::new();
        server.register_typed("add", |req: AddRequest| async move {
            Ok(AddResponse {
                result: req.a + req.b,
            })
        });

        let _handle = server.spawn_handler(server_channel);

        let request: Message = Message::call("add", AddRequest { a: 10, b: 32 }).unwrap();
        client_channel.send(&request).await.unwrap();

        let response = client_channel.recv().await.unwrap();
        assert_eq!(response.msg_type, MessageType::Reply);

        let resp: AddResponse = response.deserialize_payload().unwrap();
        assert_eq!(resp.result, 42);
    }

    /// A streaming call delivers every item in order and then ends the stream.
    #[tokio::test]
    async fn test_server_stream_handler() {
        let config = ChannelConfig::default();
        let (t1, t2) = ChannelFrameTransport::create_pair("test", config).unwrap();

        let client_channel = Arc::new(MessageChannelAdapter::new(t1));
        let server_channel = MessageChannelAdapter::new(t2);

        let server = RpcServer::new();
        server.register_stream("range", |count: i32| {
            futures::stream::iter((1..=count).map(|i| Ok(i)))
        });

        let _handle = server.spawn_handler(server_channel);

        let stream_id = next_stream_id();
        let mut request: Message = Message::call("range", 5i32).unwrap();
        request.metadata = request.metadata.with_stream(stream_id, 0);

        let manager = crate::streaming::StreamManager::new();
        let mut receiver: StreamReceiver<i32> = manager.create_receiver(stream_id);

        client_channel.send(&request).await.unwrap();

        // Pump incoming frames into the stream manager until the end marker
        // arrives or the channel fails.
        let client_channel_clone = client_channel.clone();
        let recv_task = tokio::spawn(async move {
            while let Ok(msg) = client_channel_clone.recv().await {
                manager.handle_message(&msg);
                if msg.msg_type == MessageType::StreamEnd {
                    break;
                }
            }
        });

        let mut items = Vec::new();
        while let Some(result) = receiver.recv().await {
            items.push(result.unwrap());
        }

        recv_task.await.unwrap();
        assert_eq!(items, vec![1, 2, 3, 4, 5]);
    }

    /// A notification invokes the registered handler.
    #[tokio::test]
    async fn test_server_notification() {
        use std::sync::atomic::{AtomicBool, Ordering};

        let config = ChannelConfig::default();
        let (t1, t2) = ChannelFrameTransport::create_pair("test", config).unwrap();

        let client_channel = MessageChannelAdapter::new(t1);
        let server_channel = MessageChannelAdapter::new(t2);

        let called = Arc::new(AtomicBool::new(false));
        let called_clone = called.clone();

        let server = RpcServer::new();
        server.register_fn("log", move |_msg: Message| {
            let called = called_clone.clone();
            async move {
                called.store(true, Ordering::Release);
                Ok(Message::reply(MessageId::new(), ())?)
            }
        });

        let _handle = server.spawn_handler(server_channel);

        let notification: Message = Message::notification("log", "test").unwrap();
        client_channel.send(&notification).await.unwrap();

        // Notifications get no reply, so give the server task time to run the handler.
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        assert!(called.load(Ordering::Acquire));
    }
}