Skip to main content

seekr_code/search/
ast_pattern.rs

1//! AST pattern search.
2//!
3//! Parses user-provided function signature patterns (e.g., "fn(string) -> number")
4//! and matches them against indexed CodeChunk signatures.
5//! Supports wildcards (`*`), optional keywords (`async`, `pub`), and fuzzy type matching.
6//!
//! ## Pattern Syntax
//!
//! ```text
//! [qualifier ...] kind [name] [([param_type, ...])] [-> return_type]
//! ```
12//!
13//! Examples:
14//! - `fn(string) -> number`        — any function taking a string, returning a number
15//! - `fn(*) -> Result`             — any function returning a Result
16//! - `async fn(*)`                 — any async function
17//! - `fn(*, *) -> bool`            — any function with 2 params returning bool
18//! - `fn authenticate(*)`          — function named "authenticate" with any params
19//! - `class User`                  — class named User
20//! - `struct *Config`              — struct whose name ends with "Config"
21
22use crate::error::SearchError;
23use crate::index::store::SeekrIndex;
24use crate::parser::{ChunkKind, CodeChunk};
25
26/// A parsed AST search pattern.
27#[derive(Debug, Clone)]
28pub struct AstPattern {
29    /// Optional qualifiers: "async", "pub", "static", etc.
30    pub qualifiers: Vec<String>,
31
32    /// The kind of construct to match.
33    pub kind: PatternKind,
34
35    /// Optional name pattern (supports `*` wildcard).
36    pub name_pattern: Option<String>,
37
38    /// Optional parameter type patterns (supports `*` wildcard).
39    pub param_patterns: Option<Vec<String>>,
40
41    /// Optional return type pattern (supports `*` wildcard).
42    pub return_pattern: Option<String>,
43}
44
/// The category of code construct an [`AstPattern`] is looking for.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum PatternKind {
    /// Functions and methods.
    Function,
    /// Classes.
    Class,
    /// Structs.
    Struct,
    /// Enums.
    Enum,
    /// Interfaces / traits.
    Interface,
    /// No restriction: every kind matches.
    Any,
}
61
/// One hit produced by an AST pattern search.
#[derive(Clone, Debug)]
pub struct AstMatch {
    /// Identifier of the chunk that matched.
    pub chunk_id: u64,

    /// How well the chunk matched, from 0.0 to 1.0 (higher is better).
    pub score: f32,
}
71
72/// Parse a user-provided AST pattern string.
73///
74/// # Examples
75/// ```
76/// use seekr_code::search::ast_pattern::parse_pattern;
77///
78/// let pat = parse_pattern("fn(string) -> number").unwrap();
79/// assert_eq!(pat.kind, seekr_code::search::ast_pattern::PatternKind::Function);
80/// ```
81pub fn parse_pattern(pattern: &str) -> Result<AstPattern, SearchError> {
82    let pattern = pattern.trim();
83
84    if pattern.is_empty() {
85        return Err(SearchError::InvalidAstPattern("Empty pattern".to_string()));
86    }
87
88    let tokens = tokenize_pattern(pattern);
89
90    if tokens.is_empty() {
91        return Err(SearchError::InvalidAstPattern(
92            "Could not parse pattern".to_string(),
93        ));
94    }
95
96    let mut idx = 0;
97    let mut qualifiers = Vec::new();
98
99    // Collect qualifiers (async, pub, static, etc.)
100    while idx < tokens.len() {
101        match tokens[idx].as_str() {
102            "async" | "pub" | "static" | "export" | "private" | "protected" | "public"
103            | "abstract" | "virtual" | "override" | "const" | "mut" | "unsafe" => {
104                qualifiers.push(tokens[idx].clone());
105                idx += 1;
106            }
107            _ => break,
108        }
109    }
110
111    if idx >= tokens.len() {
112        return Err(SearchError::InvalidAstPattern(
113            "Pattern has only qualifiers, missing kind (fn, class, struct, etc.)".to_string(),
114        ));
115    }
116
117    // Parse kind keyword
118    let (kind, idx) = parse_kind(&tokens, idx)?;
119
120    // Parse optional name
121    let (name_pattern, idx) = parse_name(&tokens, idx);
122
123    // Parse optional parameters (only for function-like kinds)
124    let (param_patterns, idx) = if matches!(kind, PatternKind::Function | PatternKind::Any) {
125        parse_params(&tokens, idx)?
126    } else {
127        (None, idx)
128    };
129
130    // Parse optional return type
131    let return_pattern = parse_return_type(&tokens, idx);
132
133    Ok(AstPattern {
134        qualifiers,
135        kind,
136        name_pattern,
137        param_patterns,
138        return_pattern,
139    })
140}
141
142/// Search chunks in the index using an AST pattern.
143pub fn search_ast_pattern(
144    index: &SeekrIndex,
145    pattern: &str,
146    top_k: usize,
147) -> Result<Vec<AstMatch>, SearchError> {
148    let parsed = parse_pattern(pattern)?;
149
150    let mut matches: Vec<AstMatch> = Vec::new();
151
152    for chunk in index.chunks.values() {
153        let score = match_chunk(&parsed, chunk);
154        if score > 0.0 {
155            matches.push(AstMatch {
156                chunk_id: chunk.id,
157                score,
158            });
159        }
160    }
161
162    // Sort by score descending
163    matches.sort_by(|a, b| {
164        b.score
165            .partial_cmp(&a.score)
166            .unwrap_or(std::cmp::Ordering::Equal)
167    });
168
169    matches.truncate(top_k);
170
171    Ok(matches)
172}
173
174/// Match a single chunk against a parsed AST pattern.
175/// Returns a score from 0.0 (no match) to 1.0 (perfect match).
176fn match_chunk(pattern: &AstPattern, chunk: &CodeChunk) -> f32 {
177    let mut score = 0.0f32;
178    let mut total_criteria = 0.0f32;
179
180    // 1. Match kind (required, weight = 0.3)
181    total_criteria += 0.3;
182    if match_kind(&pattern.kind, &chunk.kind) {
183        score += 0.3;
184    } else {
185        // Kind mismatch is a hard filter — return 0
186        return 0.0;
187    }
188
189    // 2. Match qualifiers (weight = 0.1)
190    if !pattern.qualifiers.is_empty() {
191        total_criteria += 0.1;
192        let sig_lower = chunk.signature.as_deref().unwrap_or("").to_lowercase();
193        let body_start = chunk.body.lines().next().unwrap_or("").to_lowercase();
194        let combined = format!("{} {}", sig_lower, body_start);
195
196        let matched_quals = pattern
197            .qualifiers
198            .iter()
199            .filter(|q| combined.contains(q.as_str()))
200            .count();
201
202        if !pattern.qualifiers.is_empty() {
203            score += 0.1 * (matched_quals as f32 / pattern.qualifiers.len() as f32);
204        }
205    }
206
207    // 3. Match name (weight = 0.3)
208    if let Some(ref name_pat) = pattern.name_pattern {
209        total_criteria += 0.3;
210        if let Some(ref chunk_name) = chunk.name {
211            if wildcard_match(name_pat, chunk_name) {
212                score += 0.3;
213            } else if chunk_name
214                .to_lowercase()
215                .contains(&name_pat.to_lowercase().replace('*', ""))
216            {
217                // Partial name match gets partial credit
218                score += 0.15;
219            }
220        }
221    }
222
223    // 4. Match parameter patterns (weight = 0.15)
224    if let Some(ref param_pats) = pattern.param_patterns {
225        total_criteria += 0.15;
226        let sig = chunk.signature.as_deref().unwrap_or(&chunk.body);
227        let chunk_params = extract_params_from_signature(sig);
228
229        if param_pats.len() == 1 && param_pats[0] == "*" {
230            // Wildcard: any params match
231            score += 0.15;
232        } else if param_pats.is_empty() && chunk_params.is_empty() {
233            // Both empty: match
234            score += 0.15;
235        } else {
236            let param_score = match_param_types(param_pats, &chunk_params);
237            score += 0.15 * param_score;
238        }
239    }
240
241    // 5. Match return type (weight = 0.15)
242    if let Some(ref ret_pat) = pattern.return_pattern {
243        total_criteria += 0.15;
244        let sig = chunk.signature.as_deref().unwrap_or(&chunk.body);
245        let chunk_ret = extract_return_type_from_signature(sig);
246
247        if ret_pat == "*" {
248            score += 0.15;
249        } else if let Some(ref chunk_ret) = chunk_ret {
250            if fuzzy_type_match(ret_pat, chunk_ret) {
251                score += 0.15;
252            } else if chunk_ret.to_lowercase().contains(&ret_pat.to_lowercase()) {
253                score += 0.075; // partial credit
254            }
255        }
256    }
257
258    // Normalize: if no optional criteria were specified, boost the score
259    if total_criteria > 0.0 {
260        score / total_criteria
261    } else {
262        0.0
263    }
264}
265
266// ============================================================
267// Pattern Parsing Helpers
268// ============================================================
269
/// Tokenize a pattern string into meaningful tokens.
///
/// `(`, `)` and `,` are each emitted as their own token, `->` is emitted as a
/// single token, and any Unicode whitespace (spaces, tabs, newlines, ...)
/// separates tokens without being emitted. Everything else accumulates into
/// word tokens. A `-` that is not followed by `>` is treated as an ordinary
/// word character.
fn tokenize_pattern(pattern: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    let mut current = String::new();
    let mut chars = pattern.chars().peekable();

    while let Some(ch) = chars.next() {
        match ch {
            '(' | ')' | ',' => {
                if !current.is_empty() {
                    tokens.push(std::mem::take(&mut current));
                }
                tokens.push(ch.to_string());
            }
            '-' if chars.peek() == Some(&'>') => {
                if !current.is_empty() {
                    tokens.push(std::mem::take(&mut current));
                }
                chars.next(); // consume '>'
                tokens.push("->".to_string());
            }
            // Any whitespace ends the current word, not just ' ' and '\t',
            // so multi-line or CRLF input does not leak into tokens.
            _ if ch.is_whitespace() => {
                if !current.is_empty() {
                    tokens.push(std::mem::take(&mut current));
                }
            }
            _ => current.push(ch),
        }
    }

    if !current.is_empty() {
        tokens.push(current);
    }

    tokens
}
308
309/// Parse the kind keyword from tokens.
310fn parse_kind(tokens: &[String], idx: usize) -> Result<(PatternKind, usize), SearchError> {
311    if idx >= tokens.len() {
312        return Ok((PatternKind::Any, idx));
313    }
314
315    let kind_str = tokens[idx].to_lowercase();
316    let kind = match kind_str.as_str() {
317        "fn" | "func" | "function" | "def" | "method" => PatternKind::Function,
318        "class" => PatternKind::Class,
319        "struct" => PatternKind::Struct,
320        "enum" => PatternKind::Enum,
321        "interface" | "trait" | "protocol" => PatternKind::Interface,
322        "*" => PatternKind::Any,
323        _ => {
324            // If it's not a recognized keyword, treat the first token as a kind "function"
325            // (since most patterns are for functions) and this token as the name
326            return Ok((PatternKind::Function, idx));
327        }
328    };
329
330    Ok((kind, idx + 1))
331}
332
/// Read an optional name pattern at `tokens[idx]`.
///
/// Any token other than the structural ones (`(`, `)`, `,`, `->`) is taken
/// as the name and consumed. Otherwise no name is reported and `idx` is
/// returned unchanged.
fn parse_name(tokens: &[String], idx: usize) -> (Option<String>, usize) {
    match tokens.get(idx) {
        Some(token) if !matches!(token.as_str(), "(" | ")" | "," | "->") => {
            (Some(token.clone()), idx + 1)
        }
        _ => (None, idx),
    }
}
346
347/// Parse optional parameter type patterns from parentheses.
348fn parse_params(
349    tokens: &[String],
350    idx: usize,
351) -> Result<(Option<Vec<String>>, usize), SearchError> {
352    if idx >= tokens.len() || tokens[idx] != "(" {
353        return Ok((None, idx));
354    }
355
356    let mut params = Vec::new();
357    let mut i = idx + 1; // skip '('
358
359    while i < tokens.len() && tokens[i] != ")" {
360        if tokens[i] == "," {
361            i += 1;
362            continue;
363        }
364        params.push(tokens[i].clone());
365        i += 1;
366    }
367
368    if i < tokens.len() && tokens[i] == ")" {
369        i += 1; // skip ')'
370    }
371
372    Ok((Some(params), i))
373}
374
/// Parse optional return type from `->`.
///
/// When `tokens[idx]` is `->` and at least one token follows, all remaining
/// tokens form the return type. Because the tokenizer splits on `(`, `)` and
/// `,`, the pieces are re-joined without a space after `(` or before `,` and
/// `)`, so `()` and `Result<String, Error>` survive the round trip instead of
/// becoming `( )` and `Result<String , Error>`. Other tokens are separated by
/// a single space.
///
/// Returns `None` when there is no `->` at `idx` or nothing follows it.
fn parse_return_type(tokens: &[String], idx: usize) -> Option<String> {
    if tokens.get(idx).map(String::as_str) != Some("->") {
        return None;
    }

    let parts = &tokens[idx + 1..];
    if parts.is_empty() {
        return None;
    }

    let mut ret = String::new();
    for part in parts {
        let closes_or_separates = matches!(part.as_str(), "," | ")");
        let needs_space = !ret.is_empty() && !ret.ends_with('(') && !closes_or_separates;
        if needs_space {
            ret.push(' ');
        }
        ret.push_str(part);
    }

    Some(ret)
}
389
390// ============================================================
391// Matching Helpers
392// ============================================================
393
394/// Check if the pattern kind matches the chunk kind.
395fn match_kind(pattern_kind: &PatternKind, chunk_kind: &ChunkKind) -> bool {
396    match pattern_kind {
397        PatternKind::Any => true,
398        PatternKind::Function => matches!(chunk_kind, ChunkKind::Function | ChunkKind::Method),
399        PatternKind::Class => matches!(chunk_kind, ChunkKind::Class),
400        PatternKind::Struct => matches!(chunk_kind, ChunkKind::Struct),
401        PatternKind::Enum => matches!(chunk_kind, ChunkKind::Enum),
402        PatternKind::Interface => matches!(chunk_kind, ChunkKind::Interface),
403    }
404}
405
/// Case-insensitive glob match where `*` stands for any run of characters.
///
/// Without a `*` the pattern must equal the text. Otherwise the piece before
/// the first `*` must be a prefix of the text, the piece after the last `*`
/// must be a suffix of what remains, and the pieces in between must occur in
/// order inside the remainder.
fn wildcard_match(pattern: &str, text: &str) -> bool {
    let pattern = pattern.to_lowercase();
    let text = text.to_lowercase();

    let mut segments = pattern.split('*');
    // `split` always yields at least one segment.
    let head = segments.next().unwrap_or("");
    let Some(tail) = segments.next_back() else {
        // Only one segment means the pattern contained no '*'.
        return pattern == text;
    };

    // Leading segment anchors at the start.
    let Some(mut rest) = text.strip_prefix(head) else {
        return false;
    };

    // Whatever is left in `segments` are the middle pieces: each must be
    // found after the previous one, taking the leftmost occurrence.
    for segment in segments {
        if segment.is_empty() {
            continue;
        }
        match rest.find(segment) {
            Some(at) => rest = &rest[at + segment.len()..],
            None => return false,
        }
    }

    // Trailing segment anchors at the end of the unconsumed text.
    rest.ends_with(tail)
}
453
454/// Extract parameter types from a function signature string.
455///
456/// Handles common signature formats:
457/// - `fn foo(x: i32, y: String) -> bool`
458/// - `def foo(x: int, y: str) -> bool`
459/// - `function foo(x, y)`
460fn extract_params_from_signature(sig: &str) -> Vec<String> {
461    // Find content between first '(' and matching ')'
462    let Some(open) = sig.find('(') else {
463        return Vec::new();
464    };
465
466    let mut depth = 0;
467    let mut close = None;
468
469    for (i, ch) in sig[open..].char_indices() {
470        match ch {
471            '(' => depth += 1,
472            ')' => {
473                depth -= 1;
474                if depth == 0 {
475                    close = Some(open + i);
476                    break;
477                }
478            }
479            _ => {}
480        }
481    }
482
483    let Some(close) = close else {
484        return Vec::new();
485    };
486
487    let params_str = &sig[open + 1..close];
488    if params_str.trim().is_empty() {
489        return Vec::new();
490    }
491
492    // Split by commas (being careful of nested generics)
493    split_params(params_str)
494        .iter()
495        .filter_map(|p| extract_type_from_param(p.trim()))
496        .collect()
497}
498
/// Split parameter list by commas, respecting nesting (generics, etc.).
///
/// Commas inside `<>`, `()`, `[]` or `{}` do not split. The `>` of an arrow
/// (`->` or `=>`) is not a closing bracket, so a callback-typed parameter
/// such as `f: fn(i32) -> bool` does not unbalance the nesting and swallow
/// the parameters that follow it. Nesting never drops below zero, so a stray
/// closing bracket cannot disable splitting for the rest of the list.
///
/// Pieces are returned untrimmed; a trailing comma does not produce an empty
/// final piece.
fn split_params(params: &str) -> Vec<String> {
    let mut parts = Vec::new();
    let mut current = String::new();
    let mut depth = 0usize;
    let mut prev: Option<char> = None;

    for ch in params.chars() {
        match ch {
            '<' | '(' | '[' | '{' => {
                depth += 1;
                current.push(ch);
            }
            // Second half of `->` / `=>`: part of an arrow, not a bracket.
            '>' if matches!(prev, Some('-' | '=')) => current.push(ch),
            '>' | ')' | ']' | '}' => {
                depth = depth.saturating_sub(1);
                current.push(ch);
            }
            ',' if depth == 0 => {
                parts.push(std::mem::take(&mut current));
            }
            _ => current.push(ch),
        }
        prev = Some(ch);
    }

    if !current.is_empty() {
        parts.push(current);
    }

    parts
}
528
/// Extract the type from a parameter declaration.
///
/// Handles:
/// - `x: i32` -> "i32"
/// - `x: &str` -> "str"
/// - `x: &'a str` -> "str" (reference lifetimes are dropped)
/// - `x: &'static mut Foo` -> "Foo"
/// - `name: String` -> "String"
/// - `x: int = 5` -> "int" (default values are dropped)
/// - `x int` -> "int" (Go style)
/// - `x` -> "x" (untyped, treated as name/type ambiguous)
///
/// Returns `None` when nothing is left after trimming and removing a
/// default value.
fn extract_type_from_param(param: &str) -> Option<String> {
    let param = param.trim();

    // Cut off a default value (`= ...`). An `=` that starts `=>` belongs to
    // an arrow function type and is left alone.
    let bytes = param.as_bytes();
    let default_at =
        (0..bytes.len()).find(|&i| bytes[i] == b'=' && bytes.get(i + 1) != Some(&b'>'));
    let param = match default_at {
        Some(cut) => param[..cut].trim_end(),
        None => param,
    };

    if param.is_empty() {
        return None;
    }

    // Rust/Python style: `name: Type`
    if let Some(colon_pos) = param.find(':') {
        let type_part = param[colon_pos + 1..].trim();

        // Strip references (&, &'a, &mut).
        let after_amp = type_part.trim_start_matches('&');
        let was_reference = after_amp.len() != type_part.len();
        let after_lifetime = if was_reference && after_amp.starts_with('\'') {
            // `'a Type`: skip the lifetime label and the whitespace after it.
            match after_amp.find(char::is_whitespace) {
                Some(ws) => after_amp[ws..].trim_start(),
                None => after_amp,
            }
        } else {
            after_amp
        };
        let type_part = after_lifetime.trim_start_matches("mut ").trim();
        return Some(type_part.to_string());
    }

    // Go style: `name Type` (space separated, last token is the type).
    // A single token is returned as-is — it could be a type or a name.
    param.split_whitespace().last().map(str::to_string)
}
563
/// Extract return type from a function signature.
///
/// Two notations are recognised:
/// - `-> Type` (Rust/Python). The arrow is searched for only after the
///   closing paren that matches the first `(`, so an arrow inside a
///   callback-typed parameter (`f: fn(i32) -> bool`) is not mistaken for the
///   function's own return arrow. When no balanced parameter list is found
///   the whole signature is searched.
/// - `: Type` directly after the last `)` (TypeScript style).
///
/// A trailing `{`, `:` (arrow form only) and whitespace are trimmed. Returns
/// `None` when neither notation yields a non-empty type.
fn extract_return_type_from_signature(sig: &str) -> Option<String> {
    // Byte offset just past the parameter list, or 0 when it cannot be located.
    let params_end = sig
        .find('(')
        .and_then(|open| {
            let mut depth = 0usize;
            sig[open..].char_indices().find_map(|(offset, ch)| {
                match ch {
                    '(' => depth += 1,
                    ')' => {
                        depth = depth.saturating_sub(1);
                        if depth == 0 {
                            return Some(open + offset + 1);
                        }
                    }
                    _ => {}
                }
                None
            })
        })
        .unwrap_or(0);

    // Look for `->` (Rust/Python) after the parameter list
    if let Some(arrow_pos) = sig[params_end..].find("->") {
        let ret = sig[params_end + arrow_pos + 2..].trim();
        // Strip trailing '{' or ':'
        let ret = ret.trim_end_matches(|c: char| c == '{' || c == ':' || c.is_whitespace());
        if !ret.is_empty() {
            return Some(ret.to_string());
        }
    }

    // Look for `: ReturnType` after closing paren (TypeScript/Go style)
    // Find the last ')' and check if there's `: Type` after it
    if let Some(close_paren) = sig.rfind(')') {
        let after = sig[close_paren + 1..].trim();
        if let Some(stripped) = after.strip_prefix(':') {
            let ret = stripped
                .trim()
                .trim_end_matches(|c: char| c == '{' || c.is_whitespace());
            if !ret.is_empty() {
                return Some(ret.to_string());
            }
        }
    }

    None
}
592
593/// Match parameter types between pattern and chunk.
594/// Returns a score from 0.0 to 1.0.
595fn match_param_types(pattern_params: &[String], chunk_params: &[String]) -> f32 {
596    if pattern_params.is_empty() && chunk_params.is_empty() {
597        return 1.0;
598    }
599
600    if pattern_params.is_empty() || chunk_params.is_empty() {
601        return 0.0;
602    }
603
604    // Check parameter count match (with wildcards)
605    let pattern_count = pattern_params.len();
606    let chunk_count = chunk_params.len();
607
608    // If pattern has a single `*`, match any number of params
609    if pattern_count == 1 && pattern_params[0] == "*" {
610        return 1.0;
611    }
612
613    // Count non-wildcard params in pattern
614    let fixed_params: Vec<&String> = pattern_params
615        .iter()
616        .filter(|p| p.as_str() != "*")
617        .collect();
618
619    if fixed_params.len() > chunk_count {
620        return 0.0; // More fixed params than chunk has
621    }
622
623    let mut matched = 0;
624    let mut chunk_idx = 0;
625
626    for pat_param in pattern_params {
627        if pat_param == "*" {
628            matched += 1;
629            if chunk_idx < chunk_count {
630                chunk_idx += 1;
631            }
632            continue;
633        }
634
635        // Try to match this pattern param against remaining chunk params
636        while chunk_idx < chunk_count {
637            if fuzzy_type_match(pat_param, &chunk_params[chunk_idx]) {
638                matched += 1;
639                chunk_idx += 1;
640                break;
641            }
642            chunk_idx += 1;
643        }
644    }
645
646    matched as f32 / pattern_params.len() as f32
647}
648
649/// Fuzzy type matching between pattern type and actual type.
650///
651/// Handles:
652/// - Case-insensitive comparison
653/// - Common type aliases: "string" matches "String", "&str", "str"
654/// - "number" matches "i32", "f64", "int", "float", etc.
655/// - "bool" matches "boolean", "bool"
656/// - Wildcard `*` matches anything
657/// - Prefix/suffix with `*`
658fn fuzzy_type_match(pattern: &str, actual: &str) -> bool {
659    let pattern = pattern.trim().to_lowercase();
660    let actual = actual.trim().to_lowercase();
661
662    if pattern == "*" {
663        return true;
664    }
665
666    // Exact match
667    if pattern == actual {
668        return true;
669    }
670
671    // Wildcard matching
672    if pattern.contains('*') {
673        return wildcard_match(&pattern, &actual);
674    }
675
676    // Type group matching
677    match pattern.as_str() {
678        "string" | "str" => {
679            matches!(
680                actual.as_str(),
681                "string" | "str" | "&str" | "std::string::string" | "text"
682            )
683        }
684        "number" | "num" | "int" | "integer" => {
685            matches!(
686                actual.as_str(),
687                "i8" | "i16"
688                    | "i32"
689                    | "i64"
690                    | "i128"
691                    | "isize"
692                    | "u8"
693                    | "u16"
694                    | "u32"
695                    | "u64"
696                    | "u128"
697                    | "usize"
698                    | "f32"
699                    | "f64"
700                    | "int"
701                    | "int8"
702                    | "int16"
703                    | "int32"
704                    | "int64"
705                    | "uint"
706                    | "float"
707                    | "float32"
708                    | "float64"
709                    | "number"
710                    | "double"
711            )
712        }
713        "bool" | "boolean" => {
714            matches!(actual.as_str(), "bool" | "boolean")
715        }
716        "void" | "none" | "unit" | "()" => {
717            matches!(actual.as_str(), "void" | "none" | "()" | "unit" | "null")
718        }
719        _ => {
720            // Check if actual contains the pattern (partial match)
721            actual.contains(&pattern) || pattern.contains(&actual)
722        }
723    }
724}
725
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::ChunkKind;
    use std::path::PathBuf;

    /// Build a `CodeChunk` fixture located in "test.rs" with language "rust".
    /// Byte and line ranges start at zero and span the given body.
    fn make_chunk(id: u64, kind: ChunkKind, name: &str, sig: &str, body: &str) -> CodeChunk {
        let line_count = body.lines().count();
        CodeChunk {
            id,
            file_path: PathBuf::from("test.rs"),
            language: String::from("rust"),
            kind,
            name: Some(name.to_owned()),
            signature: Some(sig.to_owned()),
            doc_comment: None,
            body: body.to_owned(),
            byte_range: 0..body.len(),
            line_range: 0..line_count,
        }
    }

    // ---- pattern parsing ----

    #[test]
    fn test_parse_simple_pattern() {
        let parsed = parse_pattern("fn(string) -> number").unwrap();
        assert_eq!(parsed.kind, PatternKind::Function);
        assert_eq!(parsed.name_pattern, None);
        assert_eq!(parsed.param_patterns, Some(vec!["string".to_string()]));
        assert_eq!(parsed.return_pattern.as_deref(), Some("number"));
    }

    #[test]
    fn test_parse_named_function_pattern() {
        let parsed = parse_pattern("fn authenticate(*)").unwrap();
        assert_eq!(parsed.kind, PatternKind::Function);
        assert_eq!(parsed.name_pattern.as_deref(), Some("authenticate"));
        assert_eq!(parsed.param_patterns, Some(vec!["*".to_string()]));
    }

    #[test]
    fn test_parse_async_pattern() {
        let parsed = parse_pattern("async fn(*) -> Result").unwrap();
        assert_eq!(parsed.kind, PatternKind::Function);
        assert!(parsed.qualifiers.iter().any(|q| q == "async"));
        assert_eq!(parsed.return_pattern.as_deref(), Some("Result"));
    }

    #[test]
    fn test_parse_class_pattern() {
        let parsed = parse_pattern("class User").unwrap();
        assert_eq!(parsed.kind, PatternKind::Class);
        assert_eq!(parsed.name_pattern.as_deref(), Some("User"));
    }

    #[test]
    fn test_parse_struct_wildcard() {
        let parsed = parse_pattern("struct *Config").unwrap();
        assert_eq!(parsed.kind, PatternKind::Struct);
        assert_eq!(parsed.name_pattern.as_deref(), Some("*Config"));
    }

    #[test]
    fn test_parse_multi_param() {
        let parsed = parse_pattern("fn(string, number) -> bool").unwrap();
        assert_eq!(parsed.kind, PatternKind::Function);
        assert_eq!(
            parsed.param_patterns,
            Some(vec!["string".to_string(), "number".to_string()])
        );
        assert_eq!(parsed.return_pattern.as_deref(), Some("bool"));
    }

    #[test]
    fn test_parse_empty_params() {
        let parsed = parse_pattern("fn()").unwrap();
        assert_eq!(parsed.kind, PatternKind::Function);
        assert_eq!(parsed.param_patterns, Some(Vec::new()));
    }

    // ---- matching helpers ----

    #[test]
    fn test_wildcard_match() {
        let cases = [
            ("*Config", "SeekrConfig", true),
            ("*Config", "AppConfig", true),
            ("*Config", "ConfigManager", false),
            ("Auth*", "AuthService", true),
            ("*", "anything", true),
            ("exact", "exact", true),
            ("exact", "notexact", false),
        ];
        for (pattern, text, expected) in cases {
            assert_eq!(wildcard_match(pattern, text), expected);
        }
    }

    #[test]
    fn test_fuzzy_type_match() {
        let matching_pairs = [
            ("string", "String"),
            ("string", "&str"),
            ("number", "i32"),
            ("number", "f64"),
            ("bool", "boolean"),
            ("*", "anything"),
            ("Result*", "Result<String, Error>"),
        ];
        for (pattern, actual) in matching_pairs {
            assert!(fuzzy_type_match(pattern, actual));
        }
    }

    #[test]
    fn test_extract_params_rust() {
        let signature = "fn authenticate(user: &str, password: String) -> bool";
        let extracted = extract_params_from_signature(signature);
        assert_eq!(extracted, vec!["str", "String"]);
    }

    #[test]
    fn test_extract_return_type_rust() {
        let found = extract_return_type_from_signature("fn foo(x: i32) -> Result<String, Error>");
        assert_eq!(found.as_deref(), Some("Result<String, Error>"));
    }

    #[test]
    fn test_extract_return_type_arrow() {
        let found = extract_return_type_from_signature("def foo(x: int) -> bool:");
        assert_eq!(found.as_deref(), Some("bool"));
    }

    // ---- chunk scoring ----

    #[test]
    fn test_match_function_by_return_type() {
        let chunk = make_chunk(
            1,
            ChunkKind::Function,
            "authenticate",
            "fn authenticate(user: &str) -> Result<Token, Error>",
            "fn authenticate(user: &str) -> Result<Token, Error> { }",
        );
        let parsed = parse_pattern("fn(*) -> Result").unwrap();

        let score = match_chunk(&parsed, &chunk);
        assert!(
            score > 0.5,
            "Should match function returning Result, got {}",
            score
        );
    }

    #[test]
    fn test_match_function_by_name() {
        let chunk = make_chunk(
            1,
            ChunkKind::Function,
            "authenticate",
            "fn authenticate(user: &str, pass: &str) -> bool",
            "fn authenticate(user: &str, pass: &str) -> bool { }",
        );
        let parsed = parse_pattern("fn authenticate(*)").unwrap();

        let score = match_chunk(&parsed, &chunk);
        assert!(score > 0.5, "Should match by name, got {}", score);
    }

    #[test]
    fn test_no_match_wrong_kind() {
        let chunk = make_chunk(1, ChunkKind::Function, "Foo", "fn Foo()", "fn Foo() {}");
        let parsed = parse_pattern("class Foo").unwrap();

        let score = match_chunk(&parsed, &chunk);
        assert_eq!(score, 0.0, "Should not match wrong kind");
    }

    // ---- end-to-end search ----

    #[test]
    fn test_search_ast_pattern_integration() {
        let mut index = SeekrIndex::new(4);

        let fixtures = vec![
            make_chunk(
                1,
                ChunkKind::Function,
                "authenticate",
                "fn authenticate(user: &str) -> Result<Token, AuthError>",
                "fn authenticate(user: &str) -> Result<Token, AuthError> { }",
            ),
            make_chunk(
                2,
                ChunkKind::Function,
                "calculate",
                "fn calculate(x: f64, y: f64) -> f64",
                "fn calculate(x: f64, y: f64) -> f64 { x + y }",
            ),
            make_chunk(
                3,
                ChunkKind::Struct,
                "AppConfig",
                "pub struct AppConfig",
                "pub struct AppConfig { pub port: u16 }",
            ),
        ];

        for chunk in fixtures {
            let entry = crate::index::IndexEntry {
                chunk_id: chunk.id,
                embedding: vec![0.1; 4],
                text_tokens: vec![],
            };
            index.add_entry(entry, chunk);
        }

        // Each query should rank the listed chunk first:
        // functions returning Result, structs named *Config, function named calculate.
        let expectations = [
            ("fn(*) -> Result", 1u64),
            ("struct *Config", 3),
            ("fn calculate(*)", 2),
        ];
        for (query, expected_top) in expectations {
            let results = search_ast_pattern(&index, query, 10).unwrap();
            assert!(!results.is_empty());
            assert_eq!(results[0].chunk_id, expected_top);
        }
    }

    #[test]
    fn test_empty_pattern_error() {
        assert!(parse_pattern("").is_err());
    }
}