1use crate::mcp::McpToolInfo;
20use anyhow::Result;
21use serde_json::Value;
22use std::cmp::Ordering;
23use std::sync::Arc;
24use tracing::{debug, info};
25
/// How much information about each tool is included in discovery output.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum DetailLevel {
    /// Only the tool name and the provider that exposes it.
    NameOnly,
    /// Name, provider and the tool's description.
    NameAndDescription,
    /// Name, provider, description and the tool's input schema.
    Full,
}
36
37impl DetailLevel {
38 pub fn as_str(&self) -> &'static str {
39 match self {
40 Self::NameOnly => "name-only",
41 Self::NameAndDescription => "name-and-description",
42 Self::Full => "full",
43 }
44 }
45}
46
47#[derive(Debug, Clone, serde::Serialize)]
49pub struct ToolDiscoveryResult {
50 pub name: String,
51 pub provider: String,
52 pub description: String,
53 pub relevance_score: f32,
54 pub input_schema: Option<Value>,
56}
57
58impl ToolDiscoveryResult {
59 pub fn to_json(&self, detail_level: DetailLevel) -> Value {
61 match detail_level {
62 DetailLevel::NameOnly => serde_json::json!({
63 "name": self.name,
64 "provider": self.provider,
65 }),
66 DetailLevel::NameAndDescription => serde_json::json!({
67 "name": self.name,
68 "provider": self.provider,
69 "description": self.description,
70 }),
71 DetailLevel::Full => serde_json::json!({
72 "name": self.name,
73 "provider": self.provider,
74 "description": self.description,
75 "input_schema": self.input_schema,
76 }),
77 }
78 }
79}
80
81pub struct ToolDiscovery {
83 mcp_client: Arc<dyn crate::mcp::McpToolExecutor>,
84}
85
86fn group_results_by_provider_preserving_order(
87 tools: impl IntoIterator<Item = ToolDiscoveryResult>,
88) -> Vec<(String, Vec<ToolDiscoveryResult>)> {
89 let mut grouped: Vec<(String, Vec<ToolDiscoveryResult>)> = Vec::new();
90
91 for tool in tools {
92 let provider = tool.provider.clone();
93 if let Some((_, provider_tools)) = grouped
94 .iter_mut()
95 .find(|(existing_provider, _)| *existing_provider == provider)
96 {
97 provider_tools.push(tool);
98 } else {
99 grouped.push((provider, vec![tool]));
100 }
101 }
102
103 grouped
104}
105
106impl ToolDiscovery {
107 pub fn new(mcp_client: Arc<dyn crate::mcp::McpToolExecutor>) -> Self {
109 Self { mcp_client }
110 }
111
112 pub async fn search_tools(
120 &self,
121 keyword: &str,
122 detail_level: DetailLevel,
123 ) -> Result<Vec<ToolDiscoveryResult>> {
124 let tools = self.mcp_client.list_mcp_tools().await?;
125
126 debug!(
127 keyword = keyword,
128 count = tools.len(),
129 "Searching tools for keyword"
130 );
131
132 let mut results = Vec::with_capacity(tools.len() / 4);
134
135 for tool in tools {
136 let relevance_score = self.calculate_relevance(&tool, keyword);
137
138 if relevance_score <= 0.0 {
140 continue;
141 }
142
143 let input_schema = match detail_level {
145 DetailLevel::Full => Some(tool.input_schema.clone()),
146 _ => None,
147 };
148
149 results.push(ToolDiscoveryResult {
150 name: tool.name.clone(),
151 provider: tool.provider.clone(),
152 description: tool.description.clone(),
153 relevance_score,
154 input_schema,
155 });
156 }
157
158 results.sort_by(|a, b| {
160 b.relevance_score
161 .partial_cmp(&a.relevance_score)
162 .unwrap_or(Ordering::Equal)
163 });
164
165 let total_results = results.len();
167 if total_results > 5 {
168 info!(
169 keyword = keyword,
170 matched = total_results,
171 displayed = 5,
172 overflow = total_results - 5,
173 detail_level = detail_level.as_str(),
174 "Tool search completed with overflow"
175 );
176 results.truncate(5);
177 } else {
178 info!(
179 keyword = keyword,
180 matched = total_results,
181 detail_level = detail_level.as_str(),
182 "Tool search completed"
183 );
184 }
185
186 Ok(results)
187 }
188
189 pub async fn get_tool_detail(&self, tool_name: &str) -> Result<Option<ToolDiscoveryResult>> {
191 let tools = self.mcp_client.list_mcp_tools().await?;
192
193 for tool in tools {
194 if tool.name.eq_ignore_ascii_case(tool_name) {
195 return Ok(Some(ToolDiscoveryResult {
196 name: tool.name.clone(),
197 provider: tool.provider.clone(),
198 description: tool.description.clone(),
199 relevance_score: 1.0,
200 input_schema: Some(tool.input_schema),
201 }));
202 }
203 }
204
205 Ok(None)
206 }
207
208 pub async fn list_tools_by_provider(&self) -> Result<Vec<(String, Vec<ToolDiscoveryResult>)>> {
210 let tools = self.mcp_client.list_mcp_tools().await?;
211
212 Ok(group_results_by_provider_preserving_order(
213 tools.into_iter().map(|tool| ToolDiscoveryResult {
214 name: tool.name,
215 provider: tool.provider,
216 description: tool.description,
217 relevance_score: 1.0,
218 input_schema: None,
219 }),
220 ))
221 }
222
223 fn calculate_relevance(&self, tool: &McpToolInfo, keyword: &str) -> f32 {
227 let keyword_lower = keyword.to_lowercase();
228
229 if tool.name.eq_ignore_ascii_case(keyword) {
231 return 1.0;
232 }
233
234 if tool.name.to_lowercase().contains(&keyword_lower) {
236 return 0.8;
237 }
238
239 if tool.description.to_lowercase().contains(&keyword_lower) {
241 return 0.6;
242 }
243
244 let name_fuzzy = self.fuzzy_score(&tool.name.to_lowercase(), &keyword_lower);
246 if name_fuzzy > 0.3 {
247 return 0.5 * name_fuzzy;
248 }
249
250 let desc_fuzzy = self.fuzzy_score(&tool.description.to_lowercase(), &keyword_lower);
251 if desc_fuzzy > 0.3 {
252 return 0.3 * desc_fuzzy;
253 }
254
255 0.0
256 }
257
258 fn fuzzy_score(&self, haystack: &str, needle: &str) -> f32 {
260 if needle.is_empty() {
261 return 1.0;
262 }
263
264 if haystack.is_empty() {
265 return 0.0;
266 }
267
268 let mut score = 0.0;
269 let mut haystack_idx = 0;
270
271 for needle_char in needle.chars() {
272 if let Some(pos) = haystack[haystack_idx..].find(needle_char) {
273 haystack_idx += pos + 1;
274 score += 1.0;
275 } else {
276 return 0.0;
277 }
278 }
279
280 score / needle.len() as f32
282 }
283}
284
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a tool description with an empty-object input schema.
    fn mock_tool(provider: &str, name: &str, description: &str) -> McpToolInfo {
        McpToolInfo {
            provider: provider.to_string(),
            name: name.to_string(),
            description: description.to_string(),
            input_schema: json!({}),
        }
    }

    /// Wraps `tools` in a mock executor and builds a `ToolDiscovery` over it.
    fn discovery_with(tools: Vec<McpToolInfo>) -> ToolDiscovery {
        ToolDiscovery::new(Arc::new(MockMcpClient { tools }))
    }

    #[test]
    fn fuzzy_score_exact_match() {
        let score = discovery_with(Vec::new()).fuzzy_score("read_file", "read_file");
        assert_eq!(score, 1.0);
    }

    #[test]
    fn fuzzy_score_partial_match() {
        let score = discovery_with(Vec::new()).fuzzy_score("read_file_contents", "read");
        assert!(score > 0.5, "score {score} should exceed 0.5");
        assert!(score <= 1.0, "score {score} should not exceed 1.0");
    }

    #[test]
    fn fuzzy_score_no_match() {
        let score = discovery_with(Vec::new()).fuzzy_score("read_file", "xyz");
        assert_eq!(score, 0.0);
    }

    #[tokio::test]
    async fn list_tools_by_provider_preserves_first_seen_provider_and_tool_order() {
        let discovery = discovery_with(vec![
            mock_tool("gmail", "send_email", "Send an email."),
            mock_tool("calendar", "create_event", "Create a calendar event."),
            mock_tool("gmail", "read_email", "Read an email."),
            mock_tool("docs", "search", "Search docs."),
            mock_tool("calendar", "list_events", "List calendar events."),
        ]);

        let grouped = discovery
            .list_tools_by_provider()
            .await
            .expect("grouped tools");

        let providers: Vec<&str> = grouped
            .iter()
            .map(|(provider, _)| provider.as_str())
            .collect();
        assert_eq!(providers, ["gmail", "calendar", "docs"]);

        let tool_names: Vec<Vec<&str>> = grouped
            .iter()
            .map(|(_, tools)| tools.iter().map(|tool| tool.name.as_str()).collect())
            .collect();
        assert_eq!(
            tool_names,
            [
                vec!["send_email", "read_email"],
                vec!["create_event", "list_events"],
                vec!["search"],
            ]
        );
    }

    /// In-memory executor that serves a fixed tool list and otherwise returns canned values.
    #[derive(Default)]
    struct MockMcpClient {
        tools: Vec<McpToolInfo>,
    }

    #[async_trait::async_trait]
    impl crate::mcp::McpToolExecutor for MockMcpClient {
        async fn execute_mcp_tool(&self, _tool_name: &str, _args: &Value) -> Result<Value> {
            Ok(Value::Null)
        }

        async fn list_mcp_tools(&self) -> Result<Vec<McpToolInfo>> {
            Ok(self.tools.clone())
        }

        async fn has_mcp_tool(&self, _tool_name: &str) -> Result<bool> {
            Ok(false)
        }

        fn get_status(&self) -> crate::mcp::McpClientStatus {
            crate::mcp::McpClientStatus {
                enabled: true,
                provider_count: 0,
                active_connections: 0,
                configured_providers: Vec::new(),
            }
        }
    }
}