1use serde::{Deserialize, Serialize};
6
7#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
19#[serde(rename_all = "snake_case")]
20pub enum Intent {
21 SymbolQuery,
23 TextSearch,
25 TracePath,
27 FindCallers,
29 FindCallees,
31 Visualize,
33 IndexStatus,
35 Ambiguous,
37}
38
39impl Intent {
40 pub const NUM_CLASSES: usize = 8;
42
43 #[must_use]
45 pub fn from_index(idx: usize) -> Self {
46 match idx {
47 0 => Self::SymbolQuery,
48 1 => Self::TextSearch,
49 2 => Self::TracePath,
50 3 => Self::FindCallers,
51 4 => Self::FindCallees,
52 5 => Self::Visualize,
53 6 => Self::IndexStatus,
54 _ => Self::Ambiguous,
55 }
56 }
57
58 #[must_use]
60 pub const fn to_index(self) -> usize {
61 match self {
62 Self::SymbolQuery => 0,
63 Self::TextSearch => 1,
64 Self::TracePath => 2,
65 Self::FindCallers => 3,
66 Self::FindCallees => 4,
67 Self::Visualize => 5,
68 Self::IndexStatus => 6,
69 Self::Ambiguous => 7,
70 }
71 }
72
73 #[must_use]
75 pub const fn as_str(&self) -> &'static str {
76 match self {
77 Self::SymbolQuery => "symbol_query",
78 Self::TextSearch => "text_search",
79 Self::TracePath => "trace_path",
80 Self::FindCallers => "find_callers",
81 Self::FindCallees => "find_callees",
82 Self::Visualize => "visualize",
83 Self::IndexStatus => "index_status",
84 Self::Ambiguous => "ambiguous",
85 }
86 }
87
88 #[must_use]
90 pub const fn description(&self) -> &'static str {
91 match self {
92 Self::SymbolQuery => "Search for symbols by name or pattern",
93 Self::TextSearch => "Search for text patterns in code",
94 Self::TracePath => "Find call path between two symbols",
95 Self::FindCallers => "Find all places that call a symbol",
96 Self::FindCallees => "Find all symbols called by a function",
97 Self::Visualize => "Generate a diagram of code relationships",
98 Self::IndexStatus => "Check the status of the code index",
99 Self::Ambiguous => "Intent unclear, needs clarification",
100 }
101 }
102}
103
104impl std::fmt::Display for Intent {
105 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106 write!(f, "{}", self.as_str())
107 }
108}
109
110#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
112#[serde(rename_all = "snake_case")]
113pub enum ValidationStatus {
114 Valid,
116 RejectedMetachar,
118 RejectedPathTraversal,
120 RejectedWriteMode,
122 RejectedEnvVar,
124 RejectedTooLong,
126 RejectedUnknown,
128}
129
130impl ValidationStatus {
131 #[must_use]
133 pub const fn is_valid(&self) -> bool {
134 matches!(self, Self::Valid)
135 }
136
137 #[must_use]
139 pub const fn rejection_reason(&self) -> Option<&'static str> {
140 match self {
141 Self::Valid => None,
142 Self::RejectedMetachar => Some("Contains shell metacharacters"),
143 Self::RejectedPathTraversal => Some("Contains path traversal"),
144 Self::RejectedWriteMode => Some("Attempts write operation"),
145 Self::RejectedEnvVar => Some("Contains environment variable"),
146 Self::RejectedTooLong => Some("Exceeds maximum command length"),
147 Self::RejectedUnknown => Some("Doesn't match allowed command patterns"),
148 }
149 }
150}
151
152#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
154#[serde(rename_all = "snake_case")]
155pub enum PredicateType {
156 Impl,
158 Duplicates,
160 Circular,
162 Unused,
164}
165
166impl PredicateType {
167 #[must_use]
169 pub const fn as_prefix(&self) -> &'static str {
170 match self {
171 Self::Impl => "impl:",
172 Self::Duplicates => "duplicates:",
173 Self::Circular => "circular:",
174 Self::Unused => "unused:",
175 }
176 }
177}
178
179impl std::fmt::Display for PredicateType {
180 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
181 write!(f, "{}", self.as_prefix())
182 }
183}
184
185#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
187#[serde(rename_all = "snake_case")]
188pub enum Visibility {
189 Public,
191 Private,
193}
194
195impl Visibility {
196 #[must_use]
198 pub const fn as_str(&self) -> &'static str {
199 match self {
200 Self::Public => "public",
201 Self::Private => "private",
202 }
203 }
204}
205
206impl std::fmt::Display for Visibility {
207 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
208 write!(f, "{}", self.as_str())
209 }
210}
211
212#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
214#[serde(rename_all = "snake_case")]
215pub enum SymbolKind {
216 Function,
218 Class,
220 Struct,
222 Enum,
224 Trait,
226 Interface,
228 Method,
230 Module,
232 Constant,
234 Variable,
236 TypeAlias,
238}
239
240impl SymbolKind {
241 #[must_use]
243 pub const fn as_str(&self) -> &'static str {
244 match self {
245 Self::Function => "function",
246 Self::Class => "class",
247 Self::Struct => "struct",
248 Self::Enum => "enum",
249 Self::Trait => "trait",
250 Self::Interface => "interface",
251 Self::Method => "method",
252 Self::Module => "module",
253 Self::Constant => "constant",
254 Self::Variable => "variable",
255 Self::TypeAlias => "type_alias",
256 }
257 }
258}
259
260impl std::fmt::Display for SymbolKind {
261 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
262 write!(f, "{}", self.as_str())
263 }
264}
265
266#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
268#[serde(rename_all = "snake_case")]
269pub enum OutputFormat {
270 Mermaid,
272 Dot,
274 Json,
276}
277
278impl OutputFormat {
279 #[must_use]
281 pub const fn as_str(&self) -> &'static str {
282 match self {
283 Self::Mermaid => "mermaid",
284 Self::Dot => "dot",
285 Self::Json => "json",
286 }
287 }
288}
289
290impl std::fmt::Display for OutputFormat {
291 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
292 write!(f, "{}", self.as_str())
293 }
294}
295
296#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
301pub struct ExtractedEntities {
302 pub symbols: Vec<String>,
304
305 pub languages: Vec<String>,
307
308 pub paths: Vec<String>,
310
311 pub kind: Option<SymbolKind>,
313
314 pub limit: Option<u32>,
316
317 pub depth: Option<u32>,
319
320 pub format: Option<OutputFormat>,
322
323 pub from_symbol: Option<String>,
325
326 pub to_symbol: Option<String>,
328
329 pub relation: Option<String>,
331
332 pub predicate_type: Option<PredicateType>,
335
336 pub impl_trait: Option<String>,
338
339 pub predicate_arg: Option<String>,
341
342 pub visibility: Option<Visibility>,
344
345 pub is_async: Option<bool>,
347
348 pub is_unsafe: Option<bool>,
350}
351
352impl ExtractedEntities {
353 #[must_use]
355 pub fn new() -> Self {
356 Self::default()
357 }
358
359 #[must_use]
361 pub fn has_symbols(&self) -> bool {
362 !self.symbols.is_empty()
363 }
364
365 #[must_use]
367 pub fn has_trace_path(&self) -> bool {
368 self.from_symbol.is_some() && self.to_symbol.is_some()
369 }
370
371 #[must_use]
373 pub fn primary_symbol(&self) -> Option<&str> {
374 self.symbols.first().map(String::as_str)
375 }
376
377 #[must_use]
379 pub fn has_predicate(&self) -> bool {
380 self.predicate_type.is_some()
381 || self.impl_trait.is_some()
382 || self.visibility.is_some()
383 || self.is_async.is_some()
384 || self.is_unsafe.is_some()
385 }
386
387 #[must_use]
389 pub fn is_impl_query(&self) -> bool {
390 self.predicate_type == Some(PredicateType::Impl) || self.impl_trait.is_some()
391 }
392}
393
394#[derive(Debug, Clone, Serialize, Deserialize)]
402#[serde(tag = "type", rename_all = "snake_case")]
403pub enum TranslationResponse {
404 Execute {
406 command: String,
408 confidence: f32,
410 intent: Intent,
412 cached: bool,
414 latency_ms: u64,
416 },
417
418 Confirm {
420 command: String,
422 confidence: f32,
424 prompt: String,
426 },
427
428 Disambiguate {
430 options: Vec<DisambiguationOption>,
432 prompt: String,
434 },
435
436 Reject {
438 reason: String,
440 suggestions: Vec<String>,
442 },
443}
444
445#[derive(Debug, Clone, Serialize, Deserialize)]
447pub struct DisambiguationOption {
448 pub command: String,
450 pub intent: Intent,
452 pub description: String,
454 pub confidence: f32,
456}
457
/// Text produced by the preprocessing step, plus flags describing what that step did.
#[derive(Debug)]
#[derive(Clone)]
pub struct PreprocessResult {
    /// The preprocessed text.
    pub text: String,
    /// Quoted spans captured from the input (presumably kept verbatim; TODO confirm).
    pub quoted_spans: Vec<String>,
    /// Whether the text was normalized.
    pub normalized: bool,
    /// Whether homoglyph characters were replaced.
    pub homoglyphs_replaced: bool,
}
470
471impl PreprocessResult {
472 #[must_use]
474 pub fn ok(text: String, quoted_spans: Vec<String>) -> Self {
475 Self {
476 text,
477 quoted_spans,
478 normalized: false,
479 homoglyphs_replaced: false,
480 }
481 }
482}
483
484#[derive(Debug, Clone)]
486pub struct ClassificationResult {
487 pub intent: Intent,
489 pub confidence: f32,
491 pub all_probabilities: Vec<f32>,
493 pub model_version: String,
495}
496
497#[derive(Debug, Clone)]
499pub struct AssembledCommand {
500 pub command: String,
502 pub template_type: TemplateType,
504}
505
/// Template family used when assembling a command.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum TemplateType {
    /// Query template.
    Query,
    /// Search template.
    Search,
    /// Trace-path template.
    TracePath,
    /// Graph callers template.
    GraphCallers,
    /// Graph callees template.
    GraphCallees,
    /// Visualization template.
    Visualize,
    /// Index-status template.
    IndexStatus,
}
524
#[cfg(test)]
mod tests {
    use super::*;

    /// Every class index must survive a `from_index` -> `to_index` round trip.
    #[test]
    fn test_intent_round_trip() {
        (0..Intent::NUM_CLASSES).for_each(|idx| {
            assert_eq!(Intent::from_index(idx).to_index(), idx);
        });
    }

    /// `Display` must produce the snake_case names.
    #[test]
    fn test_intent_display() {
        let cases = [
            (Intent::SymbolQuery, "symbol_query"),
            (Intent::FindCallers, "find_callers"),
        ];
        for &(intent, expected) in cases.iter() {
            assert_eq!(intent.to_string(), expected);
        }
    }

    #[test]
    fn test_validation_status_is_valid() {
        let accepted = ValidationStatus::Valid;
        let rejected = ValidationStatus::RejectedMetachar;
        assert!(accepted.is_valid());
        assert!(!rejected.is_valid());
    }

    #[test]
    fn test_extracted_entities_default() {
        let empty = ExtractedEntities::new();
        assert!(!empty.has_symbols());
        assert!(!empty.has_trace_path());
        assert_eq!(empty.primary_symbol(), None);
    }

    #[test]
    fn test_extracted_entities_with_symbols() {
        let entities = ExtractedEntities {
            symbols: vec!["foo".to_string(), "bar".to_string()],
            ..ExtractedEntities::new()
        };

        assert!(entities.has_symbols());
        assert_eq!(entities.primary_symbol(), Some("foo"));
    }

    /// `Execute` serializes with a `"type"` tag and round-trips through JSON.
    #[test]
    fn test_translation_response_serde() {
        let original = TranslationResponse::Execute {
            command: "sqry query \"test\"".to_string(),
            confidence: 0.95,
            intent: Intent::SymbolQuery,
            cached: false,
            latency_ms: 42,
        };

        let json = serde_json::to_string(&original).unwrap();
        assert!(json.contains("\"type\":\"execute\""));
        assert!(json.contains("symbol_query"));

        let round_tripped: TranslationResponse = serde_json::from_str(&json).unwrap();
        match round_tripped {
            TranslationResponse::Execute { confidence, .. } => {
                assert!((confidence - 0.95).abs() < f32::EPSILON);
            }
            _ => panic!("Wrong variant"),
        }
    }
}