//! Name-keyed registry of type-erased, shareable components (`turbomcp_protocol/registry.rs`).

use std::any::{Any, TypeId};
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::sync::Arc;

use parking_lot::RwLock;
use thiserror::Error;
28#[derive(Error, Debug, Clone)]
33pub enum RegistryError {
34 #[error("Component {0} not found")]
36 NotFound(String),
37
38 #[error("Component {0} already exists")]
40 AlreadyExists(String),
41
42 #[error("Type mismatch for component {0}")]
44 TypeMismatch(String),
45}
46
47impl From<RegistryError> for Box<crate::error::Error> {
49 fn from(err: RegistryError) -> Self {
50 use crate::error::Error;
51 match err {
52 RegistryError::NotFound(name) => {
53 Error::internal(format!("Component '{}' not found in registry", name))
54 .with_component("registry")
55 }
56 RegistryError::AlreadyExists(name) => {
57 Error::validation(format!("Component '{}' already exists in registry", name))
58 .with_component("registry")
59 }
60 RegistryError::TypeMismatch(name) => Error::internal(format!(
61 "Type mismatch when accessing component '{}' in registry",
62 name
63 ))
64 .with_component("registry"),
65 }
66 }
67}
68
69#[derive(Debug)]
71pub struct Registry {
72 components: RwLock<HashMap<String, Arc<dyn Any + Send + Sync>>>,
74
75 type_map: RwLock<HashMap<String, TypeId>>,
77}
78
79#[derive(Debug)]
81pub struct RegistryBuilder {
82 registry: Registry,
83}
84
85impl Registry {
86 #[must_use]
88 pub fn new() -> Self {
89 Self {
90 components: RwLock::new(HashMap::new()),
91 type_map: RwLock::new(HashMap::new()),
92 }
93 }
94
95 #[must_use]
97 pub fn builder() -> RegistryBuilder {
98 RegistryBuilder {
99 registry: Self::new(),
100 }
101 }
102
103 pub fn register<T>(&self, name: impl Into<String>, component: T) -> Result<(), RegistryError>
105 where
106 T: 'static + Send + Sync,
107 {
108 let name = name.into();
109 let type_id = TypeId::of::<T>();
110
111 {
112 let mut components = self.components.write();
113 if components.contains_key(&name) {
114 return Err(RegistryError::AlreadyExists(name));
115 }
116 components.insert(name.clone(), Arc::new(component));
117 }
118
119 {
120 let mut type_map = self.type_map.write();
121 type_map.insert(name, type_id);
122 }
123
124 Ok(())
125 }
126
127 pub fn get<T>(&self, name: &str) -> Result<Arc<T>, RegistryError>
129 where
130 T: 'static + Send + Sync,
131 {
132 let component = {
133 let components = self.components.read();
134 components
135 .get(name)
136 .ok_or_else(|| RegistryError::NotFound(name.to_string()))?
137 .clone()
138 }; component
141 .downcast::<T>()
142 .map_err(|_| RegistryError::TypeMismatch(name.to_string()))
143 }
144
145 pub fn contains(&self, name: &str) -> bool {
147 self.components.read().contains_key(name)
148 }
149
150 pub fn component_names(&self) -> Vec<String> {
152 self.components.read().keys().cloned().collect()
153 }
154
155 pub fn remove(&self, name: &str) -> Option<Arc<dyn Any + Send + Sync>> {
157 {
158 let mut type_map = self.type_map.write();
159 type_map.remove(name);
160 } let mut components = self.components.write();
163 components.remove(name)
164 }
165
166 pub fn clear(&self) {
168 self.components.write().clear();
169 self.type_map.write().clear();
170 }
171
172 pub fn len(&self) -> usize {
174 self.components.read().len()
175 }
176
177 pub fn is_empty(&self) -> bool {
179 self.components.read().is_empty()
180 }
181}
182
183impl RegistryBuilder {
184 pub fn register<T>(self, name: impl Into<String>, component: T) -> Result<Self, RegistryError>
186 where
187 T: 'static + Send + Sync,
188 {
189 self.registry.register(name, component)?;
190 Ok(self)
191 }
192
193 pub fn build(self) -> Registry {
195 self.registry
196 }
197}
198
199impl Default for Registry {
200 fn default() -> Self {
201 Self::new()
202 }
203}
204
205pub trait Component: 'static + Send + Sync {
207 fn name(&self) -> &'static str;
209
210 fn register_in(self, registry: &Registry) -> Result<(), RegistryError>
212 where
213 Self: Sized,
214 {
215 registry.register(self.name(), self)
216 }
217}
218
/// Registers one or more components in a registry.
///
/// * `register_component!(registry, name, component)` expands to a single
///   `registry.register(name, component)` call and evaluates to its `Result`.
/// * `register_component!(registry, name => component, ...)` registers each pair
///   in order. A failure is propagated with `?`, which returns from the
///   *enclosing function*, so this form may only be used inside a function whose
///   error type can be built from the `register` error. On success it evaluates to
///   `Ok::<(), RegistryError>(())`.
///
/// In both forms the registry expression is evaluated exactly once.
#[macro_export]
macro_rules! register_component {
    ($registry:expr, $name:expr, $component:expr) => {
        $registry.register($name, $component)
    };
    ($registry:expr, $($name:expr => $component:expr),+ $(,)?) => {{
        // Bind the registry expression once. Re-expanding `$registry` for every pair
        // would repeat its side effects and, if it builds a new registry, would
        // spread the registrations across different instances.
        let registry = &$registry;
        $(
            registry.register($name, $component)?;
        )+
        Ok::<(), $crate::registry::RegistryError>(())
    }};
}
234
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Minimal component with an id and an interior-mutable counter.
    #[derive(Debug)]
    struct TestService {
        id: u32,
        counter: AtomicU32,
    }

    impl TestService {
        fn new(id: u32) -> Self {
            Self {
                id,
                counter: AtomicU32::new(0),
            }
        }

        fn increment(&self) -> u32 {
            self.counter.fetch_add(1, Ordering::SeqCst) + 1
        }

        fn get_id(&self) -> u32 {
            self.id
        }
    }

    impl Component for TestService {
        fn name(&self) -> &'static str {
            "test_service"
        }
    }

    #[test]
    fn test_registry_basic_operations() {
        let registry = Registry::new();
        let service = TestService::new(42);

        assert!(registry.register("test", service).is_ok());

        assert!(registry.contains("test"));
        assert!(!registry.contains("nonexistent"));

        let retrieved: Arc<TestService> = registry.get("test").unwrap();
        assert_eq!(retrieved.get_id(), 42);

        // The handle is shared, so interior mutation is visible through it.
        assert_eq!(retrieved.increment(), 1);
        assert_eq!(retrieved.increment(), 2);

        assert_eq!(registry.len(), 1);
        assert!(!registry.is_empty());
    }

    #[test]
    fn test_registry_errors() {
        let registry = Registry::new();

        let result: Result<Arc<TestService>, _> = registry.get("nonexistent");
        assert!(matches!(result, Err(RegistryError::NotFound(_))));

        let service1 = TestService::new(1);
        let service2 = TestService::new(2);

        assert!(registry.register("duplicate", service1).is_ok());
        let result = registry.register("duplicate", service2);
        assert!(matches!(result, Err(RegistryError::AlreadyExists(_))));
    }

    #[test]
    fn test_registry_type_mismatch() {
        let registry = Registry::new();
        registry.register("test", TestService::new(7)).unwrap();

        let result: Result<Arc<String>, _> = registry.get("test");
        assert!(matches!(result, Err(RegistryError::TypeMismatch(ref name)) if name == "test"));

        // A failed downcast must leave the stored component untouched.
        let service: Arc<TestService> = registry.get("test").unwrap();
        assert_eq!(service.get_id(), 7);
    }

    #[test]
    fn test_failed_duplicate_keeps_original() {
        let registry = Registry::new();
        registry.register("svc", TestService::new(1)).unwrap();
        assert!(registry.register("svc", TestService::new(2)).is_err());

        let service: Arc<TestService> = registry.get("svc").unwrap();
        assert_eq!(service.get_id(), 1);
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_type_map_tracks_components() {
        let registry = Registry::new();
        registry.register("a", TestService::new(1)).unwrap();
        registry.register("b", 2_u64).unwrap();
        assert!(registry.register("a", TestService::new(3)).is_err());

        assert_eq!(registry.type_map.read().len(), 2);
        assert_eq!(
            registry.type_map.read().get("b"),
            Some(&TypeId::of::<u64>())
        );

        registry.remove("a");
        assert_eq!(registry.type_map.read().len(), 1);

        registry.clear();
        assert!(registry.type_map.read().is_empty());
    }

    #[test]
    fn test_registry_builder() {
        let registry = Registry::builder()
            .register("service1", TestService::new(1))
            .unwrap()
            .register("service2", TestService::new(2))
            .unwrap()
            .build();

        assert_eq!(registry.len(), 2);

        let service1: Arc<TestService> = registry.get("service1").unwrap();
        let service2: Arc<TestService> = registry.get("service2").unwrap();

        assert_eq!(service1.get_id(), 1);
        assert_eq!(service2.get_id(), 2);
    }

    #[test]
    fn test_component_trait() {
        let registry = Registry::new();
        let service = TestService::new(123);

        assert!(service.register_in(&registry).is_ok());

        let retrieved: Arc<TestService> = registry.get("test_service").unwrap();
        assert_eq!(retrieved.get_id(), 123);
    }

    #[test]
    fn test_registry_removal() {
        let registry = Registry::new();
        let service = TestService::new(42);

        registry.register("test", service).unwrap();
        assert!(registry.contains("test"));

        let removed = registry.remove("test");
        assert!(removed.is_some());
        assert!(!registry.contains("test"));

        let removed = registry.remove("nonexistent");
        assert!(removed.is_none());
    }

    #[test]
    fn test_registry_clear() {
        let registry = Registry::new();

        registry.register("service1", TestService::new(1)).unwrap();
        registry.register("service2", TestService::new(2)).unwrap();

        assert_eq!(registry.len(), 2);

        registry.clear();

        assert_eq!(registry.len(), 0);
        assert!(registry.is_empty());
    }

    #[test]
    fn test_component_names() {
        let registry = Registry::new();

        registry.register("alpha", TestService::new(1)).unwrap();
        registry.register("beta", TestService::new(2)).unwrap();

        // HashMap iteration order is unspecified, so sort before comparing.
        let mut names = registry.component_names();
        names.sort();

        assert_eq!(names, vec!["alpha", "beta"]);
    }
}
385}