1use dashmap::DashMap;
33use std::sync::Arc;
34
35use crate::handlers::{
36 CompletionProvider, ElicitationHandler, HandlerCapabilities, PingHandler,
37 ResourceTemplateHandler,
38};
39use crate::registry::{Registry, RegistryError};
40
/// Registers `$handler` under `$name` in the handler map `$map` and sets the
/// boolean `$cap_field` on that name's entry in the capability map `$caps`
/// (creating the entry with `Default` if it does not exist yet).
///
/// Evaluates to `Ok(())` on success, or `Err(RegistryError::AlreadyExists(name))`
/// when `$map` already holds a handler under that name; in the error case
/// neither map is modified.
macro_rules! register_handler {
    ($map:expr, $caps:expr, $name:expr, $handler:expr, $cap_field:ident) => {{
        let name: String = $name.into();
        // Check and insert through the entry API so both happen under a single
        // shard lock. A separate `contains_key` + `insert` lets two concurrent
        // registrations of the same name both pass the check, with the second
        // silently overwriting the first instead of being rejected.
        match $map.entry(name.clone()) {
            ::dashmap::mapref::entry::Entry::Occupied(_) => {
                Err(RegistryError::AlreadyExists(name))
            }
            ::dashmap::mapref::entry::Entry::Vacant(slot) => {
                slot.insert($handler);
                $caps.entry(name).or_default().$cap_field = true;
                Ok(())
            }
        }
    }};
}
57
58pub struct EnhancedRegistry {
60 base: Registry,
62
63 elicitation_handlers: Arc<DashMap<String, Arc<dyn ElicitationHandler>>>,
65
66 completion_providers: Arc<DashMap<String, Arc<dyn CompletionProvider>>>,
68
69 template_handlers: Arc<DashMap<String, Arc<dyn ResourceTemplateHandler>>>,
71
72 ping_handlers: Arc<DashMap<String, Arc<dyn PingHandler>>>,
74
75 capabilities: Arc<DashMap<String, HandlerCapabilities>>,
77}
78
79impl EnhancedRegistry {
80 pub fn new() -> Self {
82 Self {
83 base: Registry::new(),
84 elicitation_handlers: Arc::new(DashMap::new()),
85 completion_providers: Arc::new(DashMap::new()),
86 template_handlers: Arc::new(DashMap::new()),
87 ping_handlers: Arc::new(DashMap::new()),
88 capabilities: Arc::new(DashMap::new()),
89 }
90 }
91
92 pub fn register_elicitation_handler(
94 &self,
95 name: impl Into<String>,
96 handler: Arc<dyn ElicitationHandler>,
97 ) -> Result<(), RegistryError> {
98 register_handler!(
99 self.elicitation_handlers,
100 self.capabilities,
101 name,
102 handler,
103 elicitation
104 )
105 }
106
107 pub fn get_elicitation_handler(&self, name: &str) -> Option<Arc<dyn ElicitationHandler>> {
109 self.elicitation_handlers.get(name).map(|h| h.clone())
110 }
111
112 pub fn list_elicitation_handlers(&self) -> Vec<String> {
114 self.elicitation_handlers
115 .iter()
116 .map(|entry| entry.key().clone())
117 .collect()
118 }
119
120 pub fn register_completion_provider(
122 &self,
123 name: impl Into<String>,
124 provider: Arc<dyn CompletionProvider>,
125 ) -> Result<(), RegistryError> {
126 register_handler!(
127 self.completion_providers,
128 self.capabilities,
129 name,
130 provider,
131 completion
132 )
133 }
134
135 pub fn get_completion_provider(&self, name: &str) -> Option<Arc<dyn CompletionProvider>> {
137 self.completion_providers.get(name).map(|p| p.clone())
138 }
139
140 pub fn get_matching_completion_providers(
142 &self,
143 context: &crate::context::CompletionContext,
144 ) -> Vec<Arc<dyn CompletionProvider>> {
145 let mut providers: Vec<_> = self
146 .completion_providers
147 .iter()
148 .filter_map(|entry| {
149 let provider = entry.value();
150 if provider.can_provide(context) {
151 Some(provider.clone())
152 } else {
153 None
154 }
155 })
156 .collect();
157
158 providers.sort_by_key(|p| -p.priority());
160 providers
161 }
162
163 pub fn register_template_handler(
165 &self,
166 name: impl Into<String>,
167 handler: Arc<dyn ResourceTemplateHandler>,
168 ) -> Result<(), RegistryError> {
169 register_handler!(
170 self.template_handlers,
171 self.capabilities,
172 name,
173 handler,
174 templates
175 )
176 }
177
178 pub fn get_template_handler(&self, name: &str) -> Option<Arc<dyn ResourceTemplateHandler>> {
180 self.template_handlers.get(name).map(|h| h.clone())
181 }
182
183 pub fn register_ping_handler(
185 &self,
186 name: impl Into<String>,
187 handler: Arc<dyn PingHandler>,
188 ) -> Result<(), RegistryError> {
189 register_handler!(self.ping_handlers, self.capabilities, name, handler, ping)
190 }
191
192 pub fn get_ping_handler(&self, name: &str) -> Option<Arc<dyn PingHandler>> {
194 self.ping_handlers.get(name).map(|h| h.clone())
195 }
196
197 pub fn get_capabilities(&self, name: &str) -> Option<HandlerCapabilities> {
199 self.capabilities.get(name).map(|c| c.clone())
200 }
201
202 pub fn find_by_capabilities(
204 &self,
205 filter: impl Fn(&HandlerCapabilities) -> bool,
206 ) -> Vec<String> {
207 self.capabilities
208 .iter()
209 .filter(|entry| filter(entry.value()))
210 .map(|entry| entry.key().clone())
211 .collect()
212 }
213
214 pub fn clear_handlers(&self) {
216 self.elicitation_handlers.clear();
217 self.completion_providers.clear();
218 self.template_handlers.clear();
219 self.ping_handlers.clear();
220 self.capabilities.clear();
221 }
222
223 pub fn handler_stats(&self) -> HandlerStats {
225 HandlerStats {
226 elicitation_handlers: self.elicitation_handlers.len(),
227 completion_providers: self.completion_providers.len(),
228 template_handlers: self.template_handlers.len(),
229 ping_handlers: self.ping_handlers.len(),
230 total_components: self.capabilities.len(),
231 }
232 }
233
234 pub fn base(&self) -> &Registry {
236 &self.base
237 }
238}
239
240impl Default for EnhancedRegistry {
241 fn default() -> Self {
242 Self::new()
243 }
244}
245
246impl std::fmt::Debug for EnhancedRegistry {
247 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
248 f.debug_struct("EnhancedRegistry")
249 .field("base", &self.base)
250 .field(
251 "elicitation_handlers_count",
252 &self.elicitation_handlers.len(),
253 )
254 .field(
255 "completion_providers_count",
256 &self.completion_providers.len(),
257 )
258 .field("template_handlers_count", &self.template_handlers.len())
259 .field("ping_handlers_count", &self.ping_handlers.len())
260 .field("capabilities_count", &self.capabilities.len())
261 .finish()
262 }
263}
264
/// Handler counts reported by `EnhancedRegistry::handler_stats`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct HandlerStats {
    /// Number of registered elicitation handlers.
    pub elicitation_handlers: usize,
    /// Number of registered completion providers.
    pub completion_providers: usize,
    /// Number of registered resource-template handlers.
    pub template_handlers: usize,
    /// Number of registered ping handlers.
    pub ping_handlers: usize,
    /// Number of distinct names in the capability map. A name registered for
    /// several handler kinds counts once.
    pub total_components: usize,
}
279
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::{CompletionContext, CompletionReference, ElicitationContext};
    use crate::handlers::{CompletionItem, ElicitationResponse};
    use async_trait::async_trait;

    /// Elicitation handler that accepts every request.
    struct TestElicitationHandler;

    #[async_trait]
    impl ElicitationHandler for TestElicitationHandler {
        async fn handle_elicitation(
            &self,
            _context: &ElicitationContext,
        ) -> crate::error::Result<ElicitationResponse> {
            Ok(ElicitationResponse {
                accepted: true,
                content: None,
                decline_reason: None,
            })
        }

        fn can_handle(&self, _context: &ElicitationContext) -> bool {
            true
        }
    }

    /// Completion provider that matches every context and reports a fixed
    /// priority, so ordering can be exercised with several instances.
    struct TestCompletionProvider {
        priority: i32,
    }

    #[async_trait]
    impl CompletionProvider for TestCompletionProvider {
        async fn provide_completions(
            &self,
            _context: &CompletionContext,
        ) -> crate::error::Result<Vec<CompletionItem>> {
            Ok(vec![])
        }

        fn can_provide(&self, _context: &CompletionContext) -> bool {
            true
        }

        fn priority(&self) -> i32 {
            self.priority
        }
    }

    /// Builds a simple tool-argument completion context.
    fn tool_context() -> CompletionContext {
        CompletionContext::new(CompletionReference::Tool {
            name: "test".to_string(),
            argument: "arg".to_string(),
        })
    }

    #[test]
    fn test_enhanced_registry() {
        let registry = EnhancedRegistry::new();

        registry
            .register_elicitation_handler("test_handler", Arc::new(TestElicitationHandler))
            .unwrap();

        assert!(registry.get_elicitation_handler("test_handler").is_some());
        assert_eq!(registry.list_elicitation_handlers(), vec!["test_handler"]);

        let caps = registry.get_capabilities("test_handler").unwrap();
        assert!(caps.elicitation);
        assert!(!caps.completion);
    }

    #[test]
    fn test_duplicate_registration_is_rejected() {
        let registry = EnhancedRegistry::new();

        registry
            .register_elicitation_handler("dup", Arc::new(TestElicitationHandler))
            .unwrap();
        let second =
            registry.register_elicitation_handler("dup", Arc::new(TestElicitationHandler));

        assert!(matches!(second, Err(RegistryError::AlreadyExists(_))));
        assert_eq!(registry.list_elicitation_handlers().len(), 1);
    }

    #[test]
    fn test_same_name_across_kinds_merges_capabilities() {
        let registry = EnhancedRegistry::new();

        registry
            .register_elicitation_handler("shared", Arc::new(TestElicitationHandler))
            .unwrap();
        registry
            .register_completion_provider(
                "shared",
                Arc::new(TestCompletionProvider { priority: 0 }),
            )
            .unwrap();

        let caps = registry.get_capabilities("shared").unwrap();
        assert!(caps.elicitation);
        assert!(caps.completion);
        assert_eq!(registry.handler_stats().total_components, 1);
    }

    #[test]
    fn test_completion_provider_priority() {
        let registry = EnhancedRegistry::new();

        registry
            .register_completion_provider("low", Arc::new(TestCompletionProvider { priority: 1 }))
            .unwrap();
        registry
            .register_completion_provider(
                "high",
                Arc::new(TestCompletionProvider { priority: 10 }),
            )
            .unwrap();

        let providers = registry.get_matching_completion_providers(&tool_context());
        let priorities: Vec<i32> = providers.iter().map(|p| p.priority()).collect();
        assert_eq!(priorities, vec![10, 1]);
    }

    #[test]
    fn test_handler_stats() {
        let registry = EnhancedRegistry::new();

        registry
            .register_elicitation_handler("elicit1", Arc::new(TestElicitationHandler))
            .unwrap();
        registry
            .register_completion_provider("comp1", Arc::new(TestCompletionProvider { priority: 10 }))
            .unwrap();

        let stats = registry.handler_stats();
        assert_eq!(stats.elicitation_handlers, 1);
        assert_eq!(stats.completion_providers, 1);
        assert_eq!(stats.template_handlers, 0);
        assert_eq!(stats.ping_handlers, 0);
        assert_eq!(stats.total_components, 2);
    }

    #[test]
    fn test_clear_handlers() {
        let registry = EnhancedRegistry::new();

        registry
            .register_elicitation_handler("elicit1", Arc::new(TestElicitationHandler))
            .unwrap();
        registry
            .register_completion_provider("comp1", Arc::new(TestCompletionProvider { priority: 0 }))
            .unwrap();
        registry.clear_handlers();

        let stats = registry.handler_stats();
        assert_eq!(stats.elicitation_handlers, 0);
        assert_eq!(stats.completion_providers, 0);
        assert_eq!(stats.total_components, 0);
        assert!(registry.get_capabilities("elicit1").is_none());
    }
}
389}