1use crate::error::SearchError;
23use crate::index::store::SeekrIndex;
24use crate::parser::{ChunkKind, CodeChunk};
25
/// A structural query produced by [`parse_pattern`] from a textual AST pattern.
#[derive(Debug, Clone)]
pub struct AstPattern {
    /// Qualifier keywords (e.g. `async`, `pub`) that preceded the kind, in the
    /// order they were written.
    pub qualifiers: Vec<String>,

    /// Kind of definition the pattern targets.
    pub kind: PatternKind,

    /// Name to look for, if the pattern gave one. `*` acts as a wildcard when
    /// the name is matched against a chunk.
    pub name_pattern: Option<String>,

    /// Parameter type patterns. `None` when no parameter list was parsed; an
    /// empty vector when the list was written but empty (`fn()`).
    pub param_patterns: Option<Vec<String>>,

    /// Return type pattern that followed `->`, if any.
    pub return_pattern: Option<String>,
}
44
/// The kind of definition an [`AstPattern`] is looking for.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PatternKind {
    /// A function or method (`fn`, `func`, `function`, `def`, `method`).
    Function,
    /// A class (`class`).
    Class,
    /// A struct (`struct`).
    Struct,
    /// An enum (`enum`).
    Enum,
    /// An interface-like definition (`interface`, `trait`, `protocol`).
    Interface,
    /// Any kind of definition (`*`).
    Any,
}
61
/// One result of an AST pattern search.
#[derive(Debug, Clone)]
pub struct AstMatch {
    /// Identifier of the chunk that matched (copied from the chunk's `id`).
    pub chunk_id: u64,

    /// Match score; higher is better. `search_ast_pattern` only reports
    /// chunks whose score is above zero.
    pub score: f32,
}
71
72pub fn parse_pattern(pattern: &str) -> Result<AstPattern, SearchError> {
82 let pattern = pattern.trim();
83
84 if pattern.is_empty() {
85 return Err(SearchError::InvalidAstPattern(
86 "Empty pattern".to_string(),
87 ));
88 }
89
90 let tokens = tokenize_pattern(pattern);
91
92 if tokens.is_empty() {
93 return Err(SearchError::InvalidAstPattern(
94 "Could not parse pattern".to_string(),
95 ));
96 }
97
98 let mut idx = 0;
99 let mut qualifiers = Vec::new();
100
101 while idx < tokens.len() {
103 match tokens[idx].as_str() {
104 "async" | "pub" | "static" | "export" | "private" | "protected" | "public"
105 | "abstract" | "virtual" | "override" | "const" | "mut" | "unsafe" => {
106 qualifiers.push(tokens[idx].clone());
107 idx += 1;
108 }
109 _ => break,
110 }
111 }
112
113 if idx >= tokens.len() {
114 return Err(SearchError::InvalidAstPattern(
115 "Pattern has only qualifiers, missing kind (fn, class, struct, etc.)".to_string(),
116 ));
117 }
118
119 let (kind, idx) = parse_kind(&tokens, idx)?;
121
122 let (name_pattern, idx) = parse_name(&tokens, idx);
124
125 let (param_patterns, idx) = if matches!(kind, PatternKind::Function | PatternKind::Any) {
127 parse_params(&tokens, idx)?
128 } else {
129 (None, idx)
130 };
131
132 let return_pattern = parse_return_type(&tokens, idx);
134
135 Ok(AstPattern {
136 qualifiers,
137 kind,
138 name_pattern,
139 param_patterns,
140 return_pattern,
141 })
142}
143
144pub fn search_ast_pattern(
146 index: &SeekrIndex,
147 pattern: &str,
148 top_k: usize,
149) -> Result<Vec<AstMatch>, SearchError> {
150 let parsed = parse_pattern(pattern)?;
151
152 let mut matches: Vec<AstMatch> = Vec::new();
153
154 for (_chunk_id, chunk) in &index.chunks {
155 let score = match_chunk(&parsed, chunk);
156 if score > 0.0 {
157 matches.push(AstMatch {
158 chunk_id: chunk.id,
159 score,
160 });
161 }
162 }
163
164 matches.sort_by(|a, b| {
166 b.score
167 .partial_cmp(&a.score)
168 .unwrap_or(std::cmp::Ordering::Equal)
169 });
170
171 matches.truncate(top_k);
172
173 Ok(matches)
174}
175
176fn match_chunk(pattern: &AstPattern, chunk: &CodeChunk) -> f32 {
179 let mut score = 0.0f32;
180 let mut total_criteria = 0.0f32;
181
182 total_criteria += 0.3;
184 if match_kind(&pattern.kind, &chunk.kind) {
185 score += 0.3;
186 } else {
187 return 0.0;
189 }
190
191 if !pattern.qualifiers.is_empty() {
193 total_criteria += 0.1;
194 let sig_lower = chunk
195 .signature
196 .as_deref()
197 .unwrap_or("")
198 .to_lowercase();
199 let body_start = chunk.body.lines().next().unwrap_or("").to_lowercase();
200 let combined = format!("{} {}", sig_lower, body_start);
201
202 let matched_quals = pattern
203 .qualifiers
204 .iter()
205 .filter(|q| combined.contains(q.as_str()))
206 .count();
207
208 if pattern.qualifiers.len() > 0 {
209 score += 0.1 * (matched_quals as f32 / pattern.qualifiers.len() as f32);
210 }
211 }
212
213 if let Some(ref name_pat) = pattern.name_pattern {
215 total_criteria += 0.3;
216 if let Some(ref chunk_name) = chunk.name {
217 if wildcard_match(name_pat, chunk_name) {
218 score += 0.3;
219 } else if chunk_name.to_lowercase().contains(&name_pat.to_lowercase().replace('*', ""))
220 {
221 score += 0.15;
223 }
224 }
225 }
226
227 if let Some(ref param_pats) = pattern.param_patterns {
229 total_criteria += 0.15;
230 let sig = chunk.signature.as_deref().unwrap_or(&chunk.body);
231 let chunk_params = extract_params_from_signature(sig);
232
233 if param_pats.len() == 1 && param_pats[0] == "*" {
234 score += 0.15;
236 } else if param_pats.is_empty() && chunk_params.is_empty() {
237 score += 0.15;
239 } else {
240 let param_score = match_param_types(param_pats, &chunk_params);
241 score += 0.15 * param_score;
242 }
243 }
244
245 if let Some(ref ret_pat) = pattern.return_pattern {
247 total_criteria += 0.15;
248 let sig = chunk.signature.as_deref().unwrap_or(&chunk.body);
249 let chunk_ret = extract_return_type_from_signature(sig);
250
251 if ret_pat == "*" {
252 score += 0.15;
253 } else if let Some(ref chunk_ret) = chunk_ret {
254 if fuzzy_type_match(ret_pat, chunk_ret) {
255 score += 0.15;
256 } else if chunk_ret.to_lowercase().contains(&ret_pat.to_lowercase()) {
257 score += 0.075; }
259 }
260 }
261
262 if total_criteria > 0.0 {
264 score / total_criteria
265 } else {
266 0.0
267 }
268}
269
/// Splits a pattern string into tokens.
///
/// `(`, `)`, `,` and the two-character arrow `->` are always tokens of their
/// own. Any whitespace (spaces, tabs, newlines, ...) separates tokens and is
/// discarded, so a pattern may be written across several lines. Every other
/// run of characters forms one token.
fn tokenize_pattern(pattern: &str) -> Vec<String> {
    /// Moves the pending word, if any, onto the token list.
    fn flush(word: &mut String, tokens: &mut Vec<String>) {
        if !word.is_empty() {
            tokens.push(std::mem::take(word));
        }
    }

    let mut tokens = Vec::new();
    let mut word = String::new();
    let mut chars = pattern.chars().peekable();

    while let Some(ch) = chars.next() {
        match ch {
            '(' | ')' | ',' => {
                flush(&mut word, &mut tokens);
                tokens.push(ch.to_string());
            }
            '-' if chars.peek() == Some(&'>') => {
                flush(&mut word, &mut tokens);
                chars.next(); // consume the '>' half of the arrow
                tokens.push("->".to_string());
            }
            // All whitespace separates tokens, not just ' ' and '\t'; otherwise
            // a newline inside the pattern would become part of a token.
            c if c.is_whitespace() => flush(&mut word, &mut tokens),
            _ => word.push(ch),
        }
    }

    flush(&mut word, &mut tokens);
    tokens
}
312
313fn parse_kind(tokens: &[String], idx: usize) -> Result<(PatternKind, usize), SearchError> {
315 if idx >= tokens.len() {
316 return Ok((PatternKind::Any, idx));
317 }
318
319 let kind_str = tokens[idx].to_lowercase();
320 let kind = match kind_str.as_str() {
321 "fn" | "func" | "function" | "def" | "method" => PatternKind::Function,
322 "class" => PatternKind::Class,
323 "struct" => PatternKind::Struct,
324 "enum" => PatternKind::Enum,
325 "interface" | "trait" | "protocol" => PatternKind::Interface,
326 "*" => PatternKind::Any,
327 _ => {
328 return Ok((PatternKind::Function, idx));
331 }
332 };
333
334 Ok((kind, idx + 1))
335}
336
/// Reads an optional name at `tokens[idx]`.
///
/// Any token other than the punctuation tokens `(`, `)`, `,` and `->` is
/// taken as the name. Returns the name (if one was read) and the index of the
/// next unread token.
fn parse_name(tokens: &[String], idx: usize) -> (Option<String>, usize) {
    match tokens.get(idx).map(String::as_str) {
        None | Some("(" | "->" | ")" | ",") => (None, idx),
        Some(name) => (Some(name.to_string()), idx + 1),
    }
}
350
351fn parse_params(
353 tokens: &[String],
354 idx: usize,
355) -> Result<(Option<Vec<String>>, usize), SearchError> {
356 if idx >= tokens.len() || tokens[idx] != "(" {
357 return Ok((None, idx));
358 }
359
360 let mut params = Vec::new();
361 let mut i = idx + 1; while i < tokens.len() && tokens[i] != ")" {
364 if tokens[i] == "," {
365 i += 1;
366 continue;
367 }
368 params.push(tokens[i].clone());
369 i += 1;
370 }
371
372 if i < tokens.len() && tokens[i] == ")" {
373 i += 1; }
375
376 Ok((Some(params), i))
377}
378
/// Reads an optional return type pattern starting at `tokens[idx]`.
///
/// Returns `None` unless the token at `idx` is `->` and at least one token
/// follows it. All remaining tokens form the return type. Because the
/// tokenizer splits on `,`, `(` and `)`, the tokens are re-joined with
/// conventional type spacing rather than a blanket space, so that
/// `Result<String, Error>` and `(i32, i32)` come back in the form they are
/// normally written in a signature.
fn parse_return_type(tokens: &[String], idx: usize) -> Option<String> {
    if tokens.get(idx).map(String::as_str) != Some("->") {
        return None;
    }

    let parts = tokens.get(idx + 1..)?;
    if parts.is_empty() {
        return None;
    }

    let mut ret = String::new();
    let mut prev: Option<&str> = None;

    for tok in parts.iter().map(String::as_str) {
        // Nothing after `(`, nothing before `)` or `,`, and `(` attaches to a
        // preceding word (`Fn(`, `Result<(`) but not to `->` or `,`.
        let needs_space = match (prev, tok) {
            (None, _) | (Some("("), _) => false,
            (_, ")" | ",") => false,
            (Some("->" | ","), "(") => true,
            (_, "(") => false,
            _ => true,
        };
        if needs_space {
            ret.push(' ');
        }
        ret.push_str(tok);
        prev = Some(tok);
    }

    Some(ret)
}
393
394fn match_kind(pattern_kind: &PatternKind, chunk_kind: &ChunkKind) -> bool {
400 match pattern_kind {
401 PatternKind::Any => true,
402 PatternKind::Function => matches!(chunk_kind, ChunkKind::Function | ChunkKind::Method),
403 PatternKind::Class => matches!(chunk_kind, ChunkKind::Class),
404 PatternKind::Struct => matches!(chunk_kind, ChunkKind::Struct),
405 PatternKind::Enum => matches!(chunk_kind, ChunkKind::Enum),
406 PatternKind::Interface => matches!(chunk_kind, ChunkKind::Interface),
407 }
408}
409
/// Case-insensitive glob match where `*` stands for any run of characters
/// (including none).
///
/// Without a `*` the pattern must equal the text. Otherwise the text must
/// start with the segment before the first `*`, end with the segment after
/// the last `*`, and contain the segments in between in order.
fn wildcard_match(pattern: &str, text: &str) -> bool {
    let pattern = pattern.to_lowercase();
    let text = text.to_lowercase();

    if !pattern.contains('*') {
        return pattern == text;
    }

    // At least one `*` is present, so there are always two or more segments.
    let segments: Vec<&str> = pattern.split('*').collect();
    let last = segments.len() - 1;

    // Portion of the text not yet claimed by an earlier segment.
    let mut rest = text.as_str();

    for (position, segment) in segments.iter().copied().enumerate() {
        if segment.is_empty() {
            continue;
        }

        if position == 0 {
            match rest.strip_prefix(segment) {
                Some(tail) => rest = tail,
                None => return false,
            }
        } else if position == last {
            if !rest.ends_with(segment) {
                return false;
            }
        } else {
            match rest.find(segment) {
                Some(at) => rest = &rest[at + segment.len()..],
                None => return false,
            }
        }
    }

    true
}
457
458fn extract_params_from_signature(sig: &str) -> Vec<String> {
465 let Some(open) = sig.find('(') else {
467 return Vec::new();
468 };
469
470 let mut depth = 0;
471 let mut close = None;
472
473 for (i, ch) in sig[open..].char_indices() {
474 match ch {
475 '(' => depth += 1,
476 ')' => {
477 depth -= 1;
478 if depth == 0 {
479 close = Some(open + i);
480 break;
481 }
482 }
483 _ => {}
484 }
485 }
486
487 let Some(close) = close else {
488 return Vec::new();
489 };
490
491 let params_str = &sig[open + 1..close];
492 if params_str.trim().is_empty() {
493 return Vec::new();
494 }
495
496 split_params(params_str)
498 .iter()
499 .filter_map(|p| extract_type_from_param(p.trim()))
500 .collect()
501}
502
/// Splits a parameter list on its top-level commas.
///
/// Commas nested inside `<>`, `()`, `[]` or `{}` do not split. The pieces are
/// returned verbatim (surrounding whitespace included); a trailing empty piece
/// is dropped.
///
/// The `>` of an arrow (`->` or `=>`) is not treated as a closing bracket, and
/// the nesting depth never drops below zero. Without this, a parameter such as
/// `f: fn(i32) -> i32` would leave the depth negative and every later comma
/// would be ignored.
fn split_params(params: &str) -> Vec<String> {
    let mut parts = Vec::new();
    let mut current = String::new();
    let mut depth = 0usize;
    let mut prev: Option<char> = None;

    for ch in params.chars() {
        match ch {
            '<' | '(' | '[' | '{' => depth += 1,
            // Part of `->` / `=>`: an arrow, not a closing angle bracket.
            '>' if matches!(prev, Some('-' | '=')) => {}
            '>' | ')' | ']' | '}' => depth = depth.saturating_sub(1),
            ',' if depth == 0 => {
                parts.push(std::mem::take(&mut current));
                prev = Some(ch);
                continue;
            }
            _ => {}
        }
        current.push(ch);
        prev = Some(ch);
    }

    if !current.is_empty() {
        parts.push(current);
    }

    parts
}
532
/// Reduces one parameter declaration to its type.
///
/// * `name: Type` - everything after the first `:` is the type. Leading `&`
///   reference markers, a reference lifetime (`&'a T`, `&'static str`) and a
///   leading `mut ` are stripped, so `&str`, `&'a str` and `&mut str` all
///   yield `str`.
/// * Otherwise, for a declaration of two or more whitespace-separated words,
///   the last word is returned.
/// * A single word is returned unchanged.
///
/// Returns `None` only for a blank parameter.
fn extract_type_from_param(param: &str) -> Option<String> {
    let param = param.trim();
    if param.is_empty() {
        return None;
    }

    if let Some(colon_pos) = param.find(':') {
        let annotated = param[colon_pos + 1..].trim();
        let referent = annotated.trim_start_matches('&');

        let mut ty = referent;
        if referent.len() != annotated.len() {
            // A reference: step over optional whitespace and a lifetime such
            // as `'a` or `'static` so only the referent type is left.
            ty = ty.trim_start();
            if let Some(after_tick) = ty.strip_prefix('\'') {
                let lifetime_len = after_tick
                    .find(|c: char| !(c.is_alphanumeric() || c == '_'))
                    .unwrap_or(after_tick.len());
                ty = after_tick[lifetime_len..].trim_start();
            }
        }

        let ty = ty.trim_start_matches("mut ").trim();
        return Some(ty.to_string());
    }

    // No colon: take the last word, or the whole text when it is one word.
    let last_word = param.split_whitespace().last().unwrap_or(param);
    Some(last_word.to_string())
}
567
/// Extracts the return type from a signature string, if one can be found.
///
/// Two notations are recognised:
///
/// * Arrow style (`fn f() -> T`, `def f() -> T:`): the text after the first
///   `->`, cut at the first `{` that is not nested inside brackets and with
///   trailing `{`, `:`, `;` and whitespace removed. Cutting at the brace
///   matters when the input is a whole definition rather than a bare
///   signature, so that the body does not end up in the type.
/// * Colon style (`function f(): T {`): the text after a `:` that directly
///   follows the last `)`, with trailing `{` and whitespace removed.
fn extract_return_type_from_signature(sig: &str) -> Option<String> {
    /// Byte offset of the first `{` outside `<>`, `()` and `[]`, or the length
    /// of `text` when there is none. The `>` of `->` / `=>` is not a closer.
    fn body_start(text: &str) -> usize {
        let mut depth = 0usize;
        let mut prev = '\0';
        for (offset, ch) in text.char_indices() {
            match ch {
                '<' | '(' | '[' => depth += 1,
                '>' if prev == '-' || prev == '=' => {}
                '>' | ')' | ']' => depth = depth.saturating_sub(1),
                '{' if depth == 0 => return offset,
                _ => {}
            }
            prev = ch;
        }
        text.len()
    }

    if let Some(arrow_pos) = sig.find("->") {
        let after_arrow = &sig[arrow_pos + 2..];
        let ret = after_arrow[..body_start(after_arrow)]
            .trim()
            .trim_end_matches(|c: char| matches!(c, '{' | ':' | ';') || c.is_whitespace());
        if !ret.is_empty() {
            return Some(ret.to_string());
        }
    }

    if let Some(close_paren) = sig.rfind(')') {
        let after = sig[close_paren + 1..].trim();
        if let Some(stripped) = after.strip_prefix(':') {
            let ret = stripped
                .trim()
                .trim_end_matches(|c: char| c == '{' || c.is_whitespace());
            if !ret.is_empty() {
                return Some(ret.to_string());
            }
        }
    }

    None
}
594
595fn match_param_types(pattern_params: &[String], chunk_params: &[String]) -> f32 {
598 if pattern_params.is_empty() && chunk_params.is_empty() {
599 return 1.0;
600 }
601
602 if pattern_params.is_empty() || chunk_params.is_empty() {
603 return 0.0;
604 }
605
606 let pattern_count = pattern_params.len();
608 let chunk_count = chunk_params.len();
609
610 if pattern_count == 1 && pattern_params[0] == "*" {
612 return 1.0;
613 }
614
615 let fixed_params: Vec<&String> = pattern_params.iter().filter(|p| p.as_str() != "*").collect();
617
618 if fixed_params.len() > chunk_count {
619 return 0.0; }
621
622 let mut matched = 0;
623 let mut chunk_idx = 0;
624
625 for pat_param in pattern_params {
626 if pat_param == "*" {
627 matched += 1;
628 if chunk_idx < chunk_count {
629 chunk_idx += 1;
630 }
631 continue;
632 }
633
634 while chunk_idx < chunk_count {
636 if fuzzy_type_match(pat_param, &chunk_params[chunk_idx]) {
637 matched += 1;
638 chunk_idx += 1;
639 break;
640 }
641 chunk_idx += 1;
642 }
643 }
644
645 matched as f32 / pattern_params.len() as f32
646}
647
648fn fuzzy_type_match(pattern: &str, actual: &str) -> bool {
658 let pattern = pattern.trim().to_lowercase();
659 let actual = actual.trim().to_lowercase();
660
661 if pattern == "*" {
662 return true;
663 }
664
665 if pattern == actual {
667 return true;
668 }
669
670 if pattern.contains('*') {
672 return wildcard_match(&pattern, &actual);
673 }
674
675 match pattern.as_str() {
677 "string" | "str" => {
678 matches!(
679 actual.as_str(),
680 "string" | "str" | "&str" | "std::string::string" | "text"
681 )
682 }
683 "number" | "num" | "int" | "integer" => {
684 matches!(
685 actual.as_str(),
686 "i8" | "i16"
687 | "i32"
688 | "i64"
689 | "i128"
690 | "isize"
691 | "u8"
692 | "u16"
693 | "u32"
694 | "u64"
695 | "u128"
696 | "usize"
697 | "f32"
698 | "f64"
699 | "int"
700 | "int8"
701 | "int16"
702 | "int32"
703 | "int64"
704 | "uint"
705 | "float"
706 | "float32"
707 | "float64"
708 | "number"
709 | "double"
710 )
711 }
712 "bool" | "boolean" => {
713 matches!(actual.as_str(), "bool" | "boolean")
714 }
715 "void" | "none" | "unit" | "()" => {
716 matches!(actual.as_str(), "void" | "none" | "()" | "unit" | "null")
717 }
718 _ => {
719 actual.contains(&pattern) || pattern.contains(&actual)
721 }
722 }
723}
724
// Unit tests for pattern parsing, the matching helpers and the end-to-end
// search over a small in-memory index.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::ChunkKind;
    use std::path::PathBuf;

    // Builds a chunk fixture; byte and line ranges are derived from `body`.
    fn make_chunk(id: u64, kind: ChunkKind, name: &str, sig: &str, body: &str) -> CodeChunk {
        CodeChunk {
            id,
            file_path: PathBuf::from("test.rs"),
            language: "rust".to_string(),
            kind,
            name: Some(name.to_string()),
            signature: Some(sig.to_string()),
            doc_comment: None,
            body: body.to_string(),
            byte_range: 0..body.len(),
            line_range: 0..body.lines().count(),
        }
    }

    // Anonymous function pattern with one parameter and a return type.
    #[test]
    fn test_parse_simple_pattern() {
        let pat = parse_pattern("fn(string) -> number").unwrap();
        assert_eq!(pat.kind, PatternKind::Function);
        assert!(pat.name_pattern.is_none());
        assert_eq!(pat.param_patterns.as_ref().unwrap(), &["string"]);
        assert_eq!(pat.return_pattern.as_ref().unwrap(), "number");
    }

    // A name between the kind and the parameter list is captured.
    #[test]
    fn test_parse_named_function_pattern() {
        let pat = parse_pattern("fn authenticate(*)").unwrap();
        assert_eq!(pat.kind, PatternKind::Function);
        assert_eq!(pat.name_pattern.as_ref().unwrap(), "authenticate");
        assert_eq!(pat.param_patterns.as_ref().unwrap(), &["*"]);
    }

    // Leading qualifier keywords are collected separately from the kind.
    #[test]
    fn test_parse_async_pattern() {
        let pat = parse_pattern("async fn(*) -> Result").unwrap();
        assert_eq!(pat.kind, PatternKind::Function);
        assert!(pat.qualifiers.contains(&"async".to_string()));
        assert_eq!(pat.return_pattern.as_ref().unwrap(), "Result");
    }

    // Non-function kinds take a name but no parameter list.
    #[test]
    fn test_parse_class_pattern() {
        let pat = parse_pattern("class User").unwrap();
        assert_eq!(pat.kind, PatternKind::Class);
        assert_eq!(pat.name_pattern.as_ref().unwrap(), "User");
    }

    // Wildcards in the name are kept verbatim in the parsed pattern.
    #[test]
    fn test_parse_struct_wildcard() {
        let pat = parse_pattern("struct *Config").unwrap();
        assert_eq!(pat.kind, PatternKind::Struct);
        assert_eq!(pat.name_pattern.as_ref().unwrap(), "*Config");
    }

    // Comma-separated parameters become separate entries, in order.
    #[test]
    fn test_parse_multi_param() {
        let pat = parse_pattern("fn(string, number) -> bool").unwrap();
        assert_eq!(pat.kind, PatternKind::Function);
        let params = pat.param_patterns.as_ref().unwrap();
        assert_eq!(params.len(), 2);
        assert_eq!(params[0], "string");
        assert_eq!(params[1], "number");
        assert_eq!(pat.return_pattern.as_ref().unwrap(), "bool");
    }

    // `()` yields an empty parameter list rather than `None`.
    #[test]
    fn test_parse_empty_params() {
        let pat = parse_pattern("fn()").unwrap();
        assert_eq!(pat.kind, PatternKind::Function);
        assert!(pat.param_patterns.as_ref().unwrap().is_empty());
    }

    // Prefix, suffix, catch-all and exact (no `*`) glob behaviour.
    #[test]
    fn test_wildcard_match() {
        assert!(wildcard_match("*Config", "SeekrConfig"));
        assert!(wildcard_match("*Config", "AppConfig"));
        assert!(!wildcard_match("*Config", "ConfigManager"));
        assert!(wildcard_match("Auth*", "AuthService"));
        assert!(wildcard_match("*", "anything"));
        assert!(wildcard_match("exact", "exact"));
        assert!(!wildcard_match("exact", "notexact"));
    }

    // Generic type names match concrete spellings; globs work on types too.
    #[test]
    fn test_fuzzy_type_match() {
        assert!(fuzzy_type_match("string", "String"));
        assert!(fuzzy_type_match("string", "&str"));
        assert!(fuzzy_type_match("number", "i32"));
        assert!(fuzzy_type_match("number", "f64"));
        assert!(fuzzy_type_match("bool", "boolean"));
        assert!(fuzzy_type_match("*", "anything"));
        assert!(fuzzy_type_match("Result*", "Result<String, Error>"));
    }

    // Parameter names and reference markers are stripped, leaving the types.
    #[test]
    fn test_extract_params_rust() {
        let params = extract_params_from_signature("fn authenticate(user: &str, password: String) -> bool");
        assert_eq!(params.len(), 2);
        assert_eq!(params[0], "str");
        assert_eq!(params[1], "String");
    }

    // The full generic return type after `->` is returned.
    #[test]
    fn test_extract_return_type_rust() {
        let ret = extract_return_type_from_signature("fn foo(x: i32) -> Result<String, Error>");
        assert_eq!(ret, Some("Result<String, Error>".to_string()));
    }

    // A trailing `:` (Python-style signature) is not part of the type.
    #[test]
    fn test_extract_return_type_arrow() {
        let ret = extract_return_type_from_signature("def foo(x: int) -> bool:");
        assert_eq!(ret, Some("bool".to_string()));
    }

    // A return pattern matches a chunk whose return type contains it.
    #[test]
    fn test_match_function_by_return_type() {
        let pat = parse_pattern("fn(*) -> Result").unwrap();

        let chunk = make_chunk(
            1,
            ChunkKind::Function,
            "authenticate",
            "fn authenticate(user: &str) -> Result<Token, Error>",
            "fn authenticate(user: &str) -> Result<Token, Error> { }",
        );

        let score = match_chunk(&pat, &chunk);
        assert!(score > 0.5, "Should match function returning Result, got {}", score);
    }

    // An exact name with a catch-all parameter list scores highly.
    #[test]
    fn test_match_function_by_name() {
        let pat = parse_pattern("fn authenticate(*)").unwrap();

        let chunk = make_chunk(
            1,
            ChunkKind::Function,
            "authenticate",
            "fn authenticate(user: &str, pass: &str) -> bool",
            "fn authenticate(user: &str, pass: &str) -> bool { }",
        );

        let score = match_chunk(&pat, &chunk);
        assert!(score > 0.5, "Should match by name, got {}", score);
    }

    // A kind mismatch zeroes the score even when the name matches.
    #[test]
    fn test_no_match_wrong_kind() {
        let pat = parse_pattern("class Foo").unwrap();

        let chunk = make_chunk(
            1,
            ChunkKind::Function,
            "Foo",
            "fn Foo()",
            "fn Foo() {}",
        );

        let score = match_chunk(&pat, &chunk);
        assert_eq!(score, 0.0, "Should not match wrong kind");
    }

    // End-to-end: index three chunks, then check the top hit for three queries.
    #[test]
    fn test_search_ast_pattern_integration() {
        let mut index = SeekrIndex::new(4);

        let chunks = vec![
            make_chunk(
                1,
                ChunkKind::Function,
                "authenticate",
                "fn authenticate(user: &str) -> Result<Token, AuthError>",
                "fn authenticate(user: &str) -> Result<Token, AuthError> { }",
            ),
            make_chunk(
                2,
                ChunkKind::Function,
                "calculate",
                "fn calculate(x: f64, y: f64) -> f64",
                "fn calculate(x: f64, y: f64) -> f64 { x + y }",
            ),
            make_chunk(
                3,
                ChunkKind::Struct,
                "AppConfig",
                "pub struct AppConfig",
                "pub struct AppConfig { pub port: u16 }",
            ),
        ];

        // The embedding and token contents are placeholders here; the
        // assertions below only concern the chunks themselves.
        for chunk in &chunks {
            let entry = crate::index::IndexEntry {
                chunk_id: chunk.id,
                embedding: vec![0.1; 4],
                text_tokens: vec![],
            };
            index.add_entry(entry, chunk.clone());
        }

        // Query by return type.
        let results = search_ast_pattern(&index, "fn(*) -> Result", 10).unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].chunk_id, 1);

        // Query by kind plus wildcard name.
        let results = search_ast_pattern(&index, "struct *Config", 10).unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].chunk_id, 3);

        // Query by exact function name.
        let results = search_ast_pattern(&index, "fn calculate(*)", 10).unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].chunk_id, 2);
    }

    // An empty pattern is rejected instead of matching everything.
    #[test]
    fn test_empty_pattern_error() {
        let result = parse_pattern("");
        assert!(result.is_err());
    }
}