steer_core/tools/
resolver.rs1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use steer_tools::ToolSchema;
6use tokio::sync::RwLock;
7
8use super::backend::ToolBackend;
9use super::mcp::McpBackend;
10
11#[async_trait]
12pub trait BackendResolver: Send + Sync {
13 async fn resolve(&self, tool_name: &str) -> Option<Arc<dyn ToolBackend>>;
14
15 async fn get_tool_schemas(&self) -> Vec<ToolSchema>;
16
17 fn requires_approval(&self, tool_name: &str) -> Option<bool>;
18}
19
20#[async_trait]
21impl BackendResolver for super::BackendRegistry {
22 async fn resolve(&self, tool_name: &str) -> Option<Arc<dyn ToolBackend>> {
23 self.get_backend_for_tool(tool_name).cloned()
24 }
25
26 async fn get_tool_schemas(&self) -> Vec<ToolSchema> {
27 let mut schemas = Vec::new();
28 for (_, backend) in self.backends() {
29 schemas.extend(backend.get_tool_schemas().await);
30 }
31 schemas
32 }
33
34 fn requires_approval(&self, _tool_name: &str) -> Option<bool> {
35 None
36 }
37}
38
39pub struct SessionMcpBackends {
40 backends: RwLock<HashMap<String, Arc<McpBackend>>>,
41 tool_to_backend: RwLock<HashMap<String, String>>,
42 generations: RwLock<HashMap<String, u64>>,
43}
44
45impl Default for SessionMcpBackends {
46 fn default() -> Self {
47 Self::new()
48 }
49}
50
51impl SessionMcpBackends {
52 pub fn new() -> Self {
53 Self {
54 backends: RwLock::new(HashMap::new()),
55 tool_to_backend: RwLock::new(HashMap::new()),
56 generations: RwLock::new(HashMap::new()),
57 }
58 }
59
60 pub async fn next_generation(&self, server_name: &str) -> u64 {
61 let mut generations = self.generations.write().await;
62 let next = generations
63 .get(server_name)
64 .copied()
65 .unwrap_or(0)
66 .wrapping_add(1);
67 generations.insert(server_name.to_string(), next);
68 next
69 }
70
71 pub async fn is_current_generation(&self, server_name: &str, generation: u64) -> bool {
72 let generations = self.generations.read().await;
73 generations.get(server_name).copied().unwrap_or(0) == generation
74 }
75
76 pub async fn register(&self, server_name: String, backend: Arc<McpBackend>) {
77 let tool_names = backend.supported_tools().await;
78
79 let mut tool_mapping = self.tool_to_backend.write().await;
80 tool_mapping.retain(|_, name| name != &server_name);
81 for tool_name in tool_names {
82 tool_mapping.insert(tool_name, server_name.clone());
83 }
84 drop(tool_mapping);
85
86 let mut backends = self.backends.write().await;
87 backends.insert(server_name, backend);
88 }
89
90 pub async fn unregister(&self, server_name: &str) -> Option<Arc<McpBackend>> {
91 let mut backends = self.backends.write().await;
92 let removed = backends.remove(server_name);
93
94 if removed.is_some() {
95 let mut tool_mapping = self.tool_to_backend.write().await;
96 tool_mapping.retain(|_, name| name != server_name);
97 }
98
99 removed
100 }
101
102 pub async fn get(&self, server_name: &str) -> Option<Arc<McpBackend>> {
103 let backends = self.backends.read().await;
104 backends.get(server_name).cloned()
105 }
106
107 pub async fn clear(&self) {
108 let mut backends = self.backends.write().await;
109 backends.clear();
110 drop(backends);
111
112 let mut tool_mapping = self.tool_to_backend.write().await;
113 tool_mapping.clear();
114
115 let mut generations = self.generations.write().await;
116 generations.clear();
117 }
118}
119
120#[async_trait]
121impl BackendResolver for SessionMcpBackends {
122 async fn resolve(&self, tool_name: &str) -> Option<Arc<dyn ToolBackend>> {
123 let tool_mapping = self.tool_to_backend.read().await;
124 let server_name = tool_mapping.get(tool_name)?.clone();
125 drop(tool_mapping);
126
127 let backends = self.backends.read().await;
128 backends
129 .get(&server_name)
130 .map(|b| b.clone() as Arc<dyn ToolBackend>)
131 }
132
133 async fn get_tool_schemas(&self) -> Vec<ToolSchema> {
134 let mut schemas = Vec::new();
135 let backends = self.backends.read().await;
136 for backend in backends.values() {
137 schemas.extend(backend.get_tool_schemas().await);
138 }
139 schemas
140 }
141
142 fn requires_approval(&self, tool_name: &str) -> Option<bool> {
143 let tool_mapping = self.tool_to_backend.try_read().ok()?;
144 if tool_mapping.contains_key(tool_name) {
145 Some(true)
146 } else {
147 None
148 }
149 }
150}
151
152pub struct OverlayResolver {
153 session: Arc<SessionMcpBackends>,
154 static_resolver: Arc<dyn BackendResolver>,
155}
156
157impl OverlayResolver {
158 pub fn new(
159 session: Arc<SessionMcpBackends>,
160 static_resolver: Arc<dyn BackendResolver>,
161 ) -> Self {
162 Self {
163 session,
164 static_resolver,
165 }
166 }
167
168 pub fn session_backends(&self) -> &Arc<SessionMcpBackends> {
169 &self.session
170 }
171}
172
173#[async_trait]
174impl BackendResolver for OverlayResolver {
175 async fn resolve(&self, tool_name: &str) -> Option<Arc<dyn ToolBackend>> {
176 if let Some(backend) = self.session.resolve(tool_name).await {
177 return Some(backend);
178 }
179 self.static_resolver.resolve(tool_name).await
180 }
181
182 async fn get_tool_schemas(&self) -> Vec<ToolSchema> {
183 let mut schemas = self.session.get_tool_schemas().await;
184 schemas.extend(self.static_resolver.get_tool_schemas().await);
185 schemas
186 }
187
188 fn requires_approval(&self, tool_name: &str) -> Option<bool> {
189 self.session
190 .requires_approval(tool_name)
191 .or_else(|| self.static_resolver.requires_approval(tool_name))
192 }
193}