1use crate::token_budget::{BudgetSection, TokenBudgetConfig, TokenBudgetEnforcer, TokenEstimator};
55use crate::soch_ql::{ComparisonOp, Condition, LogicalOp, SochValue, WhereClause};
56use std::collections::HashMap;
57
58#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
64pub struct ContextSelectQuery {
65 pub output_name: String,
67 pub session: SessionReference,
69 pub options: ContextQueryOptions,
71 pub sections: Vec<ContextSection>,
73}
74
75#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
77pub enum SessionReference {
78 Session(String),
80 Agent(String),
82 None,
84}
85
86#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
88pub struct ContextQueryOptions {
89 pub token_limit: usize,
91 pub include_schema: bool,
93 pub format: OutputFormat,
95 pub truncation: TruncationStrategy,
97 pub include_headers: bool,
99}
100
101impl Default for ContextQueryOptions {
102 fn default() -> Self {
103 Self {
104 token_limit: 4096,
105 include_schema: true,
106 format: OutputFormat::Soch,
107 truncation: TruncationStrategy::TailDrop,
108 include_headers: true,
109 }
110 }
111}
112
113#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
115pub enum OutputFormat {
116 Soch,
118 Json,
120 Markdown,
122}
123
124#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
126pub enum TruncationStrategy {
127 TailDrop,
129 HeadDrop,
131 Proportional,
133 Fail,
135}
136
137#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
143pub struct ContextSection {
144 pub name: String,
146 pub priority: i32,
148 pub content: SectionContent,
150 pub transform: Option<SectionTransform>,
152}
153
154#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
156pub enum SectionContent {
157 Get { path: PathExpression },
160
161 Last {
164 count: usize,
165 table: String,
166 where_clause: Option<WhereClause>,
167 },
168
169 Search {
172 collection: String,
173 query: SimilarityQuery,
174 top_k: usize,
175 min_score: Option<f32>,
176 },
177
178 Select {
180 columns: Vec<String>,
181 table: String,
182 where_clause: Option<WhereClause>,
183 limit: Option<usize>,
184 },
185
186 Literal { value: String },
188
189 Variable { name: String },
191
192 ToolRegistry {
195 include: Vec<String>,
197 exclude: Vec<String>,
199 include_schema: bool,
201 },
202
203 ToolCalls {
206 count: usize,
208 tool_filter: Option<String>,
210 status_filter: Option<String>,
212 include_outputs: bool,
214 },
215}
216
217#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
219pub struct PathExpression {
220 pub segments: Vec<String>,
222 pub fields: Vec<String>,
224 pub all_fields: bool,
226}
227
228impl PathExpression {
229 pub fn parse(input: &str) -> Result<Self, ContextParseError> {
232 let input = input.trim();
233
234 if let Some(brace_start) = input.find('{') {
236 if !input.ends_with('}') {
237 return Err(ContextParseError::InvalidPath(
238 "unclosed field projection".to_string(),
239 ));
240 }
241
242 let path_part = &input[..brace_start].trim_end_matches('.');
243 let fields_part = &input[brace_start + 1..input.len() - 1];
244
245 let segments: Vec<String> = path_part.split('.').map(|s| s.to_string()).collect();
246 let fields: Vec<String> = fields_part
247 .split(',')
248 .map(|s| s.trim().to_string())
249 .filter(|s| !s.is_empty())
250 .collect();
251
252 Ok(PathExpression {
253 segments,
254 fields,
255 all_fields: false,
256 })
257 } else if let Some(path_part) = input.strip_suffix(".**") {
258 let segments: Vec<String> = path_part.split('.').map(|s| s.to_string()).collect();
260
261 Ok(PathExpression {
262 segments,
263 fields: vec![],
264 all_fields: true,
265 })
266 } else {
267 let segments: Vec<String> = input.split('.').map(|s| s.to_string()).collect();
269
270 Ok(PathExpression {
271 segments,
272 fields: vec![],
273 all_fields: true,
274 })
275 }
276 }
277
278 pub fn to_path_string(&self) -> String {
280 let base = self.segments.join(".");
281 if self.all_fields {
282 format!("{}.**", base)
283 } else if !self.fields.is_empty() {
284 format!("{}.{{{}}}", base, self.fields.join(", "))
285 } else {
286 base
287 }
288 }
289}
290
291#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
293pub enum SimilarityQuery {
294 Variable(String),
296 Embedding(Vec<f32>),
298 Text(String),
300}
301
302#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
304pub enum SectionTransform {
305 Summarize { max_tokens: usize },
307 Project { fields: Vec<String> },
309 Template { template: String },
311 Custom { function: String },
313}
314
315#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
327pub struct ContextRecipe {
328 pub id: String,
330 pub name: String,
332 pub description: String,
334 pub version: String,
336 pub query: ContextSelectQuery,
338 pub metadata: RecipeMetadata,
340 pub session_binding: Option<SessionBinding>,
342}
343
344#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
346pub struct RecipeMetadata {
347 pub author: Option<String>,
349 pub created_at: Option<String>,
351 pub updated_at: Option<String>,
353 pub tags: Vec<String>,
355 pub usage_count: u64,
357 pub avg_tokens: Option<f32>,
359}
360
361#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
363pub enum SessionBinding {
364 Session(String),
366 Agent(String),
368 Pattern(String),
370 None,
372}
373
374pub struct ContextRecipeStore {
376 recipes: std::sync::RwLock<HashMap<String, ContextRecipe>>,
378 versions: std::sync::RwLock<HashMap<String, Vec<String>>>,
380}
381
382impl ContextRecipeStore {
383 pub fn new() -> Self {
385 Self {
386 recipes: std::sync::RwLock::new(HashMap::new()),
387 versions: std::sync::RwLock::new(HashMap::new()),
388 }
389 }
390
391 pub fn save(&self, recipe: ContextRecipe) -> Result<(), String> {
393 let mut recipes = self.recipes.write().map_err(|e| e.to_string())?;
394 let mut versions = self.versions.write().map_err(|e| e.to_string())?;
395
396 let key = format!("{}:{}", recipe.id, recipe.version);
397 recipes.insert(key.clone(), recipe.clone());
398
399 versions
400 .entry(recipe.id.clone())
401 .or_default()
402 .push(recipe.version.clone());
403
404 Ok(())
405 }
406
407 pub fn get_latest(&self, recipe_id: &str) -> Option<ContextRecipe> {
409 let versions = self.versions.read().ok()?;
410 let latest_version = versions.get(recipe_id)?.last()?;
411
412 let recipes = self.recipes.read().ok()?;
413 let key = format!("{}:{}", recipe_id, latest_version);
414 recipes.get(&key).cloned()
415 }
416
417 pub fn get_version(&self, recipe_id: &str, version: &str) -> Option<ContextRecipe> {
419 let recipes = self.recipes.read().ok()?;
420 let key = format!("{}:{}", recipe_id, version);
421 recipes.get(&key).cloned()
422 }
423
424 pub fn list_versions(&self, recipe_id: &str) -> Vec<String> {
426 self.versions
427 .read()
428 .ok()
429 .and_then(|v| v.get(recipe_id).cloned())
430 .unwrap_or_default()
431 }
432
433 pub fn find_by_session(&self, session_id: &str) -> Vec<ContextRecipe> {
435 let recipes = match self.recipes.read() {
436 Ok(r) => r,
437 Err(_) => return Vec::new(),
438 };
439
440 recipes
441 .values()
442 .filter(|r| match &r.session_binding {
443 Some(SessionBinding::Session(sid)) => sid == session_id,
444 Some(SessionBinding::Pattern(pattern)) => {
445 glob_match(pattern, session_id)
446 }
447 _ => false,
448 })
449 .cloned()
450 .collect()
451 }
452
453 pub fn find_by_agent(&self, agent_id: &str) -> Vec<ContextRecipe> {
455 let recipes = match self.recipes.read() {
456 Ok(r) => r,
457 Err(_) => return Vec::new(),
458 };
459
460 recipes
461 .values()
462 .filter(|r| matches!(&r.session_binding, Some(SessionBinding::Agent(aid)) if aid == agent_id))
463 .cloned()
464 .collect()
465 }
466}
467
468impl Default for ContextRecipeStore {
469 fn default() -> Self {
470 Self::new()
471 }
472}
473
/// Matches `input` against a glob `pattern` in which `*` stands for any
/// (possibly empty) run of characters. Every other character is literal.
///
/// A pattern without `*` must equal `input` exactly. Any number of `*`
/// wildcards is supported, and the literal pieces must appear in order
/// without overlapping (so `ab*ba` does not match `aba`).
fn glob_match(pattern: &str, input: &str) -> bool {
    if !pattern.contains('*') {
        return pattern == input;
    }

    let mut literals = pattern.split('*');

    // The piece before the first `*` is anchored at the start of the input.
    let prefix = literals.next().unwrap_or("");
    let mut rest = match input.strip_prefix(prefix) {
        Some(rest) => rest,
        None => return false,
    };

    let tail: Vec<&str> = literals.collect();
    let (suffix, middle) = match tail.split_last() {
        Some(parts) => parts,
        // Cannot happen: a pattern containing `*` splits into >= 2 pieces.
        None => return rest.is_empty(),
    };

    // Consume each inner piece at its leftmost occurrence; that leaves the
    // longest possible remainder for the pieces that follow.
    for &piece in middle {
        match rest.find(piece) {
            Some(at) => rest = &rest[at + piece.len()..],
            None => return false,
        }
    }

    // The piece after the last `*` is anchored at the end of what is left.
    rest.ends_with(*suffix)
}
488
489#[derive(Debug, Clone)]
495pub struct VectorSearchResult {
496 pub id: String,
498 pub score: f32,
500 pub content: String,
502 pub metadata: HashMap<String, SochValue>,
504}
505
506pub trait VectorIndex: Send + Sync {
513 fn search_by_embedding(
515 &self,
516 collection: &str,
517 embedding: &[f32],
518 k: usize,
519 min_score: Option<f32>,
520 ) -> Result<Vec<VectorSearchResult>, String>;
521
522 fn search_by_text(
524 &self,
525 collection: &str,
526 text: &str,
527 k: usize,
528 min_score: Option<f32>,
529 ) -> Result<Vec<VectorSearchResult>, String>;
530
531 fn stats(&self, collection: &str) -> Option<VectorIndexStats>;
533}
534
/// Summary statistics for one vector collection.
#[derive(Debug, Clone)]
pub struct VectorIndexStats {
    /// Number of vectors stored.
    pub vector_count: usize,
    /// Dimensionality every vector in the collection must have.
    pub dimension: usize,
    /// Name of the similarity metric (the indexes in this file report `"cosine"`).
    pub metric: String,
}
545
546pub struct SimpleVectorIndex {
551 collections: std::sync::RwLock<HashMap<String, VectorCollection>>,
553}
554
555struct VectorCollection {
557 #[allow(clippy::type_complexity)]
559 vectors: Vec<(String, Vec<f32>, String, HashMap<String, SochValue>)>,
560 dimension: usize,
562}
563
564impl SimpleVectorIndex {
565 pub fn new() -> Self {
567 Self {
568 collections: std::sync::RwLock::new(HashMap::new()),
569 }
570 }
571
572 pub fn create_collection(&self, name: &str, dimension: usize) {
574 let mut collections = self.collections.write().unwrap();
575 collections
576 .entry(name.to_string())
577 .or_insert_with(|| VectorCollection {
578 vectors: Vec::new(),
579 dimension,
580 });
581 }
582
583 pub fn insert(
585 &self,
586 collection: &str,
587 id: String,
588 vector: Vec<f32>,
589 content: String,
590 metadata: HashMap<String, SochValue>,
591 ) -> Result<(), String> {
592 let mut collections = self.collections.write().unwrap();
593 let coll = collections
594 .get_mut(collection)
595 .ok_or_else(|| format!("Collection '{}' not found", collection))?;
596
597 if vector.len() != coll.dimension {
598 return Err(format!(
599 "Vector dimension mismatch: expected {}, got {}",
600 coll.dimension,
601 vector.len()
602 ));
603 }
604
605 coll.vectors.push((id, vector, content, metadata));
606 Ok(())
607 }
608
609 fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
611 let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
612 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
613 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
614 if norm_a == 0.0 || norm_b == 0.0 {
615 0.0
616 } else {
617 dot / (norm_a * norm_b)
618 }
619 }
620}
621
622impl Default for SimpleVectorIndex {
623 fn default() -> Self {
624 Self::new()
625 }
626}
627
628impl VectorIndex for SimpleVectorIndex {
629 fn search_by_embedding(
630 &self,
631 collection: &str,
632 embedding: &[f32],
633 k: usize,
634 min_score: Option<f32>,
635 ) -> Result<Vec<VectorSearchResult>, String> {
636 let collections = self.collections.read().unwrap();
637 let coll = collections
638 .get(collection)
639 .ok_or_else(|| format!("Collection '{}' not found", collection))?;
640
641 let mut scored: Vec<_> = coll
643 .vectors
644 .iter()
645 .map(|(id, vec, content, meta)| {
646 let score = Self::cosine_similarity(embedding, vec);
647 (id, score, content, meta)
648 })
649 .filter(|(_, score, _, _)| min_score.map(|min| *score >= min).unwrap_or(true))
650 .collect();
651
652 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
654
655 Ok(scored
657 .into_iter()
658 .take(k)
659 .map(|(id, score, content, meta)| VectorSearchResult {
660 id: id.clone(),
661 score,
662 content: content.clone(),
663 metadata: meta.clone(),
664 })
665 .collect())
666 }
667
668 fn search_by_text(
669 &self,
670 _collection: &str,
671 _text: &str,
672 _k: usize,
673 _min_score: Option<f32>,
674 ) -> Result<Vec<VectorSearchResult>, String> {
675 Err(
677 "Text-based search requires an embedding model. Use search_by_embedding instead."
678 .to_string(),
679 )
680 }
681
682 fn stats(&self, collection: &str) -> Option<VectorIndexStats> {
683 let collections = self.collections.read().unwrap();
684 collections.get(collection).map(|coll| VectorIndexStats {
685 vector_count: coll.vectors.len(),
686 dimension: coll.dimension,
687 metric: "cosine".to_string(),
688 })
689 }
690}
691
692pub struct HnswVectorIndex {
701 collections: std::sync::RwLock<HashMap<String, HnswCollection>>,
703}
704
705struct HnswCollection {
707 index: sochdb_index::vector::VectorIndex,
709 #[allow(clippy::type_complexity)]
711 metadata: HashMap<u128, (String, String, HashMap<String, SochValue>)>,
712 next_edge_id: u128,
714 dimension: usize,
716}
717
718impl HnswVectorIndex {
719 pub fn new() -> Self {
721 Self {
722 collections: std::sync::RwLock::new(HashMap::new()),
723 }
724 }
725
726 pub fn create_collection(&self, name: &str, dimension: usize) {
728 let mut collections = self.collections.write().unwrap();
729 collections.entry(name.to_string()).or_insert_with(|| {
730 let index = sochdb_index::vector::VectorIndex::with_dimension(
731 sochdb_index::vector::DistanceMetric::Cosine,
732 dimension,
733 );
734 HnswCollection {
735 index,
736 metadata: HashMap::new(),
737 next_edge_id: 0,
738 dimension,
739 }
740 });
741 }
742
743 pub fn insert(
745 &self,
746 collection: &str,
747 id: String,
748 vector: Vec<f32>,
749 content: String,
750 metadata: HashMap<String, SochValue>,
751 ) -> Result<(), String> {
752 let mut collections = self.collections.write().unwrap();
753 let coll = collections
754 .get_mut(collection)
755 .ok_or_else(|| format!("Collection '{}' not found", collection))?;
756
757 if vector.len() != coll.dimension {
758 return Err(format!(
759 "Vector dimension mismatch: expected {}, got {}",
760 coll.dimension,
761 vector.len()
762 ));
763 }
764
765 let edge_id = coll.next_edge_id;
767 coll.next_edge_id += 1;
768 coll.metadata.insert(edge_id, (id, content, metadata));
769
770 let embedding = ndarray::Array1::from_vec(vector);
772
773 coll.index.add(edge_id, embedding)?;
775
776 Ok(())
777 }
778
779 pub fn vector_count(&self, collection: &str) -> Option<usize> {
781 let collections = self.collections.read().unwrap();
782 collections.get(collection).map(|c| c.metadata.len())
783 }
784}
785
786impl Default for HnswVectorIndex {
787 fn default() -> Self {
788 Self::new()
789 }
790}
791
792impl VectorIndex for HnswVectorIndex {
793 fn search_by_embedding(
794 &self,
795 collection: &str,
796 embedding: &[f32],
797 k: usize,
798 min_score: Option<f32>,
799 ) -> Result<Vec<VectorSearchResult>, String> {
800 let collections = self.collections.read().unwrap();
801 let coll = collections
802 .get(collection)
803 .ok_or_else(|| format!("Collection '{}' not found", collection))?;
804
805 let query = ndarray::Array1::from_vec(embedding.to_vec());
807
808 let results = coll.index.search(&query, k)?;
810
811 let mut search_results = Vec::with_capacity(results.len());
814 for (edge_id, distance) in results {
815 let score = 1.0 - distance;
817
818 if let Some(min) = min_score {
820 if score < min {
821 continue;
822 }
823 }
824
825 if let Some((id, content, meta)) = coll.metadata.get(&edge_id) {
827 search_results.push(VectorSearchResult {
828 id: id.clone(),
829 score,
830 content: content.clone(),
831 metadata: meta.clone(),
832 });
833 }
834 }
835
836 Ok(search_results)
837 }
838
839 fn search_by_text(
840 &self,
841 _collection: &str,
842 _text: &str,
843 _k: usize,
844 _min_score: Option<f32>,
845 ) -> Result<Vec<VectorSearchResult>, String> {
846 Err(
848 "Text-based search requires an embedding model. Use search_by_embedding instead."
849 .to_string(),
850 )
851 }
852
853 fn stats(&self, collection: &str) -> Option<VectorIndexStats> {
854 let collections = self.collections.read().unwrap();
855 collections.get(collection).map(|coll| VectorIndexStats {
856 vector_count: coll.metadata.len(),
857 dimension: coll.dimension,
858 metric: "cosine".to_string(),
859 })
860 }
861}
862
863#[derive(Debug, Clone)]
869pub struct ContextResult {
870 pub context: String,
872 pub token_count: usize,
874 pub token_budget: usize,
876 pub sections_included: Vec<SectionResult>,
878 pub sections_truncated: Vec<String>,
880 pub sections_dropped: Vec<String>,
882}
883
/// Result of materializing a single section.
///
/// NOTE(review): the code that fills this struct is outside this chunk; in
/// particular the difference between `tokens` and `tokens_used` should be
/// confirmed there.
#[derive(Debug, Clone)]
pub struct SectionResult {
    /// Section name.
    pub name: String,
    /// Section priority.
    pub priority: i32,
    /// Rendered section content.
    pub content: String,
    /// Token count of the section (presumably before truncation).
    pub tokens: usize,
    /// Tokens actually consumed from the budget (presumably after truncation).
    pub tokens_used: usize,
    /// Whether the content was truncated.
    pub truncated: bool,
    /// Number of rows that contributed to the content.
    pub row_count: usize,
}
902
903#[derive(Debug, Clone)]
909pub enum ContextQueryError {
910 SessionMismatch { expected: String, actual: String },
912 VariableNotFound(String),
914 InvalidVariableType { variable: String, expected: String },
916 BudgetExceeded {
918 section: String,
919 requested: usize,
920 available: usize,
921 },
922 BudgetExhausted(String),
924 PermissionDenied(String),
926 InvalidPath(String),
928 Parse(ContextParseError),
930 FormatError(String),
932 InvalidQuery(String),
934 VectorSearchError(String),
936}
937
938impl std::fmt::Display for ContextQueryError {
939 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
940 match self {
941 Self::SessionMismatch { expected, actual } => {
942 write!(f, "session mismatch: expected {}, got {}", expected, actual)
943 }
944 Self::VariableNotFound(name) => write!(f, "variable not found: {}", name),
945 Self::InvalidVariableType { variable, expected } => {
946 write!(
947 f,
948 "variable {} has invalid type, expected {}",
949 variable, expected
950 )
951 }
952 Self::BudgetExceeded {
953 section,
954 requested,
955 available,
956 } => {
957 write!(
958 f,
959 "section {} exceeds budget: {} > {}",
960 section, requested, available
961 )
962 }
963 Self::BudgetExhausted(msg) => write!(f, "budget exhausted: {}", msg),
964 Self::PermissionDenied(msg) => write!(f, "permission denied: {}", msg),
965 Self::InvalidPath(path) => write!(f, "invalid path: {}", path),
966 Self::Parse(e) => write!(f, "parse error: {}", e),
967 Self::FormatError(e) => write!(f, "format error: {}", e),
968 Self::InvalidQuery(msg) => write!(f, "invalid query: {}", msg),
969 Self::VectorSearchError(e) => write!(f, "vector search error: {}", e),
970 }
971 }
972}
973
974impl std::error::Error for ContextQueryError {}
975
/// Errors raised while parsing a context query or a path expression.
#[derive(Debug, Clone)]
pub enum ContextParseError {
    /// A specific token was required but something else was found.
    UnexpectedToken { expected: String, found: String },
    /// A required clause is missing.
    MissingClause(String),
    /// An option key or option value is not recognized.
    InvalidOption(String),
    /// A path expression is malformed.
    InvalidPath(String),
    /// A section body is not one of the supported forms.
    InvalidSection(String),
    /// Any other syntax problem.
    SyntaxError(String),
}
992
993impl std::fmt::Display for ContextParseError {
994 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
995 match self {
996 Self::UnexpectedToken { expected, found } => {
997 write!(f, "expected {}, found '{}'", expected, found)
998 }
999 Self::MissingClause(clause) => write!(f, "missing {} clause", clause),
1000 Self::InvalidOption(opt) => write!(f, "invalid option: {}", opt),
1001 Self::InvalidPath(path) => write!(f, "invalid path: {}", path),
1002 Self::InvalidSection(sec) => write!(f, "invalid section: {}", sec),
1003 Self::SyntaxError(msg) => write!(f, "syntax error: {}", msg),
1004 }
1005 }
1006}
1007
1008impl std::error::Error for ContextParseError {}
1009
1010pub struct ContextQueryParser {
1012 pos: usize,
1014 tokens: Vec<Token>,
1016}
1017
/// Lexical token produced by the query tokenizer.
#[derive(Debug, Clone, PartialEq)]
enum Token {
    /// Reserved word, stored upper-cased.
    Keyword(String),
    /// Any other identifier, stored as written.
    Ident(String),
    /// Numeric literal.
    Number(f64),
    /// Contents of a single- or double-quoted string (quotes removed).
    String(String),
    /// Single punctuation character.
    Punct(char),
    /// `$name` variable reference (name stored without the `$`).
    Variable(String),
    /// End of input.
    Eof,
}
1036
1037impl ContextQueryParser {
1038 pub fn new(input: &str) -> Self {
1040 let tokens = Self::tokenize(input);
1041 Self { pos: 0, tokens }
1042 }
1043
1044 pub fn parse(&mut self) -> Result<ContextSelectQuery, ContextParseError> {
1046 self.expect_keyword("CONTEXT")?;
1048 self.expect_keyword("SELECT")?;
1049 let output_name = self.expect_ident()?;
1050
1051 let session = if self.match_keyword("FROM") {
1053 self.parse_session_reference()?
1054 } else {
1055 SessionReference::None
1056 };
1057
1058 let options = if self.match_keyword("WITH") {
1060 self.parse_options()?
1061 } else {
1062 ContextQueryOptions::default()
1063 };
1064
1065 self.expect_keyword("SECTIONS")?;
1067 let sections = self.parse_sections()?;
1068
1069 Ok(ContextSelectQuery {
1070 output_name,
1071 session,
1072 options,
1073 sections,
1074 })
1075 }
1076
1077 fn parse_session_reference(&mut self) -> Result<SessionReference, ContextParseError> {
1079 if self.match_keyword("session") {
1080 self.expect_punct('(')?;
1081 let var = self.expect_variable()?;
1082 self.expect_punct(')')?;
1083 Ok(SessionReference::Session(var))
1084 } else if self.match_keyword("agent") {
1085 self.expect_punct('(')?;
1086 let var = self.expect_variable()?;
1087 self.expect_punct(')')?;
1088 Ok(SessionReference::Agent(var))
1089 } else {
1090 Err(ContextParseError::SyntaxError(
1091 "expected 'session' or 'agent'".to_string(),
1092 ))
1093 }
1094 }
1095
1096 fn parse_options(&mut self) -> Result<ContextQueryOptions, ContextParseError> {
1098 self.expect_punct('(')?;
1099 let mut options = ContextQueryOptions::default();
1100
1101 loop {
1102 let key = self.expect_ident()?;
1103 self.expect_punct('=')?;
1104
1105 match key.as_str() {
1106 "token_limit" => {
1107 if let Token::Number(n) = self.current().clone() {
1108 options.token_limit = n as usize;
1109 self.advance();
1110 }
1111 }
1112 "include_schema" => {
1113 options.include_schema = self.parse_bool()?;
1114 }
1115 "format" => {
1116 let format = self.expect_ident()?;
1117 options.format = match format.to_lowercase().as_str() {
1118 "toon" => OutputFormat::Soch,
1119 "json" => OutputFormat::Json,
1120 "markdown" => OutputFormat::Markdown,
1121 _ => return Err(ContextParseError::InvalidOption(format)),
1122 };
1123 }
1124 "truncation" => {
1125 let strategy = self.expect_ident()?;
1126 options.truncation = match strategy.to_lowercase().as_str() {
1127 "tail_drop" | "taildrop" => TruncationStrategy::TailDrop,
1128 "head_drop" | "headdrop" => TruncationStrategy::HeadDrop,
1129 "proportional" => TruncationStrategy::Proportional,
1130 "fail" => TruncationStrategy::Fail,
1131 _ => return Err(ContextParseError::InvalidOption(strategy)),
1132 };
1133 }
1134 "include_headers" => {
1135 options.include_headers = self.parse_bool()?;
1136 }
1137 _ => return Err(ContextParseError::InvalidOption(key)),
1138 }
1139
1140 if !self.match_punct(',') {
1141 break;
1142 }
1143 }
1144
1145 self.expect_punct(')')?;
1146 Ok(options)
1147 }
1148
1149 fn parse_sections(&mut self) -> Result<Vec<ContextSection>, ContextParseError> {
1151 self.expect_punct('(')?;
1152 let mut sections = Vec::new();
1153
1154 loop {
1155 if self.check_punct(')') {
1156 break;
1157 }
1158
1159 let section = self.parse_section()?;
1160 sections.push(section);
1161
1162 if !self.match_punct(',') {
1163 break;
1164 }
1165 }
1166
1167 self.expect_punct(')')?;
1168 Ok(sections)
1169 }
1170
1171 fn parse_section(&mut self) -> Result<ContextSection, ContextParseError> {
1173 let name = self.expect_ident()?;
1175
1176 self.expect_keyword("PRIORITY")?;
1177 let priority = if let Token::Number(n) = self.current().clone() {
1178 let val = n as i32;
1179 self.advance();
1180 val
1181 } else {
1182 0
1183 };
1184
1185 self.expect_punct(':')?;
1186
1187 let content = self.parse_section_content()?;
1188
1189 Ok(ContextSection {
1190 name,
1191 priority,
1192 content,
1193 transform: None,
1194 })
1195 }
1196
1197 fn parse_section_content(&mut self) -> Result<SectionContent, ContextParseError> {
1199 if self.match_keyword("GET") {
1200 let path_str = self.collect_until(&[',', ')']);
1202 let path = PathExpression::parse(&path_str)?;
1203 Ok(SectionContent::Get { path })
1204 } else if self.match_keyword("LAST") {
1205 let count = if let Token::Number(n) = self.current().clone() {
1207 let val = n as usize;
1208 self.advance();
1209 val
1210 } else {
1211 10 };
1213
1214 self.expect_keyword("FROM")?;
1215 let table = self.expect_ident()?;
1216
1217 let where_clause = if self.match_keyword("WHERE") {
1218 Some(self.parse_where_clause()?)
1219 } else {
1220 None
1221 };
1222
1223 Ok(SectionContent::Last {
1224 count,
1225 table,
1226 where_clause,
1227 })
1228 } else if self.match_keyword("SEARCH") {
1229 let collection = self.expect_ident()?;
1231 self.expect_keyword("BY")?;
1232 self.expect_keyword("SIMILARITY")?;
1233
1234 self.expect_punct('(')?;
1235 let query = if let Token::Variable(v) = self.current().clone() {
1236 self.advance();
1237 SimilarityQuery::Variable(v)
1238 } else if let Token::String(s) = self.current().clone() {
1239 self.advance();
1240 SimilarityQuery::Text(s)
1241 } else {
1242 return Err(ContextParseError::SyntaxError(
1243 "expected variable or string for similarity query".to_string(),
1244 ));
1245 };
1246 self.expect_punct(')')?;
1247
1248 self.expect_keyword("TOP")?;
1249 let top_k = if let Token::Number(n) = self.current().clone() {
1250 let val = n as usize;
1251 self.advance();
1252 val
1253 } else {
1254 5 };
1256
1257 Ok(SectionContent::Search {
1258 collection,
1259 query,
1260 top_k,
1261 min_score: None,
1262 })
1263 } else if self.match_keyword("SELECT") {
1264 let columns = self.parse_column_list()?;
1266 self.expect_keyword("FROM")?;
1267 let table = self.expect_ident()?;
1268
1269 let where_clause = if self.match_keyword("WHERE") {
1270 Some(self.parse_where_clause()?)
1271 } else {
1272 None
1273 };
1274
1275 let limit = if self.match_keyword("LIMIT") {
1276 if let Token::Number(n) = self.current().clone() {
1277 let val = n as usize;
1278 self.advance();
1279 Some(val)
1280 } else {
1281 None
1282 }
1283 } else {
1284 None
1285 };
1286
1287 Ok(SectionContent::Select {
1288 columns,
1289 table,
1290 where_clause,
1291 limit,
1292 })
1293 } else if let Token::Variable(v) = self.current().clone() {
1294 self.advance();
1295 Ok(SectionContent::Variable { name: v })
1296 } else if let Token::String(s) = self.current().clone() {
1297 self.advance();
1298 Ok(SectionContent::Literal { value: s })
1299 } else {
1300 Err(ContextParseError::InvalidSection(
1301 "expected GET, LAST, SEARCH, SELECT, or literal".to_string(),
1302 ))
1303 }
1304 }
1305
1306 fn parse_where_clause(&mut self) -> Result<WhereClause, ContextParseError> {
1308 let mut conditions = Vec::new();
1309
1310 loop {
1311 let column = self.expect_ident()?;
1312 let operator = self.parse_comparison_op()?;
1313 let value = self.parse_value()?;
1314
1315 conditions.push(Condition {
1316 column,
1317 operator,
1318 value,
1319 });
1320
1321 if !self.match_keyword("AND") && !self.match_keyword("OR") {
1322 break;
1323 }
1324 }
1325
1326 Ok(WhereClause {
1327 conditions,
1328 operator: LogicalOp::And,
1329 })
1330 }
1331
1332 fn parse_comparison_op(&mut self) -> Result<ComparisonOp, ContextParseError> {
1334 match self.current() {
1335 Token::Punct('=') => {
1336 self.advance();
1337 Ok(ComparisonOp::Eq)
1338 }
1339 Token::Punct('>') => {
1340 self.advance();
1341 if self.check_punct('=') {
1342 self.advance();
1343 Ok(ComparisonOp::Ge)
1344 } else {
1345 Ok(ComparisonOp::Gt)
1346 }
1347 }
1348 Token::Punct('<') => {
1349 self.advance();
1350 if self.check_punct('=') {
1351 self.advance();
1352 Ok(ComparisonOp::Le)
1353 } else {
1354 Ok(ComparisonOp::Lt)
1355 }
1356 }
1357 _ => {
1358 if self.match_keyword("LIKE") {
1359 Ok(ComparisonOp::Like)
1360 } else if self.match_keyword("IN") {
1361 Ok(ComparisonOp::In)
1362 } else {
1363 Err(ContextParseError::SyntaxError(
1364 "expected comparison operator".to_string(),
1365 ))
1366 }
1367 }
1368 }
1369 }
1370
1371 fn parse_value(&mut self) -> Result<SochValue, ContextParseError> {
1373 match self.current().clone() {
1374 Token::Number(n) => {
1375 self.advance();
1376 if n.fract() == 0.0 {
1377 Ok(SochValue::Int(n as i64))
1378 } else {
1379 Ok(SochValue::Float(n))
1380 }
1381 }
1382 Token::String(s) => {
1383 self.advance();
1384 Ok(SochValue::Text(s))
1385 }
1386 Token::Keyword(k) if k.eq_ignore_ascii_case("null") => {
1387 self.advance();
1388 Ok(SochValue::Null)
1389 }
1390 Token::Keyword(k) if k.eq_ignore_ascii_case("true") => {
1391 self.advance();
1392 Ok(SochValue::Bool(true))
1393 }
1394 Token::Keyword(k) if k.eq_ignore_ascii_case("false") => {
1395 self.advance();
1396 Ok(SochValue::Bool(false))
1397 }
1398 Token::Variable(v) => {
1399 self.advance();
1400 Ok(SochValue::Text(format!("${}", v)))
1402 }
1403 _ => Err(ContextParseError::SyntaxError("expected value".to_string())),
1404 }
1405 }
1406
1407 fn parse_column_list(&mut self) -> Result<Vec<String>, ContextParseError> {
1409 let mut columns = Vec::new();
1410
1411 if self.check_punct('*') {
1412 self.advance();
1413 columns.push("*".to_string());
1414 } else {
1415 loop {
1416 columns.push(self.expect_ident()?);
1417 if !self.match_punct(',') {
1418 break;
1419 }
1420 }
1421 }
1422
1423 Ok(columns)
1424 }
1425
1426 fn parse_bool(&mut self) -> Result<bool, ContextParseError> {
1428 match self.current() {
1429 Token::Keyword(k) if k.eq_ignore_ascii_case("true") => {
1430 self.advance();
1431 Ok(true)
1432 }
1433 Token::Keyword(k) if k.eq_ignore_ascii_case("false") => {
1434 self.advance();
1435 Ok(false)
1436 }
1437 _ => Err(ContextParseError::SyntaxError(
1438 "expected boolean".to_string(),
1439 )),
1440 }
1441 }
1442
1443 fn tokenize(input: &str) -> Vec<Token> {
1445 let mut tokens = Vec::new();
1446 let mut chars = input.chars().peekable();
1447
1448 while let Some(&ch) = chars.peek() {
1449 match ch {
1450 ' ' | '\t' | '\n' | '\r' => {
1452 chars.next();
1453 }
1454
1455 '(' | ')' | ',' | ':' | '=' | '<' | '>' | '*' | '{' | '}' | '.' => {
1457 tokens.push(Token::Punct(ch));
1458 chars.next();
1459 }
1460
1461 '$' => {
1463 chars.next();
1464 let mut name = String::new();
1465 while let Some(&c) = chars.peek() {
1466 if c.is_alphanumeric() || c == '_' {
1467 name.push(c);
1468 chars.next();
1469 } else {
1470 break;
1471 }
1472 }
1473 tokens.push(Token::Variable(name));
1474 }
1475
1476 '\'' | '"' => {
1478 let quote = ch;
1479 chars.next();
1480 let mut s = String::new();
1481 while let Some(&c) = chars.peek() {
1482 if c == quote {
1483 chars.next(); break;
1485 }
1486 s.push(c);
1487 chars.next();
1488 }
1489 tokens.push(Token::String(s));
1490 }
1491
1492 '0'..='9' | '-' => {
1494 let mut num_str = String::new();
1495 if ch == '-' {
1496 num_str.push(ch);
1497 chars.next();
1498 }
1499 while let Some(&c) = chars.peek() {
1500 if c.is_ascii_digit() || c == '.' {
1501 num_str.push(c);
1502 chars.next();
1503 } else {
1504 break;
1505 }
1506 }
1507 if let Ok(n) = num_str.parse::<f64>() {
1508 tokens.push(Token::Number(n));
1509 }
1510 }
1511
1512 'a'..='z' | 'A'..='Z' | '_' => {
1514 let mut ident = String::new();
1515 while let Some(&c) = chars.peek() {
1516 if c.is_alphanumeric() || c == '_' {
1517 ident.push(c);
1518 chars.next();
1519 } else {
1520 break;
1521 }
1522 }
1523
1524 let keywords = [
1526 "CONTEXT",
1527 "SELECT",
1528 "FROM",
1529 "WITH",
1530 "SECTIONS",
1531 "PRIORITY",
1532 "GET",
1533 "LAST",
1534 "SEARCH",
1535 "BY",
1536 "SIMILARITY",
1537 "TOP",
1538 "WHERE",
1539 "AND",
1540 "OR",
1541 "LIKE",
1542 "IN",
1543 "LIMIT",
1544 "session",
1545 "agent",
1546 "true",
1547 "false",
1548 "null",
1549 ];
1550
1551 if keywords.iter().any(|k| k.eq_ignore_ascii_case(&ident)) {
1552 tokens.push(Token::Keyword(ident.to_uppercase()));
1553 } else {
1554 tokens.push(Token::Ident(ident));
1555 }
1556 }
1557
1558 _ => {
1560 chars.next();
1561 }
1562 }
1563 }
1564
1565 tokens.push(Token::Eof);
1566 tokens
1567 }
1568
1569 fn current(&self) -> &Token {
1571 self.tokens.get(self.pos).unwrap_or(&Token::Eof)
1572 }
1573
1574 fn advance(&mut self) {
1575 if self.pos < self.tokens.len() {
1576 self.pos += 1;
1577 }
1578 }
1579
1580 fn expect_keyword(&mut self, kw: &str) -> Result<(), ContextParseError> {
1581 match self.current() {
1582 Token::Keyword(k) if k.eq_ignore_ascii_case(kw) => {
1583 self.advance();
1584 Ok(())
1585 }
1586 other => Err(ContextParseError::UnexpectedToken {
1587 expected: kw.to_string(),
1588 found: format!("{:?}", other),
1589 }),
1590 }
1591 }
1592
1593 fn match_keyword(&mut self, kw: &str) -> bool {
1594 match self.current() {
1595 Token::Keyword(k) if k.eq_ignore_ascii_case(kw) => {
1596 self.advance();
1597 true
1598 }
1599 _ => false,
1600 }
1601 }
1602
1603 fn expect_ident(&mut self) -> Result<String, ContextParseError> {
1604 match self.current().clone() {
1605 Token::Ident(s) => {
1606 self.advance();
1607 Ok(s)
1608 }
1609 Token::Keyword(s) => {
1610 self.advance();
1612 Ok(s)
1613 }
1614 other => Err(ContextParseError::UnexpectedToken {
1615 expected: "identifier".to_string(),
1616 found: format!("{:?}", other),
1617 }),
1618 }
1619 }
1620
1621 fn expect_variable(&mut self) -> Result<String, ContextParseError> {
1622 match self.current().clone() {
1623 Token::Variable(v) => {
1624 self.advance();
1625 Ok(v)
1626 }
1627 other => Err(ContextParseError::UnexpectedToken {
1628 expected: "variable ($name)".to_string(),
1629 found: format!("{:?}", other),
1630 }),
1631 }
1632 }
1633
1634 fn expect_punct(&mut self, p: char) -> Result<(), ContextParseError> {
1635 match self.current() {
1636 Token::Punct(c) if *c == p => {
1637 self.advance();
1638 Ok(())
1639 }
1640 other => Err(ContextParseError::UnexpectedToken {
1641 expected: p.to_string(),
1642 found: format!("{:?}", other),
1643 }),
1644 }
1645 }
1646
1647 fn match_punct(&mut self, p: char) -> bool {
1648 match self.current() {
1649 Token::Punct(c) if *c == p => {
1650 self.advance();
1651 true
1652 }
1653 _ => false,
1654 }
1655 }
1656
1657 fn check_punct(&self, p: char) -> bool {
1658 matches!(self.current(), Token::Punct(c) if *c == p)
1659 }
1660
1661 fn collect_until(&mut self, terminators: &[char]) -> String {
1662 let mut result = String::new();
1663 let mut depth = 0;
1664
1665 loop {
1666 match self.current() {
1667 Token::Punct('{') => {
1668 depth += 1;
1669 result.push('{');
1670 self.advance();
1671 }
1672 Token::Punct('}') => {
1673 depth -= 1;
1674 result.push('}');
1675 self.advance();
1676 }
1677 Token::Punct(c) if depth == 0 && terminators.contains(c) => {
1678 break;
1679 }
1680 Token::Punct(c) => {
1681 result.push(*c);
1682 self.advance();
1683 }
1684 Token::Ident(s) | Token::Keyword(s) => {
1685 if !result.is_empty() && !result.ends_with(['.', '{']) {
1686 result.push(' ');
1687 }
1688 result.push_str(s);
1689 self.advance();
1690 }
1691 Token::Eof => break,
1692 _ => {
1693 self.advance();
1694 }
1695 }
1696 }
1697
1698 result.trim().to_string()
1699 }
1700}
1701
1702use crate::agent_context::{AgentContext, AuditOperation, ContextValue};
1707
/// Executes `ContextSelectQuery` values against a mutably borrowed
/// `AgentContext`, with optional vector search and embedding support.
pub struct AgentContextIntegration<'a> {
    /// Agent context that supplies the session id, variables, permissions,
    /// budget and audit log used while executing queries.
    context: &'a mut AgentContext,
    /// Distributes the token budget across the sections of a query.
    budget_enforcer: TokenBudgetEnforcer,
    /// Estimates token counts and truncates text to a token count.
    estimator: TokenEstimator,
    /// Vector index used by `Search` sections; when `None` a placeholder
    /// string is produced for those sections instead.
    vector_index: Option<std::sync::Arc<dyn VectorIndex>>,
    /// Optional provider used to turn text queries into embeddings before
    /// searching the vector index.
    embedding_provider: Option<std::sync::Arc<dyn EmbeddingProvider>>,
}
1729
/// A thread-safe source of text embeddings.
pub trait EmbeddingProvider: Send + Sync {
    /// Embeds a single piece of text, or returns an error message.
    fn embed_text(&self, text: &str) -> Result<Vec<f32>, String>;

    /// Embeds every text in order by calling [`embed_text`](Self::embed_text),
    /// stopping at and returning the first error.
    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, String> {
        let mut vectors = Vec::with_capacity(texts.len());
        for text in texts {
            vectors.push(self.embed_text(text)?);
        }
        Ok(vectors)
    }

    /// Reports the embedding dimensionality (presumably the length of the
    /// vectors returned by `embed_text` — confirm with implementors).
    fn dimension(&self) -> usize;

    /// Name of the underlying embedding model.
    fn model_name(&self) -> &str;
}
1748
1749impl<'a> AgentContextIntegration<'a> {
1750 pub fn new(context: &'a mut AgentContext) -> Self {
1752 let config = TokenBudgetConfig {
1753 total_budget: context.budget.max_tokens.unwrap_or(4096) as usize,
1754 ..Default::default()
1755 };
1756
1757 Self {
1758 context,
1759 budget_enforcer: TokenBudgetEnforcer::new(config),
1760 estimator: TokenEstimator::default(),
1761 vector_index: None,
1762 embedding_provider: None,
1763 }
1764 }
1765
1766 pub fn with_vector_index(
1768 context: &'a mut AgentContext,
1769 vector_index: std::sync::Arc<dyn VectorIndex>,
1770 ) -> Self {
1771 let config = TokenBudgetConfig {
1772 total_budget: context.budget.max_tokens.unwrap_or(4096) as usize,
1773 ..Default::default()
1774 };
1775
1776 Self {
1777 context,
1778 budget_enforcer: TokenBudgetEnforcer::new(config),
1779 estimator: TokenEstimator::default(),
1780 vector_index: Some(vector_index),
1781 embedding_provider: None,
1782 }
1783 }
1784
1785 pub fn with_vector_and_embedding(
1787 context: &'a mut AgentContext,
1788 vector_index: std::sync::Arc<dyn VectorIndex>,
1789 embedding_provider: std::sync::Arc<dyn EmbeddingProvider>,
1790 ) -> Self {
1791 let config = TokenBudgetConfig {
1792 total_budget: context.budget.max_tokens.unwrap_or(4096) as usize,
1793 ..Default::default()
1794 };
1795
1796 Self {
1797 context,
1798 budget_enforcer: TokenBudgetEnforcer::new(config),
1799 estimator: TokenEstimator::default(),
1800 vector_index: Some(vector_index),
1801 embedding_provider: Some(embedding_provider),
1802 }
1803 }
1804
1805 pub fn set_embedding_provider(&mut self, provider: std::sync::Arc<dyn EmbeddingProvider>) {
1807 self.embedding_provider = Some(provider);
1808 }
1809
1810 pub fn set_vector_index(&mut self, index: std::sync::Arc<dyn VectorIndex>) {
1812 self.vector_index = Some(index);
1813 }
1814
1815 pub fn execute(
1817 &mut self,
1818 query: &ContextSelectQuery,
1819 ) -> Result<ContextQueryResult, ContextQueryError> {
1820 self.validate_session(&query.session)?;
1822
1823 self.context.audit.push(crate::agent_context::AuditEntry {
1825 timestamp: std::time::SystemTime::now(),
1826 operation: AuditOperation::DbQuery,
1827 resource: format!("CONTEXT SELECT {}", query.output_name),
1828 result: crate::agent_context::AuditResult::Success,
1829 metadata: std::collections::HashMap::new(),
1830 });
1831
1832 let resolved_sections = self.resolve_sections(&query.sections)?;
1834
1835 for section in &resolved_sections {
1837 self.check_section_permissions(section)?;
1838 }
1839
1840 let mut section_contents: Vec<(ContextSection, String)> = Vec::new();
1842 for section in &resolved_sections {
1843 let content = self.execute_section_content(section, query.options.token_limit)?;
1844 section_contents.push((section.clone(), content));
1845 }
1846
1847 let budget_sections: Vec<BudgetSection> = section_contents
1849 .iter()
1850 .map(|(section, content)| {
1851 let estimated = self.estimator.estimate_text(content);
1852 let minimum = if query.options.truncation == TruncationStrategy::Fail {
1854 None
1855 } else {
1856 Some(estimated.min(100).max(estimated / 10))
1857 };
1858 BudgetSection {
1859 name: section.name.clone(),
1860 estimated_tokens: estimated,
1861 minimum_tokens: minimum,
1862 priority: section.priority,
1863 required: section.priority == 0, weight: 1.0,
1865 }
1866 })
1867 .collect();
1868
1869 let allocation = self.budget_enforcer.allocate_sections(&budget_sections);
1871
1872 let mut result = ContextQueryResult::new(query.output_name.clone());
1874 result.format = query.options.format;
1875 result.allocation_explain = Some(allocation.explain.clone());
1876
1877 for (section, content) in section_contents.iter() {
1879 if allocation.full_sections.contains(§ion.name) {
1880 let tokens = self.estimator.estimate_text(content);
1881 result.sections.push(SectionResult {
1882 name: section.name.clone(),
1883 priority: section.priority,
1884 content: content.clone(),
1885 tokens,
1886 tokens_used: tokens,
1887 truncated: false,
1888 row_count: 0,
1889 });
1890 }
1891 }
1892
1893 for (section_name, _original, truncated_to) in &allocation.truncated_sections {
1895 if let Some((section, content)) = section_contents
1896 .iter()
1897 .find(|(s, _)| &s.name == section_name)
1898 {
1899 let truncated = self.estimator.truncate_to_tokens(content, *truncated_to);
1901 let actual_tokens = self.estimator.estimate_text(&truncated);
1902 result.sections.push(SectionResult {
1903 name: section.name.clone(),
1904 priority: section.priority,
1905 content: truncated,
1906 tokens: actual_tokens,
1907 tokens_used: actual_tokens,
1908 truncated: true,
1909 row_count: 0,
1910 });
1911 }
1912 }
1913
1914 result.sections.sort_by_key(|s| s.priority);
1916
1917 result.total_tokens = allocation.tokens_allocated;
1918 result.token_limit = query.options.token_limit;
1919
1920 self.context
1922 .consume_budget(result.total_tokens as u64, 0)
1923 .map_err(|e| ContextQueryError::BudgetExhausted(e.to_string()))?;
1924
1925 Ok(result)
1926 }
1927
1928 pub fn execute_explain(
1930 &mut self,
1931 query: &ContextSelectQuery,
1932 ) -> Result<(ContextQueryResult, String), ContextQueryError> {
1933 let result = self.execute(query)?;
1934 let explain = result
1935 .allocation_explain
1936 .as_ref()
1937 .map(|decisions| {
1938 use crate::token_budget::BudgetAllocation;
1939 let allocation = BudgetAllocation {
1940 full_sections: result
1941 .sections
1942 .iter()
1943 .filter(|s| !s.truncated)
1944 .map(|s| s.name.clone())
1945 .collect(),
1946 truncated_sections: result
1947 .sections
1948 .iter()
1949 .filter(|s| s.truncated)
1950 .map(|s| (s.name.clone(), s.tokens, s.tokens_used))
1951 .collect(),
1952 dropped_sections: Vec::new(),
1953 tokens_allocated: result.total_tokens,
1954 tokens_remaining: result.token_limit.saturating_sub(result.total_tokens),
1955 explain: decisions.clone(),
1956 };
1957 allocation.explain_text()
1958 })
1959 .unwrap_or_else(|| "No allocation explain available".to_string());
1960 Ok((result, explain))
1961 }
1962
1963 fn validate_session(&self, session_ref: &SessionReference) -> Result<(), ContextQueryError> {
1965 match session_ref {
1966 SessionReference::Session(sid) => {
1967 if sid.starts_with('$') {
1969 return Ok(());
1970 }
1971 if sid != &self.context.session_id && sid != "*" {
1973 return Err(ContextQueryError::SessionMismatch {
1974 expected: sid.clone(),
1975 actual: self.context.session_id.clone(),
1976 });
1977 }
1978 }
1979 SessionReference::Agent(aid) => {
1980 if let Some(ContextValue::String(agent_id)) = self.context.peek_var("agent_id")
1982 && aid != agent_id
1983 && aid != "*"
1984 {
1985 return Err(ContextQueryError::SessionMismatch {
1986 expected: aid.clone(),
1987 actual: agent_id.clone(),
1988 });
1989 }
1990 }
1991 SessionReference::None => {}
1992 }
1993 Ok(())
1994 }
1995
1996 fn resolve_sections(
1998 &self,
1999 sections: &[ContextSection],
2000 ) -> Result<Vec<ContextSection>, ContextQueryError> {
2001 let mut resolved = Vec::new();
2002
2003 for section in sections {
2004 let mut resolved_section = section.clone();
2005
2006 resolved_section.content = match §ion.content {
2008 SectionContent::Literal { value } => {
2009 let resolved_value = self.resolve_variables(value);
2010 SectionContent::Literal {
2011 value: resolved_value,
2012 }
2013 }
2014 SectionContent::Variable { name } => {
2015 if let Some(value) = self.context.peek_var(name) {
2016 SectionContent::Literal {
2017 value: value.to_string(),
2018 }
2019 } else {
2020 return Err(ContextQueryError::VariableNotFound(name.clone()));
2021 }
2022 }
2023 SectionContent::Search {
2024 collection,
2025 query,
2026 top_k,
2027 min_score,
2028 } => {
2029 let resolved_query = match query {
2030 SimilarityQuery::Variable(var) => {
2031 if let Some(value) = self.context.peek_var(var) {
2032 match value {
2033 ContextValue::String(s) => SimilarityQuery::Text(s.clone()),
2034 ContextValue::List(l) => {
2035 let vec: Vec<f32> = l
2036 .iter()
2037 .filter_map(|v| match v {
2038 ContextValue::Number(n) => Some(*n as f32),
2039 _ => None,
2040 })
2041 .collect();
2042 SimilarityQuery::Embedding(vec)
2043 }
2044 _ => {
2045 return Err(ContextQueryError::InvalidVariableType {
2046 variable: var.clone(),
2047 expected: "string or vector".to_string(),
2048 });
2049 }
2050 }
2051 } else {
2052 return Err(ContextQueryError::VariableNotFound(var.clone()));
2053 }
2054 }
2055 other => other.clone(),
2056 };
2057 SectionContent::Search {
2058 collection: collection.clone(),
2059 query: resolved_query,
2060 top_k: *top_k,
2061 min_score: *min_score,
2062 }
2063 }
2064 other => other.clone(),
2065 };
2066
2067 resolved.push(resolved_section);
2068 }
2069
2070 Ok(resolved)
2071 }
2072
2073 fn resolve_variables(&self, input: &str) -> String {
2075 self.context.substitute_vars(input)
2076 }
2077
2078 fn check_section_permissions(&self, section: &ContextSection) -> Result<(), ContextQueryError> {
2080 match §ion.content {
2081 SectionContent::Get { path } => {
2082 let path_str = path.to_path_string();
2084 if path_str.starts_with('/') {
2085 self.context
2086 .check_fs_permission(&path_str, AuditOperation::FsRead)
2087 .map_err(|e| ContextQueryError::PermissionDenied(e.to_string()))?;
2088 } else {
2089 let table = path
2091 .segments
2092 .first()
2093 .ok_or_else(|| ContextQueryError::InvalidPath("empty path".to_string()))?;
2094 self.context
2095 .check_db_permission(table, AuditOperation::DbQuery)
2096 .map_err(|e| ContextQueryError::PermissionDenied(e.to_string()))?;
2097 }
2098 }
2099 SectionContent::Last { table, .. } | SectionContent::Select { table, .. } => {
2100 self.context
2101 .check_db_permission(table, AuditOperation::DbQuery)
2102 .map_err(|e| ContextQueryError::PermissionDenied(e.to_string()))?;
2103 }
2104 SectionContent::Search { collection, .. } => {
2105 self.context
2106 .check_db_permission(collection, AuditOperation::DbQuery)
2107 .map_err(|e| ContextQueryError::PermissionDenied(e.to_string()))?;
2108 }
2109 SectionContent::Literal { .. } | SectionContent::Variable { .. } => {
2110 }
2112 SectionContent::ToolRegistry { .. } | SectionContent::ToolCalls { .. } => {
2113 }
2115 }
2116 Ok(())
2117 }
2118
2119 fn execute_section_content(
2121 &self,
2122 section: &ContextSection,
2123 _budget: usize,
2124 ) -> Result<String, ContextQueryError> {
2125 match §ion.content {
2128 SectionContent::Literal { value } => Ok(value.clone()),
2129 SectionContent::Variable { name } => self
2130 .context
2131 .peek_var(name)
2132 .map(|v| v.to_string())
2133 .ok_or_else(|| ContextQueryError::VariableNotFound(name.clone())),
2134 SectionContent::Get { path } => {
2135 Ok(format!(
2137 "[{}: path={}]",
2138 section.name,
2139 path.to_path_string()
2140 ))
2141 }
2142 SectionContent::Last { count, table, .. } => {
2143 Ok(format!("[{}: last {} from {}]", section.name, count, table))
2145 }
2146 SectionContent::Search {
2147 collection,
2148 query: similarity_query,
2149 top_k,
2150 min_score,
2151 } => {
2152 match &self.vector_index {
2154 Some(index) => {
2155 let results = match similarity_query {
2157 SimilarityQuery::Embedding(emb) => {
2158 index.search_by_embedding(collection, emb, *top_k, *min_score)
2159 }
2160 SimilarityQuery::Text(text) => {
2161 self.search_by_text_with_embedding(
2163 index, collection, text, *top_k, *min_score,
2164 )
2165 }
2166 SimilarityQuery::Variable(var_name) => {
2167 match self.context.peek_var(var_name) {
2169 Some(ContextValue::String(text)) => {
2170 self.search_by_text_with_embedding(
2171 index, collection, text, *top_k, *min_score,
2172 )
2173 }
2174 Some(ContextValue::List(list)) => {
2175 let embedding: Result<Vec<f32>, _> = list
2177 .iter()
2178 .map(|v| match v {
2179 ContextValue::Number(n) => Ok(*n as f32),
2180 ContextValue::String(s) => {
2181 s.parse::<f32>().map_err(|_| "not a number")
2182 }
2183 _ => Err("not a number"),
2184 })
2185 .collect();
2186
2187 match embedding {
2188 Ok(emb) => index.search_by_embedding(
2189 collection, &emb, *top_k, *min_score,
2190 ),
2191 Err(_) => {
2192 Err("Variable is not a valid embedding vector"
2193 .to_string())
2194 }
2195 }
2196 }
2197 _ => Err(format!(
2198 "Variable '{}' not found or has wrong type",
2199 var_name
2200 )),
2201 }
2202 }
2203 };
2204
2205 match results {
2206 Ok(search_results) => {
2207 self.format_search_results(§ion.name, &search_results)
2209 }
2210 Err(e) => {
2211 Ok(format!("[{}: search error: {}]", section.name, e))
2213 }
2214 }
2215 }
2216 None => {
2217 Ok(format!(
2219 "[{}: search {} top {}]",
2220 section.name, collection, top_k
2221 ))
2222 }
2223 }
2224 }
2225 SectionContent::Select { table, limit, .. } => {
2226 let limit_str = limit.map(|l| format!(" limit {}", l)).unwrap_or_default();
2228 Ok(format!(
2229 "[{}: select from {}{}]",
2230 section.name, table, limit_str
2231 ))
2232 }
2233 SectionContent::ToolRegistry {
2234 include,
2235 exclude,
2236 include_schema,
2237 } => {
2238 self.format_tool_registry(include, exclude, *include_schema)
2240 }
2241 SectionContent::ToolCalls {
2242 count,
2243 tool_filter,
2244 status_filter,
2245 include_outputs,
2246 } => {
2247 self.format_tool_calls(*count, tool_filter.as_deref(), status_filter.as_deref(), *include_outputs)
2249 }
2250 }
2251 }
2252
2253 fn format_tool_registry(
2255 &self,
2256 include: &[String],
2257 exclude: &[String],
2258 include_schema: bool,
2259 ) -> Result<String, ContextQueryError> {
2260 use std::fmt::Write;
2261
2262 let tools = &self.context.tool_registry;
2264 let mut output = String::new();
2265
2266 writeln!(output, "[tool_registry ({} tools)]", tools.len())
2267 .map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
2268
2269 for tool in tools {
2270 if !include.is_empty() && !include.contains(&tool.name) {
2272 continue;
2273 }
2274 if exclude.contains(&tool.name) {
2275 continue;
2276 }
2277
2278 writeln!(output, " [{}]", tool.name)
2279 .map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
2280 writeln!(output, " description = {:?}", tool.description)
2281 .map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
2282
2283 if include_schema {
2284 if let Some(schema) = &tool.parameters_schema {
2285 writeln!(output, " parameters = {}", schema)
2286 .map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
2287 }
2288 }
2289 }
2290
2291 Ok(output)
2292 }
2293
2294 fn format_tool_calls(
2296 &self,
2297 count: usize,
2298 tool_filter: Option<&str>,
2299 status_filter: Option<&str>,
2300 include_outputs: bool,
2301 ) -> Result<String, ContextQueryError> {
2302 use std::fmt::Write;
2303
2304 let calls = &self.context.tool_calls;
2306 let mut output = String::new();
2307
2308 let filtered: Vec<_> = calls
2310 .iter()
2311 .filter(|call| {
2312 tool_filter.map(|f| call.tool_name == f).unwrap_or(true)
2313 && status_filter
2314 .map(|s| {
2315 match s {
2316 "success" => call.result.is_some() && call.error.is_none(),
2317 "error" => call.error.is_some(),
2318 "pending" => call.result.is_none() && call.error.is_none(),
2319 _ => true,
2320 }
2321 })
2322 .unwrap_or(true)
2323 })
2324 .rev() .take(count)
2326 .collect();
2327
2328 writeln!(output, "[tool_calls ({} calls)]", filtered.len())
2329 .map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
2330
2331 for call in filtered {
2332 writeln!(output, " [call {}]", call.call_id)
2333 .map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
2334 writeln!(output, " tool = {:?}", call.tool_name)
2335 .map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
2336 writeln!(output, " arguments = {:?}", call.arguments)
2337 .map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
2338
2339 if include_outputs {
2340 if let Some(result) = &call.result {
2341 writeln!(output, " result = {:?}", result)
2342 .map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
2343 }
2344 if let Some(error) = &call.error {
2345 writeln!(output, " error = {:?}", error)
2346 .map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
2347 }
2348 }
2349 }
2350
2351 Ok(output)
2352 }
2353
2354 fn search_by_text_with_embedding(
2360 &self,
2361 index: &std::sync::Arc<dyn VectorIndex>,
2362 collection: &str,
2363 text: &str,
2364 k: usize,
2365 min_score: Option<f32>,
2366 ) -> Result<Vec<VectorSearchResult>, String> {
2367 match &self.embedding_provider {
2368 Some(provider) => {
2369 let embedding = provider.embed_text(text)?;
2371 index.search_by_embedding(collection, &embedding, k, min_score)
2373 }
2374 None => {
2375 index.search_by_text(collection, text, k, min_score)
2377 }
2378 }
2379 }
2380
2381 fn format_search_results(
2383 &self,
2384 section_name: &str,
2385 results: &[VectorSearchResult],
2386 ) -> Result<String, ContextQueryError> {
2387 use std::fmt::Write;
2388
2389 let mut output = String::new();
2390 writeln!(output, "[{} ({} results)]", section_name, results.len())
2391 .map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
2392
2393 for (i, result) in results.iter().enumerate() {
2394 writeln!(output, " [result {} score={:.4}]", i + 1, result.score)
2395 .map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
2396 writeln!(output, " id = {}", result.id)
2397 .map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
2398
2399 for line in result.content.lines() {
2401 writeln!(output, " {}", line)
2402 .map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
2403 }
2404
2405 if !result.metadata.is_empty() {
2407 writeln!(output, " [metadata]")
2408 .map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
2409 for (key, value) in &result.metadata {
2410 writeln!(output, " {} = {:?}", key, value)
2411 .map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
2412 }
2413 }
2414 }
2415
2416 Ok(output)
2417 }
2418
2419 #[allow(dead_code)]
2421 fn truncate_content(
2422 &self,
2423 content: &str,
2424 max_tokens: usize,
2425 strategy: TruncationStrategy,
2426 ) -> String {
2427 let max_chars = max_tokens * 4;
2429
2430 if content.len() <= max_chars {
2431 return content.to_string();
2432 }
2433
2434 match strategy {
2435 TruncationStrategy::TailDrop => {
2436 let mut result: String = content.chars().take(max_chars - 3).collect();
2437 result.push_str("...");
2438 result
2439 }
2440 TruncationStrategy::HeadDrop => {
2441 let skip = content.len() - max_chars + 3;
2442 let mut result = "...".to_string();
2443 result.extend(content.chars().skip(skip));
2444 result
2445 }
2446 TruncationStrategy::Proportional => {
2447 let quarter = max_chars / 4;
2449 let first: String = content.chars().take(quarter).collect();
2450 let last: String = content
2451 .chars()
2452 .skip(content.len().saturating_sub(quarter))
2453 .collect();
2454 format!("{}...{}...", first, last)
2455 }
2456 TruncationStrategy::Fail => {
2457 content.to_string() }
2459 }
2460 }
2461
2462 pub fn get_session_context(&self) -> HashMap<String, String> {
2464 self.context
2465 .variables
2466 .iter()
2467 .map(|(k, v)| (k.clone(), v.to_string()))
2468 .collect()
2469 }
2470
2471 pub fn set_variable(&mut self, name: &str, value: ContextValue) {
2473 self.context.set_var(name, value);
2474 }
2475
2476 pub fn remaining_budget(&self) -> u64 {
2478 self.context
2479 .budget
2480 .max_tokens
2481 .map(|max| max.saturating_sub(self.context.budget.tokens_used))
2482 .unwrap_or(u64::MAX)
2483 }
2484}
2485
/// Outcome of executing a context query: the assembled sections plus
/// token accounting.
#[derive(Debug, Clone)]
pub struct ContextQueryResult {
    /// Name given to the output by the query.
    pub output_name: String,
    /// Sections included in the result.
    pub sections: Vec<SectionResult>,
    /// Total number of tokens attributed to this result.
    pub total_tokens: usize,
    /// Token limit the query was executed with.
    pub token_limit: usize,
    /// Format used by `render`.
    pub format: OutputFormat,
    /// Budget allocation decisions, when available.
    pub allocation_explain: Option<Vec<crate::token_budget::AllocationDecision>>,
}
2502
2503impl ContextQueryResult {
2504 fn new(output_name: String) -> Self {
2505 Self {
2506 output_name,
2507 sections: Vec::new(),
2508 total_tokens: 0,
2509 token_limit: 0,
2510 format: OutputFormat::Soch,
2511 allocation_explain: None,
2512 }
2513 }
2514
2515 pub fn render(&self) -> String {
2517 let mut output = String::new();
2518
2519 match self.format {
2520 OutputFormat::Soch => {
2521 output.push_str(&format!("{}[{}]:\n", self.output_name, self.sections.len()));
2523 for section in &self.sections {
2524 output.push_str(&format!(
2525 " {}[{}{}]:\n",
2526 section.name,
2527 section.tokens_used,
2528 if section.truncated { "T" } else { "" }
2529 ));
2530 for line in section.content.lines() {
2531 output.push_str(&format!(" {}\n", line));
2532 }
2533 }
2534 }
2535 OutputFormat::Json => {
2536 output.push_str("{\n");
2537 output.push_str(&format!(" \"name\": \"{}\",\n", self.output_name));
2538 output.push_str(&format!(" \"total_tokens\": {},\n", self.total_tokens));
2539 output.push_str(" \"sections\": [\n");
2540 for (i, section) in self.sections.iter().enumerate() {
2541 output.push_str(&format!(" {{\"name\": \"{}\", \"tokens\": {}, \"truncated\": {}, \"content\": \"{}\"}}",
2542 section.name,
2543 section.tokens_used,
2544 section.truncated,
2545 section.content.replace('"', "\\\"").replace('\n', "\\n")
2546 ));
2547 if i < self.sections.len() - 1 {
2548 output.push(',');
2549 }
2550 output.push('\n');
2551 }
2552 output.push_str(" ]\n}");
2553 }
2554 OutputFormat::Markdown => {
2555 output.push_str(&format!("# {}\n\n", self.output_name));
2556 output.push_str(&format!(
2557 "*Tokens: {}/{}*\n\n",
2558 self.total_tokens, self.token_limit
2559 ));
2560 for section in &self.sections {
2561 output.push_str(&format!("## {}", section.name));
2562 if section.truncated {
2563 output.push_str(" *(truncated)*");
2564 }
2565 output.push_str("\n\n");
2566 output.push_str(§ion.content);
2567 output.push_str("\n\n");
2568 }
2569 }
2570 }
2571
2572 output
2573 }
2574
2575 pub fn utilization(&self) -> f64 {
2577 if self.token_limit == 0 {
2578 return 0.0;
2579 }
2580 (self.total_tokens as f64 / self.token_limit as f64) * 100.0
2581 }
2582
2583 pub fn has_truncation(&self) -> bool {
2585 self.sections.iter().any(|s| s.truncated)
2586 }
2587}
2588
/// Newtype around a section's integer priority.
///
/// The derived ordering follows the wrapped integer, so numerically smaller
/// priorities compare as less than larger ones.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct SectionPriority(pub i32);

/// Named priority levels, listed from the smallest value to the largest.
impl SectionPriority {
    pub const CRITICAL: Self = Self(-100);
    pub const SYSTEM: Self = Self(-1);
    pub const USER: Self = Self(0);
    pub const HISTORY: Self = Self(1);
    pub const KNOWLEDGE: Self = Self(2);
    pub const SUPPLEMENTARY: Self = Self(10);
}
2601
/// Fluent builder that assembles a `ContextSelectQuery` piece by piece.
pub struct ContextQueryBuilder {
    /// Name of the output the query produces.
    output_name: String,
    /// Session or agent the query is bound to, if any.
    session: SessionReference,
    /// Query options (token limit, format, truncation, ...).
    options: ContextQueryOptions,
    /// Sections in the order they were added.
    sections: Vec<ContextSection>,
}
2613
2614impl ContextQueryBuilder {
2615 pub fn new(output_name: &str) -> Self {
2617 Self {
2618 output_name: output_name.to_string(),
2619 session: SessionReference::None,
2620 options: ContextQueryOptions::default(),
2621 sections: Vec::new(),
2622 }
2623 }
2624
2625 pub fn from_session(mut self, session_id: &str) -> Self {
2627 self.session = SessionReference::Session(session_id.to_string());
2628 self
2629 }
2630
2631 pub fn from_agent(mut self, agent_id: &str) -> Self {
2633 self.session = SessionReference::Agent(agent_id.to_string());
2634 self
2635 }
2636
2637 pub fn with_token_limit(mut self, limit: usize) -> Self {
2639 self.options.token_limit = limit;
2640 self
2641 }
2642
2643 pub fn include_schema(mut self, include: bool) -> Self {
2645 self.options.include_schema = include;
2646 self
2647 }
2648
2649 pub fn format(mut self, format: OutputFormat) -> Self {
2651 self.options.format = format;
2652 self
2653 }
2654
2655 pub fn truncation(mut self, strategy: TruncationStrategy) -> Self {
2657 self.options.truncation = strategy;
2658 self
2659 }
2660
2661 pub fn get(mut self, name: &str, priority: i32, path: &str) -> Self {
2663 let path_expr = PathExpression::parse(path).unwrap_or(PathExpression {
2664 segments: vec![path.to_string()],
2665 fields: vec![],
2666 all_fields: true,
2667 });
2668
2669 self.sections.push(ContextSection {
2670 name: name.to_string(),
2671 priority,
2672 content: SectionContent::Get { path: path_expr },
2673 transform: None,
2674 });
2675 self
2676 }
2677
2678 pub fn last(mut self, name: &str, priority: i32, count: usize, table: &str) -> Self {
2680 self.sections.push(ContextSection {
2681 name: name.to_string(),
2682 priority,
2683 content: SectionContent::Last {
2684 count,
2685 table: table.to_string(),
2686 where_clause: None,
2687 },
2688 transform: None,
2689 });
2690 self
2691 }
2692
2693 pub fn search(
2695 mut self,
2696 name: &str,
2697 priority: i32,
2698 collection: &str,
2699 query_var: &str,
2700 top_k: usize,
2701 ) -> Self {
2702 self.sections.push(ContextSection {
2703 name: name.to_string(),
2704 priority,
2705 content: SectionContent::Search {
2706 collection: collection.to_string(),
2707 query: SimilarityQuery::Variable(query_var.to_string()),
2708 top_k,
2709 min_score: None,
2710 },
2711 transform: None,
2712 });
2713 self
2714 }
2715
2716 pub fn literal(mut self, name: &str, priority: i32, value: &str) -> Self {
2718 self.sections.push(ContextSection {
2719 name: name.to_string(),
2720 priority,
2721 content: SectionContent::Literal {
2722 value: value.to_string(),
2723 },
2724 transform: None,
2725 });
2726 self
2727 }
2728
2729 pub fn build(self) -> ContextSelectQuery {
2731 ContextSelectQuery {
2732 output_name: self.output_name,
2733 session: self.session,
2734 options: self.options,
2735 sections: self.sections,
2736 }
2737 }
2738}
2739
2740#[cfg(test)]
2745mod tests {
2746 use super::*;
2747
    // A plain dotted path splits into segments and selects all fields.
    #[test]
    fn test_path_expression_simple() {
        let path = PathExpression::parse("user.profile").unwrap();
        assert_eq!(path.segments, vec!["user", "profile"]);
        assert!(path.all_fields);
    }
2754
    // A trailing `{a, b}` group becomes the field list and clears
    // `all_fields`.
    #[test]
    fn test_path_expression_with_fields() {
        let path = PathExpression::parse("user.profile.{name, email}").unwrap();
        assert_eq!(path.segments, vec!["user", "profile"]);
        assert_eq!(path.fields, vec!["name", "email"]);
        assert!(!path.all_fields);
    }
2762
    // A trailing `**` glob is not kept as a segment and selects all fields.
    #[test]
    fn test_path_expression_glob() {
        let path = PathExpression::parse("user.**").unwrap();
        assert_eq!(path.segments, vec!["user"]);
        assert!(path.all_fields);
    }
2769
    // Parses a query with a session variable, options and one GET section.
    #[test]
    fn test_parse_simple_query() {
        let query = r#"
 CONTEXT SELECT prompt_context
 FROM session($SESSION_ID)
 WITH (token_limit = 2048, include_schema = true)
 SECTIONS (
 USER PRIORITY 0: GET user.profile.{name, preferences}
 )
 "#;

        let mut parser = ContextQueryParser::new(query);
        let result = parser.parse().unwrap();

        assert_eq!(result.output_name, "prompt_context");
        // The session name is stored without the leading `$`.
        assert!(matches!(result.session, SessionReference::Session(s) if s == "SESSION_ID"));
        assert_eq!(result.options.token_limit, 2048);
        assert!(result.options.include_schema);
        assert_eq!(result.sections.len(), 1);
        assert_eq!(result.sections[0].name, "USER");
        assert_eq!(result.sections[0].priority, 0);
    }
2792
    // Parses three comma-separated sections: a literal, a LAST and a SEARCH.
    #[test]
    fn test_parse_multiple_sections() {
        let query = r#"
 CONTEXT SELECT context
 SECTIONS (
 A PRIORITY 0: "literal value",
 B PRIORITY 1: LAST 10 FROM logs,
 C PRIORITY 2: SEARCH docs BY SIMILARITY($query) TOP 5
 )
 "#;

        let mut parser = ContextQueryParser::new(query);
        let result = parser.parse().unwrap();

        assert_eq!(result.sections.len(), 3);

        // Section A: quoted literal.
        assert_eq!(result.sections[0].name, "A");
        assert!(
            matches!(&result.sections[0].content, SectionContent::Literal { value } if value == "literal value")
        );

        // Section B: LAST n FROM table.
        assert_eq!(result.sections[1].name, "B");
        assert!(
            matches!(&result.sections[1].content, SectionContent::Last { count: 10, table, .. } if table == "logs")
        );

        // Section C: similarity search with TOP k.
        assert_eq!(result.sections[2].name, "C");
        assert!(
            matches!(&result.sections[2].content, SectionContent::Search { collection, top_k: 5, .. } if collection == "docs")
        );
    }
2827
    // The builder carries options through and keeps every added section.
    #[test]
    fn test_builder() {
        let query = ContextQueryBuilder::new("prompt")
            .from_session("sess123")
            .with_token_limit(4096)
            .include_schema(false)
            .get("USER", 0, "user.profile.{name, email}")
            .last("HISTORY", 1, 20, "events")
            .search("DOCS", 2, "knowledge_base", "query_embedding", 10)
            .literal("SYSTEM", -1, "You are a helpful assistant")
            .build();

        assert_eq!(query.output_name, "prompt");
        assert_eq!(query.options.token_limit, 4096);
        assert!(!query.options.include_schema);
        assert_eq!(query.sections.len(), 4);

        // Negative priorities are preserved as given.
        let system = query.sections.iter().find(|s| s.name == "SYSTEM").unwrap();
        assert_eq!(system.priority, -1);
    }
2849
    // `format = markdown` in the WITH clause selects the Markdown format.
    #[test]
    fn test_output_format() {
        let query = r#"
 CONTEXT SELECT ctx
 WITH (format = markdown)
 SECTIONS ()
 "#;

        let mut parser = ContextQueryParser::new(query);
        let result = parser.parse().unwrap();

        assert_eq!(result.options.format, OutputFormat::Markdown);
    }
2863
/// A `truncation = proportional` entry in the `WITH (...)` clause must be
/// parsed into `TruncationStrategy::Proportional`.
#[test]
fn test_truncation_strategy() {
    let source = r#"
        CONTEXT SELECT ctx
        WITH (truncation = proportional)
        SECTIONS ()
    "#;

    let parsed = ContextQueryParser::new(source).parse().unwrap();
    let strategy = parsed.options.truncation;

    assert_eq!(strategy, TruncationStrategy::Proportional);
}
2877
/// After creating a 3-dimensional collection, its stats must exist and
/// report dimension 3, zero stored vectors, and the "cosine" metric.
#[test]
fn test_simple_vector_index_creation() {
    let vector_index = SimpleVectorIndex::new();
    vector_index.create_collection("test", 3);

    // Stats must be present for a collection that was just created.
    let lookup = vector_index.stats("test");
    assert!(lookup.is_some());

    let collection_stats = lookup.unwrap();
    assert_eq!(collection_stats.dimension, 3);
    assert_eq!(collection_stats.vector_count, 0);
    assert_eq!(collection_stats.metric, "cosine");
}
2894
/// Inserts three documents and runs a top-2 embedding search with a query
/// vector equal to `doc1`'s embedding. `doc1` must rank first with a score
/// within 0.001 of 1.0, followed by `doc2` with a score above 0.9.
#[test]
fn test_simple_vector_index_insert_and_search() {
    let vector_index = SimpleVectorIndex::new();
    vector_index.create_collection("docs", 3);

    // (id, embedding, content) triples, inserted in this order.
    let fixtures = [
        ("doc1", vec![1.0, 0.0, 0.0], "Document about cats"),
        ("doc2", vec![0.9, 0.1, 0.0], "Document about dogs"),
        ("doc3", vec![0.0, 0.0, 1.0], "Document about cars"),
    ];
    for (id, embedding, content) in fixtures {
        vector_index
            .insert(
                "docs",
                id.to_string(),
                embedding,
                content.to_string(),
                HashMap::new(),
            )
            .unwrap();
    }

    let hits = vector_index
        .search_by_embedding("docs", &[1.0, 0.0, 0.0], 2, None)
        .unwrap();

    // The length check runs first so the indexing below cannot go out of bounds.
    assert_eq!(hits.len(), 2);
    let (best, runner_up) = (&hits[0], &hits[1]);

    // Best hit: the document whose embedding equals the query vector.
    assert_eq!(best.id, "doc1");
    assert!((best.score - 1.0).abs() < 0.001);

    // Second hit: the near-parallel vector.
    assert_eq!(runner_up.id, "doc2");
    assert!(runner_up.score > 0.9);
}
2942
/// Inserts three axis-aligned unit vectors and searches along the first axis
/// with `Some(0.9)` as the final argument. Although up to 10 results are
/// requested, only `a` is expected to come back.
#[test]
fn test_simple_vector_index_min_score_filter() {
    let vector_index = SimpleVectorIndex::new();
    vector_index.create_collection("docs", 3);

    // One document per coordinate axis.
    let fixtures = [
        ("a", vec![1.0, 0.0, 0.0], "A"),
        ("b", vec![0.0, 1.0, 0.0], "B"),
        ("c", vec![0.0, 0.0, 1.0], "C"),
    ];
    for (id, embedding, content) in fixtures {
        vector_index
            .insert(
                "docs",
                id.to_string(),
                embedding,
                content.to_string(),
                HashMap::new(),
            )
            .unwrap();
    }

    let hits = vector_index
        .search_by_embedding("docs", &[1.0, 0.0, 0.0], 10, Some(0.9))
        .unwrap();

    assert_eq!(hits.len(), 1);
    let only_hit = &hits[0];
    assert_eq!(only_hit.id, "a");
}
2984
/// Inserting a 2-element embedding into a collection created with dimension 3
/// must fail with an error that mentions "dimension mismatch".
#[test]
fn test_simple_vector_index_dimension_mismatch() {
    let vector_index = SimpleVectorIndex::new();
    vector_index.create_collection("docs", 3);

    // Two components only, while the collection was created with three.
    let too_short = vec![1.0, 0.0];
    let outcome = vector_index.insert(
        "docs",
        "bad".to_string(),
        too_short,
        "Content".to_string(),
        HashMap::new(),
    );

    assert!(outcome.is_err());
    let message = outcome.unwrap_err();
    assert!(message.contains("dimension mismatch"));
}
3001
/// Searching a collection that was never created must fail with an error
/// that mentions "not found".
#[test]
fn test_simple_vector_index_nonexistent_collection() {
    // No collection is created on this index before searching.
    let empty_index = SimpleVectorIndex::new();

    let outcome = empty_index.search_by_embedding("nonexistent", &[1.0], 1, None);

    assert!(outcome.is_err());
    let message = outcome.unwrap_err();
    assert!(message.contains("not found"));
}
3010
/// Metadata supplied at insert time must be present on the search result:
/// both the `author` and `year` keys are expected on the single hit.
#[test]
fn test_vector_index_with_metadata() {
    let vector_index = SimpleVectorIndex::new();
    vector_index.create_collection("docs", 2);

    let metadata = HashMap::from([
        ("author".to_string(), SochValue::Text("Alice".to_string())),
        ("year".to_string(), SochValue::Int(2024)),
    ]);

    vector_index
        .insert(
            "docs",
            "doc1".to_string(),
            vec![1.0, 0.0],
            "Document content".to_string(),
            metadata,
        )
        .unwrap();

    let hits = vector_index
        .search_by_embedding("docs", &[1.0, 0.0], 1, None)
        .unwrap();

    assert_eq!(hits.len(), 1);

    // Checked in the same order as the keys were inserted above.
    let hit = &hits[0];
    for key in ["author", "year"] {
        assert!(hit.metadata.contains_key(key));
    }
}
3038
/// `search_by_text` on `SimpleVectorIndex` is expected to return an error
/// that mentions "embedding model".
#[test]
fn test_vector_index_text_search_unsupported() {
    let vector_index = SimpleVectorIndex::new();
    vector_index.create_collection("docs", 2);

    let outcome = vector_index.search_by_text("docs", "hello", 5, None);

    assert!(outcome.is_err());
    let message = outcome.unwrap_err();
    assert!(message.contains("embedding model"));
}
3049}