synapse_rpc/
method_router.rs1use crate::{message::error_response, registry::RpcHandler};
4use async_trait::async_trait;
5use std::{collections::HashMap, sync::Arc};
6use synapse_primitives::MethodId;
7use synapse_proto::{RpcRequest, RpcResponse, RpcStatus};
8
9pub struct MethodRouter {
11 handlers: HashMap<MethodId, Arc<dyn RpcHandler>>,
12 default_handler: Option<Arc<dyn RpcHandler>>,
13}
14
15impl MethodRouter {
16 pub fn new() -> Self {
18 Self {
19 handlers: HashMap::new(),
20 default_handler: None,
21 }
22 }
23
24 pub fn method(mut self, method_id: MethodId, handler: Arc<dyn RpcHandler>) -> Self {
26 self.handlers.insert(method_id, handler);
27 self
28 }
29
30 pub fn default(mut self, handler: Arc<dyn RpcHandler>) -> Self {
32 self.default_handler = Some(handler);
33 self
34 }
35
36 pub fn build(self) -> Arc<Self> {
38 Arc::new(self)
39 }
40}
41
42impl Default for MethodRouter {
43 fn default() -> Self {
44 Self::new()
45 }
46}
47
48#[async_trait]
49impl RpcHandler for MethodRouter {
50 async fn handle(&self, request: RpcRequest) -> RpcResponse {
51 let method_id = MethodId::from_raw(request.method_id);
52
53 if let Some(handler) = self.handlers.get(&method_id) {
55 return handler.handle(request).await;
56 }
57
58 if let Some(handler) = &self.default_handler {
60 return handler.handle(request).await;
61 }
62
63 error_response(
65 RpcStatus::MethodNotFound,
66 1004,
67 format!("Method {} not found", request.method_id),
68 )
69 }
70}
71
#[cfg(test)]
mod tests {
    use super::*;
    use crate::registry::FunctionHandler;
    use bytes::Bytes;

    /// Builds a minimal request addressed to `method_id` on interface 1.
    fn request_for(method_id: u32) -> RpcRequest {
        RpcRequest {
            interface_id: 1,
            method_id,
            headers: Vec::new(),
            payload: Bytes::new(),
            sent_at_unix_ms: 0,
        }
    }

    /// Handler that always answers `Ok` with `tag` as its payload, so a test can
    /// tell which handler served a given request.
    fn tagged(tag: &'static str) -> Arc<dyn RpcHandler> {
        Arc::new(FunctionHandler::new(move |_req| {
            Box::pin(async move {
                RpcResponse {
                    status: RpcStatus::Ok as i32,
                    payload: Bytes::from(tag),
                    error: None,
                    headers: Vec::new(),
                    responded_at_unix_ms: 0,
                }
            })
        }))
    }

    #[tokio::test]
    async fn test_route_to_registered_method() {
        let router = MethodRouter::new()
            .method(MethodId::from_raw(1), tagged("method_1"))
            .build();

        let response = router.handle(request_for(1)).await;

        assert_eq!(response.status, RpcStatus::Ok as i32);
        assert_eq!(response.payload, Bytes::from("method_1"));
    }

    #[tokio::test]
    async fn test_route_multiple_methods() {
        let router = MethodRouter::new()
            .method(MethodId::from_raw(1), tagged("m1"))
            .method(MethodId::from_raw(2), tagged("m2"))
            .build();

        for (id, expected) in [(1, "m1"), (2, "m2")] {
            let response = router.handle(request_for(id)).await;
            assert_eq!(response.payload, Bytes::from(expected));
        }
    }

    #[tokio::test]
    async fn test_default_handler_on_miss() {
        let router = MethodRouter::new()
            .method(MethodId::from_raw(1), tagged("m1"))
            .default(tagged("fallback"))
            .build();

        let response = router.handle(request_for(999)).await;

        assert_eq!(response.status, RpcStatus::Ok as i32);
        assert_eq!(response.payload, Bytes::from("fallback"));
    }

    #[tokio::test]
    async fn test_method_not_found_without_default() {
        let router = MethodRouter::new()
            .method(MethodId::from_raw(1), tagged("m1"))
            .build();

        let response = router.handle(request_for(999)).await;

        assert_eq!(response.status, RpcStatus::MethodNotFound as i32);
    }

    #[tokio::test]
    async fn test_exact_match_preferred_over_default() {
        let router = MethodRouter::new()
            .method(MethodId::from_raw(1), tagged("exact"))
            .default(tagged("default"))
            .build();

        let response = router.handle(request_for(1)).await;

        assert_eq!(response.payload, Bytes::from("exact"));
    }

    #[test]
    fn test_default_trait() {
        // Annotating the binding makes `Default::default()` pick the trait impl,
        // sidestepping the inherent `MethodRouter::default` builder method.
        let router: MethodRouter = Default::default();

        assert!(router.handlers.is_empty());
        assert!(router.default_handler.is_none());
    }
}