1use async_trait::async_trait;
7use bytes::Bytes;
8use std::{
9 collections::{HashMap, HashSet},
10 sync::Arc,
11};
12use synapse_primitives::{InstanceId, InterfaceId, MethodId};
13use synapse_proto::{RpcResponse, RpcStatus};
14use tokio::sync::RwLock;
15
16#[async_trait]
18pub trait RpcHandler: Send + Sync {
19 async fn handle(&self, request: synapse_proto::RpcRequest) -> RpcResponse;
20}
21
22#[derive(Debug, Clone)]
24pub struct InterfaceRegistration {
25 pub interface_id: InterfaceId,
26 pub interface_version: u32,
27 pub method_ids: HashSet<MethodId>,
28 pub method_names: Vec<String>,
29 pub instance_id: InstanceId,
30 pub service_name: String,
31 pub interface_name: String,
32}
33
34impl InterfaceRegistration {
35 pub fn new(
37 interface_name: &str,
38 service_name: &str,
39 method_names: &[&str],
40 instance_id: InstanceId,
41 ) -> Self {
42 Self {
43 interface_id: InterfaceId::from_name(interface_name),
44 interface_version: 1_000_000, method_ids: method_names
46 .iter()
47 .map(|n| MethodId::from_name(n))
48 .collect(),
49 method_names: method_names.iter().map(|n| n.to_string()).collect(),
50 instance_id,
51 service_name: service_name.to_string(),
52 interface_name: interface_name.to_string(),
53 }
54 }
55
56 pub fn with_version(mut self, major: u16, minor: u16, patch: u16) -> Self {
58 self.interface_version =
59 (major as u32) * 1_000_000 + (minor as u32) * 1_000 + (patch as u32);
60 self
61 }
62}
63
64pub struct InterfaceRegistry {
66 interfaces: Arc<RwLock<HashMap<InterfaceId, InterfaceEntry>>>,
67}
68
69struct InterfaceEntry {
70 registration: InterfaceRegistration,
71 handler: Arc<dyn RpcHandler>,
72}
73
74impl InterfaceRegistry {
75 pub fn new() -> Self {
76 Self {
77 interfaces: Arc::new(RwLock::new(HashMap::new())),
78 }
79 }
80
81 pub async fn register(
83 &self,
84 registration: InterfaceRegistration,
85 handler: Arc<dyn RpcHandler>,
86 ) -> Result<(), String> {
87 let mut interfaces = self.interfaces.write().await;
88
89 if interfaces.contains_key(®istration.interface_id) {
90 return Err(format!(
91 "Interface {} already registered",
92 registration.interface_name
93 ));
94 }
95
96 interfaces.insert(
97 registration.interface_id,
98 InterfaceEntry {
99 registration,
100 handler,
101 },
102 );
103
104 Ok(())
105 }
106
107 pub async fn deregister(&self, interface_id: InterfaceId) -> Result<(), String> {
109 let mut interfaces = self.interfaces.write().await;
110
111 if interfaces.remove(&interface_id).is_none() {
112 return Err(format!("Interface {:?} not found", interface_id));
113 }
114
115 Ok(())
116 }
117
118 pub async fn route(&self, request: synapse_proto::RpcRequest) -> RpcResponse {
120 let interfaces = self.interfaces.read().await;
121
122 let interface_id = InterfaceId::from_raw(request.interface_id);
123 let method_id = MethodId::from_raw(request.method_id);
124
125 match interfaces.get(&interface_id) {
126 Some(entry) => {
127 if !entry.registration.method_ids.contains(&method_id) {
129 return RpcResponse {
130 status: RpcStatus::MethodNotFound as i32,
131 payload: Bytes::new(),
132 error: Some(synapse_proto::RpcError {
133 code: 1004,
134 message: format!("Method {:?} not found", method_id),
135 details: vec![],
136 }),
137 headers: vec![],
138 responded_at_unix_ms: chrono::Utc::now().timestamp_millis(),
139 };
140 }
141
142 entry.handler.handle(request).await
144 }
145 None => {
146 RpcResponse {
148 status: RpcStatus::InterfaceNotFound as i32,
149 payload: Bytes::new(),
150 error: Some(synapse_proto::RpcError {
151 code: 1003,
152 message: format!("Interface {:?} not found", interface_id),
153 details: vec![],
154 }),
155 headers: vec![],
156 responded_at_unix_ms: chrono::Utc::now().timestamp_millis(),
157 }
158 }
159 }
160 }
161
162 pub async fn list_interfaces(&self) -> Vec<InterfaceRegistration> {
164 let interfaces = self.interfaces.read().await;
165 interfaces
166 .values()
167 .map(|entry| entry.registration.clone())
168 .collect()
169 }
170
171 pub async fn has_interface(&self, interface_id: InterfaceId) -> bool {
173 let interfaces = self.interfaces.read().await;
174 interfaces.contains_key(&interface_id)
175 }
176}
177
178impl Default for InterfaceRegistry {
179 fn default() -> Self {
180 Self::new()
181 }
182}
183
/// Adapts a closure that returns a boxed future into an [`RpcHandler`].
///
/// The closure's bounds are stated on the `impl` blocks that need them rather
/// than on the struct, so the type can be named without repeating them.
pub struct FunctionHandler<F> {
    /// The wrapped request-handling closure.
    func: F,
}
193
194impl<F> FunctionHandler<F>
195where
196 F: Fn(synapse_proto::RpcRequest) -> futures::future::BoxFuture<'static, RpcResponse>
197 + Send
198 + Sync,
199{
200 pub fn new(func: F) -> Self {
201 Self { func }
202 }
203}
204
205#[async_trait]
206impl<F> RpcHandler for FunctionHandler<F>
207where
208 F: Fn(synapse_proto::RpcRequest) -> futures::future::BoxFuture<'static, RpcResponse>
209 + Send
210 + Sync,
211{
212 async fn handle(&self, request: synapse_proto::RpcRequest) -> RpcResponse {
213 (self.func)(request).await
214 }
215}
216
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a handler that always answers `Ok` with `payload`.
    fn ok_handler(payload: &'static str) -> Arc<dyn RpcHandler> {
        Arc::new(FunctionHandler::new(move |_req| {
            Box::pin(async move {
                RpcResponse {
                    status: RpcStatus::Ok as i32,
                    payload: Bytes::from(payload),
                    error: None,
                    headers: vec![],
                    responded_at_unix_ms: chrono::Utc::now().timestamp_millis(),
                }
            })
        }))
    }

    /// Builds a request targeting `interface_id` / `method_id`.
    fn request_for(
        interface_id: InterfaceId,
        method_id: MethodId,
        payload: Bytes,
    ) -> synapse_proto::RpcRequest {
        synapse_proto::RpcRequest {
            interface_id: interface_id.into(),
            method_id: method_id.into(),
            headers: vec![],
            payload,
            sent_at_unix_ms: chrono::Utc::now().timestamp_millis(),
        }
    }

    #[tokio::test]
    async fn test_registry_basic() {
        let registry = InterfaceRegistry::new();

        let registration = InterfaceRegistration::new(
            "test.Interface",
            "test-service",
            &["TestMethod"],
            InstanceId::new_random(),
        );

        let interface_id = registration.interface_id;
        let method_id = *registration.method_ids.iter().next().unwrap();

        registry
            .register(registration, ok_handler("test response"))
            .await
            .unwrap();

        let request = request_for(interface_id, method_id, Bytes::from("test payload"));
        let response = registry.route(request).await;
        assert_eq!(response.status, RpcStatus::Ok as i32);
        assert_eq!(response.payload, Bytes::from("test response"));
    }

    #[tokio::test]
    async fn test_interface_not_found() {
        let registry = InterfaceRegistry::new();

        let request = request_for(
            InterfaceId::from_name("nonexistent.Interface"),
            MethodId::from_name("Method"),
            Bytes::new(),
        );

        let response = registry.route(request).await;
        assert_eq!(response.status, RpcStatus::InterfaceNotFound as i32);
    }

    #[tokio::test]
    async fn test_method_not_found() {
        let registry = InterfaceRegistry::new();

        let registration = InterfaceRegistration::new(
            "test.Interface",
            "test-service",
            &["MethodA"],
            InstanceId::new_random(),
        );
        let interface_id = registration.interface_id;

        registry.register(registration, ok_handler("")).await.unwrap();

        let request = request_for(interface_id, MethodId::from_name("MethodB"), Bytes::new());
        let response = registry.route(request).await;
        assert_eq!(response.status, RpcStatus::MethodNotFound as i32);
    }

    #[tokio::test]
    async fn test_duplicate_registration_rejected() {
        let registry = InterfaceRegistry::new();

        let first = InterfaceRegistration::new(
            "test.Interface",
            "service-a",
            &["MethodA"],
            InstanceId::new_random(),
        );
        let second = InterfaceRegistration::new(
            "test.Interface",
            "service-b",
            &["MethodB"],
            InstanceId::new_random(),
        );

        registry.register(first, ok_handler("a")).await.unwrap();
        assert!(registry.register(second, ok_handler("b")).await.is_err());

        // The original registration must survive the rejected duplicate.
        let listed = registry.list_interfaces().await;
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].service_name, "service-a");
    }

    #[tokio::test]
    async fn test_deregister() {
        let registry = InterfaceRegistry::new();

        let registration = InterfaceRegistration::new(
            "test.Interface",
            "test-service",
            &["MethodA"],
            InstanceId::new_random(),
        );
        let interface_id = registration.interface_id;

        registry.register(registration, ok_handler("")).await.unwrap();
        assert!(registry.has_interface(interface_id).await);

        registry.deregister(interface_id).await.unwrap();
        assert!(!registry.has_interface(interface_id).await);
        assert!(registry.deregister(interface_id).await.is_err());

        let request = request_for(interface_id, MethodId::from_name("MethodA"), Bytes::new());
        let response = registry.route(request).await;
        assert_eq!(response.status, RpcStatus::InterfaceNotFound as i32);
    }
}