Skip to main content

synapse_rpc/
registry.rs

//! Interface and instance registry
//!
//! The registry tracks which interfaces are available and routes
//! incoming RPC requests to the appropriate handlers.
5
6use async_trait::async_trait;
7use bytes::Bytes;
8use std::{
9    collections::{HashMap, HashSet},
10    sync::Arc,
11};
12use synapse_primitives::{InstanceId, InterfaceId, MethodId};
13use synapse_proto::{RpcResponse, RpcStatus};
14use tokio::sync::RwLock;
15
16/// Handler trait for RPC methods
17#[async_trait]
18pub trait RpcHandler: Send + Sync {
19    async fn handle(&self, request: synapse_proto::RpcRequest) -> RpcResponse;
20}
21
22/// Registration info for an interface
23#[derive(Debug, Clone)]
24pub struct InterfaceRegistration {
25    pub interface_id: InterfaceId,
26    pub interface_version: u32,
27    pub method_ids: HashSet<MethodId>,
28    pub method_names: Vec<String>,
29    pub instance_id: InstanceId,
30    pub service_name: String,
31    pub interface_name: String,
32}
33
34impl InterfaceRegistration {
35    /// Create a new interface registration
36    pub fn new(
37        interface_name: &str,
38        service_name: &str,
39        method_names: &[&str],
40        instance_id: InstanceId,
41    ) -> Self {
42        Self {
43            interface_id: InterfaceId::from_name(interface_name),
44            interface_version: 1_000_000, // 1.0.0
45            method_ids: method_names
46                .iter()
47                .map(|n| MethodId::from_name(n))
48                .collect(),
49            method_names: method_names.iter().map(|n| n.to_string()).collect(),
50            instance_id,
51            service_name: service_name.to_string(),
52            interface_name: interface_name.to_string(),
53        }
54    }
55
56    /// Set the interface version
57    pub fn with_version(mut self, major: u16, minor: u16, patch: u16) -> Self {
58        self.interface_version =
59            (major as u32) * 1_000_000 + (minor as u32) * 1_000 + (patch as u32);
60        self
61    }
62}
63
64/// Registry for interfaces and their handlers
65pub struct InterfaceRegistry {
66    interfaces: Arc<RwLock<HashMap<InterfaceId, InterfaceEntry>>>,
67}
68
69struct InterfaceEntry {
70    registration: InterfaceRegistration,
71    handler: Arc<dyn RpcHandler>,
72}
73
74impl InterfaceRegistry {
75    pub fn new() -> Self {
76        Self {
77            interfaces: Arc::new(RwLock::new(HashMap::new())),
78        }
79    }
80
81    /// Register an interface with its handler
82    pub async fn register(
83        &self,
84        registration: InterfaceRegistration,
85        handler: Arc<dyn RpcHandler>,
86    ) -> Result<(), String> {
87        let mut interfaces = self.interfaces.write().await;
88
89        if interfaces.contains_key(&registration.interface_id) {
90            return Err(format!(
91                "Interface {} already registered",
92                registration.interface_name
93            ));
94        }
95
96        interfaces.insert(
97            registration.interface_id,
98            InterfaceEntry {
99                registration,
100                handler,
101            },
102        );
103
104        Ok(())
105    }
106
107    /// Deregister an interface
108    pub async fn deregister(&self, interface_id: InterfaceId) -> Result<(), String> {
109        let mut interfaces = self.interfaces.write().await;
110
111        if interfaces.remove(&interface_id).is_none() {
112            return Err(format!("Interface {:?} not found", interface_id));
113        }
114
115        Ok(())
116    }
117
118    /// Route a request to the appropriate handler
119    pub async fn route(&self, request: synapse_proto::RpcRequest) -> RpcResponse {
120        let interfaces = self.interfaces.read().await;
121
122        let interface_id = InterfaceId::from_raw(request.interface_id);
123        let method_id = MethodId::from_raw(request.method_id);
124
125        match interfaces.get(&interface_id) {
126            Some(entry) => {
127                // Check if method is supported
128                if !entry.registration.method_ids.contains(&method_id) {
129                    return RpcResponse {
130                        status: RpcStatus::MethodNotFound as i32,
131                        payload: Bytes::new(),
132                        error: Some(synapse_proto::RpcError {
133                            code: 1004,
134                            message: format!("Method {:?} not found", method_id),
135                            details: vec![],
136                        }),
137                        headers: vec![],
138                        responded_at_unix_ms: chrono::Utc::now().timestamp_millis(),
139                    };
140                }
141
142                // Call the handler
143                entry.handler.handle(request).await
144            }
145            None => {
146                // Interface not found error
147                RpcResponse {
148                    status: RpcStatus::InterfaceNotFound as i32,
149                    payload: Bytes::new(),
150                    error: Some(synapse_proto::RpcError {
151                        code: 1003,
152                        message: format!("Interface {:?} not found", interface_id),
153                        details: vec![],
154                    }),
155                    headers: vec![],
156                    responded_at_unix_ms: chrono::Utc::now().timestamp_millis(),
157                }
158            }
159        }
160    }
161
162    /// List all registered interfaces
163    pub async fn list_interfaces(&self) -> Vec<InterfaceRegistration> {
164        let interfaces = self.interfaces.read().await;
165        interfaces
166            .values()
167            .map(|entry| entry.registration.clone())
168            .collect()
169    }
170
171    /// Check if an interface is registered
172    pub async fn has_interface(&self, interface_id: InterfaceId) -> bool {
173        let interfaces = self.interfaces.read().await;
174        interfaces.contains_key(&interface_id)
175    }
176}
177
178impl Default for InterfaceRegistry {
179    fn default() -> Self {
180        Self::new()
181    }
182}
183
/// Simple function-based RPC handler.
///
/// Wraps a function or closure so it can be used where an `RpcHandler` is
/// expected. The `Fn(..) -> BoxFuture<..> + Send + Sync` requirements are
/// stated on the impl blocks that need them rather than on the type itself,
/// so the struct definition stays bound-free.
pub struct FunctionHandler<F> {
    /// The wrapped callable invoked for every request.
    func: F,
}
193
194impl<F> FunctionHandler<F>
195where
196    F: Fn(synapse_proto::RpcRequest) -> futures::future::BoxFuture<'static, RpcResponse>
197        + Send
198        + Sync,
199{
200    pub fn new(func: F) -> Self {
201        Self { func }
202    }
203}
204
205#[async_trait]
206impl<F> RpcHandler for FunctionHandler<F>
207where
208    F: Fn(synapse_proto::RpcRequest) -> futures::future::BoxFuture<'static, RpcResponse>
209        + Send
210        + Sync,
211{
212    async fn handle(&self, request: synapse_proto::RpcRequest) -> RpcResponse {
213        (self.func)(request).await
214    }
215}
216
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a successful response with the given payload and timestamp.
    fn ok_response(payload: Bytes, responded_at_unix_ms: i64) -> RpcResponse {
        RpcResponse {
            status: RpcStatus::Ok as i32,
            payload,
            error: None,
            headers: vec![],
            responded_at_unix_ms,
        }
    }

    /// Build a header-less request addressed to `interface_id` / `method_id`.
    fn request_for(
        interface_id: InterfaceId,
        method_id: MethodId,
        payload: Bytes,
        sent_at_unix_ms: i64,
    ) -> synapse_proto::RpcRequest {
        synapse_proto::RpcRequest {
            interface_id: interface_id.into(),
            method_id: method_id.into(),
            headers: vec![],
            payload,
            sent_at_unix_ms,
        }
    }

    /// A registered interface and method are routed to the handler.
    #[tokio::test]
    async fn test_registry_basic() {
        let registry = InterfaceRegistry::new();

        let registration = InterfaceRegistration::new(
            "test.Interface",
            "test-service",
            &["TestMethod"],
            InstanceId::new_random(),
        );

        let interface_id = registration.interface_id;
        let method_id = *registration.method_ids.iter().next().unwrap();

        // Handler that always answers Ok with a fixed payload.
        let handler = Arc::new(FunctionHandler::new(move |_req| {
            Box::pin(async move {
                ok_response(
                    Bytes::from("test response"),
                    chrono::Utc::now().timestamp_millis(),
                )
            })
        }));

        registry.register(registration, handler).await.unwrap();

        let request = request_for(
            interface_id,
            method_id,
            Bytes::from("test payload"),
            chrono::Utc::now().timestamp_millis(),
        );

        let response = registry.route(request).await;
        assert_eq!(response.status, RpcStatus::Ok as i32);
    }

    /// Routing to an interface that was never registered is rejected.
    #[tokio::test]
    async fn test_interface_not_found() {
        let registry = InterfaceRegistry::new();

        let request = request_for(
            InterfaceId::from_name("nonexistent.Interface"),
            MethodId::from_name("Method"),
            Bytes::new(),
            0,
        );

        let response = registry.route(request).await;
        assert_eq!(response.status, RpcStatus::InterfaceNotFound as i32);
    }

    /// Routing to an unknown method on a registered interface is rejected.
    #[tokio::test]
    async fn test_method_not_found() {
        let registry = InterfaceRegistry::new();

        let registration = InterfaceRegistration::new(
            "test.Interface",
            "test-service",
            &["MethodA"],
            InstanceId::new_random(),
        );

        let interface_id = registration.interface_id;

        let handler = Arc::new(FunctionHandler::new(move |_req| {
            Box::pin(async move { ok_response(Bytes::new(), 0) })
        }));

        registry.register(registration, handler).await.unwrap();

        // "MethodB" was never registered on this interface.
        let request = request_for(
            interface_id,
            MethodId::from_name("MethodB"),
            Bytes::new(),
            0,
        );

        let response = registry.route(request).await;
        assert_eq!(response.status, RpcStatus::MethodNotFound as i32);
    }
}