1use crate::soch_ql::{ComparisonOp, Condition, LogicalOp, SochValue, WhereClause};
55use crate::token_budget::{BudgetSection, TokenBudgetConfig, TokenBudgetEnforcer, TokenEstimator};
56use std::collections::HashMap;
57
/// A parsed `CONTEXT SELECT` query: output name, optional `FROM` binding, query-level
/// options, and the ordered list of sections to assemble.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ContextSelectQuery {
    /// Identifier that follows `CONTEXT SELECT`.
    pub output_name: String,
    /// Binding from the optional `FROM session($var)` / `FROM agent($var)` clause.
    pub session: SessionReference,
    /// Options from the optional `WITH (...)` clause; `ContextQueryOptions::default()` when absent.
    pub options: ContextQueryOptions,
    /// Sections listed inside `SECTIONS (...)`, in source order.
    pub sections: Vec<ContextSection>,
}

/// The `FROM` clause of a context query.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum SessionReference {
    /// `FROM session($var)`; holds the variable name without the leading `$`.
    Session(String),
    /// `FROM agent($var)`; holds the variable name without the leading `$`.
    Agent(String),
    /// No `FROM` clause was given.
    None,
}

/// Query-level options set via `WITH (key = value, ...)`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ContextQueryOptions {
    /// `token_limit` option (default 4096). Presumably the token budget for the
    /// assembled context — the consumer is not in this chunk; confirm.
    pub token_limit: usize,
    /// `include_schema` option (default `true`). NOTE(review): consumer not visible here.
    pub include_schema: bool,
    /// `format` option (default `OutputFormat::Soch`).
    pub format: OutputFormat,
    /// `truncation` option (default `TruncationStrategy::TailDrop`).
    pub truncation: TruncationStrategy,
    /// `include_headers` option (default `true`). NOTE(review): consumer not visible here.
    pub include_headers: bool,
}

impl Default for ContextQueryOptions {
    /// Values used when `WITH` is omitted or a key is not specified in it.
    fn default() -> Self {
        Self {
            token_limit: 4096,
            include_schema: true,
            format: OutputFormat::Soch,
            truncation: TruncationStrategy::TailDrop,
            include_headers: true,
        }
    }
}

/// Output format selected by the `format` option.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum OutputFormat {
    /// Written as `format = toon` in query text.
    Soch,
    /// Written as `format = json`.
    Json,
    /// Written as `format = markdown`.
    Markdown,
}

/// Strategy selected by the `truncation` option. The behavior of each strategy is
/// implemented outside this chunk; the notes below follow the names — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum TruncationStrategy {
    /// `tail_drop` / `taildrop` (presumably drops content from the end).
    TailDrop,
    /// `head_drop` / `headdrop` (presumably drops content from the start).
    HeadDrop,
    /// `proportional`.
    Proportional,
    /// `fail` (presumably errors instead of truncating).
    Fail,
}
136
/// One `name PRIORITY n: <content>` entry inside `SECTIONS (...)`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ContextSection {
    /// Identifier before `PRIORITY`.
    pub name: String,
    /// Number after `PRIORITY`; the parser uses 0 when no number is given.
    /// NOTE(review): whether lower or higher means "more important" is decided elsewhere.
    pub priority: i32,
    /// What the section pulls into the context.
    pub content: SectionContent,
    /// Optional post-processing; the parser in this file always sets `None`.
    pub transform: Option<SectionTransform>,
}

/// Source of a section's content.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum SectionContent {
    /// `GET <path>`.
    Get { path: PathExpression },

    /// `LAST [n] FROM table [WHERE ...]`; the parser defaults `count` to 10.
    Last {
        count: usize,
        table: String,
        where_clause: Option<WhereClause>,
    },

    /// `SEARCH collection BY SIMILARITY(($var | 'text')) TOP [k]`; the parser defaults
    /// `top_k` to 5 and always leaves `min_score` as `None`.
    Search {
        collection: String,
        query: SimilarityQuery,
        top_k: usize,
        min_score: Option<f32>,
    },

    /// `SELECT cols FROM table [WHERE ...] [LIMIT n]`; `columns` is `["*"]` for `SELECT *`.
    Select {
        columns: Vec<String>,
        table: String,
        where_clause: Option<WhereClause>,
        limit: Option<usize>,
    },

    /// A quoted string literal used verbatim as the section.
    Literal { value: String },

    /// A `$name` variable reference (stored without the `$`).
    Variable { name: String },

    /// Not produced by the parser code in this chunk.
    /// NOTE(review): presumably lists registered tools filtered by name — confirm.
    ToolRegistry {
        include: Vec<String>,
        exclude: Vec<String>,
        include_schema: bool,
    },

    /// Not produced by the parser code in this chunk.
    /// NOTE(review): presumably the most recent `count` tool calls, optionally filtered — confirm.
    ToolCalls {
        count: usize,
        tool_filter: Option<String>,
        status_filter: Option<String>,
        include_outputs: bool,
    },
}

/// A dotted path with an optional field projection, e.g. `user.profile.{name, email}`
/// or `user.profile.**`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct PathExpression {
    /// Dot-separated path components before any projection.
    pub segments: Vec<String>,
    /// Projected field names from `{a, b}`; empty when `all_fields` is set.
    pub fields: Vec<String>,
    /// `true` for `path.**` and for a bare `path` with no projection.
    pub all_fields: bool,
}
227
228impl PathExpression {
229 pub fn parse(input: &str) -> Result<Self, ContextParseError> {
232 let input = input.trim();
233
234 if let Some(brace_start) = input.find('{') {
236 if !input.ends_with('}') {
237 return Err(ContextParseError::InvalidPath(
238 "unclosed field projection".to_string(),
239 ));
240 }
241
242 let path_part = &input[..brace_start].trim_end_matches('.');
243 let fields_part = &input[brace_start + 1..input.len() - 1];
244
245 let segments: Vec<String> = path_part.split('.').map(|s| s.to_string()).collect();
246 let fields: Vec<String> = fields_part
247 .split(',')
248 .map(|s| s.trim().to_string())
249 .filter(|s| !s.is_empty())
250 .collect();
251
252 Ok(PathExpression {
253 segments,
254 fields,
255 all_fields: false,
256 })
257 } else if let Some(path_part) = input.strip_suffix(".**") {
258 let segments: Vec<String> = path_part.split('.').map(|s| s.to_string()).collect();
260
261 Ok(PathExpression {
262 segments,
263 fields: vec![],
264 all_fields: true,
265 })
266 } else {
267 let segments: Vec<String> = input.split('.').map(|s| s.to_string()).collect();
269
270 Ok(PathExpression {
271 segments,
272 fields: vec![],
273 all_fields: true,
274 })
275 }
276 }
277
278 pub fn to_path_string(&self) -> String {
280 let base = self.segments.join(".");
281 if self.all_fields {
282 format!("{}.**", base)
283 } else if !self.fields.is_empty() {
284 format!("{}.{{{}}}", base, self.fields.join(", "))
285 } else {
286 base
287 }
288 }
289}
290
/// Query operand of `SEARCH ... BY SIMILARITY(...)`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum SimilarityQuery {
    /// `SIMILARITY($var)`; variable name without the `$`.
    Variable(String),
    /// A literal embedding; not produced by the parser code in this chunk.
    Embedding(Vec<f32>),
    /// `SIMILARITY('text')`; the quoted text.
    Text(String),
}

/// Optional post-processing for a section. The parser in this file never sets one, and
/// the implementations live elsewhere; the notes below follow the names — confirm.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum SectionTransform {
    /// Presumably summarizes the section down to `max_tokens`.
    Summarize { max_tokens: usize },
    /// Presumably keeps only `fields`.
    Project { fields: Vec<String> },
    /// Presumably renders the section through `template`.
    Template { template: String },
    /// Presumably applies a named user function.
    Custom { function: String },
}

/// A named, versioned, stored context query.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ContextRecipe {
    /// Recipe identifier; together with `version` it forms the store key `"{id}:{version}"`.
    pub id: String,
    /// Display name.
    pub name: String,
    /// Free-form description.
    pub description: String,
    /// Version label. The store orders versions by first save, not by parsing this string.
    pub version: String,
    /// The query this recipe runs.
    pub query: ContextSelectQuery,
    /// Bookkeeping metadata.
    pub metadata: RecipeMetadata,
    /// Which sessions/agents the recipe is found for by `find_by_session` / `find_by_agent`.
    pub session_binding: Option<SessionBinding>,
}

/// Bookkeeping fields for a recipe. Nothing in this chunk reads or updates them;
/// the formats below are unspecified here.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct RecipeMetadata {
    pub author: Option<String>,
    /// NOTE(review): timestamp format not established in this chunk.
    pub created_at: Option<String>,
    /// NOTE(review): timestamp format not established in this chunk.
    pub updated_at: Option<String>,
    pub tags: Vec<String>,
    pub usage_count: u64,
    pub avg_tokens: Option<f32>,
}

/// How a stored recipe is matched by `ContextRecipeStore::find_by_session` / `find_by_agent`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum SessionBinding {
    /// Matches exactly this session id.
    Session(String),
    /// Matches exactly this agent id.
    Agent(String),
    /// Session-id pattern where `*` is a wildcard (see `glob_match`).
    Pattern(String),
    /// Matched by neither finder.
    None,
}

/// In-memory, thread-safe store of versioned recipes.
pub struct ContextRecipeStore {
    /// Recipes keyed by `"{id}:{version}"`.
    recipes: std::sync::RwLock<HashMap<String, ContextRecipe>>,
    /// Version labels per recipe id, in the order they were saved.
    versions: std::sync::RwLock<HashMap<String, Vec<String>>>,
}
381
382impl ContextRecipeStore {
383 pub fn new() -> Self {
385 Self {
386 recipes: std::sync::RwLock::new(HashMap::new()),
387 versions: std::sync::RwLock::new(HashMap::new()),
388 }
389 }
390
391 pub fn save(&self, recipe: ContextRecipe) -> Result<(), String> {
393 let mut recipes = self.recipes.write().map_err(|e| e.to_string())?;
394 let mut versions = self.versions.write().map_err(|e| e.to_string())?;
395
396 let key = format!("{}:{}", recipe.id, recipe.version);
397 recipes.insert(key.clone(), recipe.clone());
398
399 versions
400 .entry(recipe.id.clone())
401 .or_default()
402 .push(recipe.version.clone());
403
404 Ok(())
405 }
406
407 pub fn get_latest(&self, recipe_id: &str) -> Option<ContextRecipe> {
409 let versions = self.versions.read().ok()?;
410 let latest_version = versions.get(recipe_id)?.last()?;
411
412 let recipes = self.recipes.read().ok()?;
413 let key = format!("{}:{}", recipe_id, latest_version);
414 recipes.get(&key).cloned()
415 }
416
417 pub fn get_version(&self, recipe_id: &str, version: &str) -> Option<ContextRecipe> {
419 let recipes = self.recipes.read().ok()?;
420 let key = format!("{}:{}", recipe_id, version);
421 recipes.get(&key).cloned()
422 }
423
424 pub fn list_versions(&self, recipe_id: &str) -> Vec<String> {
426 self.versions
427 .read()
428 .ok()
429 .and_then(|v| v.get(recipe_id).cloned())
430 .unwrap_or_default()
431 }
432
433 pub fn find_by_session(&self, session_id: &str) -> Vec<ContextRecipe> {
435 let recipes = match self.recipes.read() {
436 Ok(r) => r,
437 Err(_) => return Vec::new(),
438 };
439
440 recipes
441 .values()
442 .filter(|r| match &r.session_binding {
443 Some(SessionBinding::Session(sid)) => sid == session_id,
444 Some(SessionBinding::Pattern(pattern)) => glob_match(pattern, session_id),
445 _ => false,
446 })
447 .cloned()
448 .collect()
449 }
450
451 pub fn find_by_agent(&self, agent_id: &str) -> Vec<ContextRecipe> {
453 let recipes = match self.recipes.read() {
454 Ok(r) => r,
455 Err(_) => return Vec::new(),
456 };
457
458 recipes
459 .values()
460 .filter(|r| matches!(&r.session_binding, Some(SessionBinding::Agent(aid)) if aid == agent_id))
461 .cloned()
462 .collect()
463 }
464}
465
466impl Default for ContextRecipeStore {
467 fn default() -> Self {
468 Self::new()
469 }
470}
471
/// Matches `input` against `pattern`, where each `*` matches any (possibly empty) run of
/// characters and every other character must match literally.
///
/// A pattern without `*` requires exact equality. Any number of wildcards is supported.
/// The literal prefix and suffix may not overlap in `input`, so `ab*ba` matches `abba`
/// but not `aba`.
fn glob_match(pattern: &str, input: &str) -> bool {
    let mut pieces = pattern.split('*');
    // `split` always yields at least one piece.
    let head = pieces.next().unwrap_or("");
    let rest: Vec<&str> = pieces.collect();

    let (tail, middle) = match rest.split_last() {
        Some(split) => split,
        None => return pattern == input, // no wildcard at all
    };

    // Anchor the literal prefix, then the literal suffix on what remains, so the two can
    // never share characters.
    let remaining = match input.strip_prefix(head) {
        Some(r) => r,
        None => return false,
    };
    let mut remaining = match remaining.strip_suffix(*tail) {
        Some(r) => r,
        None => return false,
    };

    // Interior literals must appear in order. Taking the leftmost occurrence of each is
    // optimal when the only wildcard is `*`.
    for piece in middle {
        match remaining.find(*piece) {
            Some(idx) => remaining = &remaining[idx + piece.len()..],
            None => return false,
        }
    }
    true
}
486
/// One hit returned by a [`VectorIndex`] search.
#[derive(Debug, Clone)]
pub struct VectorSearchResult {
    /// Caller-supplied id of the stored vector.
    pub id: String,
    /// Similarity score. Both implementations in this file use cosine, where higher is closer.
    pub score: f32,
    /// Content stored alongside the vector.
    pub content: String,
    /// Metadata stored alongside the vector.
    pub metadata: HashMap<String, SochValue>,
}

/// Similarity search over named collections of vectors.
pub trait VectorIndex: Send + Sync {
    /// Returns up to `k` results for `embedding` from `collection`, best first. Results
    /// scoring below `min_score` (when given) are dropped.
    fn search_by_embedding(
        &self,
        collection: &str,
        embedding: &[f32],
        k: usize,
        min_score: Option<f32>,
    ) -> Result<Vec<VectorSearchResult>, String>;

    /// Text-query counterpart of `search_by_embedding`. Both implementations in this file
    /// return an error because they have no embedding model.
    fn search_by_text(
        &self,
        collection: &str,
        text: &str,
        k: usize,
        min_score: Option<f32>,
    ) -> Result<Vec<VectorSearchResult>, String>;

    /// Summary of `collection`, or `None` if it does not exist.
    fn stats(&self, collection: &str) -> Option<VectorIndexStats>;
}

/// Summary returned by [`VectorIndex::stats`].
#[derive(Debug, Clone)]
pub struct VectorIndexStats {
    /// Number of stored vectors.
    pub vector_count: usize,
    /// Required vector length for the collection.
    pub dimension: usize,
    /// Similarity metric name (`"cosine"` in this file's implementations).
    pub metric: String,
}

/// Brute-force in-memory index: every search scores every stored vector.
pub struct SimpleVectorIndex {
    /// Collections keyed by name.
    collections: std::sync::RwLock<HashMap<String, VectorCollection>>,
}

/// Storage for one `SimpleVectorIndex` collection.
struct VectorCollection {
    // The tuple type trips clippy's type-complexity lint; it is only used here.
    #[allow(clippy::type_complexity)]
    /// `(id, vector, content, metadata)` in insertion order.
    vectors: Vec<(String, Vec<f32>, String, HashMap<String, SochValue>)>,
    /// Required length of every stored vector.
    dimension: usize,
}
561
562impl SimpleVectorIndex {
563 pub fn new() -> Self {
565 Self {
566 collections: std::sync::RwLock::new(HashMap::new()),
567 }
568 }
569
570 pub fn create_collection(&self, name: &str, dimension: usize) {
572 let mut collections = self.collections.write().unwrap();
573 collections
574 .entry(name.to_string())
575 .or_insert_with(|| VectorCollection {
576 vectors: Vec::new(),
577 dimension,
578 });
579 }
580
581 pub fn insert(
583 &self,
584 collection: &str,
585 id: String,
586 vector: Vec<f32>,
587 content: String,
588 metadata: HashMap<String, SochValue>,
589 ) -> Result<(), String> {
590 let mut collections = self.collections.write().unwrap();
591 let coll = collections
592 .get_mut(collection)
593 .ok_or_else(|| format!("Collection '{}' not found", collection))?;
594
595 if vector.len() != coll.dimension {
596 return Err(format!(
597 "Vector dimension mismatch: expected {}, got {}",
598 coll.dimension,
599 vector.len()
600 ));
601 }
602
603 coll.vectors.push((id, vector, content, metadata));
604 Ok(())
605 }
606
607 fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
609 let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
610 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
611 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
612 if norm_a == 0.0 || norm_b == 0.0 {
613 0.0
614 } else {
615 dot / (norm_a * norm_b)
616 }
617 }
618}
619
620impl Default for SimpleVectorIndex {
621 fn default() -> Self {
622 Self::new()
623 }
624}
625
626impl VectorIndex for SimpleVectorIndex {
627 fn search_by_embedding(
628 &self,
629 collection: &str,
630 embedding: &[f32],
631 k: usize,
632 min_score: Option<f32>,
633 ) -> Result<Vec<VectorSearchResult>, String> {
634 let collections = self.collections.read().unwrap();
635 let coll = collections
636 .get(collection)
637 .ok_or_else(|| format!("Collection '{}' not found", collection))?;
638
639 let mut scored: Vec<_> = coll
641 .vectors
642 .iter()
643 .map(|(id, vec, content, meta)| {
644 let score = Self::cosine_similarity(embedding, vec);
645 (id, score, content, meta)
646 })
647 .filter(|(_, score, _, _)| min_score.map(|min| *score >= min).unwrap_or(true))
648 .collect();
649
650 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
652
653 Ok(scored
655 .into_iter()
656 .take(k)
657 .map(|(id, score, content, meta)| VectorSearchResult {
658 id: id.clone(),
659 score,
660 content: content.clone(),
661 metadata: meta.clone(),
662 })
663 .collect())
664 }
665
666 fn search_by_text(
667 &self,
668 _collection: &str,
669 _text: &str,
670 _k: usize,
671 _min_score: Option<f32>,
672 ) -> Result<Vec<VectorSearchResult>, String> {
673 Err(
675 "Text-based search requires an embedding model. Use search_by_embedding instead."
676 .to_string(),
677 )
678 }
679
680 fn stats(&self, collection: &str) -> Option<VectorIndexStats> {
681 let collections = self.collections.read().unwrap();
682 collections.get(collection).map(|coll| VectorIndexStats {
683 vector_count: coll.vectors.len(),
684 dimension: coll.dimension,
685 metric: "cosine".to_string(),
686 })
687 }
688}
689
/// Vector index whose collections are backed by `sochdb_index::vector::VectorIndex`.
/// NOTE(review): the type name suggests an HNSW graph underneath; that is assumed, not
/// shown in this chunk — confirm against `sochdb_index`.
pub struct HnswVectorIndex {
    /// Collections keyed by name.
    collections: std::sync::RwLock<HashMap<String, HnswCollection>>,
}

/// One `HnswVectorIndex` collection: the backing index plus per-vector payloads.
struct HnswCollection {
    /// Backing index, created with the cosine metric and `dimension`.
    index: sochdb_index::vector::VectorIndex,
    // The tuple type trips clippy's type-complexity lint; it is only used here.
    #[allow(clippy::type_complexity)]
    /// `(id, content, metadata)` keyed by the numeric id passed to `index.add`.
    metadata: HashMap<u128, (String, String, HashMap<String, SochValue>)>,
    /// Next numeric id to hand to `index.add`; incremented per insert.
    next_edge_id: u128,
    /// Required vector length.
    dimension: usize,
}
715
716impl HnswVectorIndex {
717 pub fn new() -> Self {
719 Self {
720 collections: std::sync::RwLock::new(HashMap::new()),
721 }
722 }
723
724 pub fn create_collection(&self, name: &str, dimension: usize) {
726 let mut collections = self.collections.write().unwrap();
727 collections.entry(name.to_string()).or_insert_with(|| {
728 let index = sochdb_index::vector::VectorIndex::with_dimension(
729 sochdb_index::vector::DistanceMetric::Cosine,
730 dimension,
731 );
732 HnswCollection {
733 index,
734 metadata: HashMap::new(),
735 next_edge_id: 0,
736 dimension,
737 }
738 });
739 }
740
741 pub fn insert(
743 &self,
744 collection: &str,
745 id: String,
746 vector: Vec<f32>,
747 content: String,
748 metadata: HashMap<String, SochValue>,
749 ) -> Result<(), String> {
750 let mut collections = self.collections.write().unwrap();
751 let coll = collections
752 .get_mut(collection)
753 .ok_or_else(|| format!("Collection '{}' not found", collection))?;
754
755 if vector.len() != coll.dimension {
756 return Err(format!(
757 "Vector dimension mismatch: expected {}, got {}",
758 coll.dimension,
759 vector.len()
760 ));
761 }
762
763 let edge_id = coll.next_edge_id;
765 coll.next_edge_id += 1;
766 coll.metadata.insert(edge_id, (id, content, metadata));
767
768 let embedding = ndarray::Array1::from_vec(vector);
770
771 coll.index.add(edge_id, embedding)?;
773
774 Ok(())
775 }
776
777 pub fn vector_count(&self, collection: &str) -> Option<usize> {
779 let collections = self.collections.read().unwrap();
780 collections.get(collection).map(|c| c.metadata.len())
781 }
782}
783
784impl Default for HnswVectorIndex {
785 fn default() -> Self {
786 Self::new()
787 }
788}
789
790impl VectorIndex for HnswVectorIndex {
791 fn search_by_embedding(
792 &self,
793 collection: &str,
794 embedding: &[f32],
795 k: usize,
796 min_score: Option<f32>,
797 ) -> Result<Vec<VectorSearchResult>, String> {
798 let collections = self.collections.read().unwrap();
799 let coll = collections
800 .get(collection)
801 .ok_or_else(|| format!("Collection '{}' not found", collection))?;
802
803 let query = ndarray::Array1::from_vec(embedding.to_vec());
805
806 let results = coll.index.search(&query, k)?;
808
809 let mut search_results = Vec::with_capacity(results.len());
812 for (edge_id, distance) in results {
813 let score = 1.0 - distance;
815
816 if let Some(min) = min_score {
818 if score < min {
819 continue;
820 }
821 }
822
823 if let Some((id, content, meta)) = coll.metadata.get(&edge_id) {
825 search_results.push(VectorSearchResult {
826 id: id.clone(),
827 score,
828 content: content.clone(),
829 metadata: meta.clone(),
830 });
831 }
832 }
833
834 Ok(search_results)
835 }
836
837 fn search_by_text(
838 &self,
839 _collection: &str,
840 _text: &str,
841 _k: usize,
842 _min_score: Option<f32>,
843 ) -> Result<Vec<VectorSearchResult>, String> {
844 Err(
846 "Text-based search requires an embedding model. Use search_by_embedding instead."
847 .to_string(),
848 )
849 }
850
851 fn stats(&self, collection: &str) -> Option<VectorIndexStats> {
852 let collections = self.collections.read().unwrap();
853 collections.get(collection).map(|coll| VectorIndexStats {
854 vector_count: coll.metadata.len(),
855 dimension: coll.dimension,
856 metric: "cosine".to_string(),
857 })
858 }
859}
860
/// Result of executing a context query.
/// NOTE(review): the executor that fills this is not in this chunk; the field notes below
/// follow the names and should be confirmed.
#[derive(Debug, Clone)]
pub struct ContextResult {
    /// Presumably the assembled context text.
    pub context: String,
    /// Presumably the estimated token count of `context`.
    pub token_count: usize,
    /// Presumably the budget the query ran under.
    pub token_budget: usize,
    /// Per-section details for sections that made it into `context`.
    pub sections_included: Vec<SectionResult>,
    /// Names of sections that were cut down to fit.
    pub sections_truncated: Vec<String>,
    /// Names of sections left out entirely.
    pub sections_dropped: Vec<String>,
}

/// Per-section outcome inside a [`ContextResult`].
/// NOTE(review): field semantics inferred from names — confirm against the executor.
#[derive(Debug, Clone)]
pub struct SectionResult {
    pub name: String,
    pub priority: i32,
    pub content: String,
    /// Presumably tokens before truncation.
    pub tokens: usize,
    /// Presumably tokens actually charged against the budget.
    pub tokens_used: usize,
    pub truncated: bool,
    pub row_count: usize,
}

/// Errors raised while executing (rather than parsing) a context query.
/// The executor that produces them is not in this chunk.
#[derive(Debug, Clone)]
pub enum ContextQueryError {
    SessionMismatch { expected: String, actual: String },
    VariableNotFound(String),
    InvalidVariableType { variable: String, expected: String },
    BudgetExceeded {
        section: String,
        requested: usize,
        available: usize,
    },
    BudgetExhausted(String),
    PermissionDenied(String),
    InvalidPath(String),
    /// Wraps a parse failure; its message is embedded in this error's `Display` output.
    Parse(ContextParseError),
    FormatError(String),
    InvalidQuery(String),
    VectorSearchError(String),
}
935
936impl std::fmt::Display for ContextQueryError {
937 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
938 match self {
939 Self::SessionMismatch { expected, actual } => {
940 write!(f, "session mismatch: expected {}, got {}", expected, actual)
941 }
942 Self::VariableNotFound(name) => write!(f, "variable not found: {}", name),
943 Self::InvalidVariableType { variable, expected } => {
944 write!(
945 f,
946 "variable {} has invalid type, expected {}",
947 variable, expected
948 )
949 }
950 Self::BudgetExceeded {
951 section,
952 requested,
953 available,
954 } => {
955 write!(
956 f,
957 "section {} exceeds budget: {} > {}",
958 section, requested, available
959 )
960 }
961 Self::BudgetExhausted(msg) => write!(f, "budget exhausted: {}", msg),
962 Self::PermissionDenied(msg) => write!(f, "permission denied: {}", msg),
963 Self::InvalidPath(path) => write!(f, "invalid path: {}", path),
964 Self::Parse(e) => write!(f, "parse error: {}", e),
965 Self::FormatError(e) => write!(f, "format error: {}", e),
966 Self::InvalidQuery(msg) => write!(f, "invalid query: {}", msg),
967 Self::VectorSearchError(e) => write!(f, "vector search error: {}", e),
968 }
969 }
970}
971
972impl std::error::Error for ContextQueryError {}
973
/// Errors produced while parsing `CONTEXT SELECT` query text.
#[derive(Debug, Clone)]
pub enum ContextParseError {
    /// A specific token was required; `found` is the `Debug` rendering of the actual token.
    UnexpectedToken { expected: String, found: String },
    /// A required clause was absent (not raised by the parser code visible in this chunk).
    MissingClause(String),
    /// Unknown `WITH` option key or an unsupported option value.
    InvalidOption(String),
    /// Malformed `GET` path expression.
    InvalidPath(String),
    /// Section content did not start with a recognized form.
    InvalidSection(String),
    /// Other grammar violation.
    SyntaxError(String),
}
990
991impl std::fmt::Display for ContextParseError {
992 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
993 match self {
994 Self::UnexpectedToken { expected, found } => {
995 write!(f, "expected {}, found '{}'", expected, found)
996 }
997 Self::MissingClause(clause) => write!(f, "missing {} clause", clause),
998 Self::InvalidOption(opt) => write!(f, "invalid option: {}", opt),
999 Self::InvalidPath(path) => write!(f, "invalid path: {}", path),
1000 Self::InvalidSection(sec) => write!(f, "invalid section: {}", sec),
1001 Self::SyntaxError(msg) => write!(f, "syntax error: {}", msg),
1002 }
1003 }
1004}
1005
1006impl std::error::Error for ContextParseError {}
1007
/// Recursive-descent parser for `CONTEXT SELECT` query text.
///
/// The input is tokenized once in `new`; parsing then walks `tokens` with the cursor `pos`.
pub struct ContextQueryParser {
    /// Index of the current token in `tokens`.
    pos: usize,
    /// Token stream from `tokenize`, always terminated by `Token::Eof`.
    tokens: Vec<Token>,
}

/// Lexical token produced by `ContextQueryParser::tokenize`.
#[derive(Debug, Clone, PartialEq)]
enum Token {
    /// Reserved word, stored upper-cased (e.g. `select` → `SELECT`).
    Keyword(String),
    /// Non-reserved word: starts with an ASCII letter or `_`, continues with alphanumerics or `_`.
    Ident(String),
    /// Numeric literal; integers and decimals are both stored as `f64`.
    Number(f64),
    /// Single- or double-quoted literal, without the quotes.
    String(String),
    /// Single punctuation character.
    Punct(char),
    /// `$name` reference, stored without the `$`.
    Variable(String),
    /// End of input.
    Eof,
}
1034
1035impl ContextQueryParser {
1036 pub fn new(input: &str) -> Self {
1038 let tokens = Self::tokenize(input);
1039 Self { pos: 0, tokens }
1040 }
1041
1042 pub fn parse(&mut self) -> Result<ContextSelectQuery, ContextParseError> {
1044 self.expect_keyword("CONTEXT")?;
1046 self.expect_keyword("SELECT")?;
1047 let output_name = self.expect_ident()?;
1048
1049 let session = if self.match_keyword("FROM") {
1051 self.parse_session_reference()?
1052 } else {
1053 SessionReference::None
1054 };
1055
1056 let options = if self.match_keyword("WITH") {
1058 self.parse_options()?
1059 } else {
1060 ContextQueryOptions::default()
1061 };
1062
1063 self.expect_keyword("SECTIONS")?;
1065 let sections = self.parse_sections()?;
1066
1067 Ok(ContextSelectQuery {
1068 output_name,
1069 session,
1070 options,
1071 sections,
1072 })
1073 }
1074
1075 fn parse_session_reference(&mut self) -> Result<SessionReference, ContextParseError> {
1077 if self.match_keyword("session") {
1078 self.expect_punct('(')?;
1079 let var = self.expect_variable()?;
1080 self.expect_punct(')')?;
1081 Ok(SessionReference::Session(var))
1082 } else if self.match_keyword("agent") {
1083 self.expect_punct('(')?;
1084 let var = self.expect_variable()?;
1085 self.expect_punct(')')?;
1086 Ok(SessionReference::Agent(var))
1087 } else {
1088 Err(ContextParseError::SyntaxError(
1089 "expected 'session' or 'agent'".to_string(),
1090 ))
1091 }
1092 }
1093
1094 fn parse_options(&mut self) -> Result<ContextQueryOptions, ContextParseError> {
1096 self.expect_punct('(')?;
1097 let mut options = ContextQueryOptions::default();
1098
1099 loop {
1100 let key = self.expect_ident()?;
1101 self.expect_punct('=')?;
1102
1103 match key.as_str() {
1104 "token_limit" => {
1105 if let Token::Number(n) = self.current().clone() {
1106 options.token_limit = n as usize;
1107 self.advance();
1108 }
1109 }
1110 "include_schema" => {
1111 options.include_schema = self.parse_bool()?;
1112 }
1113 "format" => {
1114 let format = self.expect_ident()?;
1115 options.format = match format.to_lowercase().as_str() {
1116 "toon" => OutputFormat::Soch,
1117 "json" => OutputFormat::Json,
1118 "markdown" => OutputFormat::Markdown,
1119 _ => return Err(ContextParseError::InvalidOption(format)),
1120 };
1121 }
1122 "truncation" => {
1123 let strategy = self.expect_ident()?;
1124 options.truncation = match strategy.to_lowercase().as_str() {
1125 "tail_drop" | "taildrop" => TruncationStrategy::TailDrop,
1126 "head_drop" | "headdrop" => TruncationStrategy::HeadDrop,
1127 "proportional" => TruncationStrategy::Proportional,
1128 "fail" => TruncationStrategy::Fail,
1129 _ => return Err(ContextParseError::InvalidOption(strategy)),
1130 };
1131 }
1132 "include_headers" => {
1133 options.include_headers = self.parse_bool()?;
1134 }
1135 _ => return Err(ContextParseError::InvalidOption(key)),
1136 }
1137
1138 if !self.match_punct(',') {
1139 break;
1140 }
1141 }
1142
1143 self.expect_punct(')')?;
1144 Ok(options)
1145 }
1146
1147 fn parse_sections(&mut self) -> Result<Vec<ContextSection>, ContextParseError> {
1149 self.expect_punct('(')?;
1150 let mut sections = Vec::new();
1151
1152 loop {
1153 if self.check_punct(')') {
1154 break;
1155 }
1156
1157 let section = self.parse_section()?;
1158 sections.push(section);
1159
1160 if !self.match_punct(',') {
1161 break;
1162 }
1163 }
1164
1165 self.expect_punct(')')?;
1166 Ok(sections)
1167 }
1168
1169 fn parse_section(&mut self) -> Result<ContextSection, ContextParseError> {
1171 let name = self.expect_ident()?;
1173
1174 self.expect_keyword("PRIORITY")?;
1175 let priority = if let Token::Number(n) = self.current().clone() {
1176 let val = n as i32;
1177 self.advance();
1178 val
1179 } else {
1180 0
1181 };
1182
1183 self.expect_punct(':')?;
1184
1185 let content = self.parse_section_content()?;
1186
1187 Ok(ContextSection {
1188 name,
1189 priority,
1190 content,
1191 transform: None,
1192 })
1193 }
1194
    /// Parses the content after `name PRIORITY n :`. Accepted forms:
    ///
    /// - `GET <path>`
    /// - `LAST [n] FROM table [WHERE ...]` (n defaults to 10)
    /// - `SEARCH coll BY SIMILARITY(($var | 'text')) TOP [k]` (k defaults to 5)
    /// - `SELECT cols FROM table [WHERE ...] [LIMIT n]`
    /// - `$var` or a quoted string literal
    ///
    /// Numeric counts go through `as usize`, so negative values saturate to 0 and fractions
    /// are truncated.
    fn parse_section_content(&mut self) -> Result<SectionContent, ContextParseError> {
        if self.match_keyword("GET") {
            // Gathers the raw path text for `PathExpression::parse`. NOTE(review):
            // presumably stops at the first `,` / `)` outside braces — `collect_until`'s
            // body is only partly visible here.
            let path_str = self.collect_until(&[',', ')']);
            let path = PathExpression::parse(&path_str)?;
            Ok(SectionContent::Get { path })
        } else if self.match_keyword("LAST") {
            // Optional count; 10 when omitted.
            let count = if let Token::Number(n) = self.current().clone() {
                let val = n as usize;
                self.advance();
                val
            } else {
                10
            };

            self.expect_keyword("FROM")?;
            let table = self.expect_ident()?;

            let where_clause = if self.match_keyword("WHERE") {
                Some(self.parse_where_clause()?)
            } else {
                None
            };

            Ok(SectionContent::Last {
                count,
                table,
                where_clause,
            })
        } else if self.match_keyword("SEARCH") {
            let collection = self.expect_ident()?;
            self.expect_keyword("BY")?;
            self.expect_keyword("SIMILARITY")?;

            // The similarity operand is either a `$variable` or a quoted text query.
            self.expect_punct('(')?;
            let query = if let Token::Variable(v) = self.current().clone() {
                self.advance();
                SimilarityQuery::Variable(v)
            } else if let Token::String(s) = self.current().clone() {
                self.advance();
                SimilarityQuery::Text(s)
            } else {
                return Err(ContextParseError::SyntaxError(
                    "expected variable or string for similarity query".to_string(),
                ));
            };
            self.expect_punct(')')?;

            // `TOP` is required; its number is optional and defaults to 5.
            self.expect_keyword("TOP")?;
            let top_k = if let Token::Number(n) = self.current().clone() {
                let val = n as usize;
                self.advance();
                val
            } else {
                5
            };

            // The grammar has no minimum-score clause, so `min_score` is never set here.
            Ok(SectionContent::Search {
                collection,
                query,
                top_k,
                min_score: None,
            })
        } else if self.match_keyword("SELECT") {
            let columns = self.parse_column_list()?;
            self.expect_keyword("FROM")?;
            let table = self.expect_ident()?;

            let where_clause = if self.match_keyword("WHERE") {
                Some(self.parse_where_clause()?)
            } else {
                None
            };

            // `LIMIT` without a following number yields no limit, and the non-number token
            // is left unconsumed for the caller to trip over.
            let limit = if self.match_keyword("LIMIT") {
                if let Token::Number(n) = self.current().clone() {
                    let val = n as usize;
                    self.advance();
                    Some(val)
                } else {
                    None
                }
            } else {
                None
            };

            Ok(SectionContent::Select {
                columns,
                table,
                where_clause,
                limit,
            })
        } else if let Token::Variable(v) = self.current().clone() {
            self.advance();
            Ok(SectionContent::Variable { name: v })
        } else if let Token::String(s) = self.current().clone() {
            self.advance();
            Ok(SectionContent::Literal { value: s })
        } else {
            Err(ContextParseError::InvalidSection(
                "expected GET, LAST, SEARCH, SELECT, or literal".to_string(),
            ))
        }
    }
1303
1304 fn parse_where_clause(&mut self) -> Result<WhereClause, ContextParseError> {
1306 let mut conditions = Vec::new();
1307
1308 loop {
1309 let column = self.expect_ident()?;
1310 let operator = self.parse_comparison_op()?;
1311 let value = self.parse_value()?;
1312
1313 conditions.push(Condition {
1314 column,
1315 operator,
1316 value,
1317 });
1318
1319 if !self.match_keyword("AND") && !self.match_keyword("OR") {
1320 break;
1321 }
1322 }
1323
1324 Ok(WhereClause {
1325 conditions,
1326 operator: LogicalOp::And,
1327 })
1328 }
1329
1330 fn parse_comparison_op(&mut self) -> Result<ComparisonOp, ContextParseError> {
1332 match self.current() {
1333 Token::Punct('=') => {
1334 self.advance();
1335 Ok(ComparisonOp::Eq)
1336 }
1337 Token::Punct('>') => {
1338 self.advance();
1339 if self.check_punct('=') {
1340 self.advance();
1341 Ok(ComparisonOp::Ge)
1342 } else {
1343 Ok(ComparisonOp::Gt)
1344 }
1345 }
1346 Token::Punct('<') => {
1347 self.advance();
1348 if self.check_punct('=') {
1349 self.advance();
1350 Ok(ComparisonOp::Le)
1351 } else {
1352 Ok(ComparisonOp::Lt)
1353 }
1354 }
1355 _ => {
1356 if self.match_keyword("LIKE") {
1357 Ok(ComparisonOp::Like)
1358 } else if self.match_keyword("IN") {
1359 Ok(ComparisonOp::In)
1360 } else {
1361 Err(ContextParseError::SyntaxError(
1362 "expected comparison operator".to_string(),
1363 ))
1364 }
1365 }
1366 }
1367 }
1368
1369 fn parse_value(&mut self) -> Result<SochValue, ContextParseError> {
1371 match self.current().clone() {
1372 Token::Number(n) => {
1373 self.advance();
1374 if n.fract() == 0.0 {
1375 Ok(SochValue::Int(n as i64))
1376 } else {
1377 Ok(SochValue::Float(n))
1378 }
1379 }
1380 Token::String(s) => {
1381 self.advance();
1382 Ok(SochValue::Text(s))
1383 }
1384 Token::Keyword(k) if k.eq_ignore_ascii_case("null") => {
1385 self.advance();
1386 Ok(SochValue::Null)
1387 }
1388 Token::Keyword(k) if k.eq_ignore_ascii_case("true") => {
1389 self.advance();
1390 Ok(SochValue::Bool(true))
1391 }
1392 Token::Keyword(k) if k.eq_ignore_ascii_case("false") => {
1393 self.advance();
1394 Ok(SochValue::Bool(false))
1395 }
1396 Token::Variable(v) => {
1397 self.advance();
1398 Ok(SochValue::Text(format!("${}", v)))
1400 }
1401 _ => Err(ContextParseError::SyntaxError("expected value".to_string())),
1402 }
1403 }
1404
1405 fn parse_column_list(&mut self) -> Result<Vec<String>, ContextParseError> {
1407 let mut columns = Vec::new();
1408
1409 if self.check_punct('*') {
1410 self.advance();
1411 columns.push("*".to_string());
1412 } else {
1413 loop {
1414 columns.push(self.expect_ident()?);
1415 if !self.match_punct(',') {
1416 break;
1417 }
1418 }
1419 }
1420
1421 Ok(columns)
1422 }
1423
1424 fn parse_bool(&mut self) -> Result<bool, ContextParseError> {
1426 match self.current() {
1427 Token::Keyword(k) if k.eq_ignore_ascii_case("true") => {
1428 self.advance();
1429 Ok(true)
1430 }
1431 Token::Keyword(k) if k.eq_ignore_ascii_case("false") => {
1432 self.advance();
1433 Ok(false)
1434 }
1435 _ => Err(ContextParseError::SyntaxError(
1436 "expected boolean".to_string(),
1437 )),
1438 }
1439 }
1440
1441 fn tokenize(input: &str) -> Vec<Token> {
1443 let mut tokens = Vec::new();
1444 let mut chars = input.chars().peekable();
1445
1446 while let Some(&ch) = chars.peek() {
1447 match ch {
1448 ' ' | '\t' | '\n' | '\r' => {
1450 chars.next();
1451 }
1452
1453 '(' | ')' | ',' | ':' | '=' | '<' | '>' | '*' | '{' | '}' | '.' => {
1455 tokens.push(Token::Punct(ch));
1456 chars.next();
1457 }
1458
1459 '$' => {
1461 chars.next();
1462 let mut name = String::new();
1463 while let Some(&c) = chars.peek() {
1464 if c.is_alphanumeric() || c == '_' {
1465 name.push(c);
1466 chars.next();
1467 } else {
1468 break;
1469 }
1470 }
1471 tokens.push(Token::Variable(name));
1472 }
1473
1474 '\'' | '"' => {
1476 let quote = ch;
1477 chars.next();
1478 let mut s = String::new();
1479 while let Some(&c) = chars.peek() {
1480 if c == quote {
1481 chars.next(); break;
1483 }
1484 s.push(c);
1485 chars.next();
1486 }
1487 tokens.push(Token::String(s));
1488 }
1489
1490 '0'..='9' | '-' => {
1492 let mut num_str = String::new();
1493 if ch == '-' {
1494 num_str.push(ch);
1495 chars.next();
1496 }
1497 while let Some(&c) = chars.peek() {
1498 if c.is_ascii_digit() || c == '.' {
1499 num_str.push(c);
1500 chars.next();
1501 } else {
1502 break;
1503 }
1504 }
1505 if let Ok(n) = num_str.parse::<f64>() {
1506 tokens.push(Token::Number(n));
1507 }
1508 }
1509
1510 'a'..='z' | 'A'..='Z' | '_' => {
1512 let mut ident = String::new();
1513 while let Some(&c) = chars.peek() {
1514 if c.is_alphanumeric() || c == '_' {
1515 ident.push(c);
1516 chars.next();
1517 } else {
1518 break;
1519 }
1520 }
1521
1522 let keywords = [
1524 "CONTEXT",
1525 "SELECT",
1526 "FROM",
1527 "WITH",
1528 "SECTIONS",
1529 "PRIORITY",
1530 "GET",
1531 "LAST",
1532 "SEARCH",
1533 "BY",
1534 "SIMILARITY",
1535 "TOP",
1536 "WHERE",
1537 "AND",
1538 "OR",
1539 "LIKE",
1540 "IN",
1541 "LIMIT",
1542 "session",
1543 "agent",
1544 "true",
1545 "false",
1546 "null",
1547 ];
1548
1549 if keywords.iter().any(|k| k.eq_ignore_ascii_case(&ident)) {
1550 tokens.push(Token::Keyword(ident.to_uppercase()));
1551 } else {
1552 tokens.push(Token::Ident(ident));
1553 }
1554 }
1555
1556 _ => {
1558 chars.next();
1559 }
1560 }
1561 }
1562
1563 tokens.push(Token::Eof);
1564 tokens
1565 }
1566
1567 fn current(&self) -> &Token {
1569 self.tokens.get(self.pos).unwrap_or(&Token::Eof)
1570 }
1571
1572 fn advance(&mut self) {
1573 if self.pos < self.tokens.len() {
1574 self.pos += 1;
1575 }
1576 }
1577
1578 fn expect_keyword(&mut self, kw: &str) -> Result<(), ContextParseError> {
1579 match self.current() {
1580 Token::Keyword(k) if k.eq_ignore_ascii_case(kw) => {
1581 self.advance();
1582 Ok(())
1583 }
1584 other => Err(ContextParseError::UnexpectedToken {
1585 expected: kw.to_string(),
1586 found: format!("{:?}", other),
1587 }),
1588 }
1589 }
1590
1591 fn match_keyword(&mut self, kw: &str) -> bool {
1592 match self.current() {
1593 Token::Keyword(k) if k.eq_ignore_ascii_case(kw) => {
1594 self.advance();
1595 true
1596 }
1597 _ => false,
1598 }
1599 }
1600
1601 fn expect_ident(&mut self) -> Result<String, ContextParseError> {
1602 match self.current().clone() {
1603 Token::Ident(s) => {
1604 self.advance();
1605 Ok(s)
1606 }
1607 Token::Keyword(s) => {
1608 self.advance();
1610 Ok(s)
1611 }
1612 other => Err(ContextParseError::UnexpectedToken {
1613 expected: "identifier".to_string(),
1614 found: format!("{:?}", other),
1615 }),
1616 }
1617 }
1618
1619 fn expect_variable(&mut self) -> Result<String, ContextParseError> {
1620 match self.current().clone() {
1621 Token::Variable(v) => {
1622 self.advance();
1623 Ok(v)
1624 }
1625 other => Err(ContextParseError::UnexpectedToken {
1626 expected: "variable ($name)".to_string(),
1627 found: format!("{:?}", other),
1628 }),
1629 }
1630 }
1631
1632 fn expect_punct(&mut self, p: char) -> Result<(), ContextParseError> {
1633 match self.current() {
1634 Token::Punct(c) if *c == p => {
1635 self.advance();
1636 Ok(())
1637 }
1638 other => Err(ContextParseError::UnexpectedToken {
1639 expected: p.to_string(),
1640 found: format!("{:?}", other),
1641 }),
1642 }
1643 }
1644
1645 fn match_punct(&mut self, p: char) -> bool {
1646 match self.current() {
1647 Token::Punct(c) if *c == p => {
1648 self.advance();
1649 true
1650 }
1651 _ => false,
1652 }
1653 }
1654
1655 fn check_punct(&self, p: char) -> bool {
1656 matches!(self.current(), Token::Punct(c) if *c == p)
1657 }
1658
1659 fn collect_until(&mut self, terminators: &[char]) -> String {
1660 let mut result = String::new();
1661 let mut depth = 0;
1662
1663 loop {
1664 match self.current() {
1665 Token::Punct('{') => {
1666 depth += 1;
1667 result.push('{');
1668 self.advance();
1669 }
1670 Token::Punct('}') => {
1671 depth -= 1;
1672 result.push('}');
1673 self.advance();
1674 }
1675 Token::Punct(c) if depth == 0 && terminators.contains(c) => {
1676 break;
1677 }
1678 Token::Punct(c) => {
1679 result.push(*c);
1680 self.advance();
1681 }
1682 Token::Ident(s) | Token::Keyword(s) => {
1683 if !result.is_empty() && !result.ends_with(['.', '{']) {
1684 result.push(' ');
1685 }
1686 result.push_str(s);
1687 self.advance();
1688 }
1689 Token::Eof => break,
1690 _ => {
1691 self.advance();
1692 }
1693 }
1694 }
1695
1696 result.trim().to_string()
1697 }
1698}
1699
1700use crate::agent_context::{AgentContext, AuditOperation, ContextValue};
1705
/// Executes [`ContextSelectQuery`] values against a borrowed [`AgentContext`].
///
/// What `execute` does:
/// - resolves variables from the context;
/// - checks per-section permissions;
/// - produces section content, optionally using a vector index and an
///   embedding provider;
/// - fits the sections into a token budget;
/// - charges the allocated tokens back to the context.
pub struct AgentContextIntegration<'a> {
    /// Supplies variables, permissions, tool data and the token budget.
    /// `execute` also appends an audit entry to it.
    context: &'a mut AgentContext,
    /// Allocates tokens across sections. The constructors size it from the
    /// context's `budget.max_tokens`, or 4096 when that is unset.
    budget_enforcer: TokenBudgetEnforcer,
    /// Estimates token counts and truncates section text.
    estimator: TokenEstimator,
    /// Backend for `SEARCH` sections; when `None` those sections render a
    /// placeholder string.
    vector_index: Option<std::sync::Arc<dyn VectorIndex>>,
    /// Turns text queries into embeddings. When `None`, text searches use the
    /// index's own `search_by_text`.
    embedding_provider: Option<std::sync::Arc<dyn EmbeddingProvider>>,
}
1727
/// Source of text embeddings, used to turn text similarity queries into
/// vectors.
pub trait EmbeddingProvider: Send + Sync {
    /// Embeds one text; failures are reported as human-readable strings.
    fn embed_text(&self, text: &str) -> Result<Vec<f32>, String>;

    /// Embeds several texts, returning the embeddings in input order.
    ///
    /// The default implementation calls [`embed_text`](Self::embed_text) once
    /// per input and stops at the first error. Providers with a native batch
    /// API may override it.
    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, String> {
        let mut embeddings = Vec::with_capacity(texts.len());
        for &text in texts {
            embeddings.push(self.embed_text(text)?);
        }
        Ok(embeddings)
    }

    /// Embedding dimension reported by the provider.
    fn dimension(&self) -> usize;

    /// Name of the embedding model, as reported by the provider.
    fn model_name(&self) -> &str;
}
1746
1747impl<'a> AgentContextIntegration<'a> {
1748 pub fn new(context: &'a mut AgentContext) -> Self {
1750 let config = TokenBudgetConfig {
1751 total_budget: context.budget.max_tokens.unwrap_or(4096) as usize,
1752 ..Default::default()
1753 };
1754
1755 Self {
1756 context,
1757 budget_enforcer: TokenBudgetEnforcer::new(config),
1758 estimator: TokenEstimator::default(),
1759 vector_index: None,
1760 embedding_provider: None,
1761 }
1762 }
1763
1764 pub fn with_vector_index(
1766 context: &'a mut AgentContext,
1767 vector_index: std::sync::Arc<dyn VectorIndex>,
1768 ) -> Self {
1769 let config = TokenBudgetConfig {
1770 total_budget: context.budget.max_tokens.unwrap_or(4096) as usize,
1771 ..Default::default()
1772 };
1773
1774 Self {
1775 context,
1776 budget_enforcer: TokenBudgetEnforcer::new(config),
1777 estimator: TokenEstimator::default(),
1778 vector_index: Some(vector_index),
1779 embedding_provider: None,
1780 }
1781 }
1782
1783 pub fn with_vector_and_embedding(
1785 context: &'a mut AgentContext,
1786 vector_index: std::sync::Arc<dyn VectorIndex>,
1787 embedding_provider: std::sync::Arc<dyn EmbeddingProvider>,
1788 ) -> Self {
1789 let config = TokenBudgetConfig {
1790 total_budget: context.budget.max_tokens.unwrap_or(4096) as usize,
1791 ..Default::default()
1792 };
1793
1794 Self {
1795 context,
1796 budget_enforcer: TokenBudgetEnforcer::new(config),
1797 estimator: TokenEstimator::default(),
1798 vector_index: Some(vector_index),
1799 embedding_provider: Some(embedding_provider),
1800 }
1801 }
1802
1803 pub fn set_embedding_provider(&mut self, provider: std::sync::Arc<dyn EmbeddingProvider>) {
1805 self.embedding_provider = Some(provider);
1806 }
1807
1808 pub fn set_vector_index(&mut self, index: std::sync::Arc<dyn VectorIndex>) {
1810 self.vector_index = Some(index);
1811 }
1812
1813 pub fn execute(
1815 &mut self,
1816 query: &ContextSelectQuery,
1817 ) -> Result<ContextQueryResult, ContextQueryError> {
1818 self.validate_session(&query.session)?;
1820
1821 self.context.audit.push(crate::agent_context::AuditEntry {
1823 timestamp: std::time::SystemTime::now(),
1824 operation: AuditOperation::DbQuery,
1825 resource: format!("CONTEXT SELECT {}", query.output_name),
1826 result: crate::agent_context::AuditResult::Success,
1827 metadata: std::collections::HashMap::new(),
1828 });
1829
1830 let resolved_sections = self.resolve_sections(&query.sections)?;
1832
1833 for section in &resolved_sections {
1835 self.check_section_permissions(section)?;
1836 }
1837
1838 let mut section_contents: Vec<(ContextSection, String)> = Vec::new();
1840 for section in &resolved_sections {
1841 let content = self.execute_section_content(section, query.options.token_limit)?;
1842 section_contents.push((section.clone(), content));
1843 }
1844
1845 let budget_sections: Vec<BudgetSection> = section_contents
1847 .iter()
1848 .map(|(section, content)| {
1849 let estimated = self.estimator.estimate_text(content);
1850 let minimum = if query.options.truncation == TruncationStrategy::Fail {
1852 None
1853 } else {
1854 Some(estimated.min(100).max(estimated / 10))
1855 };
1856 BudgetSection {
1857 name: section.name.clone(),
1858 estimated_tokens: estimated,
1859 minimum_tokens: minimum,
1860 priority: section.priority,
1861 required: section.priority == 0, weight: 1.0,
1863 }
1864 })
1865 .collect();
1866
1867 let allocation = self.budget_enforcer.allocate_sections(&budget_sections);
1869
1870 let mut result = ContextQueryResult::new(query.output_name.clone());
1872 result.format = query.options.format;
1873 result.allocation_explain = Some(allocation.explain.clone());
1874
1875 for (section, content) in section_contents.iter() {
1877 if allocation.full_sections.contains(§ion.name) {
1878 let tokens = self.estimator.estimate_text(content);
1879 result.sections.push(SectionResult {
1880 name: section.name.clone(),
1881 priority: section.priority,
1882 content: content.clone(),
1883 tokens,
1884 tokens_used: tokens,
1885 truncated: false,
1886 row_count: 0,
1887 });
1888 }
1889 }
1890
1891 for (section_name, _original, truncated_to) in &allocation.truncated_sections {
1893 if let Some((section, content)) = section_contents
1894 .iter()
1895 .find(|(s, _)| &s.name == section_name)
1896 {
1897 let truncated = self.estimator.truncate_to_tokens(content, *truncated_to);
1899 let actual_tokens = self.estimator.estimate_text(&truncated);
1900 result.sections.push(SectionResult {
1901 name: section.name.clone(),
1902 priority: section.priority,
1903 content: truncated,
1904 tokens: actual_tokens,
1905 tokens_used: actual_tokens,
1906 truncated: true,
1907 row_count: 0,
1908 });
1909 }
1910 }
1911
1912 result.sections.sort_by_key(|s| s.priority);
1914
1915 result.total_tokens = allocation.tokens_allocated;
1916 result.token_limit = query.options.token_limit;
1917
1918 self.context
1920 .consume_budget(result.total_tokens as u64, 0)
1921 .map_err(|e| ContextQueryError::BudgetExhausted(e.to_string()))?;
1922
1923 Ok(result)
1924 }
1925
1926 pub fn execute_explain(
1928 &mut self,
1929 query: &ContextSelectQuery,
1930 ) -> Result<(ContextQueryResult, String), ContextQueryError> {
1931 let result = self.execute(query)?;
1932 let explain = result
1933 .allocation_explain
1934 .as_ref()
1935 .map(|decisions| {
1936 use crate::token_budget::BudgetAllocation;
1937 let allocation = BudgetAllocation {
1938 full_sections: result
1939 .sections
1940 .iter()
1941 .filter(|s| !s.truncated)
1942 .map(|s| s.name.clone())
1943 .collect(),
1944 truncated_sections: result
1945 .sections
1946 .iter()
1947 .filter(|s| s.truncated)
1948 .map(|s| (s.name.clone(), s.tokens, s.tokens_used))
1949 .collect(),
1950 dropped_sections: Vec::new(),
1951 tokens_allocated: result.total_tokens,
1952 tokens_remaining: result.token_limit.saturating_sub(result.total_tokens),
1953 explain: decisions.clone(),
1954 };
1955 allocation.explain_text()
1956 })
1957 .unwrap_or_else(|| "No allocation explain available".to_string());
1958 Ok((result, explain))
1959 }
1960
1961 fn validate_session(&self, session_ref: &SessionReference) -> Result<(), ContextQueryError> {
1963 match session_ref {
1964 SessionReference::Session(sid) => {
1965 if sid.starts_with('$') {
1967 return Ok(());
1968 }
1969 if sid != &self.context.session_id && sid != "*" {
1971 return Err(ContextQueryError::SessionMismatch {
1972 expected: sid.clone(),
1973 actual: self.context.session_id.clone(),
1974 });
1975 }
1976 }
1977 SessionReference::Agent(aid) => {
1978 if let Some(ContextValue::String(agent_id)) = self.context.peek_var("agent_id")
1980 && aid != agent_id
1981 && aid != "*"
1982 {
1983 return Err(ContextQueryError::SessionMismatch {
1984 expected: aid.clone(),
1985 actual: agent_id.clone(),
1986 });
1987 }
1988 }
1989 SessionReference::None => {}
1990 }
1991 Ok(())
1992 }
1993
1994 fn resolve_sections(
1996 &self,
1997 sections: &[ContextSection],
1998 ) -> Result<Vec<ContextSection>, ContextQueryError> {
1999 let mut resolved = Vec::new();
2000
2001 for section in sections {
2002 let mut resolved_section = section.clone();
2003
2004 resolved_section.content = match §ion.content {
2006 SectionContent::Literal { value } => {
2007 let resolved_value = self.resolve_variables(value);
2008 SectionContent::Literal {
2009 value: resolved_value,
2010 }
2011 }
2012 SectionContent::Variable { name } => {
2013 if let Some(value) = self.context.peek_var(name) {
2014 SectionContent::Literal {
2015 value: value.to_string(),
2016 }
2017 } else {
2018 return Err(ContextQueryError::VariableNotFound(name.clone()));
2019 }
2020 }
2021 SectionContent::Search {
2022 collection,
2023 query,
2024 top_k,
2025 min_score,
2026 } => {
2027 let resolved_query = match query {
2028 SimilarityQuery::Variable(var) => {
2029 if let Some(value) = self.context.peek_var(var) {
2030 match value {
2031 ContextValue::String(s) => SimilarityQuery::Text(s.clone()),
2032 ContextValue::List(l) => {
2033 let vec: Vec<f32> = l
2034 .iter()
2035 .filter_map(|v| match v {
2036 ContextValue::Number(n) => Some(*n as f32),
2037 _ => None,
2038 })
2039 .collect();
2040 SimilarityQuery::Embedding(vec)
2041 }
2042 _ => {
2043 return Err(ContextQueryError::InvalidVariableType {
2044 variable: var.clone(),
2045 expected: "string or vector".to_string(),
2046 });
2047 }
2048 }
2049 } else {
2050 return Err(ContextQueryError::VariableNotFound(var.clone()));
2051 }
2052 }
2053 other => other.clone(),
2054 };
2055 SectionContent::Search {
2056 collection: collection.clone(),
2057 query: resolved_query,
2058 top_k: *top_k,
2059 min_score: *min_score,
2060 }
2061 }
2062 other => other.clone(),
2063 };
2064
2065 resolved.push(resolved_section);
2066 }
2067
2068 Ok(resolved)
2069 }
2070
2071 fn resolve_variables(&self, input: &str) -> String {
2073 self.context.substitute_vars(input)
2074 }
2075
2076 fn check_section_permissions(&self, section: &ContextSection) -> Result<(), ContextQueryError> {
2078 match §ion.content {
2079 SectionContent::Get { path } => {
2080 let path_str = path.to_path_string();
2082 if path_str.starts_with('/') {
2083 self.context
2084 .check_fs_permission(&path_str, AuditOperation::FsRead)
2085 .map_err(|e| ContextQueryError::PermissionDenied(e.to_string()))?;
2086 } else {
2087 let table = path
2089 .segments
2090 .first()
2091 .ok_or_else(|| ContextQueryError::InvalidPath("empty path".to_string()))?;
2092 self.context
2093 .check_db_permission(table, AuditOperation::DbQuery)
2094 .map_err(|e| ContextQueryError::PermissionDenied(e.to_string()))?;
2095 }
2096 }
2097 SectionContent::Last { table, .. } | SectionContent::Select { table, .. } => {
2098 self.context
2099 .check_db_permission(table, AuditOperation::DbQuery)
2100 .map_err(|e| ContextQueryError::PermissionDenied(e.to_string()))?;
2101 }
2102 SectionContent::Search { collection, .. } => {
2103 self.context
2104 .check_db_permission(collection, AuditOperation::DbQuery)
2105 .map_err(|e| ContextQueryError::PermissionDenied(e.to_string()))?;
2106 }
2107 SectionContent::Literal { .. } | SectionContent::Variable { .. } => {
2108 }
2110 SectionContent::ToolRegistry { .. } | SectionContent::ToolCalls { .. } => {
2111 }
2113 }
2114 Ok(())
2115 }
2116
2117 fn execute_section_content(
2119 &self,
2120 section: &ContextSection,
2121 _budget: usize,
2122 ) -> Result<String, ContextQueryError> {
2123 match §ion.content {
2126 SectionContent::Literal { value } => Ok(value.clone()),
2127 SectionContent::Variable { name } => self
2128 .context
2129 .peek_var(name)
2130 .map(|v| v.to_string())
2131 .ok_or_else(|| ContextQueryError::VariableNotFound(name.clone())),
2132 SectionContent::Get { path } => {
2133 Ok(format!(
2135 "[{}: path={}]",
2136 section.name,
2137 path.to_path_string()
2138 ))
2139 }
2140 SectionContent::Last { count, table, .. } => {
2141 Ok(format!("[{}: last {} from {}]", section.name, count, table))
2143 }
2144 SectionContent::Search {
2145 collection,
2146 query: similarity_query,
2147 top_k,
2148 min_score,
2149 } => {
2150 match &self.vector_index {
2152 Some(index) => {
2153 let results = match similarity_query {
2155 SimilarityQuery::Embedding(emb) => {
2156 index.search_by_embedding(collection, emb, *top_k, *min_score)
2157 }
2158 SimilarityQuery::Text(text) => {
2159 self.search_by_text_with_embedding(
2161 index, collection, text, *top_k, *min_score,
2162 )
2163 }
2164 SimilarityQuery::Variable(var_name) => {
2165 match self.context.peek_var(var_name) {
2167 Some(ContextValue::String(text)) => self
2168 .search_by_text_with_embedding(
2169 index, collection, text, *top_k, *min_score,
2170 ),
2171 Some(ContextValue::List(list)) => {
2172 let embedding: Result<Vec<f32>, _> = list
2174 .iter()
2175 .map(|v| match v {
2176 ContextValue::Number(n) => Ok(*n as f32),
2177 ContextValue::String(s) => {
2178 s.parse::<f32>().map_err(|_| "not a number")
2179 }
2180 _ => Err("not a number"),
2181 })
2182 .collect();
2183
2184 match embedding {
2185 Ok(emb) => index.search_by_embedding(
2186 collection, &emb, *top_k, *min_score,
2187 ),
2188 Err(_) => {
2189 Err("Variable is not a valid embedding vector"
2190 .to_string())
2191 }
2192 }
2193 }
2194 _ => Err(format!(
2195 "Variable '{}' not found or has wrong type",
2196 var_name
2197 )),
2198 }
2199 }
2200 };
2201
2202 match results {
2203 Ok(search_results) => {
2204 self.format_search_results(§ion.name, &search_results)
2206 }
2207 Err(e) => {
2208 Ok(format!("[{}: search error: {}]", section.name, e))
2210 }
2211 }
2212 }
2213 None => {
2214 Ok(format!(
2216 "[{}: search {} top {}]",
2217 section.name, collection, top_k
2218 ))
2219 }
2220 }
2221 }
2222 SectionContent::Select { table, limit, .. } => {
2223 let limit_str = limit.map(|l| format!(" limit {}", l)).unwrap_or_default();
2225 Ok(format!(
2226 "[{}: select from {}{}]",
2227 section.name, table, limit_str
2228 ))
2229 }
2230 SectionContent::ToolRegistry {
2231 include,
2232 exclude,
2233 include_schema,
2234 } => {
2235 self.format_tool_registry(include, exclude, *include_schema)
2237 }
2238 SectionContent::ToolCalls {
2239 count,
2240 tool_filter,
2241 status_filter,
2242 include_outputs,
2243 } => {
2244 self.format_tool_calls(
2246 *count,
2247 tool_filter.as_deref(),
2248 status_filter.as_deref(),
2249 *include_outputs,
2250 )
2251 }
2252 }
2253 }
2254
2255 fn format_tool_registry(
2257 &self,
2258 include: &[String],
2259 exclude: &[String],
2260 include_schema: bool,
2261 ) -> Result<String, ContextQueryError> {
2262 use std::fmt::Write;
2263
2264 let tools = &self.context.tool_registry;
2266 let mut output = String::new();
2267
2268 writeln!(output, "[tool_registry ({} tools)]", tools.len())
2269 .map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
2270
2271 for tool in tools {
2272 if !include.is_empty() && !include.contains(&tool.name) {
2274 continue;
2275 }
2276 if exclude.contains(&tool.name) {
2277 continue;
2278 }
2279
2280 writeln!(output, " [{}]", tool.name)
2281 .map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
2282 writeln!(output, " description = {:?}", tool.description)
2283 .map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
2284
2285 if include_schema {
2286 if let Some(schema) = &tool.parameters_schema {
2287 writeln!(output, " parameters = {}", schema)
2288 .map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
2289 }
2290 }
2291 }
2292
2293 Ok(output)
2294 }
2295
2296 fn format_tool_calls(
2298 &self,
2299 count: usize,
2300 tool_filter: Option<&str>,
2301 status_filter: Option<&str>,
2302 include_outputs: bool,
2303 ) -> Result<String, ContextQueryError> {
2304 use std::fmt::Write;
2305
2306 let calls = &self.context.tool_calls;
2308 let mut output = String::new();
2309
2310 let filtered: Vec<_> = calls
2312 .iter()
2313 .filter(|call| {
2314 tool_filter.map(|f| call.tool_name == f).unwrap_or(true)
2315 && status_filter
2316 .map(|s| match s {
2317 "success" => call.result.is_some() && call.error.is_none(),
2318 "error" => call.error.is_some(),
2319 "pending" => call.result.is_none() && call.error.is_none(),
2320 _ => true,
2321 })
2322 .unwrap_or(true)
2323 })
2324 .rev() .take(count)
2326 .collect();
2327
2328 writeln!(output, "[tool_calls ({} calls)]", filtered.len())
2329 .map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
2330
2331 for call in filtered {
2332 writeln!(output, " [call {}]", call.call_id)
2333 .map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
2334 writeln!(output, " tool = {:?}", call.tool_name)
2335 .map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
2336 writeln!(output, " arguments = {:?}", call.arguments)
2337 .map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
2338
2339 if include_outputs {
2340 if let Some(result) = &call.result {
2341 writeln!(output, " result = {:?}", result)
2342 .map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
2343 }
2344 if let Some(error) = &call.error {
2345 writeln!(output, " error = {:?}", error)
2346 .map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
2347 }
2348 }
2349 }
2350
2351 Ok(output)
2352 }
2353
2354 fn search_by_text_with_embedding(
2360 &self,
2361 index: &std::sync::Arc<dyn VectorIndex>,
2362 collection: &str,
2363 text: &str,
2364 k: usize,
2365 min_score: Option<f32>,
2366 ) -> Result<Vec<VectorSearchResult>, String> {
2367 match &self.embedding_provider {
2368 Some(provider) => {
2369 let embedding = provider.embed_text(text)?;
2371 index.search_by_embedding(collection, &embedding, k, min_score)
2373 }
2374 None => {
2375 index.search_by_text(collection, text, k, min_score)
2377 }
2378 }
2379 }
2380
2381 fn format_search_results(
2383 &self,
2384 section_name: &str,
2385 results: &[VectorSearchResult],
2386 ) -> Result<String, ContextQueryError> {
2387 use std::fmt::Write;
2388
2389 let mut output = String::new();
2390 writeln!(output, "[{} ({} results)]", section_name, results.len())
2391 .map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
2392
2393 for (i, result) in results.iter().enumerate() {
2394 writeln!(output, " [result {} score={:.4}]", i + 1, result.score)
2395 .map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
2396 writeln!(output, " id = {}", result.id)
2397 .map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
2398
2399 for line in result.content.lines() {
2401 writeln!(output, " {}", line)
2402 .map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
2403 }
2404
2405 if !result.metadata.is_empty() {
2407 writeln!(output, " [metadata]")
2408 .map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
2409 for (key, value) in &result.metadata {
2410 writeln!(output, " {} = {:?}", key, value)
2411 .map_err(|e| ContextQueryError::FormatError(e.to_string()))?;
2412 }
2413 }
2414 }
2415
2416 Ok(output)
2417 }
2418
2419 #[allow(dead_code)]
2421 fn truncate_content(
2422 &self,
2423 content: &str,
2424 max_tokens: usize,
2425 strategy: TruncationStrategy,
2426 ) -> String {
2427 let max_chars = max_tokens * 4;
2429
2430 if content.len() <= max_chars {
2431 return content.to_string();
2432 }
2433
2434 match strategy {
2435 TruncationStrategy::TailDrop => {
2436 let mut result: String = content.chars().take(max_chars - 3).collect();
2437 result.push_str("...");
2438 result
2439 }
2440 TruncationStrategy::HeadDrop => {
2441 let skip = content.len() - max_chars + 3;
2442 let mut result = "...".to_string();
2443 result.extend(content.chars().skip(skip));
2444 result
2445 }
2446 TruncationStrategy::Proportional => {
2447 let quarter = max_chars / 4;
2449 let first: String = content.chars().take(quarter).collect();
2450 let last: String = content
2451 .chars()
2452 .skip(content.len().saturating_sub(quarter))
2453 .collect();
2454 format!("{}...{}...", first, last)
2455 }
2456 TruncationStrategy::Fail => {
2457 content.to_string() }
2459 }
2460 }
2461
2462 pub fn get_session_context(&self) -> HashMap<String, String> {
2464 self.context
2465 .variables
2466 .iter()
2467 .map(|(k, v)| (k.clone(), v.to_string()))
2468 .collect()
2469 }
2470
2471 pub fn set_variable(&mut self, name: &str, value: ContextValue) {
2473 self.context.set_var(name, value);
2474 }
2475
2476 pub fn remaining_budget(&self) -> u64 {
2478 self.context
2479 .budget
2480 .max_tokens
2481 .map(|max| max.saturating_sub(self.context.budget.tokens_used))
2482 .unwrap_or(u64::MAX)
2483 }
2484}
2485
/// Output of executing a `CONTEXT SELECT` query, ready to be rendered.
#[derive(Debug, Clone)]
pub struct ContextQueryResult {
    /// Name from `CONTEXT SELECT <name>`; used as the heading when rendering.
    pub output_name: String,
    /// Sections that survived budget allocation, whole or truncated.
    pub sections: Vec<SectionResult>,
    /// Tokens charged for this result (set by `execute` from the allocation).
    pub total_tokens: usize,
    /// Token limit requested in the query options.
    pub token_limit: usize,
    /// Format used by `render`.
    pub format: OutputFormat,
    /// Allocation decisions recorded by the token-budget allocator, if any.
    pub allocation_explain: Option<Vec<crate::token_budget::AllocationDecision>>,
}
2502
2503impl ContextQueryResult {
2504 fn new(output_name: String) -> Self {
2505 Self {
2506 output_name,
2507 sections: Vec::new(),
2508 total_tokens: 0,
2509 token_limit: 0,
2510 format: OutputFormat::Soch,
2511 allocation_explain: None,
2512 }
2513 }
2514
2515 pub fn render(&self) -> String {
2517 let mut output = String::new();
2518
2519 match self.format {
2520 OutputFormat::Soch => {
2521 output.push_str(&format!("{}[{}]:\n", self.output_name, self.sections.len()));
2523 for section in &self.sections {
2524 output.push_str(&format!(
2525 " {}[{}{}]:\n",
2526 section.name,
2527 section.tokens_used,
2528 if section.truncated { "T" } else { "" }
2529 ));
2530 for line in section.content.lines() {
2531 output.push_str(&format!(" {}\n", line));
2532 }
2533 }
2534 }
2535 OutputFormat::Json => {
2536 output.push_str("{\n");
2537 output.push_str(&format!(" \"name\": \"{}\",\n", self.output_name));
2538 output.push_str(&format!(" \"total_tokens\": {},\n", self.total_tokens));
2539 output.push_str(" \"sections\": [\n");
2540 for (i, section) in self.sections.iter().enumerate() {
2541 output.push_str(&format!(" {{\"name\": \"{}\", \"tokens\": {}, \"truncated\": {}, \"content\": \"{}\"}}",
2542 section.name,
2543 section.tokens_used,
2544 section.truncated,
2545 section.content.replace('"', "\\\"").replace('\n', "\\n")
2546 ));
2547 if i < self.sections.len() - 1 {
2548 output.push(',');
2549 }
2550 output.push('\n');
2551 }
2552 output.push_str(" ]\n}");
2553 }
2554 OutputFormat::Markdown => {
2555 output.push_str(&format!("# {}\n\n", self.output_name));
2556 output.push_str(&format!(
2557 "*Tokens: {}/{}*\n\n",
2558 self.total_tokens, self.token_limit
2559 ));
2560 for section in &self.sections {
2561 output.push_str(&format!("## {}", section.name));
2562 if section.truncated {
2563 output.push_str(" *(truncated)*");
2564 }
2565 output.push_str("\n\n");
2566 output.push_str(§ion.content);
2567 output.push_str("\n\n");
2568 }
2569 }
2570 }
2571
2572 output
2573 }
2574
2575 pub fn utilization(&self) -> f64 {
2577 if self.token_limit == 0 {
2578 return 0.0;
2579 }
2580 (self.total_tokens as f64 / self.token_limit as f64) * 100.0
2581 }
2582
2583 pub fn has_truncation(&self) -> bool {
2585 self.sections.iter().any(|s| s.truncated)
2586 }
2587}
2588
/// Named priority levels for context sections.
///
/// Lower values order first: `CRITICAL < SYSTEM < USER < HISTORY < KNOWLEDGE
/// < SUPPLEMENTARY`. The inner value is presumably meant to be used as a
/// section's `priority: i32`; confirm at call sites.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct SectionPriority(pub i32);

impl SectionPriority {
    /// Orders before everything else.
    pub const CRITICAL: Self = Self(-100);
    /// System instructions.
    pub const SYSTEM: Self = Self(-1);
    /// User data; equals the zero priority.
    pub const USER: Self = Self(0);
    /// Conversation or event history.
    pub const HISTORY: Self = Self(1);
    /// Retrieved knowledge.
    pub const KNOWLEDGE: Self = Self(2);
    /// Optional extra material; orders last among the named levels.
    pub const SUPPLEMENTARY: Self = Self(10);
}
2601
/// Fluent builder for [`ContextSelectQuery`] values; an alternative to parsing
/// query text.
pub struct ContextQueryBuilder {
    /// Output name (`CONTEXT SELECT <name>`).
    output_name: String,
    /// Session/agent binding; `new` starts it as `SessionReference::None`.
    session: SessionReference,
    /// Query options; `new` starts from `ContextQueryOptions::default()`.
    options: ContextQueryOptions,
    /// Sections in the order they were added.
    sections: Vec<ContextSection>,
}
2613
2614impl ContextQueryBuilder {
2615 pub fn new(output_name: &str) -> Self {
2617 Self {
2618 output_name: output_name.to_string(),
2619 session: SessionReference::None,
2620 options: ContextQueryOptions::default(),
2621 sections: Vec::new(),
2622 }
2623 }
2624
2625 pub fn from_session(mut self, session_id: &str) -> Self {
2627 self.session = SessionReference::Session(session_id.to_string());
2628 self
2629 }
2630
2631 pub fn from_agent(mut self, agent_id: &str) -> Self {
2633 self.session = SessionReference::Agent(agent_id.to_string());
2634 self
2635 }
2636
2637 pub fn with_token_limit(mut self, limit: usize) -> Self {
2639 self.options.token_limit = limit;
2640 self
2641 }
2642
2643 pub fn include_schema(mut self, include: bool) -> Self {
2645 self.options.include_schema = include;
2646 self
2647 }
2648
2649 pub fn format(mut self, format: OutputFormat) -> Self {
2651 self.options.format = format;
2652 self
2653 }
2654
2655 pub fn truncation(mut self, strategy: TruncationStrategy) -> Self {
2657 self.options.truncation = strategy;
2658 self
2659 }
2660
2661 pub fn get(mut self, name: &str, priority: i32, path: &str) -> Self {
2663 let path_expr = PathExpression::parse(path).unwrap_or(PathExpression {
2664 segments: vec![path.to_string()],
2665 fields: vec![],
2666 all_fields: true,
2667 });
2668
2669 self.sections.push(ContextSection {
2670 name: name.to_string(),
2671 priority,
2672 content: SectionContent::Get { path: path_expr },
2673 transform: None,
2674 });
2675 self
2676 }
2677
2678 pub fn last(mut self, name: &str, priority: i32, count: usize, table: &str) -> Self {
2680 self.sections.push(ContextSection {
2681 name: name.to_string(),
2682 priority,
2683 content: SectionContent::Last {
2684 count,
2685 table: table.to_string(),
2686 where_clause: None,
2687 },
2688 transform: None,
2689 });
2690 self
2691 }
2692
2693 pub fn search(
2695 mut self,
2696 name: &str,
2697 priority: i32,
2698 collection: &str,
2699 query_var: &str,
2700 top_k: usize,
2701 ) -> Self {
2702 self.sections.push(ContextSection {
2703 name: name.to_string(),
2704 priority,
2705 content: SectionContent::Search {
2706 collection: collection.to_string(),
2707 query: SimilarityQuery::Variable(query_var.to_string()),
2708 top_k,
2709 min_score: None,
2710 },
2711 transform: None,
2712 });
2713 self
2714 }
2715
2716 pub fn literal(mut self, name: &str, priority: i32, value: &str) -> Self {
2718 self.sections.push(ContextSection {
2719 name: name.to_string(),
2720 priority,
2721 content: SectionContent::Literal {
2722 value: value.to_string(),
2723 },
2724 transform: None,
2725 });
2726 self
2727 }
2728
2729 pub fn build(self) -> ContextSelectQuery {
2731 ContextSelectQuery {
2732 output_name: self.output_name,
2733 session: self.session,
2734 options: self.options,
2735 sections: self.sections,
2736 }
2737 }
2738}
2739
2740#[cfg(test)]
2745mod tests {
2746 use super::*;
2747
#[test]
fn test_path_expression_simple() {
    // A plain dotted path splits on `.` and selects every field.
    let parsed = PathExpression::parse("user.profile").unwrap();
    assert_eq!(parsed.segments, ["user", "profile"]);
    assert!(parsed.all_fields);
}
2754
#[test]
fn test_path_expression_with_fields() {
    // A trailing `{a, b}` projection becomes an explicit field list.
    let parsed = PathExpression::parse("user.profile.{name, email}").unwrap();
    assert_eq!(parsed.segments, ["user", "profile"]);
    assert_eq!(parsed.fields, ["name", "email"]);
    assert!(!parsed.all_fields);
}
2762
#[test]
fn test_path_expression_glob() {
    // A trailing `**` glob is not kept as a segment and selects all fields.
    let parsed = PathExpression::parse("user.**").unwrap();
    assert_eq!(parsed.segments, ["user"]);
    assert!(parsed.all_fields);
}
2769
#[test]
fn test_parse_simple_query() {
    let source = r#"
        CONTEXT SELECT prompt_context
        FROM session($SESSION_ID)
        WITH (token_limit = 2048, include_schema = true)
        SECTIONS (
            USER PRIORITY 0: GET user.profile.{name, preferences}
        )
    "#;

    let mut parser = ContextQueryParser::new(source);
    let parsed = parser.parse().unwrap();

    assert_eq!(parsed.output_name, "prompt_context");
    // The `$` marker is stripped: the session refers to the variable name.
    assert!(matches!(&parsed.session, SessionReference::Session(id) if id == "SESSION_ID"));
    assert_eq!(parsed.options.token_limit, 2048);
    assert!(parsed.options.include_schema);

    assert_eq!(parsed.sections.len(), 1);
    let user = &parsed.sections[0];
    assert_eq!(user.name, "USER");
    assert_eq!(user.priority, 0);
}
2792
#[test]
fn test_parse_multiple_sections() {
    let source = r#"
        CONTEXT SELECT context
        SECTIONS (
            A PRIORITY 0: "literal value",
            B PRIORITY 1: LAST 10 FROM logs,
            C PRIORITY 2: SEARCH docs BY SIMILARITY($query) TOP 5
        )
    "#;

    let mut parser = ContextQueryParser::new(source);
    let parsed = parser.parse().unwrap();

    assert_eq!(parsed.sections.len(), 3);
    let (literal, last, search) = (&parsed.sections[0], &parsed.sections[1], &parsed.sections[2]);

    // A quoted string becomes a literal section.
    assert_eq!(literal.name, "A");
    assert!(
        matches!(&literal.content, SectionContent::Literal { value } if value == "literal value")
    );

    // LAST n FROM table.
    assert_eq!(last.name, "B");
    assert!(
        matches!(&last.content, SectionContent::Last { count: 10, table, .. } if table == "logs")
    );

    // SEARCH collection BY SIMILARITY(...) TOP k.
    assert_eq!(search.name, "C");
    assert!(
        matches!(&search.content, SectionContent::Search { collection, top_k: 5, .. } if collection == "docs")
    );
}
2827
#[test]
fn test_builder() {
    let query = ContextQueryBuilder::new("prompt")
        .from_session("sess123")
        .with_token_limit(4096)
        .include_schema(false)
        .get("USER", 0, "user.profile.{name, email}")
        .last("HISTORY", 1, 20, "events")
        .search("DOCS", 2, "knowledge_base", "query_embedding", 10)
        .literal("SYSTEM", -1, "You are a helpful assistant")
        .build();

    assert_eq!(query.output_name, "prompt");
    assert_eq!(query.options.token_limit, 4096);
    assert!(!query.options.include_schema);
    assert_eq!(query.sections.len(), 4);

    let system_priority = query
        .sections
        .iter()
        .find_map(|s| (s.name == "SYSTEM").then_some(s.priority))
        .unwrap();
    assert_eq!(system_priority, -1);
}
2849
#[test]
fn test_output_format() {
    // `format = markdown` in the WITH clause selects Markdown output.
    let source = r#"
        CONTEXT SELECT ctx
        WITH (format = markdown)
        SECTIONS ()
    "#;

    let mut parser = ContextQueryParser::new(source);
    let parsed = parser.parse().unwrap();
    assert_eq!(parsed.options.format, OutputFormat::Markdown);
}
2863
/// `WITH (truncation = proportional)` selects `TruncationStrategy::Proportional`.
#[test]
fn test_truncation_strategy() {
    let source = r#"
CONTEXT SELECT ctx
WITH (truncation = proportional)
SECTIONS ()
"#;

    let strategy = ContextQueryParser::new(source)
        .parse()
        .unwrap()
        .options
        .truncation;

    assert_eq!(strategy, TruncationStrategy::Proportional);
}
2877
/// A freshly created collection reports its configured dimension, holds no vectors,
/// and uses the "cosine" metric (no metric is passed to `create_collection` here).
#[test]
fn test_simple_vector_index_creation() {
    let index = SimpleVectorIndex::new();
    index.create_collection("test", 3);

    // `expect` replaces the `is_some()` + `unwrap()` pair and gives a clear
    // failure message.
    let stats = index
        .stats("test")
        .expect("stats should exist for a collection that was just created");
    assert_eq!(stats.dimension, 3);
    assert_eq!(stats.vector_count, 0);
    assert_eq!(stats.metric, "cosine");
}
2894
/// Inserts three 3-d vectors and runs a top-2 search against `[1, 0, 0]`. Checks which
/// documents are returned, their scores, and that results come back best-first.
#[test]
fn test_simple_vector_index_insert_and_search() {
    let index = SimpleVectorIndex::new();
    index.create_collection("docs", 3);

    index
        .insert(
            "docs",
            "doc1".to_string(),
            vec![1.0, 0.0, 0.0],
            "Document about cats".to_string(),
            HashMap::new(),
        )
        .unwrap();

    index
        .insert(
            "docs",
            "doc2".to_string(),
            vec![0.9, 0.1, 0.0],
            "Document about dogs".to_string(),
            HashMap::new(),
        )
        .unwrap();

    index
        .insert(
            "docs",
            "doc3".to_string(),
            vec![0.0, 0.0, 1.0],
            "Document about cars".to_string(),
            HashMap::new(),
        )
        .unwrap();

    let results = index
        .search_by_embedding("docs", &[1.0, 0.0, 0.0], 2, None)
        .unwrap();

    // top_k = 2: doc3 is orthogonal to the query and must be left out.
    assert_eq!(results.len(), 2);

    // doc1 is identical to the query, so it ranks first with a score of ~1.0.
    assert_eq!(results[0].id, "doc1");
    assert!((results[0].score - 1.0).abs() < 0.001);

    // doc2 is nearly parallel to the query (cosine = 0.9 / sqrt(0.82) ≈ 0.994).
    assert_eq!(results[1].id, "doc2");
    assert!(results[1].score > 0.9);

    // Results must be ordered best-first.
    assert!(results[0].score >= results[1].score);
}
2942
/// A `min_score` threshold drops results below it, even when `top_k` would admit them.
#[test]
fn test_simple_vector_index_min_score_filter() {
    let index = SimpleVectorIndex::new();
    index.create_collection("docs", 3);

    // Three mutually orthogonal unit vectors; only "a" points along the query.
    for (id, embedding, content) in [
        ("a", vec![1.0, 0.0, 0.0], "A"),
        ("b", vec![0.0, 1.0, 0.0], "B"),
        ("c", vec![0.0, 0.0, 1.0], "C"),
    ] {
        index
            .insert(
                "docs",
                id.to_string(),
                embedding,
                content.to_string(),
                HashMap::new(),
            )
            .unwrap();
    }

    // top_k = 10 would return everything. The 0.9 threshold must leave only "a".
    let hits = index
        .search_by_embedding("docs", &[1.0, 0.0, 0.0], 10, Some(0.9))
        .unwrap();

    assert_eq!(hits.len(), 1);
    assert_eq!(hits[0].id, "a");
}
2984
/// Inserting an embedding whose length differs from the collection dimension fails with
/// a "dimension mismatch" error.
#[test]
fn test_simple_vector_index_dimension_mismatch() {
    let index = SimpleVectorIndex::new();
    index.create_collection("docs", 3);

    // A 2-component vector inserted into a 3-dimensional collection.
    // `expect_err` replaces the redundant `is_err()` assert + `unwrap_err()` pair.
    let err = index
        .insert(
            "docs",
            "bad".to_string(),
            vec![1.0, 0.0],
            "Content".to_string(),
            HashMap::new(),
        )
        .expect_err("inserting a 2-d vector into a 3-d collection should fail");

    assert!(err.contains("dimension mismatch"));
}
3001
/// Searching a collection that was never created fails with a "not found" error.
#[test]
fn test_simple_vector_index_nonexistent_collection() {
    let index = SimpleVectorIndex::new();

    // `expect_err` replaces the redundant `is_err()` assert + `unwrap_err()` pair.
    let err = index
        .search_by_embedding("nonexistent", &[1.0], 1, None)
        .expect_err("searching a collection that was never created should fail");

    assert!(err.contains("not found"));
}
3010
/// Metadata attached at insert time is returned on the matching search result.
#[test]
fn test_vector_index_with_metadata() {
    let index = SimpleVectorIndex::new();
    index.create_collection("docs", 2);

    let metadata = HashMap::from([
        ("author".to_string(), SochValue::Text("Alice".to_string())),
        ("year".to_string(), SochValue::Int(2024)),
    ]);

    index
        .insert(
            "docs",
            "doc1".to_string(),
            vec![1.0, 0.0],
            "Document content".to_string(),
            metadata,
        )
        .unwrap();

    let hits = index
        .search_by_embedding("docs", &[1.0, 0.0], 1, None)
        .unwrap();
    assert_eq!(hits.len(), 1);

    // Every key supplied at insert time must come back on the hit.
    let returned = &hits[0].metadata;
    for key in ["author", "year"] {
        assert!(returned.contains_key(key));
    }
}
3038
/// `search_by_text` on `SimpleVectorIndex` fails, and the error mentions an
/// "embedding model".
#[test]
fn test_vector_index_text_search_unsupported() {
    let index = SimpleVectorIndex::new();
    index.create_collection("docs", 2);

    // A text query would need an embedding model to become a vector. This index is
    // expected to have none, and the error should say so.
    // `expect_err` replaces the redundant `is_err()` assert + `unwrap_err()` pair.
    let err = index
        .search_by_text("docs", "hello", 5, None)
        .expect_err("text search should be rejected without an embedding model");

    assert!(err.contains("embedding model"));
}
3049}