1use std::collections::{HashMap, HashSet};
10use std::sync::{Arc, RwLock};
11
12use crate::traits::tool::ErasedTool;
13use crate::traits::tool_registry::ToolRegistry;
14
15struct ToolEntry {
21 tool: Arc<dyn ErasedTool>,
22 enabled: bool,
23}
24
25pub struct DynamicRegistry {
39 tools: RwLock<Vec<ToolEntry>>,
40}
41
42impl DynamicRegistry {
43 #[must_use]
45 pub fn new() -> Self {
46 Self {
47 tools: RwLock::new(Vec::new()),
48 }
49 }
50
51 #[must_use]
53 pub fn with_tools(tools: Vec<Arc<dyn ErasedTool>>) -> Self {
54 let entries = tools
55 .into_iter()
56 .map(|tool| ToolEntry {
57 tool,
58 enabled: true,
59 })
60 .collect();
61 Self {
62 tools: RwLock::new(entries),
63 }
64 }
65}
66
67impl Default for DynamicRegistry {
68 fn default() -> Self {
69 Self::new()
70 }
71}
72
73impl ToolRegistry for DynamicRegistry {
74 fn get_tools(&self) -> Vec<Arc<dyn ErasedTool>> {
75 let tools = self.tools.read().expect("DynamicRegistry lock poisoned");
76 tools
77 .iter()
78 .filter(|e| e.enabled)
79 .map(|e| Arc::clone(&e.tool))
80 .collect()
81 }
82
83 fn find_tool(&self, name: &str) -> Option<Arc<dyn ErasedTool>> {
84 let tools = self.tools.read().expect("DynamicRegistry lock poisoned");
85 tools
86 .iter()
87 .find(|e| e.enabled && e.tool.name() == name)
88 .map(|e| Arc::clone(&e.tool))
89 }
90
91 fn register(&self, tool: Arc<dyn ErasedTool>) -> bool {
92 let mut tools = self.tools.write().expect("DynamicRegistry lock poisoned");
93 let name = tool.name().to_string();
94 if tools.iter().any(|e| e.tool.name() == name) {
96 return false;
97 }
98 tools.push(ToolEntry {
99 tool,
100 enabled: true,
101 });
102 true
103 }
104
105 fn unregister(&self, name: &str) -> bool {
106 let mut tools = self.tools.write().expect("DynamicRegistry lock poisoned");
107 let len_before = tools.len();
108 tools.retain(|e| e.tool.name() != name);
109 tools.len() < len_before
110 }
111
112 fn set_enabled(&self, name: &str, enabled: bool) -> bool {
113 let mut tools = self.tools.write().expect("DynamicRegistry lock poisoned");
114 if let Some(entry) = tools.iter_mut().find(|e| e.tool.name() == name) {
115 if entry.enabled != enabled {
116 entry.enabled = enabled;
117 return true;
118 }
119 }
120 false
121 }
122
123 fn is_enabled(&self, name: &str) -> bool {
124 let tools = self.tools.read().expect("DynamicRegistry lock poisoned");
125 tools.iter().any(|e| e.tool.name() == name && e.enabled)
126 }
127
128 fn len(&self) -> usize {
129 let tools = self.tools.read().expect("DynamicRegistry lock poisoned");
130 tools.iter().filter(|e| e.enabled).count()
131 }
132
133 fn is_empty(&self) -> bool {
134 self.len() == 0
135 }
136}
137
138pub struct GroupedRegistry {
174 groups: RwLock<HashMap<String, Vec<Arc<dyn ErasedTool>>>>,
175 active_groups: RwLock<HashSet<String>>,
176}
177
178impl GroupedRegistry {
179 #[must_use]
181 pub fn new() -> Self {
182 Self {
183 groups: RwLock::new(HashMap::new()),
184 active_groups: RwLock::new(HashSet::new()),
185 }
186 }
187
188 #[must_use]
195 pub fn group(self, name: impl Into<String>, tools: Vec<Arc<dyn ErasedTool>>) -> Self {
196 {
197 let mut groups = self.groups.write().expect("GroupedRegistry lock poisoned");
198 groups.insert(name.into(), tools);
199 }
200 self
201 }
202
203 #[must_use]
208 pub fn activate(self, name: impl Into<String>) -> Self {
209 {
210 let mut active = self
211 .active_groups
212 .write()
213 .expect("GroupedRegistry lock poisoned");
214 active.insert(name.into());
215 }
216 self
217 }
218
219 pub fn activate_group(&self, name: &str) -> bool {
223 let groups = self.groups.read().expect("GroupedRegistry lock poisoned");
224 if groups.contains_key(name) {
225 let mut active = self
226 .active_groups
227 .write()
228 .expect("GroupedRegistry lock poisoned");
229 active.insert(name.to_string());
230 true
231 } else {
232 false
233 }
234 }
235
236 pub fn deactivate_group(&self, name: &str) -> bool {
240 let mut active = self
241 .active_groups
242 .write()
243 .expect("GroupedRegistry lock poisoned");
244 active.remove(name)
245 }
246
247 #[must_use]
249 pub fn group_names(&self) -> Vec<String> {
250 let groups = self.groups.read().expect("GroupedRegistry lock poisoned");
251 groups.keys().cloned().collect()
252 }
253
254 #[must_use]
256 pub fn active_group_names(&self) -> Vec<String> {
257 let active = self
258 .active_groups
259 .read()
260 .expect("GroupedRegistry lock poisoned");
261 active.iter().cloned().collect()
262 }
263
264 #[must_use]
266 pub fn is_group_active(&self, name: &str) -> bool {
267 let active = self
268 .active_groups
269 .read()
270 .expect("GroupedRegistry lock poisoned");
271 active.contains(name)
272 }
273}
274
275impl Default for GroupedRegistry {
276 fn default() -> Self {
277 Self::new()
278 }
279}
280
281impl ToolRegistry for GroupedRegistry {
282 fn get_tools(&self) -> Vec<Arc<dyn ErasedTool>> {
284 let groups = self.groups.read().expect("GroupedRegistry lock poisoned");
285 let active = self
286 .active_groups
287 .read()
288 .expect("GroupedRegistry lock poisoned");
289
290 let mut tools = Vec::new();
291 for group_name in active.iter() {
292 if let Some(group_tools) = groups.get(group_name) {
293 for tool in group_tools {
294 tools.push(Arc::clone(tool));
295 }
296 }
297 }
298 tools
299 }
300
301 fn find_tool(&self, name: &str) -> Option<Arc<dyn ErasedTool>> {
305 let groups = self.groups.read().expect("GroupedRegistry lock poisoned");
306 for tools in groups.values() {
307 if let Some(tool) = tools.iter().find(|t| t.name() == name) {
308 return Some(Arc::clone(tool));
309 }
310 }
311 None
312 }
313
314 fn len(&self) -> usize {
316 let groups = self.groups.read().expect("GroupedRegistry lock poisoned");
317 let active = self
318 .active_groups
319 .read()
320 .expect("GroupedRegistry lock poisoned");
321
322 let mut count = 0;
323 for group_name in active.iter() {
324 if let Some(group_tools) = groups.get(group_name) {
325 count += group_tools.len();
326 }
327 }
328 count
329 }
330
331 fn is_empty(&self) -> bool {
332 self.len() == 0
333 }
334}
335
/// Maximum number of tools exposed to a model, for each model tier.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TierLimits {
    /// Tool cap for small-tier models.
    pub small: usize,
    /// Tool cap for medium-tier models.
    pub medium: usize,
    /// Tool cap for large-tier models.
    pub large: usize,
}

impl Default for TierLimits {
    /// Default caps: 5 (small), 15 (medium), and `usize::MAX` (large,
    /// i.e. effectively unlimited).
    fn default() -> Self {
        Self {
            small: 5,
            medium: 15,
            large: usize::MAX,
        }
    }
}
362
363pub struct AdaptiveRegistry {
386 tools: Vec<Arc<dyn ErasedTool>>,
387 limits: TierLimits,
388 tier: crate::types::model_info::ModelTier,
389}
390
391impl AdaptiveRegistry {
392 #[must_use]
396 pub fn new(tools: Vec<Arc<dyn ErasedTool>>, tier: crate::types::model_info::ModelTier) -> Self {
397 Self {
398 tools,
399 limits: TierLimits::default(),
400 tier,
401 }
402 }
403
404 #[must_use]
416 pub fn with_limits(mut self, small: usize, medium: usize, large: usize) -> Self {
417 self.limits = TierLimits {
418 small,
419 medium,
420 large,
421 };
422 self
423 }
424
425 #[must_use]
427 pub fn limits(&self) -> TierLimits {
428 self.limits
429 }
430
431 #[must_use]
433 pub fn tier(&self) -> crate::types::model_info::ModelTier {
434 self.tier
435 }
436
437 fn effective_limit(&self) -> usize {
439 use crate::types::model_info::ModelTier;
440 match self.tier {
441 ModelTier::Small => self.limits.small,
442 ModelTier::Medium => self.limits.medium,
443 ModelTier::Large => self.limits.large,
444 }
445 }
446}
447
448impl ToolRegistry for AdaptiveRegistry {
449 fn get_tools(&self) -> Vec<Arc<dyn ErasedTool>> {
453 let limit = self.effective_limit();
454 self.tools
455 .iter()
456 .take(limit)
457 .map(|t| Arc::clone(t))
458 .collect()
459 }
460
461 fn find_tool(&self, name: &str) -> Option<Arc<dyn ErasedTool>> {
465 self.tools
466 .iter()
467 .find(|t| t.name() == name)
468 .map(|t| Arc::clone(t))
469 }
470
471 fn len(&self) -> usize {
473 let limit = self.effective_limit();
474 self.tools.len().min(limit)
475 }
476
477 fn is_empty(&self) -> bool {
478 self.len() == 0
479 }
480}
481
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;

    /// Minimal `ErasedTool` whose only distinguishing feature is its name.
    struct FakeTool {
        tool_name: String,
    }

    impl FakeTool {
        fn new(name: &str) -> Self {
            Self {
                tool_name: name.to_string(),
            }
        }
    }

    #[async_trait]
    impl ErasedTool for FakeTool {
        fn name(&self) -> &str {
            &self.tool_name
        }
        fn description(&self) -> &str {
            "fake"
        }
        fn schema(&self) -> crate::traits::tool::ToolSchema {
            crate::traits::tool::ToolSchema {
                name: self.tool_name.clone(),
                description: "fake".to_string(),
                parameters: serde_json::json!({}),
            }
        }
        async fn execute_json(
            &self,
            _input: serde_json::Value,
        ) -> crate::Result<serde_json::Value> {
            Ok(serde_json::json!("ok"))
        }
    }

    // ---- DynamicRegistry ----

    #[test]
    fn test_dynamic_registry_empty() {
        let reg = DynamicRegistry::new();
        assert!(reg.is_empty());
        assert_eq!(reg.len(), 0);
    }

    #[test]
    fn test_dynamic_registry_register_and_find() {
        let reg = DynamicRegistry::new();
        assert!(reg.register(Arc::new(FakeTool::new("search"))));
        assert_eq!(reg.len(), 1);
        assert!(reg.find_tool("search").is_some());
        assert!(reg.find_tool("calc").is_none());
    }

    #[test]
    fn test_dynamic_registry_no_duplicates() {
        let reg = DynamicRegistry::new();
        assert!(reg.register(Arc::new(FakeTool::new("search"))));
        assert!(!reg.register(Arc::new(FakeTool::new("search"))));
        assert_eq!(reg.len(), 1);
    }

    #[test]
    fn test_dynamic_registry_unregister() {
        let reg = DynamicRegistry::new();
        reg.register(Arc::new(FakeTool::new("search")));
        assert!(reg.unregister("search"));
        assert!(reg.is_empty());
        // A second unregister finds nothing to remove.
        assert!(!reg.unregister("search"));
    }

    #[test]
    fn test_dynamic_registry_set_enabled() {
        let reg = DynamicRegistry::new();
        reg.register(Arc::new(FakeTool::new("search")));
        assert!(reg.is_enabled("search"));

        assert!(reg.set_enabled("search", false));
        assert!(!reg.is_enabled("search"));
        // Disabled tools are hidden from len() and find_tool().
        assert_eq!(reg.len(), 0);
        assert!(reg.find_tool("search").is_none());

        assert!(reg.set_enabled("search", true));
        assert!(reg.is_enabled("search"));
        assert_eq!(reg.len(), 1);
    }

    #[test]
    fn test_dynamic_registry_set_enabled_no_change() {
        let reg = DynamicRegistry::new();
        reg.register(Arc::new(FakeTool::new("search")));
        // Already enabled: nothing changes, so `false`.
        assert!(!reg.set_enabled("search", true));
    }

    #[test]
    fn test_dynamic_registry_with_tools() {
        let tools: Vec<Arc<dyn ErasedTool>> =
            vec![Arc::new(FakeTool::new("a")), Arc::new(FakeTool::new("b"))];
        let reg = DynamicRegistry::with_tools(tools);
        assert_eq!(reg.len(), 2);
        assert!(reg.is_enabled("a"));
        assert!(reg.is_enabled("b"));
    }

    #[test]
    fn test_dynamic_registry_get_tools_only_enabled() {
        let reg = DynamicRegistry::new();
        reg.register(Arc::new(FakeTool::new("a")));
        reg.register(Arc::new(FakeTool::new("b")));
        reg.set_enabled("a", false);

        let tools = reg.get_tools();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name(), "b");
    }

    // ---- GroupedRegistry ----

    #[test]
    fn test_grouped_registry_empty() {
        let reg = GroupedRegistry::new();
        assert!(reg.is_empty());
        assert_eq!(reg.len(), 0);
        assert!(reg.get_tools().is_empty());
        assert!(reg.find_tool("anything").is_none());
    }

    #[test]
    fn test_grouped_registry_single_group() {
        let tools: Vec<Arc<dyn ErasedTool>> = vec![
            Arc::new(FakeTool::new("web_search")),
            Arc::new(FakeTool::new("deep_search")),
        ];
        let reg = GroupedRegistry::new()
            .group("search", tools)
            .activate("search");

        assert_eq!(reg.len(), 2);
        assert!(!reg.is_empty());
        let active = reg.get_tools();
        assert_eq!(active.len(), 2);
    }

    #[test]
    fn test_grouped_registry_multiple_groups_activate_switch() {
        let search_tools: Vec<Arc<dyn ErasedTool>> = vec![Arc::new(FakeTool::new("web_search"))];
        let code_tools: Vec<Arc<dyn ErasedTool>> = vec![
            Arc::new(FakeTool::new("read_file")),
            Arc::new(FakeTool::new("write_file")),
        ];

        let reg = GroupedRegistry::new()
            .group("search", search_tools)
            .group("code", code_tools)
            .activate("search");

        assert_eq!(reg.len(), 1);
        assert_eq!(reg.get_tools()[0].name(), "web_search");

        assert!(reg.deactivate_group("search"));
        assert!(reg.is_empty());

        assert!(reg.activate_group("code"));
        assert_eq!(reg.len(), 2);
        let names: Vec<String> = reg
            .get_tools()
            .iter()
            .map(|t| t.name().to_string())
            .collect();
        assert!(names.contains(&"read_file".to_string()));
        assert!(names.contains(&"write_file".to_string()));
    }

    #[test]
    fn test_grouped_registry_multiple_active_groups() {
        let search_tools: Vec<Arc<dyn ErasedTool>> = vec![Arc::new(FakeTool::new("web_search"))];
        let code_tools: Vec<Arc<dyn ErasedTool>> = vec![Arc::new(FakeTool::new("read_file"))];

        let reg = GroupedRegistry::new()
            .group("search", search_tools)
            .group("code", code_tools)
            .activate("search")
            .activate("code");

        assert_eq!(reg.len(), 2);
        let names: Vec<String> = reg
            .get_tools()
            .iter()
            .map(|t| t.name().to_string())
            .collect();
        assert!(names.contains(&"web_search".to_string()));
        assert!(names.contains(&"read_file".to_string()));
    }

    #[test]
    fn test_grouped_registry_find_tool_searches_all_groups() {
        let search_tools: Vec<Arc<dyn ErasedTool>> = vec![Arc::new(FakeTool::new("web_search"))];
        let code_tools: Vec<Arc<dyn ErasedTool>> = vec![Arc::new(FakeTool::new("read_file"))];

        let reg = GroupedRegistry::new()
            .group("search", search_tools)
            .group("code", code_tools)
            .activate("search");
        // Only the active group is listed...
        assert_eq!(reg.get_tools().len(), 1);

        // ...but lookup also reaches the inactive group.
        assert!(reg.find_tool("web_search").is_some());
        assert!(reg.find_tool("read_file").is_some());
        assert!(reg.find_tool("nonexistent").is_none());
    }

    #[test]
    fn test_grouped_registry_activate_nonexistent_group() {
        let reg = GroupedRegistry::new().group("search", vec![Arc::new(FakeTool::new("a"))]);

        assert!(!reg.activate_group("nonexistent"));
        assert!(reg.activate_group("search"));
    }

    #[test]
    fn test_grouped_registry_deactivate_nonexistent() {
        let reg = GroupedRegistry::new();
        assert!(!reg.deactivate_group("nonexistent"));
    }

    #[test]
    fn test_grouped_registry_group_names() {
        let reg = GroupedRegistry::new()
            .group("search", vec![])
            .group("code", vec![])
            .activate("search");

        let mut names = reg.group_names();
        names.sort();
        assert_eq!(names, vec!["code", "search"]);

        let active = reg.active_group_names();
        assert_eq!(active.len(), 1);
        assert!(active.contains(&"search".to_string()));
    }

    #[test]
    fn test_grouped_registry_is_group_active() {
        let reg = GroupedRegistry::new()
            .group("search", vec![])
            .group("code", vec![])
            .activate("search");

        assert!(reg.is_group_active("search"));
        assert!(!reg.is_group_active("code"));
    }

    #[test]
    fn test_grouped_registry_concurrent_read() {
        use std::thread;

        let reg = Arc::new(
            GroupedRegistry::new()
                .group("a", vec![Arc::new(FakeTool::new("tool_a"))])
                .group("b", vec![Arc::new(FakeTool::new("tool_b"))])
                .activate("a"),
        );

        let mut handles = vec![];
        for _ in 0..10 {
            let reg_clone = Arc::clone(&reg);
            handles.push(thread::spawn(move || {
                for _ in 0..100 {
                    let tools = reg_clone.get_tools();
                    assert_eq!(tools.len(), 1);
                    assert!(reg_clone.find_tool("tool_a").is_some());
                    assert!(reg_clone.find_tool("tool_b").is_some());
                }
            }));
        }

        for h in handles {
            h.join().expect("thread panicked");
        }
    }

    #[test]
    fn test_grouped_registry_object_safe() {
        let reg = GroupedRegistry::new();
        let _: Arc<dyn ToolRegistry> = Arc::new(reg);
    }

    #[test]
    fn test_grouped_registry_replace_group() {
        // Registering the same group name twice replaces the earlier tools.
        let reg = GroupedRegistry::new()
            .group("search", vec![Arc::new(FakeTool::new("old_tool"))])
            .group("search", vec![Arc::new(FakeTool::new("new_tool"))])
            .activate("search");

        assert_eq!(reg.len(), 1);
        assert!(reg.find_tool("new_tool").is_some());
        assert!(reg.find_tool("old_tool").is_none());
    }

    // ---- AdaptiveRegistry ----

    /// Builds `n` fake tools named `tool_0` .. `tool_{n-1}`.
    fn make_tools(n: usize) -> Vec<Arc<dyn ErasedTool>> {
        (0..n)
            .map(|i| Arc::new(FakeTool::new(&format!("tool_{i}"))) as Arc<dyn ErasedTool>)
            .collect()
    }

    #[test]
    fn test_adaptive_registry_small_tier_limits() {
        use crate::types::model_info::ModelTier;
        let reg = AdaptiveRegistry::new(make_tools(30), ModelTier::Small);
        assert_eq!(reg.len(), 5);
        assert_eq!(reg.get_tools().len(), 5);
        assert_eq!(reg.get_tools()[0].name(), "tool_0");
        assert_eq!(reg.get_tools()[4].name(), "tool_4");
    }

    #[test]
    fn test_adaptive_registry_medium_tier_limits() {
        use crate::types::model_info::ModelTier;
        let reg = AdaptiveRegistry::new(make_tools(30), ModelTier::Medium);
        assert_eq!(reg.len(), 15);
        assert_eq!(reg.get_tools().len(), 15);
    }

    #[test]
    fn test_adaptive_registry_large_tier_all() {
        use crate::types::model_info::ModelTier;
        let reg = AdaptiveRegistry::new(make_tools(30), ModelTier::Large);
        assert_eq!(reg.len(), 30);
        assert_eq!(reg.get_tools().len(), 30);
    }

    #[test]
    fn test_adaptive_registry_custom_limits() {
        use crate::types::model_info::ModelTier;
        let reg = AdaptiveRegistry::new(make_tools(30), ModelTier::Small).with_limits(3, 10, 50);
        assert_eq!(reg.len(), 3);
        assert_eq!(reg.get_tools().len(), 3);
    }

    #[test]
    fn test_adaptive_registry_find_tool_beyond_limit() {
        use crate::types::model_info::ModelTier;
        // Lookup ignores the tier cap.
        let reg = AdaptiveRegistry::new(make_tools(30), ModelTier::Small);
        assert!(reg.find_tool("tool_29").is_some());
        assert!(reg.find_tool("tool_0").is_some());
        assert!(reg.find_tool("nonexistent").is_none());
    }

    #[test]
    fn test_adaptive_registry_empty() {
        use crate::types::model_info::ModelTier;
        let reg = AdaptiveRegistry::new(vec![], ModelTier::Large);
        assert!(reg.is_empty());
        assert_eq!(reg.len(), 0);
    }

    #[test]
    fn test_adaptive_registry_object_safe() {
        use crate::types::model_info::ModelTier;
        let reg = AdaptiveRegistry::new(vec![], ModelTier::Medium);
        let _: Arc<dyn ToolRegistry> = Arc::new(reg);
    }
}