Skip to main content

sqry_nl/
types.rs

1//! Core types for the sqry-nl crate.
2//!
3//! All types are designed to be `Send + Sync` for thread-safe usage.
4
5use serde::{Deserialize, Serialize};
6
7/// Intent classification result from the NL classifier.
8///
9/// Each intent maps to a specific sqry command template:
10/// - `SymbolQuery` → `sqry query "<expr>"`
11/// - `TextSearch` → `sqry search "<pattern>"`
12/// - `TracePath` → `sqry graph trace-path "<from>" "<to>"`
13/// - `FindCallers` → `sqry graph direct-callers "<symbol>"`
14/// - `FindCallees` → `sqry graph direct-callees "<symbol>"`
15/// - `Visualize` → `sqry visualize --relation <kind> --symbol "<name>"`
16/// - `IndexStatus` → `sqry index --status`
17/// - `Ambiguous` → Cannot determine intent, needs clarification
18#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
19#[serde(rename_all = "snake_case")]
20pub enum Intent {
21    /// Search for symbols by name, kind, or pattern
22    SymbolQuery,
23    /// Text/grep search for patterns in code
24    TextSearch,
25    /// Trace call path between two symbols
26    TracePath,
27    /// Find all callers of a symbol
28    FindCallers,
29    /// Find all callees of a symbol
30    FindCallees,
31    /// Generate visualization (Mermaid/DOT)
32    Visualize,
33    /// Check index status
34    IndexStatus,
35    /// Intent unclear, need disambiguation
36    Ambiguous,
37}
38
39impl Intent {
40    /// Number of intent classes
41    pub const NUM_CLASSES: usize = 8;
42
43    /// Convert from classifier output index
44    #[must_use]
45    pub fn from_index(idx: usize) -> Self {
46        match idx {
47            0 => Self::SymbolQuery,
48            1 => Self::TextSearch,
49            2 => Self::TracePath,
50            3 => Self::FindCallers,
51            4 => Self::FindCallees,
52            5 => Self::Visualize,
53            6 => Self::IndexStatus,
54            _ => Self::Ambiguous,
55        }
56    }
57
58    /// Convert to classifier output index
59    #[must_use]
60    pub const fn to_index(self) -> usize {
61        match self {
62            Self::SymbolQuery => 0,
63            Self::TextSearch => 1,
64            Self::TracePath => 2,
65            Self::FindCallers => 3,
66            Self::FindCallees => 4,
67            Self::Visualize => 5,
68            Self::IndexStatus => 6,
69            Self::Ambiguous => 7,
70        }
71    }
72
73    /// Human-readable name for the intent
74    #[must_use]
75    pub const fn as_str(&self) -> &'static str {
76        match self {
77            Self::SymbolQuery => "symbol_query",
78            Self::TextSearch => "text_search",
79            Self::TracePath => "trace_path",
80            Self::FindCallers => "find_callers",
81            Self::FindCallees => "find_callees",
82            Self::Visualize => "visualize",
83            Self::IndexStatus => "index_status",
84            Self::Ambiguous => "ambiguous",
85        }
86    }
87
88    /// Description of what this intent does
89    #[must_use]
90    pub const fn description(&self) -> &'static str {
91        match self {
92            Self::SymbolQuery => "Search for symbols by name or pattern",
93            Self::TextSearch => "Search for text patterns in code",
94            Self::TracePath => "Find call path between two symbols",
95            Self::FindCallers => "Find all places that call a symbol",
96            Self::FindCallees => "Find all symbols called by a function",
97            Self::Visualize => "Generate a diagram of code relationships",
98            Self::IndexStatus => "Check the status of the code index",
99            Self::Ambiguous => "Intent unclear, needs clarification",
100        }
101    }
102}
103
104impl std::fmt::Display for Intent {
105    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106        write!(f, "{}", self.as_str())
107    }
108}
109
110/// Validation status for generated commands.
111#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
112#[serde(rename_all = "snake_case")]
113pub enum ValidationStatus {
114    /// Command passed all validation checks
115    Valid,
116    /// Rejected: contains shell metacharacters
117    RejectedMetachar,
118    /// Rejected: contains path traversal
119    RejectedPathTraversal,
120    /// Rejected: attempts write operation
121    RejectedWriteMode,
122    /// Rejected: contains environment variable
123    RejectedEnvVar,
124    /// Rejected: exceeds length limit
125    RejectedTooLong,
126    /// Rejected: doesn't match any allowed template
127    RejectedUnknown,
128}
129
130impl ValidationStatus {
131    /// Whether this status represents a valid command
132    #[must_use]
133    pub const fn is_valid(&self) -> bool {
134        matches!(self, Self::Valid)
135    }
136
137    /// Human-readable reason for rejection
138    #[must_use]
139    pub const fn rejection_reason(&self) -> Option<&'static str> {
140        match self {
141            Self::Valid => None,
142            Self::RejectedMetachar => Some("Contains shell metacharacters"),
143            Self::RejectedPathTraversal => Some("Contains path traversal"),
144            Self::RejectedWriteMode => Some("Attempts write operation"),
145            Self::RejectedEnvVar => Some("Contains environment variable"),
146            Self::RejectedTooLong => Some("Exceeds maximum command length"),
147            Self::RejectedUnknown => Some("Doesn't match allowed command patterns"),
148        }
149    }
150}
151
152/// Predicate type for CD (Cross-file Discovery) queries.
153#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
154#[serde(rename_all = "snake_case")]
155pub enum PredicateType {
156    /// Find trait implementations (`impl:Trait`)
157    Impl,
158    /// Find duplicate code (`duplicates:` or `duplicates:body`)
159    Duplicates,
160    /// Find circular dependencies (`circular:` or `circular:calls`)
161    Circular,
162    /// Find unused code (`unused:`)
163    Unused,
164}
165
166impl PredicateType {
167    /// Convert to sqry predicate prefix
168    #[must_use]
169    pub const fn as_prefix(&self) -> &'static str {
170        match self {
171            Self::Impl => "impl:",
172            Self::Duplicates => "duplicates:",
173            Self::Circular => "circular:",
174            Self::Unused => "unused:",
175        }
176    }
177}
178
179impl std::fmt::Display for PredicateType {
180    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
181        write!(f, "{}", self.as_prefix())
182    }
183}
184
185/// Visibility filter for symbol queries.
186#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
187#[serde(rename_all = "snake_case")]
188pub enum Visibility {
189    /// Public symbols only
190    Public,
191    /// Private symbols only
192    Private,
193}
194
195impl Visibility {
196    /// Convert to sqry predicate value
197    #[must_use]
198    pub const fn as_str(&self) -> &'static str {
199        match self {
200            Self::Public => "public",
201            Self::Private => "private",
202        }
203    }
204}
205
206impl std::fmt::Display for Visibility {
207    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
208        write!(f, "{}", self.as_str())
209    }
210}
211
212/// Symbol kind for filtering queries.
213#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
214#[serde(rename_all = "snake_case")]
215pub enum SymbolKind {
216    /// Function or method
217    Function,
218    /// Class definition
219    Class,
220    /// Struct definition
221    Struct,
222    /// Enum definition
223    Enum,
224    /// Trait definition (Rust) or interface (other languages)
225    Trait,
226    /// Interface definition
227    Interface,
228    /// Method (attached to a type)
229    Method,
230    /// Module or package
231    Module,
232    /// Constant or static variable
233    Constant,
234    /// Variable definition
235    Variable,
236    /// Type alias
237    TypeAlias,
238}
239
240impl SymbolKind {
241    /// Convert to CLI argument value
242    #[must_use]
243    pub const fn as_str(&self) -> &'static str {
244        match self {
245            Self::Function => "function",
246            Self::Class => "class",
247            Self::Struct => "struct",
248            Self::Enum => "enum",
249            Self::Trait => "trait",
250            Self::Interface => "interface",
251            Self::Method => "method",
252            Self::Module => "module",
253            Self::Constant => "constant",
254            Self::Variable => "variable",
255            Self::TypeAlias => "type_alias",
256        }
257    }
258}
259
260impl std::fmt::Display for SymbolKind {
261    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
262        write!(f, "{}", self.as_str())
263    }
264}
265
266/// Output format for visualization commands.
267#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
268#[serde(rename_all = "snake_case")]
269pub enum OutputFormat {
270    /// Mermaid diagram format
271    Mermaid,
272    /// Graphviz DOT format
273    Dot,
274    /// JSON format
275    Json,
276}
277
278impl OutputFormat {
279    /// Convert to CLI argument value
280    #[must_use]
281    pub const fn as_str(&self) -> &'static str {
282        match self {
283            Self::Mermaid => "mermaid",
284            Self::Dot => "dot",
285            Self::Json => "json",
286        }
287    }
288}
289
290impl std::fmt::Display for OutputFormat {
291    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
292        write!(f, "{}", self.as_str())
293    }
294}
295
296/// Entities extracted from natural language input.
297///
298/// This struct contains all the "slots" that can be filled from
299/// a natural language query.
300#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
301pub struct ExtractedEntities {
302    /// Symbol names or patterns to search for
303    pub symbols: Vec<String>,
304
305    /// Programming languages to filter by
306    pub languages: Vec<String>,
307
308    /// Path patterns to filter by
309    pub paths: Vec<String>,
310
311    /// Symbol kind filter
312    pub kind: Option<SymbolKind>,
313
314    /// Maximum number of results
315    pub limit: Option<u32>,
316
317    /// Maximum depth for graph traversal
318    pub depth: Option<u32>,
319
320    /// Output format for visualization
321    pub format: Option<OutputFormat>,
322
323    /// Source symbol for trace-path
324    pub from_symbol: Option<String>,
325
326    /// Target symbol for trace-path
327    pub to_symbol: Option<String>,
328
329    /// Relation type for visualization
330    pub relation: Option<String>,
331
332    // --- CD Predicate fields ---
333    /// Predicate type for CD queries (impl, duplicates, circular, unused)
334    pub predicate_type: Option<PredicateType>,
335
336    /// Trait name for impl: predicate (e.g., "Future" in "impl:Future")
337    pub impl_trait: Option<String>,
338
339    /// Predicate argument (e.g., "body" in "duplicates:body", "calls" in "circular:calls")
340    pub predicate_arg: Option<String>,
341
342    /// Visibility filter (public/private)
343    pub visibility: Option<Visibility>,
344
345    /// Async filter (true = find async functions)
346    pub is_async: Option<bool>,
347
348    /// Unsafe filter (true = find unsafe code)
349    pub is_unsafe: Option<bool>,
350}
351
352impl ExtractedEntities {
353    /// Create empty entities
354    #[must_use]
355    pub fn new() -> Self {
356        Self::default()
357    }
358
359    /// Check if any symbols were extracted
360    #[must_use]
361    pub fn has_symbols(&self) -> bool {
362        !self.symbols.is_empty()
363    }
364
365    /// Check if trace-path entities are complete
366    #[must_use]
367    pub fn has_trace_path(&self) -> bool {
368        self.from_symbol.is_some() && self.to_symbol.is_some()
369    }
370
371    /// Get the primary symbol (first one)
372    #[must_use]
373    pub fn primary_symbol(&self) -> Option<&str> {
374        self.symbols.first().map(String::as_str)
375    }
376
377    /// Check if this is a CD predicate query
378    #[must_use]
379    pub fn has_predicate(&self) -> bool {
380        self.predicate_type.is_some()
381            || self.impl_trait.is_some()
382            || self.visibility.is_some()
383            || self.is_async.is_some()
384            || self.is_unsafe.is_some()
385    }
386
387    /// Check if this is an impl: query
388    #[must_use]
389    pub fn is_impl_query(&self) -> bool {
390        self.predicate_type == Some(PredicateType::Impl) || self.impl_trait.is_some()
391    }
392}
393
394/// Response from the translation pipeline.
395///
396/// Implements tiered confidence responses (H8 mitigation):
397/// - `Execute`: High confidence (≥0.85) - run immediately
398/// - `Confirm`: Medium confidence (0.65-0.85) - ask for confirmation
399/// - `Disambiguate`: Low confidence (<0.65) - present options
400/// - `Reject`: Validation failure or error
401#[derive(Debug, Clone, Serialize, Deserialize)]
402#[serde(tag = "type", rename_all = "snake_case")]
403pub enum TranslationResponse {
404    /// High confidence: execute the command
405    Execute {
406        /// The translated sqry command
407        command: String,
408        /// Classifier confidence (0.0-1.0)
409        confidence: f32,
410        /// Classified intent
411        intent: Intent,
412        /// Whether result was from cache
413        cached: bool,
414        /// Translation latency in milliseconds
415        latency_ms: u64,
416    },
417
418    /// Medium confidence: ask user to confirm
419    Confirm {
420        /// The translated sqry command
421        command: String,
422        /// Classifier confidence (0.0-1.0)
423        confidence: f32,
424        /// Human-readable prompt for confirmation
425        prompt: String,
426    },
427
428    /// Low confidence: present options to disambiguate
429    Disambiguate {
430        /// Possible interpretations with commands
431        options: Vec<DisambiguationOption>,
432        /// Human-readable prompt
433        prompt: String,
434    },
435
436    /// Validation failure or error
437    Reject {
438        /// Reason for rejection
439        reason: String,
440        /// Helpful suggestions
441        suggestions: Vec<String>,
442    },
443}
444
445/// A disambiguation option presented when confidence is low.
446#[derive(Debug, Clone, Serialize, Deserialize)]
447pub struct DisambiguationOption {
448    /// The translated sqry command for this interpretation
449    pub command: String,
450    /// The intent this option represents
451    pub intent: Intent,
452    /// Human-readable description
453    pub description: String,
454    /// Confidence for this interpretation
455    pub confidence: f32,
456}
457
/// Output of the preprocessing stage.
#[derive(Debug, Clone)]
pub struct PreprocessResult {
    /// Input text after cleaning and normalization
    pub text: String,
    /// Quoted spans pulled out of the input, kept verbatim
    pub quoted_spans: Vec<String>,
    /// Set when any normalization was applied
    pub normalized: bool,
    /// Set when any homoglyphs were replaced
    pub homoglyphs_replaced: bool,
}

impl PreprocessResult {
    /// Build a successful result from `text` and `quoted_spans`, with both
    /// `normalized` and `homoglyphs_replaced` cleared.
    #[must_use]
    pub fn ok(text: String, quoted_spans: Vec<String>) -> Self {
        let (normalized, homoglyphs_replaced) = (false, false);
        Self {
            text,
            quoted_spans,
            normalized,
            homoglyphs_replaced,
        }
    }
}
483
484/// Result from intent classification.
485#[derive(Debug, Clone)]
486pub struct ClassificationResult {
487    /// Classified intent
488    pub intent: Intent,
489    /// Confidence score (0.0-1.0)
490    pub confidence: f32,
491    /// All class probabilities
492    pub all_probabilities: Vec<f32>,
493    /// Model version used for classification
494    pub model_version: String,
495}
496
/// Assembled command from template.
///
/// Both fields are `Eq`, so the struct derives `PartialEq`/`Eq` for easy
/// comparison in tests and callers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AssembledCommand {
    /// The full command string
    pub command: String,
    /// The template type used
    pub template_type: TemplateType,
}

/// Template type for assembled commands.
///
/// Derives `Hash` like the file's other fieldless enums, so it can key a
/// `HashMap`/`HashSet` (e.g. per-template counters).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TemplateType {
    /// `sqry query "<expr>" [options]`
    Query,
    /// `sqry search "<pattern>" [options]`
    Search,
    /// `sqry graph trace-path "<from>" "<to>" [options]`
    TracePath,
    /// `sqry graph direct-callers "<symbol>" [options]`
    GraphCallers,
    /// `sqry graph direct-callees "<symbol>" [options]`
    GraphCallees,
    /// `sqry visualize [options]`
    Visualize,
    /// `sqry index --status [options]`
    IndexStatus,
}
524
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_intent_round_trip() {
        for i in 0..Intent::NUM_CLASSES {
            let intent = Intent::from_index(i);
            assert_eq!(intent.to_index(), i);
        }
    }

    #[test]
    fn test_intent_from_index_out_of_range_is_ambiguous() {
        assert_eq!(Intent::from_index(Intent::NUM_CLASSES), Intent::Ambiguous);
        assert_eq!(Intent::from_index(usize::MAX), Intent::Ambiguous);
    }

    #[test]
    fn test_intent_display() {
        assert_eq!(Intent::SymbolQuery.to_string(), "symbol_query");
        assert_eq!(Intent::FindCallers.to_string(), "find_callers");
    }

    #[test]
    fn test_intent_as_str_matches_serde() {
        // `as_str` is hand-written; it must agree with serde's snake_case names.
        for i in 0..Intent::NUM_CLASSES {
            let intent = Intent::from_index(i);
            let json = serde_json::to_string(&intent).unwrap();
            assert_eq!(json, format!("\"{}\"", intent.as_str()));
        }
    }

    #[test]
    fn test_validation_status_is_valid() {
        assert!(ValidationStatus::Valid.is_valid());
        assert!(!ValidationStatus::RejectedMetachar.is_valid());
    }

    #[test]
    fn test_validation_status_rejection_reason() {
        assert_eq!(ValidationStatus::Valid.rejection_reason(), None);
        let rejected = [
            ValidationStatus::RejectedMetachar,
            ValidationStatus::RejectedPathTraversal,
            ValidationStatus::RejectedWriteMode,
            ValidationStatus::RejectedEnvVar,
            ValidationStatus::RejectedTooLong,
            ValidationStatus::RejectedUnknown,
        ];
        for status in &rejected {
            assert!(!status.is_valid());
            assert!(status.rejection_reason().is_some());
        }
    }

    #[test]
    fn test_predicate_type_prefix() {
        assert_eq!(PredicateType::Impl.as_prefix(), "impl:");
        assert_eq!(PredicateType::Duplicates.as_prefix(), "duplicates:");
        assert_eq!(PredicateType::Circular.to_string(), "circular:");
        assert_eq!(PredicateType::Unused.to_string(), "unused:");
    }

    #[test]
    fn test_symbol_kind_as_str_matches_serde() {
        let kinds = [
            SymbolKind::Function,
            SymbolKind::Interface,
            SymbolKind::TypeAlias,
        ];
        for kind in &kinds {
            let json = serde_json::to_string(kind).unwrap();
            assert_eq!(json, format!("\"{}\"", kind.as_str()));
        }
    }

    #[test]
    fn test_extracted_entities_default() {
        let entities = ExtractedEntities::new();
        assert!(!entities.has_symbols());
        assert!(!entities.has_trace_path());
        assert!(!entities.has_predicate());
        assert!(!entities.is_impl_query());
        assert!(entities.primary_symbol().is_none());
    }

    #[test]
    fn test_extracted_entities_with_symbols() {
        let mut entities = ExtractedEntities::new();
        entities.symbols.push("foo".to_string());
        entities.symbols.push("bar".to_string());

        assert!(entities.has_symbols());
        assert_eq!(entities.primary_symbol(), Some("foo"));
    }

    #[test]
    fn test_extracted_entities_trace_path_requires_both_ends() {
        let mut entities = ExtractedEntities::new();
        entities.from_symbol = Some("main".to_string());
        assert!(!entities.has_trace_path());
        entities.to_symbol = Some("run".to_string());
        assert!(entities.has_trace_path());
    }

    #[test]
    fn test_extracted_entities_predicates() {
        let mut impl_query = ExtractedEntities::new();
        impl_query.impl_trait = Some("Future".to_string());
        assert!(impl_query.has_predicate());
        assert!(impl_query.is_impl_query());

        let mut typed_impl = ExtractedEntities::new();
        typed_impl.predicate_type = Some(PredicateType::Impl);
        assert!(typed_impl.is_impl_query());

        let mut async_query = ExtractedEntities::new();
        async_query.is_async = Some(true);
        assert!(async_query.has_predicate());
        assert!(!async_query.is_impl_query());
    }

    #[test]
    fn test_translation_response_serde() {
        let response = TranslationResponse::Execute {
            command: "sqry query \"test\"".to_string(),
            confidence: 0.95,
            intent: Intent::SymbolQuery,
            cached: false,
            latency_ms: 42,
        };

        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains("\"type\":\"execute\""));
        assert!(json.contains("symbol_query"));

        let parsed: TranslationResponse = serde_json::from_str(&json).unwrap();
        match parsed {
            TranslationResponse::Execute {
                confidence,
                intent,
                cached,
                latency_ms,
                ..
            } => {
                assert!((confidence - 0.95).abs() < f32::EPSILON);
                assert_eq!(intent, Intent::SymbolQuery);
                assert!(!cached);
                assert_eq!(latency_ms, 42);
            }
            other => panic!("expected Execute variant, got {:?}", other),
        }
    }

    #[test]
    fn test_translation_response_reject_serde() {
        let response = TranslationResponse::Reject {
            reason: "Contains shell metacharacters".to_string(),
            suggestions: vec!["Remove ';'".to_string()],
        };

        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains("\"type\":\"reject\""));

        let parsed: TranslationResponse = serde_json::from_str(&json).unwrap();
        match parsed {
            TranslationResponse::Reject {
                reason,
                suggestions,
            } => {
                assert_eq!(reason, "Contains shell metacharacters");
                assert_eq!(suggestions, vec!["Remove ';'".to_string()]);
            }
            other => panic!("expected Reject variant, got {:?}", other),
        }
    }
}
588}