1use parking_lot::RwLock;
10use regex::Regex;
11use schemars::JsonSchema;
12use serde::{Deserialize, Serialize};
13use skillsrs_core::{CallableId, CallableKind, CallableRecord, RiskTier};
14use skillsrs_registry::Registry;
15use std::collections::{HashMap, HashSet};
16use std::sync::Arc;
17use thiserror::Error;
18use tracing::debug;
19
20#[derive(Error, Debug)]
21pub enum IndexError {
22 #[error("Invalid query: {0}")]
23 InvalidQuery(String),
24
25 #[error("Regex compilation failed: {0}")]
26 RegexError(#[from] regex::Error),
27
28 #[error("Index error: {0}")]
29 Internal(String),
30}
31
32pub type Result<T> = std::result::Result<T, IndexError>;
33
34#[derive(Debug, Clone, Default, Serialize, Deserialize, JsonSchema)]
36pub struct SearchFilters {
37 #[serde(skip_serializing_if = "Option::is_none")]
38 pub server: Option<String>,
39
40 #[serde(skip_serializing_if = "Option::is_none")]
41 pub tags: Option<Vec<String>>,
42
43 #[serde(skip_serializing_if = "Option::is_none")]
44 pub requires: Option<Vec<String>>,
45
46 #[serde(skip_serializing_if = "Option::is_none")]
47 pub capability: Option<Vec<String>>,
48}
49
50#[derive(Debug, Clone)]
52pub struct SearchQuery {
53 pub q: String,
54 pub kind: String, pub mode: String, pub limit: usize,
57 pub filters: Option<SearchFilters>,
58 pub cursor: Option<String>,
59}
60
61#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct SearchMatch {
64 pub id: String,
65 pub kind: String,
66 pub name: String,
67 pub fq_name: String,
68 pub server: Option<String>,
69 pub description_snippet: String,
70 pub inputs: Vec<String>,
71 pub score: f64,
72
73 #[serde(skip_serializing_if = "Option::is_none")]
74 pub signature_short: Option<String>,
75
76 #[serde(skip_serializing_if = "Option::is_none")]
77 pub schema_digest: Option<String>,
78
79 #[serde(skip_serializing_if = "Option::is_none")]
80 pub uses: Option<Vec<String>>,
81}
82
83#[derive(Debug, Clone, Serialize, Deserialize)]
85pub struct SearchResults {
86 pub matches: Vec<SearchMatch>,
87 pub total_matches: usize,
88 pub next_cursor: Option<String>,
89}
90
/// Coarse classification of what a query is asking for, used to bias literal scoring.
#[derive(Debug, Clone, Copy, PartialEq)]
enum QueryIntent {
    /// The query describes a goal or procedure; literal scoring boosts skills.
    OutcomeBased,
    /// The query names low-level API concepts; literal scoring boosts tools.
    ApiBased,
    /// Neither signal dominates; no boost is applied.
    Neutral,
}
98
99pub struct SearchEngine {
101 registry: Arc<Registry>,
102 index: Arc<RwLock<InMemoryIndex>>,
103}
104
105impl SearchEngine {
106 pub fn new(registry: Arc<Registry>) -> Self {
107 SearchEngine {
108 registry,
109 index: Arc::new(RwLock::new(InMemoryIndex::new())),
110 }
111 }
112
113 pub fn rebuild(&self) {
115 debug!("Rebuilding search index");
116 let callables = self.registry.all();
117 let mut index = self.index.write();
118 index.clear();
119
120 for record in callables {
121 index.add_record(&record);
122 }
123
124 debug!("Index rebuilt with {} entries", index.len());
125 }
126
127 pub fn update_record(&self, record: &CallableRecord) {
129 let mut index = self.index.write();
130 index.add_record(record);
131 }
132
133 pub fn remove_record(&self, id: &CallableId) {
135 let mut index = self.index.write();
136 index.remove_record(id);
137 }
138
139 pub async fn search(&self, query: &SearchQuery) -> Result<SearchResults> {
141 if query.q.is_empty() {
142 return Err(IndexError::InvalidQuery(
143 "Query cannot be empty".to_string(),
144 ));
145 }
146
147 debug!("Search query: {:?}", query.q);
148
149 let intent = detect_intent(&query.q);
151
152 let mut candidates = self.registry.all();
154
155 if query.kind != "any" {
157 let target_kind = match query.kind.as_str() {
158 "tools" => CallableKind::Tool,
159 "skills" => CallableKind::Skill,
160 _ => {
161 return Err(IndexError::InvalidQuery(format!(
162 "Invalid kind: {}",
163 query.kind
164 )));
165 }
166 };
167 candidates.retain(|c| c.kind == target_kind);
168 }
169
170 if let Some(filters) = &query.filters {
172 candidates = apply_filters(candidates, filters);
173 }
174
175 let mut scored: Vec<(CallableRecord, f64)> = candidates
177 .into_iter()
178 .filter_map(|record| {
179 let score = match query.mode.as_str() {
180 "literal" => score_literal(&query.q, &record, intent),
181 "regex" => score_regex(&query.q, &record).ok()?,
182 "fuzzy" => score_fuzzy(&query.q, &record),
183 _ => return None,
184 };
185
186 if score > 0.0 {
187 Some((record, score))
188 } else {
189 None
190 }
191 })
192 .collect();
193
194 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
196
197 let total_matches = scored.len();
198
199 let offset = query
201 .cursor
202 .as_ref()
203 .and_then(|c| c.parse::<usize>().ok())
204 .unwrap_or(0);
205
206 let matches: Vec<SearchMatch> = scored
207 .into_iter()
208 .skip(offset)
209 .take(query.limit)
210 .map(|(record, score)| {
211 let inputs = extract_input_keys(&record.input_schema);
212 let description_snippet = record
213 .description
214 .clone()
215 .unwrap_or_else(|| record.title.clone().unwrap_or_default())
216 .chars()
217 .take(200)
218 .collect();
219
220 SearchMatch {
221 id: record.id.as_str().to_string(),
222 kind: record.kind.to_string(),
223 name: record.name.clone(),
224 fq_name: record.fq_name.clone(),
225 server: record.server_alias.clone(),
226 description_snippet,
227 inputs,
228 score,
229 signature_short: None,
230 schema_digest: Some(record.schema_digest.short().to_string()),
231 uses: if record.kind == CallableKind::Skill {
232 Some(
233 record
234 .uses
235 .iter()
236 .map(|id| id.as_str().to_string())
237 .collect(),
238 )
239 } else {
240 None
241 },
242 }
243 })
244 .collect();
245
246 let next_cursor = if offset + query.limit < total_matches {
247 Some((offset + query.limit).to_string())
248 } else {
249 None
250 };
251
252 Ok(SearchResults {
253 matches,
254 total_matches,
255 next_cursor,
256 })
257 }
258}
259
/// Inverted token index held entirely in memory.
struct InMemoryIndex {
    /// Token -> ids of the records indexed under that token.
    tokens: HashMap<String, HashSet<String>>,
    /// Record id -> the tokens it was indexed under, so an entry can be undone.
    reverse: HashMap<String, HashSet<String>>,
}
267
268impl InMemoryIndex {
269 fn new() -> Self {
270 InMemoryIndex {
271 tokens: HashMap::new(),
272 reverse: HashMap::new(),
273 }
274 }
275
276 fn add_record(&mut self, record: &CallableRecord) {
277 let id = record.id.as_str().to_string();
278 let tokens = tokenize_record(record);
279
280 for token in &tokens {
281 self.tokens
282 .entry(token.clone())
283 .or_default()
284 .insert(id.clone());
285 }
286
287 self.reverse.insert(id, tokens);
288 }
289
290 fn remove_record(&mut self, id: &CallableId) {
291 let id_str = id.as_str();
292 if let Some(tokens) = self.reverse.remove(id_str) {
293 for token in tokens {
294 if let Some(ids) = self.tokens.get_mut(&token) {
295 ids.remove(id_str);
296 }
297 }
298 }
299 }
300
301 fn clear(&mut self) {
302 self.tokens.clear();
303 self.reverse.clear();
304 }
305
306 fn len(&self) -> usize {
307 self.reverse.len()
308 }
309}
310
311fn tokenize_record(record: &CallableRecord) -> HashSet<String> {
313 let mut tokens = HashSet::new();
314
315 for token in tokenize(&record.name) {
317 tokens.insert(token);
318 }
319
320 for token in tokenize(&record.fq_name) {
322 tokens.insert(token);
323 }
324
325 if let Some(title) = &record.title {
327 for token in tokenize(title) {
328 tokens.insert(token);
329 }
330 }
331
332 if let Some(desc) = &record.description {
334 for token in tokenize(desc) {
335 tokens.insert(token);
336 }
337 }
338
339 for tag in &record.tags {
341 tokens.insert(tag.to_lowercase());
342 }
343
344 tokens
345}
346
/// Lowercases `text` and splits it into maximal runs of alphanumeric characters.
///
/// Every non-alphanumeric character acts as a separator and empty pieces are
/// discarded. Tokens are returned in order of appearance, duplicates included.
fn tokenize(text: &str) -> Vec<String> {
    let lowered = text.to_lowercase();
    let mut words = Vec::new();

    for piece in lowered.split(|ch: char| !ch.is_alphanumeric()) {
        if !piece.is_empty() {
            words.push(piece.to_owned());
        }
    }

    words
}
355
356fn detect_intent(query: &str) -> QueryIntent {
358 let lower = query.to_lowercase();
359
360 let outcome_keywords = [
362 "calibrate",
363 "characterize",
364 "measure",
365 "generate",
366 "report",
367 "analyze",
368 "plot",
369 "sweep",
370 "tune",
371 "optimize",
372 "workflow",
373 "procedure",
374 "sop",
375 "end-to-end",
376 "test",
377 "verify",
378 ];
379
380 let api_keywords = [
381 "path", "cursor", "bytes", "json", "schema", "id", "regex", "pattern", "list", "get",
382 "read", "write",
383 ];
384
385 let outcome_score = outcome_keywords
386 .iter()
387 .filter(|kw| lower.contains(*kw))
388 .count();
389
390 let api_score = api_keywords.iter().filter(|kw| lower.contains(*kw)).count();
391
392 if outcome_score > api_score {
393 QueryIntent::OutcomeBased
394 } else if api_score > outcome_score {
395 QueryIntent::ApiBased
396 } else {
397 QueryIntent::Neutral
398 }
399}
400
401fn score_literal(query: &str, record: &CallableRecord, intent: QueryIntent) -> f64 {
403 let query_tokens: HashSet<String> = tokenize(query).into_iter().collect();
404 let mut score = 0.0;
405
406 if record.name.to_lowercase() == query.to_lowercase() {
408 score += 100.0;
409 }
410
411 if record.fq_name.to_lowercase() == query.to_lowercase() {
413 score += 90.0;
414 }
415
416 let name_tokens = tokenize(&record.name);
418 for token in &query_tokens {
419 if name_tokens.contains(token) {
420 score += 20.0;
421 }
422 }
423
424 if let Some(title) = &record.title {
426 let title_tokens = tokenize(title);
427 for token in &query_tokens {
428 if title_tokens.contains(token) {
429 score += 10.0;
430 }
431 }
432 }
433
434 if let Some(desc) = &record.description {
436 let desc_tokens = tokenize(desc);
437 for token in &query_tokens {
438 if desc_tokens.contains(token) {
439 score += 5.0;
440 }
441 }
442 }
443
444 let schema_keys = extract_input_keys(&record.input_schema);
446 for token in &query_tokens {
447 if schema_keys.iter().any(|k| k.to_lowercase().contains(token)) {
448 score += 8.0;
449 }
450 }
451
452 for token in &query_tokens {
454 if record.tags.iter().any(|t| t.to_lowercase().contains(token)) {
455 score += 12.0;
456 }
457 }
458
459 match intent {
461 QueryIntent::OutcomeBased => {
462 if record.kind == CallableKind::Skill {
463 score *= 1.3; }
465 }
466 QueryIntent::ApiBased => {
467 if record.kind == CallableKind::Tool {
468 score *= 1.3; }
470 }
471 QueryIntent::Neutral => {}
472 }
473
474 if record.risk_tier >= RiskTier::Destructive {
476 let destructive_keywords = ["delete", "remove", "destroy", "drop", "clear"];
477 if !destructive_keywords
478 .iter()
479 .any(|kw| query.to_lowercase().contains(kw))
480 {
481 score *= 0.7;
482 }
483 }
484
485 score
486}
487
488fn score_regex(pattern: &str, record: &CallableRecord) -> Result<f64> {
490 let re = Regex::new(pattern)?;
491 let mut score = 0.0;
492
493 if re.is_match(&record.name) {
494 score += 50.0;
495 }
496
497 if re.is_match(&record.fq_name) {
498 score += 40.0;
499 }
500
501 if let Some(title) = &record.title {
502 if re.is_match(title) {
503 score += 20.0;
504 }
505 }
506
507 if let Some(desc) = &record.description {
508 if re.is_match(desc) {
509 score += 10.0;
510 }
511 }
512
513 Ok(score)
514}
515
516fn score_fuzzy(query: &str, record: &CallableRecord) -> f64 {
518 let query_lower = query.to_lowercase();
519 let mut score = 0.0;
520
521 if record.name.to_lowercase().contains(&query_lower) {
523 score += 30.0;
524 }
525
526 if record.fq_name.to_lowercase().contains(&query_lower) {
527 score += 25.0;
528 }
529
530 if let Some(title) = &record.title {
531 if title.to_lowercase().contains(&query_lower) {
532 score += 15.0;
533 }
534 }
535
536 if let Some(desc) = &record.description {
537 if desc.to_lowercase().contains(&query_lower) {
538 score += 10.0;
539 }
540 }
541
542 score
543}
544
545fn apply_filters(candidates: Vec<CallableRecord>, filters: &SearchFilters) -> Vec<CallableRecord> {
547 let mut filtered = candidates;
548
549 if let Some(server) = &filters.server {
551 filtered.retain(|c| {
552 c.server_alias
553 .as_ref()
554 .map(|s| s == server)
555 .unwrap_or(false)
556 });
557 }
558
559 if let Some(tags) = &filters.tags {
561 if !tags.is_empty() {
562 filtered.retain(|c| tags.iter().any(|tag| c.tags.contains(tag)));
563 }
564 }
565
566 if let Some(requires) = &filters.requires {
568 if !requires.is_empty() {
569 filtered.retain(|c| {
570 let keys = extract_input_keys(&c.input_schema);
571 requires.iter().all(|req| keys.contains(req))
572 });
573 }
574 }
575
576 if let Some(capability) = &filters.capability {
578 if !capability.is_empty() {
579 filtered.retain(|c| capability.iter().any(|cap| c.tags.contains(cap)));
580 }
581 }
582
583 filtered
584}
585
586fn extract_input_keys(schema: &serde_json::Value) -> Vec<String> {
588 if let Some(properties) = schema.get("properties").and_then(|p| p.as_object()) {
589 properties.keys().cloned().collect()
590 } else {
591 vec![]
592 }
593}