Skip to main content

sqry_nl/assembler/
mod.rs

//! Command template assembly.
//!
//! Maps (intent, entities) to validated sqry command strings.
4
5mod formatting;
6pub(crate) mod templates;
7
8use crate::error::{AssemblerError, NlResult};
9use crate::types::{ExtractedEntities, Intent, PredicateType, TemplateType, Visibility};
10
11/// Convenience function to assemble a command from intent and entities.
12///
13/// Uses default assembler configuration.
14///
15/// # Errors
16///
17/// Returns [`AssemblerError`] if assembly fails.
18pub fn assemble_command(
19    intent: &Intent,
20    entities: &ExtractedEntities,
21) -> Result<String, AssemblerError> {
22    let assembler = TemplateAssembler::default();
23    assembler
24        .assemble(*intent, entities)
25        .map(|cmd| cmd.command)
26        .map_err(|e| match e {
27            crate::error::NlError::Assembler(ae) => ae,
28            _ => AssemblerError::AmbiguousIntent, // Fallback for unexpected errors
29        })
30}
31
32/// Assembled command result.
33#[derive(Debug, Clone)]
34pub struct AssembledCommand {
35    /// The assembled command string.
36    pub command: String,
37    /// The template type used.
38    pub template_type: TemplateType,
39}
40
/// Command assembler configuration.
///
/// All fields are plain integers, so the type also derives `PartialEq`/`Eq`
/// to let callers and tests compare configurations by value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AssemblerConfig {
    /// Result limit applied to query commands when the request does not
    /// specify one.
    pub default_limit: u32,
    /// Maximum allowed length of an assembled command, in bytes.
    pub max_command_length: usize,
    /// Maximum traversal depth applied to graph queries when the request does
    /// not specify one.
    pub default_max_depth: u32,
}

impl Default for AssemblerConfig {
    /// Defaults: limit 100, command length 512 bytes, max depth 10.
    fn default() -> Self {
        Self {
            default_limit: 100,
            max_command_length: 512,
            default_max_depth: 10,
        }
    }
}
61
62/// Template assembler for building sqry commands.
63pub struct TemplateAssembler {
64    config: AssemblerConfig,
65}
66
67impl Default for TemplateAssembler {
68    fn default() -> Self {
69        Self::new(AssemblerConfig::default())
70    }
71}
72
73impl TemplateAssembler {
74    /// Create a new assembler with the given configuration.
75    #[must_use]
76    pub fn new(config: AssemblerConfig) -> Self {
77        Self { config }
78    }
79
80    /// Assemble a command from intent and entities.
81    ///
82    /// # Errors
83    ///
84    /// Returns [`AssemblerError`] if:
85    /// - Intent is Ambiguous
86    /// - Required entities are missing
87    /// - Generated command exceeds length limit
88    pub fn assemble(
89        &self,
90        intent: Intent,
91        entities: &ExtractedEntities,
92    ) -> NlResult<AssembledCommand> {
93        match intent {
94            Intent::SymbolQuery => self.build_query_command(entities),
95            Intent::TextSearch => self.build_search_command(entities),
96            Intent::TracePath => self.build_trace_path_command(entities),
97            Intent::FindCallers => self.build_callers_command(entities),
98            Intent::FindCallees => self.build_callees_command(entities),
99            Intent::Visualize => self.build_visualize_command(entities),
100            Intent::IndexStatus => self.build_index_status_command(entities),
101            Intent::Ambiguous => Err(AssemblerError::AmbiguousIntent.into()),
102        }
103    }
104
105    fn build_command(
106        &self,
107        parts: &[String],
108        template_type: TemplateType,
109    ) -> NlResult<AssembledCommand> {
110        let command = parts.join(" ");
111        self.validate_length(&command)?;
112
113        Ok(AssembledCommand {
114            command,
115            template_type,
116        })
117    }
118
119    fn push_languages(parts: &mut Vec<String>, languages: &[String]) {
120        for lang in languages {
121            parts.push(format!("--language {lang}"));
122        }
123    }
124
125    fn push_path(parts: &mut Vec<String>, paths: &[String]) {
126        if let Some(path) = paths.first() {
127            parts.push(format!("--path \"{}\"", formatting::escape_quotes(path)));
128        }
129    }
130
131    fn require_primary_symbol(
132        entities: &ExtractedEntities,
133        error: AssemblerError,
134    ) -> NlResult<&str> {
135        entities.primary_symbol().ok_or_else(|| error.into())
136    }
137
138    fn build_query_command(&self, entities: &ExtractedEntities) -> NlResult<AssembledCommand> {
139        // Check if this is a CD predicate query first
140        let query_expr = Self::build_query_expression(entities)?;
141
142        let mut parts = vec![
143            "sqry".to_string(),
144            "query".to_string(),
145            format!("\"{}\"", formatting::escape_quotes(&query_expr)),
146        ];
147
148        // Add language filters
149        Self::push_languages(&mut parts, &entities.languages);
150
151        // Add path filter
152        Self::push_path(&mut parts, &entities.paths);
153
154        // Note: kind filter is now included in query expression as `kind:X` predicate
155        // (see build_query_expression)
156
157        // Add limit
158        let limit = entities.limit.unwrap_or(self.config.default_limit);
159        parts.push(format!("--limit {limit}"));
160
161        self.build_command(&parts, TemplateType::Query)
162    }
163
164    /// Build the query expression, handling CD predicates.
165    ///
166    /// Returns the query expression string (e.g., "impl:Future", "duplicates:body", etc.)
167    fn build_query_expression(entities: &ExtractedEntities) -> NlResult<String> {
168        let mut expr_parts = Self::collect_predicates(entities);
169
170        if expr_parts.is_empty() {
171            return Self::build_symbol_only_query(entities);
172        }
173
174        if let Some(symbol) = entities.primary_symbol()
175            && Self::should_include_symbol(entities, symbol)
176        {
177            expr_parts.push(symbol.to_string());
178        }
179
180        // Use AND to join predicates - required for correct parsing
181        // when boolean predicates like async:true are combined with other predicates
182        Ok(expr_parts.join(" AND "))
183    }
184
185    fn build_search_command(&self, entities: &ExtractedEntities) -> NlResult<AssembledCommand> {
186        let pattern = Self::require_primary_symbol(entities, AssemblerError::MissingSymbol)?;
187
188        let mut parts = vec![
189            "sqry".to_string(),
190            "search".to_string(),
191            format!("\"{}\"", formatting::escape_quotes(pattern)),
192        ];
193
194        // Add language filters
195        Self::push_languages(&mut parts, &entities.languages);
196
197        // Add path filter
198        Self::push_path(&mut parts, &entities.paths);
199
200        self.build_command(&parts, TemplateType::Search)
201    }
202
203    fn build_trace_path_command(&self, entities: &ExtractedEntities) -> NlResult<AssembledCommand> {
204        let from = entities
205            .from_symbol
206            .as_deref()
207            .or_else(|| entities.symbols.first().map(String::as_str))
208            .ok_or(AssemblerError::MissingTracePath)?;
209
210        let to = entities
211            .to_symbol
212            .as_deref()
213            .or_else(|| entities.symbols.get(1).map(String::as_str))
214            .ok_or(AssemblerError::MissingTracePath)?;
215
216        let mut parts = vec![
217            "sqry".to_string(),
218            "graph".to_string(),
219            "trace-path".to_string(),
220            format!("\"{}\"", formatting::escape_quotes(from)),
221            format!("\"{}\"", formatting::escape_quotes(to)),
222        ];
223
224        // Add max-depth
225        let depth = entities.depth.unwrap_or(self.config.default_max_depth);
226        parts.push(format!("--max-depth {depth}"));
227
228        let command = parts.join(" ");
229        self.validate_length(&command)?;
230
231        Ok(AssembledCommand {
232            command,
233            template_type: TemplateType::TracePath,
234        })
235    }
236
237    fn build_callers_command(&self, entities: &ExtractedEntities) -> NlResult<AssembledCommand> {
238        let symbol = Self::require_primary_symbol(entities, AssemblerError::MissingSymbol)?;
239
240        let mut parts = vec![
241            "sqry".to_string(),
242            "graph".to_string(),
243            "direct-callers".to_string(),
244            format!("\"{}\"", formatting::escape_quotes(symbol)),
245        ];
246
247        // Single language only for graph commands
248        if let Some(lang) = entities.languages.first() {
249            parts.push(format!("--language {lang}"));
250        }
251
252        self.build_command(&parts, TemplateType::GraphCallers)
253    }
254
255    fn build_callees_command(&self, entities: &ExtractedEntities) -> NlResult<AssembledCommand> {
256        let symbol = Self::require_primary_symbol(entities, AssemblerError::MissingSymbol)?;
257
258        let mut parts = vec![
259            "sqry".to_string(),
260            "graph".to_string(),
261            "direct-callees".to_string(),
262            format!("\"{}\"", formatting::escape_quotes(symbol)),
263        ];
264
265        // Single language only for graph commands
266        if let Some(lang) = entities.languages.first() {
267            parts.push(format!("--language {lang}"));
268        }
269
270        self.build_command(&parts, TemplateType::GraphCallees)
271    }
272
273    fn build_visualize_command(&self, entities: &ExtractedEntities) -> NlResult<AssembledCommand> {
274        let symbol = Self::require_primary_symbol(entities, AssemblerError::MissingSymbol)?;
275
276        let relation = entities.relation.as_deref().unwrap_or("call");
277
278        let mut parts = vec![
279            "sqry".to_string(),
280            "visualize".to_string(),
281            format!("--relation {}", relation),
282            format!("--symbol \"{}\"", formatting::escape_quotes(symbol)),
283        ];
284
285        // Add format
286        if let Some(format) = entities.format {
287            parts.push(format!("--format {}", format.as_str()));
288        }
289
290        self.build_command(&parts, TemplateType::Visualize)
291    }
292
293    fn build_index_status_command(
294        &self,
295        entities: &ExtractedEntities,
296    ) -> NlResult<AssembledCommand> {
297        let mut parts = vec![
298            "sqry".to_string(),
299            "index".to_string(),
300            "--status".to_string(),
301        ];
302
303        // Add path filter
304        if let Some(path) = entities.paths.first() {
305            parts.push(format!("--path \"{}\"", formatting::escape_quotes(path)));
306        }
307
308        // Add JSON flag if requested
309        if entities.format == Some(crate::types::OutputFormat::Json) {
310            parts.push("--json".to_string());
311        }
312
313        self.build_command(&parts, TemplateType::IndexStatus)
314    }
315
316    fn collect_predicates(entities: &ExtractedEntities) -> Vec<String> {
317        let mut expr_parts = Vec::new();
318
319        // 1. Handle impl: predicate
320        if let Some(trait_name) = &entities.impl_trait {
321            expr_parts.push(format!("impl:{trait_name}"));
322        }
323
324        // 2. Handle duplicates: predicate (default to "body" if no arg specified)
325        if entities.predicate_type == Some(PredicateType::Duplicates) {
326            let arg = entities.predicate_arg.as_deref().unwrap_or("body");
327            expr_parts.push(format!("duplicates:{arg}"));
328        }
329
330        // 3. Handle circular: predicate (default to "calls" if no arg specified)
331        if entities.predicate_type == Some(PredicateType::Circular) {
332            let arg = entities.predicate_arg.as_deref().unwrap_or("calls");
333            expr_parts.push(format!("circular:{arg}"));
334        }
335
336        // 4. Handle unused: predicate
337        if entities.predicate_type == Some(PredicateType::Unused) {
338            expr_parts.push("unused:".to_string());
339        }
340
341        // 5. Handle visibility filter
342        if let Some(visibility) = entities.visibility {
343            match visibility {
344                Visibility::Public => expr_parts.push("visibility:public".to_string()),
345                Visibility::Private => expr_parts.push("visibility:private".to_string()),
346            }
347        }
348
349        // 6. Handle async filter
350        if entities.is_async == Some(true) {
351            expr_parts.push("async:true".to_string());
352        }
353
354        // 7. Handle unsafe filter
355        if entities.is_unsafe == Some(true) {
356            expr_parts.push("unsafe:true".to_string());
357        }
358
359        // 8. Handle kind filter as query predicate
360        if let Some(kind) = entities.kind {
361            expr_parts.push(format!("kind:{}", kind.as_str()));
362        }
363
364        expr_parts
365    }
366
367    fn build_symbol_only_query(entities: &ExtractedEntities) -> NlResult<String> {
368        match entities.primary_symbol() {
369            Some(symbol) => Ok(symbol.to_string()),
370            None if entities.kind.is_some() => Ok("*".to_string()),
371            None => Err(AssemblerError::MissingSymbol.into()),
372        }
373    }
374
375    fn should_include_symbol(entities: &ExtractedEntities, symbol: &str) -> bool {
376        entities.impl_trait.is_none() || symbol != entities.impl_trait.as_deref().unwrap_or("")
377    }
378
379    fn validate_length(&self, command: &str) -> NlResult<()> {
380        if command.len() > self.config.max_command_length {
381            return Err(AssemblerError::CommandTooLong {
382                len: command.len(),
383                max: self.config.max_command_length,
384            }
385            .into());
386        }
387        Ok(())
388    }
389}
390
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{PredicateType, SymbolKind, Visibility};

    /// Assemble with the default assembler, panicking if assembly fails.
    fn assemble_ok(intent: Intent, entities: &ExtractedEntities) -> AssembledCommand {
        TemplateAssembler::default()
            .assemble(intent, entities)
            .unwrap()
    }

    /// Command string produced for a symbol query over `entities`.
    fn query_command(entities: &ExtractedEntities) -> String {
        assemble_ok(Intent::SymbolQuery, entities).command
    }

    /// Fresh entities holding a single symbol.
    fn entities_with_symbol(symbol: &str) -> ExtractedEntities {
        let mut entities = ExtractedEntities::new();
        entities.symbols.push(symbol.to_string());
        entities
    }

    /// Fresh entities with only `predicate_type` set.
    fn entities_with_predicate(predicate: PredicateType) -> ExtractedEntities {
        let mut entities = ExtractedEntities::new();
        entities.predicate_type = Some(predicate);
        entities
    }

    #[test]
    fn test_build_query_basic() {
        let command = query_command(&entities_with_symbol("authenticate"));

        assert!(command.starts_with("sqry query"));
        assert!(command.contains("\"authenticate\""));
    }

    #[test]
    fn test_build_query_with_options() {
        let mut entities = entities_with_symbol("foo");
        entities.languages.push("rust".to_string());
        entities.kind = Some(SymbolKind::Function);
        entities.limit = Some(10);

        let command = query_command(&entities);
        assert!(command.contains("--language rust"));
        // kind travels inside the query expression rather than as a CLI flag
        assert!(command.contains("kind:function"));
        assert!(command.contains("--limit 10"));
    }

    #[test]
    fn test_build_callers() {
        let assembled = assemble_ok(Intent::FindCallers, &entities_with_symbol("login"));

        assert!(assembled.command.contains("sqry graph direct-callers"));
        assert!(assembled.command.contains("\"login\""));
    }

    #[test]
    fn test_build_trace_path() {
        let mut entities = ExtractedEntities::new();
        entities.from_symbol = Some("login".to_string());
        entities.to_symbol = Some("database".to_string());

        let assembled = assemble_ok(Intent::TracePath, &entities);
        assert!(assembled.command.contains("sqry graph trace-path"));
        assert!(assembled.command.contains("\"login\""));
        assert!(assembled.command.contains("\"database\""));
    }

    #[test]
    fn test_missing_symbol_error() {
        let outcome =
            TemplateAssembler::default().assemble(Intent::SymbolQuery, &ExtractedEntities::new());

        assert!(matches!(
            outcome,
            Err(crate::error::NlError::Assembler(
                AssemblerError::MissingSymbol
            ))
        ));
    }

    #[test]
    fn test_ambiguous_intent_error() {
        let outcome =
            TemplateAssembler::default().assemble(Intent::Ambiguous, &ExtractedEntities::new());

        assert!(matches!(
            outcome,
            Err(crate::error::NlError::Assembler(
                AssemblerError::AmbiguousIntent
            ))
        ));
    }

    // --- CD Predicate tests ---

    #[test]
    fn test_build_query_impl_predicate() {
        let mut entities = entities_with_predicate(PredicateType::Impl);
        entities.impl_trait = Some("Future".to_string());

        let command = query_command(&entities);
        assert!(command.contains("\"impl:Future\""));
        assert!(command.starts_with("sqry query"));
    }

    #[test]
    fn test_build_query_duplicates_predicate() {
        let command = query_command(&entities_with_predicate(PredicateType::Duplicates));

        // Without an explicit argument the predicate falls back to "body"
        assert!(command.contains("\"duplicates:body\""));
    }

    #[test]
    fn test_build_query_duplicates_signature() {
        let mut entities = entities_with_predicate(PredicateType::Duplicates);
        entities.predicate_arg = Some("signature".to_string());

        assert!(query_command(&entities).contains("\"duplicates:signature\""));
    }

    #[test]
    fn test_build_query_circular_predicate() {
        let command = query_command(&entities_with_predicate(PredicateType::Circular));

        // Without an explicit argument the predicate falls back to "calls"
        assert!(command.contains("\"circular:calls\""));
    }

    #[test]
    fn test_build_query_unused_predicate() {
        let command = query_command(&entities_with_predicate(PredicateType::Unused));

        assert!(command.contains("\"unused:\""));
    }

    #[test]
    fn test_build_query_visibility_public() {
        let mut entities = ExtractedEntities::new();
        entities.visibility = Some(Visibility::Public);

        assert!(query_command(&entities).contains("visibility:public"));
    }

    #[test]
    fn test_build_query_async_predicate() {
        let mut entities = ExtractedEntities::new();
        entities.is_async = Some(true);

        assert!(query_command(&entities).contains("async:true"));
    }

    #[test]
    fn test_build_query_unsafe_predicate() {
        let mut entities = ExtractedEntities::new();
        entities.is_unsafe = Some(true);

        assert!(query_command(&entities).contains("unsafe:true"));
    }

    #[test]
    fn test_build_query_combined_predicates() {
        let mut entities = ExtractedEntities::new();
        entities.visibility = Some(Visibility::Public);
        entities.is_async = Some(true);

        let command = query_command(&entities);
        assert!(command.contains("visibility:public"));
        assert!(command.contains("async:true"));
    }

    #[test]
    fn test_build_query_impl_with_symbol_no_duplicate() {
        let mut entities = entities_with_predicate(PredicateType::Impl);
        entities.impl_trait = Some("Iterator".to_string());
        // A symbol equal to the trait name must not be repeated in the query
        entities.symbols.push("Iterator".to_string());

        let command = query_command(&entities);
        // Expect exactly "impl:Iterator", never "impl:Iterator Iterator"
        assert!(command.contains("\"impl:Iterator\""));
        let occurrences = command.matches("Iterator").count();
        assert_eq!(
            occurrences, 1,
            "Iterator should only appear once in: {}",
            command
        );
    }
}
595
// Predicate assembly regression tests
#[cfg(test)]
mod predicate_assembly_tests {
    use super::*;
    use crate::extractor::extract_entities;

    /// Extract entities from `text`, assemble them as a symbol query with the
    /// default assembler, and return the command after asserting success.
    fn assemble_symbol_query(text: &str) -> AssembledCommand {
        let entities = extract_entities(text);
        let outcome = TemplateAssembler::default().assemble(Intent::SymbolQuery, &entities);

        assert!(outcome.is_ok());
        outcome.unwrap()
    }

    #[test]
    fn test_async_functions_assembly() {
        let assembled = assemble_symbol_query("find async functions");

        assert!(assembled.command.contains("async:true"));
    }

    #[test]
    fn test_unsafe_functions_assembly() {
        let assembled = assemble_symbol_query("find unsafe functions");

        assert!(assembled.command.contains("unsafe:true"));
    }

    #[test]
    fn test_public_async_functions_assembly() {
        let assembled = assemble_symbol_query("find public async functions");

        assert!(assembled.command.contains("visibility:public"));
        assert!(assembled.command.contains("async:true"));
    }
}
635}