1use serde::{Deserialize, Serialize};
8use thulp_core::ToolDefinition;
9
10pub type Result<T> = std::result::Result<T, QueryError>;
12
13#[derive(Debug, thiserror::Error)]
15pub enum QueryError {
16 #[error("Parse error: {0}")]
17 Parse(String),
18
19 #[error("Invalid query: {0}")]
20 Invalid(String),
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
25pub enum QueryCriteria {
26 Name(String),
28
29 Description(String),
31
32 HasParameter(String),
34
35 MinParameters(usize),
37
38 MaxParameters(usize),
40
41 And(Vec<QueryCriteria>),
43
44 Or(Vec<QueryCriteria>),
46
47 Not(Box<QueryCriteria>),
49}
50
51impl QueryCriteria {
52 pub fn matches(&self, tool: &ToolDefinition) -> bool {
54 match self {
55 QueryCriteria::Name(pattern) => {
56 if pattern.contains('*') {
57 let regex = pattern.replace('*', ".*");
58 regex::Regex::new(®ex)
59 .map(|re| re.is_match(&tool.name))
60 .unwrap_or(false)
61 } else {
62 tool.name.contains(pattern)
63 }
64 }
65 QueryCriteria::Description(keyword) => tool
66 .description
67 .to_lowercase()
68 .contains(&keyword.to_lowercase()),
69 QueryCriteria::HasParameter(param_name) => {
70 tool.parameters.iter().any(|p| p.name == *param_name)
71 }
72 QueryCriteria::MinParameters(min) => tool.parameters.len() >= *min,
73 QueryCriteria::MaxParameters(max) => tool.parameters.len() <= *max,
74 QueryCriteria::And(criteria) => criteria.iter().all(|c| c.matches(tool)),
75 QueryCriteria::Or(criteria) => criteria.iter().any(|c| c.matches(tool)),
76 QueryCriteria::Not(criteria) => !criteria.matches(tool),
77 }
78 }
79}
80
81#[derive(Debug, Default)]
83pub struct QueryBuilder {
84 criteria: Vec<QueryCriteria>,
85}
86
87impl QueryBuilder {
88 pub fn new() -> Self {
90 Self::default()
91 }
92
93 pub fn name(mut self, pattern: impl Into<String>) -> Self {
95 self.criteria.push(QueryCriteria::Name(pattern.into()));
96 self
97 }
98
99 pub fn description(mut self, keyword: impl Into<String>) -> Self {
101 self.criteria
102 .push(QueryCriteria::Description(keyword.into()));
103 self
104 }
105
106 pub fn has_parameter(mut self, param_name: impl Into<String>) -> Self {
108 self.criteria
109 .push(QueryCriteria::HasParameter(param_name.into()));
110 self
111 }
112
113 pub fn min_parameters(mut self, min: usize) -> Self {
115 self.criteria.push(QueryCriteria::MinParameters(min));
116 self
117 }
118
119 pub fn max_parameters(mut self, max: usize) -> Self {
121 self.criteria.push(QueryCriteria::MaxParameters(max));
122 self
123 }
124
125 pub fn build(self) -> Query {
127 Query {
128 criteria: if self.criteria.len() == 1 {
129 self.criteria.into_iter().next().unwrap()
130 } else {
131 QueryCriteria::And(self.criteria)
132 },
133 }
134 }
135}
136
137#[derive(Debug, Clone)]
139pub struct Query {
140 criteria: QueryCriteria,
141}
142
143impl Query {
144 pub fn new(criteria: QueryCriteria) -> Self {
146 Self { criteria }
147 }
148
149 pub fn execute(&self, tools: &[ToolDefinition]) -> Vec<ToolDefinition> {
151 tools
152 .iter()
153 .filter(|tool| self.criteria.matches(tool))
154 .cloned()
155 .collect()
156 }
157}
158
#[cfg(test)]
mod tests {
    use super::*;
    use thulp_core::Parameter;

    /// Builds a tool with the given name and description plus `param_count`
    /// parameters named `param0`, `param1`, ... created through
    /// `Parameter::required_string`.
    fn create_test_tool(name: &str, description: &str, param_count: usize) -> ToolDefinition {
        let mut builder = ToolDefinition::builder(name).description(description);

        for i in 0..param_count {
            builder = builder.parameter(Parameter::required_string(format!("param{}", i)));
        }

        builder.build()
    }

    #[test]
    fn test_query_name() {
        let tool = create_test_tool("file_read", "Read a file", 1);
        let criteria = QueryCriteria::Name("file".to_string());
        assert!(criteria.matches(&tool));

        // A name that does not occur in the tool name must be rejected.
        let other = QueryCriteria::Name("network".to_string());
        assert!(!other.matches(&tool));
    }

    #[test]
    fn test_query_name_wildcard() {
        let tool = create_test_tool("file_read", "Read a file", 1);
        let criteria = QueryCriteria::Name("file_*".to_string());
        assert!(criteria.matches(&tool));

        // The literal part of a wildcard pattern still has to be present.
        let other = QueryCriteria::Name("net_*".to_string());
        assert!(!other.matches(&tool));
    }

    #[test]
    fn test_query_description() {
        let tool = create_test_tool("file_read", "Read a file from disk", 1);
        let criteria = QueryCriteria::Description("disk".to_string());
        assert!(criteria.matches(&tool));

        // Keyword comparison ignores case.
        let shouting = QueryCriteria::Description("DISK".to_string());
        assert!(shouting.matches(&tool));

        let other = QueryCriteria::Description("network".to_string());
        assert!(!other.matches(&tool));
    }

    #[test]
    fn test_query_has_parameter() {
        let tool = ToolDefinition::builder("test")
            .parameter(Parameter::required_string("path"))
            .build();

        let criteria = QueryCriteria::HasParameter("path".to_string());
        assert!(criteria.matches(&tool));

        let other = QueryCriteria::HasParameter("missing".to_string());
        assert!(!other.matches(&tool));
    }

    #[test]
    fn test_query_min_parameters() {
        let tool = create_test_tool("test", "Test", 3);
        let criteria = QueryCriteria::MinParameters(2);
        assert!(criteria.matches(&tool));

        // The bound is inclusive; one above the actual count fails.
        assert!(QueryCriteria::MinParameters(3).matches(&tool));
        assert!(!QueryCriteria::MinParameters(4).matches(&tool));
    }

    #[test]
    fn test_query_max_parameters() {
        let tool = create_test_tool("test", "Test", 2);
        let criteria = QueryCriteria::MaxParameters(3);
        assert!(criteria.matches(&tool));

        // The bound is inclusive; one below the actual count fails.
        assert!(QueryCriteria::MaxParameters(2).matches(&tool));
        assert!(!QueryCriteria::MaxParameters(1).matches(&tool));
    }

    #[test]
    fn test_query_and() {
        let tool = create_test_tool("file_read", "Read a file", 2);
        let criteria = QueryCriteria::And(vec![
            QueryCriteria::Name("file".to_string()),
            QueryCriteria::MinParameters(2),
        ]);
        assert!(criteria.matches(&tool));

        // One failing branch is enough to reject the tool.
        let too_strict = QueryCriteria::And(vec![
            QueryCriteria::Name("file".to_string()),
            QueryCriteria::MinParameters(3),
        ]);
        assert!(!too_strict.matches(&tool));
    }

    #[test]
    fn test_query_or() {
        let tool = create_test_tool("file_read", "Read a file", 1);
        let criteria = QueryCriteria::Or(vec![
            QueryCriteria::Name("network".to_string()),
            QueryCriteria::Name("file".to_string()),
        ]);
        assert!(criteria.matches(&tool));

        // With no matching branch the tool is rejected.
        let none_match = QueryCriteria::Or(vec![
            QueryCriteria::Name("network".to_string()),
            QueryCriteria::Name("socket".to_string()),
        ]);
        assert!(!none_match.matches(&tool));
    }

    #[test]
    fn test_query_not() {
        let tool = create_test_tool("file_read", "Read a file", 1);
        let criteria = QueryCriteria::Not(Box::new(QueryCriteria::Name("network".to_string())));
        assert!(criteria.matches(&tool));

        // Negating a criterion that matches yields a rejection.
        let negated_hit = QueryCriteria::Not(Box::new(QueryCriteria::Name("file".to_string())));
        assert!(!negated_hit.matches(&tool));
    }

    #[test]
    fn test_query_builder() {
        let query = QueryBuilder::new().name("file").min_parameters(1).build();

        let tools = vec![
            create_test_tool("file_read", "Read", 2),
            create_test_tool("network_get", "Get", 1),
        ];

        let results = query.execute(&tools);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].name, "file_read");
    }

    #[test]
    fn test_query_execute() {
        let query = Query::new(QueryCriteria::MinParameters(2));

        let tools = vec![
            create_test_tool("tool1", "Test 1", 1),
            create_test_tool("tool2", "Test 2", 2),
            create_test_tool("tool3", "Test 3", 3),
        ];

        let results = query.execute(&tools);
        assert_eq!(results.len(), 2);
        // Matching tools come back in input order.
        assert_eq!(results[0].name, "tool2");
        assert_eq!(results[1].name, "tool3");
    }
}