1use crate::error::SearchError;
23use crate::index::store::SeekrIndex;
24use crate::parser::{ChunkKind, CodeChunk};
25
/// A parsed structural search pattern produced by [`parse_pattern`].
///
/// Patterns have the informal shape
/// `[qualifier...] [kind] [name] [(param, ...)] [-> return]`. Every part is optional;
/// when no kind keyword is recognised the pattern is treated as a function pattern.
#[derive(Debug, Clone)]
pub struct AstPattern {
    /// Leading modifier keywords (e.g. `async`, `pub`, `static`) in the order written.
    /// They are searched for as substrings of the chunk's lowercased signature and
    /// first body line.
    pub qualifiers: Vec<String>,

    /// Which kind of definition the pattern targets.
    pub kind: PatternKind,

    /// Optional name to match; may contain `*` wildcards (e.g. `*Config`).
    pub name_pattern: Option<String>,

    /// Parameter type patterns from a parenthesized list, or `None` when the pattern had
    /// no list. `Some(vec![])` is an explicit empty list (`fn()`), and a lone `"*"`
    /// accepts any parameters. Only parsed for function and any-kind patterns.
    pub param_patterns: Option<Vec<String>>,

    /// Return-type pattern: the tokens following `->`, joined with single spaces.
    pub return_pattern: Option<String>,
}
44
/// The kind of definition an [`AstPattern`] targets.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PatternKind {
    /// Written `fn`, `func`, `function`, `def` or `method`; matches both function and
    /// method chunks.
    Function,
    /// Written `class`.
    Class,
    /// Written `struct`.
    Struct,
    /// Written `enum`.
    Enum,
    /// Written `interface`, `trait` or `protocol`; matches interface chunks.
    Interface,
    /// Written `*`; matches every chunk kind.
    Any,
}
61
/// A single hit returned by [`search_ast_pattern`].
#[derive(Debug, Clone)]
pub struct AstMatch {
    /// The `id` of the matching [`CodeChunk`].
    pub chunk_id: u64,

    /// Normalized match quality in `(0.0, 1.0]`; higher is better.
    pub score: f32,
}
71
72pub fn parse_pattern(pattern: &str) -> Result<AstPattern, SearchError> {
82 let pattern = pattern.trim();
83
84 if pattern.is_empty() {
85 return Err(SearchError::InvalidAstPattern("Empty pattern".to_string()));
86 }
87
88 let tokens = tokenize_pattern(pattern);
89
90 if tokens.is_empty() {
91 return Err(SearchError::InvalidAstPattern(
92 "Could not parse pattern".to_string(),
93 ));
94 }
95
96 let mut idx = 0;
97 let mut qualifiers = Vec::new();
98
99 while idx < tokens.len() {
101 match tokens[idx].as_str() {
102 "async" | "pub" | "static" | "export" | "private" | "protected" | "public"
103 | "abstract" | "virtual" | "override" | "const" | "mut" | "unsafe" => {
104 qualifiers.push(tokens[idx].clone());
105 idx += 1;
106 }
107 _ => break,
108 }
109 }
110
111 if idx >= tokens.len() {
112 return Err(SearchError::InvalidAstPattern(
113 "Pattern has only qualifiers, missing kind (fn, class, struct, etc.)".to_string(),
114 ));
115 }
116
117 let (kind, idx) = parse_kind(&tokens, idx)?;
119
120 let (name_pattern, idx) = parse_name(&tokens, idx);
122
123 let (param_patterns, idx) = if matches!(kind, PatternKind::Function | PatternKind::Any) {
125 parse_params(&tokens, idx)?
126 } else {
127 (None, idx)
128 };
129
130 let return_pattern = parse_return_type(&tokens, idx);
132
133 Ok(AstPattern {
134 qualifiers,
135 kind,
136 name_pattern,
137 param_patterns,
138 return_pattern,
139 })
140}
141
142pub fn search_ast_pattern(
144 index: &SeekrIndex,
145 pattern: &str,
146 top_k: usize,
147) -> Result<Vec<AstMatch>, SearchError> {
148 let parsed = parse_pattern(pattern)?;
149
150 let mut matches: Vec<AstMatch> = Vec::new();
151
152 for chunk in index.chunks.values() {
153 let score = match_chunk(&parsed, chunk);
154 if score > 0.0 {
155 matches.push(AstMatch {
156 chunk_id: chunk.id,
157 score,
158 });
159 }
160 }
161
162 matches.sort_by(|a, b| {
164 b.score
165 .partial_cmp(&a.score)
166 .unwrap_or(std::cmp::Ordering::Equal)
167 });
168
169 matches.truncate(top_k);
170
171 Ok(matches)
172}
173
174fn match_chunk(pattern: &AstPattern, chunk: &CodeChunk) -> f32 {
177 let mut score = 0.0f32;
178 let mut total_criteria = 0.0f32;
179
180 total_criteria += 0.3;
182 if match_kind(&pattern.kind, &chunk.kind) {
183 score += 0.3;
184 } else {
185 return 0.0;
187 }
188
189 if !pattern.qualifiers.is_empty() {
191 total_criteria += 0.1;
192 let sig_lower = chunk.signature.as_deref().unwrap_or("").to_lowercase();
193 let body_start = chunk.body.lines().next().unwrap_or("").to_lowercase();
194 let combined = format!("{} {}", sig_lower, body_start);
195
196 let matched_quals = pattern
197 .qualifiers
198 .iter()
199 .filter(|q| combined.contains(q.as_str()))
200 .count();
201
202 if !pattern.qualifiers.is_empty() {
203 score += 0.1 * (matched_quals as f32 / pattern.qualifiers.len() as f32);
204 }
205 }
206
207 if let Some(ref name_pat) = pattern.name_pattern {
209 total_criteria += 0.3;
210 if let Some(ref chunk_name) = chunk.name {
211 if wildcard_match(name_pat, chunk_name) {
212 score += 0.3;
213 } else if chunk_name
214 .to_lowercase()
215 .contains(&name_pat.to_lowercase().replace('*', ""))
216 {
217 score += 0.15;
219 }
220 }
221 }
222
223 if let Some(ref param_pats) = pattern.param_patterns {
225 total_criteria += 0.15;
226 let sig = chunk.signature.as_deref().unwrap_or(&chunk.body);
227 let chunk_params = extract_params_from_signature(sig);
228
229 if param_pats.len() == 1 && param_pats[0] == "*" {
230 score += 0.15;
232 } else if param_pats.is_empty() && chunk_params.is_empty() {
233 score += 0.15;
235 } else {
236 let param_score = match_param_types(param_pats, &chunk_params);
237 score += 0.15 * param_score;
238 }
239 }
240
241 if let Some(ref ret_pat) = pattern.return_pattern {
243 total_criteria += 0.15;
244 let sig = chunk.signature.as_deref().unwrap_or(&chunk.body);
245 let chunk_ret = extract_return_type_from_signature(sig);
246
247 if ret_pat == "*" {
248 score += 0.15;
249 } else if let Some(ref chunk_ret) = chunk_ret {
250 if fuzzy_type_match(ret_pat, chunk_ret) {
251 score += 0.15;
252 } else if chunk_ret.to_lowercase().contains(&ret_pat.to_lowercase()) {
253 score += 0.075; }
255 }
256 }
257
258 if total_criteria > 0.0 {
260 score / total_criteria
261 } else {
262 0.0
263 }
264}
265
/// Splits a pattern string into tokens.
///
/// `(`, `)`, `,` and `->` become standalone tokens, and whitespace separates words.
/// Text inside angle brackets is kept together as ONE token, including any spaces,
/// commas, parentheses or arrows it contains. This lets generic types such as
/// `Result<String, Error>` or `Box<dyn Fn(i32) -> i32>` survive intact, so they can be
/// compared against the types extracted from signatures.
fn tokenize_pattern(pattern: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    let mut current = String::new();
    // Nesting level of `<...>`; while positive, separators are ordinary characters.
    let mut angle_depth: usize = 0;
    let mut chars = pattern.chars().peekable();

    while let Some(ch) = chars.next() {
        // Checked first so the `>` of an arrow never closes an angle bracket.
        if ch == '-' && chars.peek() == Some(&'>') {
            chars.next();
            if angle_depth > 0 {
                current.push_str("->");
            } else {
                if !current.is_empty() {
                    tokens.push(std::mem::take(&mut current));
                }
                tokens.push("->".to_string());
            }
            continue;
        }

        if angle_depth > 0 {
            match ch {
                '<' => angle_depth += 1,
                '>' => angle_depth -= 1,
                _ => {}
            }
            current.push(ch);
            continue;
        }

        match ch {
            '<' => {
                angle_depth = 1;
                current.push(ch);
            }
            '(' | ')' | ',' => {
                if !current.is_empty() {
                    tokens.push(std::mem::take(&mut current));
                }
                tokens.push(ch.to_string());
            }
            c if c.is_whitespace() => {
                if !current.is_empty() {
                    tokens.push(std::mem::take(&mut current));
                }
            }
            _ => current.push(ch),
        }
    }

    if !current.is_empty() {
        tokens.push(current);
    }

    tokens
}
308
309fn parse_kind(tokens: &[String], idx: usize) -> Result<(PatternKind, usize), SearchError> {
311 if idx >= tokens.len() {
312 return Ok((PatternKind::Any, idx));
313 }
314
315 let kind_str = tokens[idx].to_lowercase();
316 let kind = match kind_str.as_str() {
317 "fn" | "func" | "function" | "def" | "method" => PatternKind::Function,
318 "class" => PatternKind::Class,
319 "struct" => PatternKind::Struct,
320 "enum" => PatternKind::Enum,
321 "interface" | "trait" | "protocol" => PatternKind::Interface,
322 "*" => PatternKind::Any,
323 _ => {
324 return Ok((PatternKind::Function, idx));
327 }
328 };
329
330 Ok((kind, idx + 1))
331}
332
/// Reads an optional name at `tokens[idx]`.
///
/// Any token other than the punctuation tokens `(`, `->`, `)` and `,` is taken as the
/// name (wildcards included). Returns the name, if any, and the index of the next
/// unread token.
fn parse_name(tokens: &[String], idx: usize) -> (Option<String>, usize) {
    const PUNCTUATION: [&str; 4] = ["(", "->", ")", ","];

    match tokens.get(idx) {
        Some(token) if !PUNCTUATION.contains(&token.as_str()) => (Some(token.clone()), idx + 1),
        _ => (None, idx),
    }
}
346
347fn parse_params(
349 tokens: &[String],
350 idx: usize,
351) -> Result<(Option<Vec<String>>, usize), SearchError> {
352 if idx >= tokens.len() || tokens[idx] != "(" {
353 return Ok((None, idx));
354 }
355
356 let mut params = Vec::new();
357 let mut i = idx + 1; while i < tokens.len() && tokens[i] != ")" {
360 if tokens[i] == "," {
361 i += 1;
362 continue;
363 }
364 params.push(tokens[i].clone());
365 i += 1;
366 }
367
368 if i < tokens.len() && tokens[i] == ")" {
369 i += 1; }
371
372 Ok((Some(params), i))
373}
374
/// Reads an optional `-> return` clause starting at `tokens[idx]`.
///
/// When `tokens[idx]` is `->` and at least one token follows, all remaining tokens are
/// joined with single spaces to form the return pattern. Otherwise returns `None`.
fn parse_return_type(tokens: &[String], idx: usize) -> Option<String> {
    match tokens.get(idx..) {
        Some([arrow, rest @ ..]) if arrow.as_str() == "->" && !rest.is_empty() => {
            Some(rest.join(" "))
        }
        _ => None,
    }
}
389
390fn match_kind(pattern_kind: &PatternKind, chunk_kind: &ChunkKind) -> bool {
396 match pattern_kind {
397 PatternKind::Any => true,
398 PatternKind::Function => matches!(chunk_kind, ChunkKind::Function | ChunkKind::Method),
399 PatternKind::Class => matches!(chunk_kind, ChunkKind::Class),
400 PatternKind::Struct => matches!(chunk_kind, ChunkKind::Struct),
401 PatternKind::Enum => matches!(chunk_kind, ChunkKind::Enum),
402 PatternKind::Interface => matches!(chunk_kind, ChunkKind::Interface),
403 }
404}
405
/// Case-insensitive glob match where `*` stands for any (possibly empty) run of chars.
///
/// Without a `*` the comparison is plain equality. With wildcards, the text before
/// the first `*` must be a prefix of `text`. Inner segments must then appear in order
/// (leftmost occurrence first), and the text after the last `*` must be a suffix of
/// whatever remains.
fn wildcard_match(pattern: &str, text: &str) -> bool {
    let pattern = pattern.to_lowercase();
    let text = text.to_lowercase();

    if !pattern.contains('*') {
        return pattern == text;
    }

    let segments: Vec<&str> = pattern.split('*').collect();
    let last = segments.len() - 1;
    // Unconsumed tail of `text`; shrinks as segments are matched.
    let mut rest: &str = &text;

    for (position, segment) in segments.iter().enumerate() {
        if segment.is_empty() {
            continue;
        }

        if position == 0 {
            match rest.strip_prefix(segment) {
                Some(remaining) => rest = remaining,
                None => return false,
            }
        } else if position == last {
            if !rest.ends_with(segment) {
                return false;
            }
        } else {
            match rest.find(segment) {
                Some(found) => rest = &rest[found + segment.len()..],
                None => return false,
            }
        }
    }

    true
}
453
454fn extract_params_from_signature(sig: &str) -> Vec<String> {
461 let Some(open) = sig.find('(') else {
463 return Vec::new();
464 };
465
466 let mut depth = 0;
467 let mut close = None;
468
469 for (i, ch) in sig[open..].char_indices() {
470 match ch {
471 '(' => depth += 1,
472 ')' => {
473 depth -= 1;
474 if depth == 0 {
475 close = Some(open + i);
476 break;
477 }
478 }
479 _ => {}
480 }
481 }
482
483 let Some(close) = close else {
484 return Vec::new();
485 };
486
487 let params_str = &sig[open + 1..close];
488 if params_str.trim().is_empty() {
489 return Vec::new();
490 }
491
492 split_params(params_str)
494 .iter()
495 .filter_map(|p| extract_type_from_param(p.trim()))
496 .collect()
497}
498
/// Splits a parameter list on commas that are not nested inside `<>`, `()`, `[]` or `{}`.
///
/// Pieces are returned untrimmed (callers trim them), and a trailing empty piece is
/// dropped. The `>` of an arrow (`->` in Rust closure/fn types, `=>` in TypeScript
/// function types) is NOT treated as a closing bracket. Previously those arrows drove
/// the depth negative, so every following comma was ignored. Depth also never goes below
/// zero, so a stray closer cannot disable splitting for the rest of the list.
fn split_params(params: &str) -> Vec<String> {
    let mut parts = Vec::new();
    let mut current = String::new();
    let mut depth: usize = 0;
    let mut prev: Option<char> = None;

    for ch in params.chars() {
        match ch {
            '<' | '(' | '[' | '{' => {
                depth += 1;
                current.push(ch);
            }
            // Arrow head, not a closing angle bracket.
            '>' if matches!(prev, Some('-') | Some('=')) => current.push(ch),
            '>' | ')' | ']' | '}' => {
                depth = depth.saturating_sub(1);
                current.push(ch);
            }
            ',' if depth == 0 => parts.push(std::mem::take(&mut current)),
            _ => current.push(ch),
        }
        prev = Some(ch);
    }

    if !current.is_empty() {
        parts.push(current);
    }

    parts
}
528
/// Extracts the type portion of a single parameter declaration.
///
/// * `name: Type` (Rust, TypeScript, annotated Python): the text after the first `:`.
///   Leading `&`, a lifetime (`&'a`) and `mut` are removed, so `&'a mut Vec<u8>`
///   yields `Vec<u8>`.
/// * `name Type` (e.g. Go): the last whitespace-separated word.
/// * A single bare word is returned unchanged.
///
/// A default value outside any brackets (`x: int = 5`) is dropped first. Returns `None`
/// for a blank parameter.
fn extract_type_from_param(param: &str) -> Option<String> {
    // Remove the default before looking for ':' so `x={'a': 1}` is not read as a type.
    let param = strip_default_value(param.trim()).trim();
    if param.is_empty() {
        return None;
    }

    if let Some(colon_pos) = param.find(':') {
        let type_part = strip_reference_prefix(param[colon_pos + 1..].trim());
        return Some(type_part.to_string());
    }

    // `name Type` → last word; a lone word is its own last word.
    let last_word = param.split_whitespace().last().unwrap_or(param);
    Some(last_word.to_string())
}

/// Cuts off a default value (`= expr`) that appears at bracket depth zero.
///
/// `=>` and `->` arrows are skipped so function types such as `(a: number) => void`
/// survive. An `=` nested inside brackets (e.g. `impl Iterator<Item = u8>`) is kept.
fn strip_default_value(param: &str) -> &str {
    let mut depth: usize = 0;
    let mut chars = param.char_indices().peekable();

    while let Some((pos, ch)) = chars.next() {
        let next_is_gt = matches!(chars.peek(), Some(&(_, '>')));
        match ch {
            '=' | '-' if next_is_gt => {
                // Consume the arrow head so it is not mistaken for a closing `>`.
                chars.next();
            }
            '<' | '(' | '[' | '{' => depth += 1,
            '>' | ')' | ']' | '}' => depth = depth.saturating_sub(1),
            '=' if depth == 0 => return &param[..pos],
            _ => {}
        }
    }

    param
}

/// Removes leading `&`, an optional lifetime (`'a`) and an optional `mut` from a type.
fn strip_reference_prefix(ty: &str) -> &str {
    let ty = ty.trim_start_matches('&').trim_start();
    // `'a str` → `str`; a bare lifetime with nothing after it is left untouched.
    let ty = match ty.strip_prefix('\'') {
        Some(rest) => rest
            .split_once(char::is_whitespace)
            .map_or(ty, |(_, after)| after.trim_start()),
        None => ty,
    };
    ty.trim_start_matches("mut ").trim()
}
563
/// Extracts the declared return type from a function signature, if there is one.
///
/// Two syntaxes are recognised, both looked for only AFTER the parameter list (the
/// first `(` and its matching `)`):
/// * `-> Type` (Rust, Python, Swift). The type ends at a `{` or a line break. A
///   trailing `:` (Python) and whitespace are trimmed.
/// * `): Type` (TypeScript, PHP). A trailing `{` and whitespace are trimmed.
///
/// Restricting the search to the text after the parameter list means arrows inside
/// parameter types (`f: impl Fn() -> u8`) are not mistaken for the return type. If no
/// balanced parameter list is found, the whole signature is searched for `->`, and the
/// last `)` is used for the colon form.
fn extract_return_type_from_signature(sig: &str) -> Option<String> {
    let params_end = param_list_end(sig);
    let tail = params_end.map_or(sig, |end| &sig[end..]);

    if let Some(arrow_pos) = tail.find("->") {
        // Stop where a body would begin (when `sig` is a whole function body).
        let ret = tail[arrow_pos + 2..]
            .split(|c: char| c == '{' || c == '\n')
            .next()
            .unwrap_or("")
            .trim()
            .trim_end_matches(|c: char| c == ':' || c.is_whitespace());
        if !ret.is_empty() {
            return Some(ret.to_string());
        }
    }

    let after_close = params_end.or_else(|| sig.rfind(')').map(|pos| pos + 1));
    if let Some(start) = after_close {
        let after = sig[start..].trim();
        if let Some(stripped) = after.strip_prefix(':') {
            let ret = stripped
                .trim()
                .trim_end_matches(|c: char| c == '{' || c.is_whitespace());
            if !ret.is_empty() {
                return Some(ret.to_string());
            }
        }
    }

    None
}

/// Byte index just past the `)` that closes the first `(` in `sig`, if balanced.
fn param_list_end(sig: &str) -> Option<usize> {
    let open = sig.find('(')?;
    let mut depth: usize = 0;
    for (offset, ch) in sig[open..].char_indices() {
        match ch {
            '(' => depth += 1,
            ')' => {
                depth -= 1;
                if depth == 0 {
                    return Some(open + offset + 1);
                }
            }
            _ => {}
        }
    }
    None
}
592
593fn match_param_types(pattern_params: &[String], chunk_params: &[String]) -> f32 {
596 if pattern_params.is_empty() && chunk_params.is_empty() {
597 return 1.0;
598 }
599
600 if pattern_params.is_empty() || chunk_params.is_empty() {
601 return 0.0;
602 }
603
604 let pattern_count = pattern_params.len();
606 let chunk_count = chunk_params.len();
607
608 if pattern_count == 1 && pattern_params[0] == "*" {
610 return 1.0;
611 }
612
613 let fixed_params: Vec<&String> = pattern_params
615 .iter()
616 .filter(|p| p.as_str() != "*")
617 .collect();
618
619 if fixed_params.len() > chunk_count {
620 return 0.0; }
622
623 let mut matched = 0;
624 let mut chunk_idx = 0;
625
626 for pat_param in pattern_params {
627 if pat_param == "*" {
628 matched += 1;
629 if chunk_idx < chunk_count {
630 chunk_idx += 1;
631 }
632 continue;
633 }
634
635 while chunk_idx < chunk_count {
637 if fuzzy_type_match(pat_param, &chunk_params[chunk_idx]) {
638 matched += 1;
639 chunk_idx += 1;
640 break;
641 }
642 chunk_idx += 1;
643 }
644 }
645
646 matched as f32 / pattern_params.len() as f32
647}
648
649fn fuzzy_type_match(pattern: &str, actual: &str) -> bool {
659 let pattern = pattern.trim().to_lowercase();
660 let actual = actual.trim().to_lowercase();
661
662 if pattern == "*" {
663 return true;
664 }
665
666 if pattern == actual {
668 return true;
669 }
670
671 if pattern.contains('*') {
673 return wildcard_match(&pattern, &actual);
674 }
675
676 match pattern.as_str() {
678 "string" | "str" => {
679 matches!(
680 actual.as_str(),
681 "string" | "str" | "&str" | "std::string::string" | "text"
682 )
683 }
684 "number" | "num" | "int" | "integer" => {
685 matches!(
686 actual.as_str(),
687 "i8" | "i16"
688 | "i32"
689 | "i64"
690 | "i128"
691 | "isize"
692 | "u8"
693 | "u16"
694 | "u32"
695 | "u64"
696 | "u128"
697 | "usize"
698 | "f32"
699 | "f64"
700 | "int"
701 | "int8"
702 | "int16"
703 | "int32"
704 | "int64"
705 | "uint"
706 | "float"
707 | "float32"
708 | "float64"
709 | "number"
710 | "double"
711 )
712 }
713 "bool" | "boolean" => {
714 matches!(actual.as_str(), "bool" | "boolean")
715 }
716 "void" | "none" | "unit" | "()" => {
717 matches!(actual.as_str(), "void" | "none" | "()" | "unit" | "null")
718 }
719 _ => {
720 actual.contains(&pattern) || pattern.contains(&actual)
722 }
723 }
724}
725
/// Unit and integration tests for pattern parsing, matching helpers and search.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::ChunkKind;
    use std::path::PathBuf;

    /// Builds a `CodeChunk` for a fake `test.rs` file. `name` and `sig` are always
    /// present; the byte and line ranges are derived from `body`.
    fn make_chunk(id: u64, kind: ChunkKind, name: &str, sig: &str, body: &str) -> CodeChunk {
        CodeChunk {
            id,
            file_path: PathBuf::from("test.rs"),
            language: "rust".to_string(),
            kind,
            name: Some(name.to_string()),
            signature: Some(sig.to_string()),
            doc_comment: None,
            body: body.to_string(),
            byte_range: 0..body.len(),
            line_range: 0..body.lines().count(),
        }
    }

    // A kind keyword glued to `(` is still recognised; no name is parsed.
    #[test]
    fn test_parse_simple_pattern() {
        let pat = parse_pattern("fn(string) -> number").unwrap();
        assert_eq!(pat.kind, PatternKind::Function);
        assert!(pat.name_pattern.is_none());
        assert_eq!(pat.param_patterns.as_ref().unwrap(), &["string"]);
        assert_eq!(pat.return_pattern.as_ref().unwrap(), "number");
    }

    #[test]
    fn test_parse_named_function_pattern() {
        let pat = parse_pattern("fn authenticate(*)").unwrap();
        assert_eq!(pat.kind, PatternKind::Function);
        assert_eq!(pat.name_pattern.as_ref().unwrap(), "authenticate");
        assert_eq!(pat.param_patterns.as_ref().unwrap(), &["*"]);
    }

    // Leading modifiers are collected into `qualifiers` before the kind keyword.
    #[test]
    fn test_parse_async_pattern() {
        let pat = parse_pattern("async fn(*) -> Result").unwrap();
        assert_eq!(pat.kind, PatternKind::Function);
        assert!(pat.qualifiers.contains(&"async".to_string()));
        assert_eq!(pat.return_pattern.as_ref().unwrap(), "Result");
    }

    #[test]
    fn test_parse_class_pattern() {
        let pat = parse_pattern("class User").unwrap();
        assert_eq!(pat.kind, PatternKind::Class);
        assert_eq!(pat.name_pattern.as_ref().unwrap(), "User");
    }

    // Wildcards are kept verbatim in the name pattern.
    #[test]
    fn test_parse_struct_wildcard() {
        let pat = parse_pattern("struct *Config").unwrap();
        assert_eq!(pat.kind, PatternKind::Struct);
        assert_eq!(pat.name_pattern.as_ref().unwrap(), "*Config");
    }

    // Comma separators are dropped from the parameter list.
    #[test]
    fn test_parse_multi_param() {
        let pat = parse_pattern("fn(string, number) -> bool").unwrap();
        assert_eq!(pat.kind, PatternKind::Function);
        let params = pat.param_patterns.as_ref().unwrap();
        assert_eq!(params.len(), 2);
        assert_eq!(params[0], "string");
        assert_eq!(params[1], "number");
        assert_eq!(pat.return_pattern.as_ref().unwrap(), "bool");
    }

    // `()` yields `Some(empty)`, distinct from "no parameter list" (`None`).
    #[test]
    fn test_parse_empty_params() {
        let pat = parse_pattern("fn()").unwrap();
        assert_eq!(pat.kind, PatternKind::Function);
        assert!(pat.param_patterns.as_ref().unwrap().is_empty());
    }

    #[test]
    fn test_wildcard_match() {
        assert!(wildcard_match("*Config", "SeekrConfig"));
        assert!(wildcard_match("*Config", "AppConfig"));
        assert!(!wildcard_match("*Config", "ConfigManager"));
        assert!(wildcard_match("Auth*", "AuthService"));
        assert!(wildcard_match("*", "anything"));
        assert!(wildcard_match("exact", "exact"));
        assert!(!wildcard_match("exact", "notexact"));
    }

    #[test]
    fn test_fuzzy_type_match() {
        assert!(fuzzy_type_match("string", "String"));
        assert!(fuzzy_type_match("string", "&str"));
        assert!(fuzzy_type_match("number", "i32"));
        assert!(fuzzy_type_match("number", "f64"));
        assert!(fuzzy_type_match("bool", "boolean"));
        assert!(fuzzy_type_match("*", "anything"));
        assert!(fuzzy_type_match("Result*", "Result<String, Error>"));
    }

    // Reference markers are stripped from extracted parameter types.
    #[test]
    fn test_extract_params_rust() {
        let params =
            extract_params_from_signature("fn authenticate(user: &str, password: String) -> bool");
        assert_eq!(params.len(), 2);
        assert_eq!(params[0], "str");
        assert_eq!(params[1], "String");
    }

    #[test]
    fn test_extract_return_type_rust() {
        let ret = extract_return_type_from_signature("fn foo(x: i32) -> Result<String, Error>");
        assert_eq!(ret, Some("Result<String, Error>".to_string()));
    }

    // Python-style signature: the trailing `:` is not part of the return type.
    #[test]
    fn test_extract_return_type_arrow() {
        let ret = extract_return_type_from_signature("def foo(x: int) -> bool:");
        assert_eq!(ret, Some("bool".to_string()));
    }

    #[test]
    fn test_match_function_by_return_type() {
        let pat = parse_pattern("fn(*) -> Result").unwrap();

        let chunk = make_chunk(
            1,
            ChunkKind::Function,
            "authenticate",
            "fn authenticate(user: &str) -> Result<Token, Error>",
            "fn authenticate(user: &str) -> Result<Token, Error> { }",
        );

        let score = match_chunk(&pat, &chunk);
        assert!(
            score > 0.5,
            "Should match function returning Result, got {}",
            score
        );
    }

    #[test]
    fn test_match_function_by_name() {
        let pat = parse_pattern("fn authenticate(*)").unwrap();

        let chunk = make_chunk(
            1,
            ChunkKind::Function,
            "authenticate",
            "fn authenticate(user: &str, pass: &str) -> bool",
            "fn authenticate(user: &str, pass: &str) -> bool { }",
        );

        let score = match_chunk(&pat, &chunk);
        assert!(score > 0.5, "Should match by name, got {}", score);
    }

    // A kind mismatch is a hard filter even when the name matches exactly.
    #[test]
    fn test_no_match_wrong_kind() {
        let pat = parse_pattern("class Foo").unwrap();

        let chunk = make_chunk(1, ChunkKind::Function, "Foo", "fn Foo()", "fn Foo() {}");

        let score = match_chunk(&pat, &chunk);
        assert_eq!(score, 0.0, "Should not match wrong kind");
    }

    #[test]
    fn test_search_ast_pattern_integration() {
        let mut index = SeekrIndex::new(4);

        // Two functions with different return types plus one struct.
        let chunks = vec![
            make_chunk(
                1,
                ChunkKind::Function,
                "authenticate",
                "fn authenticate(user: &str) -> Result<Token, AuthError>",
                "fn authenticate(user: &str) -> Result<Token, AuthError> { }",
            ),
            make_chunk(
                2,
                ChunkKind::Function,
                "calculate",
                "fn calculate(x: f64, y: f64) -> f64",
                "fn calculate(x: f64, y: f64) -> f64 { x + y }",
            ),
            make_chunk(
                3,
                ChunkKind::Struct,
                "AppConfig",
                "pub struct AppConfig",
                "pub struct AppConfig { pub port: u16 }",
            ),
        ];

        // Embeddings are dummies; AST search only reads the stored chunks.
        for chunk in &chunks {
            let entry = crate::index::IndexEntry {
                chunk_id: chunk.id,
                embedding: vec![0.1; 4],
                text_tokens: vec![],
            };
            index.add_entry(entry, chunk.clone());
        }

        // Both functions pass the kind filter; only chunk 1 also matches the return type.
        let results = search_ast_pattern(&index, "fn(*) -> Result", 10).unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].chunk_id, 1);

        // Only the struct passes the kind filter.
        let results = search_ast_pattern(&index, "struct *Config", 10).unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].chunk_id, 3);

        // The exact name match must outrank the other function.
        let results = search_ast_pattern(&index, "fn calculate(*)", 10).unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].chunk_id, 2);
    }

    #[test]
    fn test_empty_pattern_error() {
        let result = parse_pattern("");
        assert!(result.is_err());
    }
}