Skip to main content

synapse_rpc/
method_router.rs

//! Method router that dispatches an interface's requests to per-method handlers.
2
3use crate::{message::error_response, registry::RpcHandler};
4use async_trait::async_trait;
5use std::{collections::HashMap, sync::Arc};
6use synapse_primitives::MethodId;
7use synapse_proto::{RpcRequest, RpcResponse, RpcStatus};
8
9/// Routes requests to different handlers based on method_id
10pub struct MethodRouter {
11    handlers: HashMap<MethodId, Arc<dyn RpcHandler>>,
12    default_handler: Option<Arc<dyn RpcHandler>>,
13}
14
15impl MethodRouter {
16    /// Create a new method router
17    pub fn new() -> Self {
18        Self {
19            handlers: HashMap::new(),
20            default_handler: None,
21        }
22    }
23
24    /// Register a handler for a specific method
25    pub fn method(mut self, method_id: MethodId, handler: Arc<dyn RpcHandler>) -> Self {
26        self.handlers.insert(method_id, handler);
27        self
28    }
29
30    /// Set a default handler for unmatched methods
31    pub fn default(mut self, handler: Arc<dyn RpcHandler>) -> Self {
32        self.default_handler = Some(handler);
33        self
34    }
35
36    /// Build the router into an RpcHandler
37    pub fn build(self) -> Arc<Self> {
38        Arc::new(self)
39    }
40}
41
42impl Default for MethodRouter {
43    fn default() -> Self {
44        Self::new()
45    }
46}
47
48#[async_trait]
49impl RpcHandler for MethodRouter {
50    async fn handle(&self, request: RpcRequest) -> RpcResponse {
51        let method_id = MethodId::from_raw(request.method_id);
52
53        // Try to find handler for this method
54        if let Some(handler) = self.handlers.get(&method_id) {
55            return handler.handle(request).await;
56        }
57
58        // Try default handler
59        if let Some(handler) = &self.default_handler {
60            return handler.handle(request).await;
61        }
62
63        // No handler found
64        error_response(
65            RpcStatus::MethodNotFound,
66            1004,
67            format!("Method {} not found", request.method_id),
68        )
69    }
70}
71
#[cfg(test)]
mod tests {
    use super::*;
    use crate::registry::FunctionHandler;
    use bytes::Bytes;

    /// Builds a minimal request for `method_id` on interface 1.
    fn make_request(method_id: u32) -> RpcRequest {
        RpcRequest {
            interface_id: 1,
            method_id,
            headers: vec![],
            payload: Bytes::new(),
            sent_at_unix_ms: 0,
        }
    }

    /// Handler that answers `Ok` with `tag` as the payload, so tests can tell
    /// which handler served a request.
    fn echo_handler(tag: &'static str) -> Arc<dyn RpcHandler> {
        Arc::new(FunctionHandler::new(move |_req| {
            Box::pin(async move {
                RpcResponse {
                    status: RpcStatus::Ok as i32,
                    payload: Bytes::from(tag),
                    error: None,
                    headers: vec![],
                    responded_at_unix_ms: 0,
                }
            })
        }))
    }

    #[tokio::test]
    async fn test_route_to_registered_method() {
        let router = MethodRouter::new()
            .method(MethodId::from_raw(1), echo_handler("method_1"))
            .build();

        let resp = router.handle(make_request(1)).await;
        assert_eq!(resp.status, RpcStatus::Ok as i32);
        assert_eq!(resp.payload, Bytes::from("method_1"));
    }

    #[tokio::test]
    async fn test_route_multiple_methods() {
        let router = MethodRouter::new()
            .method(MethodId::from_raw(1), echo_handler("m1"))
            .method(MethodId::from_raw(2), echo_handler("m2"))
            .build();

        let r1 = router.handle(make_request(1)).await;
        let r2 = router.handle(make_request(2)).await;
        assert_eq!(r1.payload, Bytes::from("m1"));
        assert_eq!(r2.payload, Bytes::from("m2"));
    }

    #[tokio::test]
    async fn test_default_handler_on_miss() {
        let router = MethodRouter::new()
            .method(MethodId::from_raw(1), echo_handler("m1"))
            .default(echo_handler("fallback"))
            .build();

        let resp = router.handle(make_request(999)).await;
        assert_eq!(resp.status, RpcStatus::Ok as i32);
        assert_eq!(resp.payload, Bytes::from("fallback"));
    }

    #[tokio::test]
    async fn test_method_not_found_without_default() {
        let router = MethodRouter::new()
            .method(MethodId::from_raw(1), echo_handler("m1"))
            .build();

        let resp = router.handle(make_request(999)).await;
        assert_eq!(resp.status, RpcStatus::MethodNotFound as i32);
    }

    #[tokio::test]
    async fn test_exact_match_preferred_over_default() {
        let router = MethodRouter::new()
            .method(MethodId::from_raw(1), echo_handler("exact"))
            .default(echo_handler("default"))
            .build();

        let resp = router.handle(make_request(1)).await;
        assert_eq!(resp.payload, Bytes::from("exact"));
    }

    #[tokio::test]
    async fn test_later_registration_replaces_earlier() {
        let router = MethodRouter::new()
            .method(MethodId::from_raw(1), echo_handler("first"))
            .method(MethodId::from_raw(1), echo_handler("second"))
            .build();

        let resp = router.handle(make_request(1)).await;
        assert_eq!(resp.payload, Bytes::from("second"));
    }

    #[tokio::test]
    async fn test_default_only_router_handles_any_method() {
        let router = MethodRouter::new()
            .default(echo_handler("fallback"))
            .build();

        for method_id in [2, 42, 999] {
            let resp = router.handle(make_request(method_id)).await;
            assert_eq!(resp.status, RpcStatus::Ok as i32);
            assert_eq!(resp.payload, Bytes::from("fallback"));
        }
    }

    #[tokio::test]
    async fn test_empty_router_returns_method_not_found() {
        let router = MethodRouter::new().build();

        let resp = router.handle(make_request(1)).await;
        assert_eq!(resp.status, RpcStatus::MethodNotFound as i32);
    }

    #[tokio::test]
    async fn test_router_nests_as_handler() {
        // `build()` output coerces to `Arc<dyn RpcHandler>`, so routers compose.
        let inner: Arc<dyn RpcHandler> = MethodRouter::new()
            .method(MethodId::from_raw(7), echo_handler("inner"))
            .build();
        let outer = MethodRouter::new()
            .method(MethodId::from_raw(7), inner)
            .build();

        let resp = outer.handle(make_request(7)).await;
        assert_eq!(resp.payload, Bytes::from("inner"));
    }

    #[test]
    fn test_default_trait() {
        let router = <MethodRouter as Default>::default();
        assert!(router.handlers.is_empty());
        assert!(router.default_handler.is_none());
    }
}