Skip to main content

steer_core/tools/
resolver.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use steer_tools::ToolSchema;
6use tokio::sync::RwLock;
7
8use super::backend::ToolBackend;
9use super::mcp::McpBackend;
10
11#[async_trait]
12pub trait BackendResolver: Send + Sync {
13    async fn resolve(&self, tool_name: &str) -> Option<Arc<dyn ToolBackend>>;
14
15    async fn get_tool_schemas(&self) -> Vec<ToolSchema>;
16
17    fn requires_approval(&self, tool_name: &str) -> Option<bool>;
18}
19
20#[async_trait]
21impl BackendResolver for super::BackendRegistry {
22    async fn resolve(&self, tool_name: &str) -> Option<Arc<dyn ToolBackend>> {
23        self.get_backend_for_tool(tool_name).cloned()
24    }
25
26    async fn get_tool_schemas(&self) -> Vec<ToolSchema> {
27        let mut schemas = Vec::new();
28        for (_, backend) in self.backends() {
29            schemas.extend(backend.get_tool_schemas().await);
30        }
31        schemas
32    }
33
34    fn requires_approval(&self, _tool_name: &str) -> Option<bool> {
35        None
36    }
37}
38
39pub struct SessionMcpBackends {
40    backends: RwLock<HashMap<String, Arc<McpBackend>>>,
41    tool_to_backend: RwLock<HashMap<String, String>>,
42    generations: RwLock<HashMap<String, u64>>,
43}
44
45impl Default for SessionMcpBackends {
46    fn default() -> Self {
47        Self::new()
48    }
49}
50
51impl SessionMcpBackends {
52    pub fn new() -> Self {
53        Self {
54            backends: RwLock::new(HashMap::new()),
55            tool_to_backend: RwLock::new(HashMap::new()),
56            generations: RwLock::new(HashMap::new()),
57        }
58    }
59
60    pub async fn next_generation(&self, server_name: &str) -> u64 {
61        let mut generations = self.generations.write().await;
62        let next = generations
63            .get(server_name)
64            .copied()
65            .unwrap_or(0)
66            .wrapping_add(1);
67        generations.insert(server_name.to_string(), next);
68        next
69    }
70
71    pub async fn is_current_generation(&self, server_name: &str, generation: u64) -> bool {
72        let generations = self.generations.read().await;
73        generations.get(server_name).copied().unwrap_or(0) == generation
74    }
75
76    pub async fn register(&self, server_name: String, backend: Arc<McpBackend>) {
77        let tool_names = backend.supported_tools().await;
78
79        let mut tool_mapping = self.tool_to_backend.write().await;
80        tool_mapping.retain(|_, name| name != &server_name);
81        for tool_name in tool_names {
82            tool_mapping.insert(tool_name, server_name.clone());
83        }
84        drop(tool_mapping);
85
86        let mut backends = self.backends.write().await;
87        backends.insert(server_name, backend);
88    }
89
90    pub async fn unregister(&self, server_name: &str) -> Option<Arc<McpBackend>> {
91        let mut backends = self.backends.write().await;
92        let removed = backends.remove(server_name);
93
94        if removed.is_some() {
95            let mut tool_mapping = self.tool_to_backend.write().await;
96            tool_mapping.retain(|_, name| name != server_name);
97        }
98
99        removed
100    }
101
102    pub async fn get(&self, server_name: &str) -> Option<Arc<McpBackend>> {
103        let backends = self.backends.read().await;
104        backends.get(server_name).cloned()
105    }
106
107    pub async fn clear(&self) {
108        let mut backends = self.backends.write().await;
109        backends.clear();
110        drop(backends);
111
112        let mut tool_mapping = self.tool_to_backend.write().await;
113        tool_mapping.clear();
114
115        let mut generations = self.generations.write().await;
116        generations.clear();
117    }
118}
119
120#[async_trait]
121impl BackendResolver for SessionMcpBackends {
122    async fn resolve(&self, tool_name: &str) -> Option<Arc<dyn ToolBackend>> {
123        let tool_mapping = self.tool_to_backend.read().await;
124        let server_name = tool_mapping.get(tool_name)?.clone();
125        drop(tool_mapping);
126
127        let backends = self.backends.read().await;
128        backends
129            .get(&server_name)
130            .map(|b| b.clone() as Arc<dyn ToolBackend>)
131    }
132
133    async fn get_tool_schemas(&self) -> Vec<ToolSchema> {
134        let mut schemas = Vec::new();
135        let backends = self.backends.read().await;
136        for backend in backends.values() {
137            schemas.extend(backend.get_tool_schemas().await);
138        }
139        schemas
140    }
141
142    fn requires_approval(&self, tool_name: &str) -> Option<bool> {
143        let tool_mapping = self.tool_to_backend.try_read().ok()?;
144        if tool_mapping.contains_key(tool_name) {
145            Some(true)
146        } else {
147            None
148        }
149    }
150}
151
152pub struct OverlayResolver {
153    session: Arc<SessionMcpBackends>,
154    static_resolver: Arc<dyn BackendResolver>,
155}
156
157impl OverlayResolver {
158    pub fn new(
159        session: Arc<SessionMcpBackends>,
160        static_resolver: Arc<dyn BackendResolver>,
161    ) -> Self {
162        Self {
163            session,
164            static_resolver,
165        }
166    }
167
168    pub fn session_backends(&self) -> &Arc<SessionMcpBackends> {
169        &self.session
170    }
171}
172
173#[async_trait]
174impl BackendResolver for OverlayResolver {
175    async fn resolve(&self, tool_name: &str) -> Option<Arc<dyn ToolBackend>> {
176        if let Some(backend) = self.session.resolve(tool_name).await {
177            return Some(backend);
178        }
179        self.static_resolver.resolve(tool_name).await
180    }
181
182    async fn get_tool_schemas(&self) -> Vec<ToolSchema> {
183        let mut schemas = self.session.get_tool_schemas().await;
184        schemas.extend(self.static_resolver.get_tool_schemas().await);
185        schemas
186    }
187
188    fn requires_approval(&self, tool_name: &str) -> Option<bool> {
189        self.session
190            .requires_approval(tool_name)
191            .or_else(|| self.static_resolver.requires_approval(tool_name))
192    }
193}