1#![allow(unexpected_cfgs)]
19
20use std::collections::BinaryHeap;
53use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
54use std::sync::Arc;
55
56use crate::context_query::{
57 ContextSection, ContextSelectQuery, SectionContent, SimilarityQuery,
58 TruncationStrategy, OutputFormat, VectorIndex,
59};
60use crate::token_budget::TokenEstimator;
61use crate::soch_ql::SochValue;
62
/// A unit of streamed output produced while executing a context query.
///
/// Chunks are emitted in order by the streaming iterator; consumers can
/// render them incrementally or collect them and materialize at the end.
#[derive(Debug, Clone)]
pub enum SectionChunk {
    /// Marks the start of a section (emitted only when
    /// `StreamingConfig::include_headers` is set and the header fit the budget).
    SectionHeader {
        name: String,
        priority: i32,
        // Tokens charged to the budget for the rendered header line.
        estimated_tokens: usize,
    },

    /// A batch of tabular rows belonging to `section_name`.
    RowBlock {
        section_name: String,
        rows: Vec<Vec<SochValue>>,
        columns: Vec<String>,
        // Tokens charged to the budget for this block.
        tokens: usize,
    },

    /// A batch of vector-search results belonging to `section_name`.
    SearchResultBlock {
        section_name: String,
        results: Vec<StreamingSearchResult>,
        // Tokens charged to the budget for this block.
        tokens: usize,
    },

    /// A slice of literal text content belonging to `section_name`.
    ContentBlock {
        section_name: String,
        content: String,
        // Tokens charged to the budget for this slice.
        tokens: usize,
    },

    /// Marks the end of a section.
    SectionComplete {
        name: String,
        // Tokens the section consumed in total.
        total_tokens: usize,
        // True when the section was cut short by budget exhaustion.
        truncated: bool,
    },

    /// Final summary; always the last chunk of a stream.
    StreamComplete {
        total_tokens: usize,
        sections_included: Vec<String>,
        sections_dropped: Vec<String>,
    },

    /// Reports a failure while executing a section (`section_name` is
    /// `None` when the error is not tied to a specific section).
    Error {
        section_name: Option<String>,
        message: String,
    },
}
119
/// A single vector-search hit as carried inside a
/// `SectionChunk::SearchResultBlock`.
#[derive(Debug, Clone)]
pub struct StreamingSearchResult {
    pub id: String,
    // Similarity score as returned by the vector index.
    pub score: f32,
    pub content: String,
}
127
/// Tuning knobs for streaming context execution.
#[derive(Debug, Clone)]
pub struct StreamingConfig {
    /// Total token budget shared by all sections of one stream.
    pub token_limit: usize,

    /// Target chunk size in tokens for literal content blocks.
    pub chunk_size: usize,

    /// Whether to emit a `SectionHeader` chunk before each section.
    pub include_headers: bool,

    /// Output format used when chunks are materialized into text.
    pub format: OutputFormat,

    /// Strategy applied when the budget runs out.
    /// NOTE(review): not consulted anywhere in this file — confirm consumers.
    pub truncation: TruncationStrategy,

    /// NOTE(review): not consulted anywhere in this file — presumably a
    /// reserved flag for parallel section execution; confirm.
    pub parallel_execution: bool,

    /// NOTE(review): not consulted anywhere in this file — presumably a
    /// reserved flag for exact (vs. estimated) token counting; confirm.
    pub exact_tokens: bool,
}
152
153impl Default for StreamingConfig {
154 fn default() -> Self {
155 Self {
156 token_limit: 4096,
157 chunk_size: 256,
158 include_headers: true,
159 format: OutputFormat::Soch,
160 truncation: TruncationStrategy::TailDrop,
161 parallel_execution: false,
162 exact_tokens: false,
163 }
164 }
165}
166
/// Thread-safe token budget shared across the sections of one streaming pass.
///
/// Consumption is lock-free: a single atomic counter tracks tokens handed
/// out, and a sticky `exhausted` flag short-circuits callers once the
/// limit has been reached.
#[derive(Debug)]
pub struct RollingBudget {
    // Hard ceiling on the number of tokens this budget will ever grant.
    limit: usize,

    // Tokens granted so far; never exceeds `limit`.
    used: AtomicUsize,

    // Sticky flag, latched once the budget is fully spent.
    exhausted: AtomicBool,
}

impl RollingBudget {
    /// Creates a fresh budget that may grant up to `limit` tokens.
    pub fn new(limit: usize) -> Self {
        Self {
            limit,
            used: AtomicUsize::new(0),
            exhausted: AtomicBool::new(false),
        }
    }

    /// Attempts to reserve `tokens`, returning how many were actually
    /// granted — possibly fewer than requested, possibly zero.
    ///
    /// The grant is clamped to whatever remains below `limit`; once the
    /// limit is hit the exhausted flag is latched and every later call
    /// returns 0 immediately.
    pub fn try_consume(&self, tokens: usize) -> usize {
        if self.exhausted.load(Ordering::Acquire) {
            return 0;
        }

        // Atomically bump `used` by the clamped grant. The closure may be
        // retried under contention, so the grant is recomputed from the
        // freshly observed counter on every attempt.
        let mut granted = 0;
        let outcome = self
            .used
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
                let available = self.limit.saturating_sub(current);
                if available == 0 {
                    None
                } else {
                    granted = tokens.min(available);
                    Some(current + granted)
                }
            });

        match outcome {
            Ok(previous) => {
                // Latch exhaustion when this grant reached the limit.
                if previous + granted >= self.limit {
                    self.exhausted.store(true, Ordering::Release);
                }
                granted
            }
            Err(_) => {
                // Nothing was available: latch exhaustion, grant nothing.
                self.exhausted.store(true, Ordering::Release);
                0
            }
        }
    }

    /// Tokens still available below the limit.
    pub fn remaining(&self) -> usize {
        self.limit.saturating_sub(self.used.load(Ordering::Acquire))
    }

    /// Whether the budget has been fully spent.
    pub fn is_exhausted(&self) -> bool {
        self.exhausted.load(Ordering::Acquire)
    }

    /// Tokens granted so far.
    pub fn used(&self) -> usize {
        self.used.load(Ordering::Acquire)
    }
}
242
/// A section queued for execution, ordered by priority and then by its
/// original position in the query.
#[derive(Debug, Clone)]
struct ScheduledSection {
    // Scheduling priority; lower values run first (see the `Ord` impl).
    priority: i32,

    // Original index within the query; stable tie-breaker for equal priorities.
    index: usize,

    // The section definition to execute.
    section: ContextSection,
}
259
// Eq is sound because equality only involves `priority` and `index`
// (see the PartialEq impl), both of which have total equality.
impl Eq for ScheduledSection {}
261
262impl PartialEq for ScheduledSection {
263 fn eq(&self, other: &Self) -> bool {
264 self.priority == other.priority && self.index == other.index
265 }
266}
267
268impl Ord for ScheduledSection {
269 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
270 other.priority.cmp(&self.priority)
272 .then_with(|| other.index.cmp(&self.index))
273 }
274}
275
impl PartialOrd for ScheduledSection {
    // Delegates to the total ordering defined by `Ord`.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
281
/// Executes context queries as a stream of `SectionChunk`s under a shared
/// rolling token budget.
pub struct StreamingContextExecutor<V: VectorIndex> {
    // Token estimator used to price content before consuming budget.
    estimator: TokenEstimator,

    // Vector index backing `Search` sections.
    vector_index: Arc<V>,

    // Budget shared by every section of a streaming pass.
    // NOTE(review): created once in `new`, so a second `execute_streaming`
    // call on the same executor starts with whatever budget the first call
    // left behind — confirm this reuse is intended.
    budget: Arc<RollingBudget>,

    // Streaming configuration (limits, formatting, chunking).
    config: StreamingConfig,
}
300
301impl<V: VectorIndex> StreamingContextExecutor<V> {
302 pub fn new(
304 vector_index: Arc<V>,
305 config: StreamingConfig,
306 ) -> Self {
307 let budget = Arc::new(RollingBudget::new(config.token_limit));
308 Self {
309 estimator: TokenEstimator::new(),
310 vector_index,
311 budget,
312 config,
313 }
314 }
315
316 pub fn execute_streaming(
321 &self,
322 query: &ContextSelectQuery,
323 ) -> StreamingContextIter<'_, V> {
324 let mut priority_queue = BinaryHeap::new();
326 for (index, section) in query.sections.iter().enumerate() {
327 priority_queue.push(ScheduledSection {
328 priority: section.priority,
329 index,
330 section: section.clone(),
331 });
332 }
333
334 StreamingContextIter {
335 executor: self,
336 priority_queue,
337 current_section: None,
338 current_section_tokens: 0,
339 sections_included: Vec::new(),
340 sections_dropped: Vec::new(),
341 completed: false,
342 }
343 }
344
345 fn execute_section(
347 &self,
348 section: &ContextSection,
349 ) -> Vec<SectionChunk> {
350 let mut chunks = Vec::new();
351
352 if self.config.include_headers {
354 let header_tokens = self.estimator.estimate_text(&format!(
355 "## {} [priority={}]\n",
356 section.name, section.priority
357 ));
358
359 if self.budget.try_consume(header_tokens) > 0 {
360 chunks.push(SectionChunk::SectionHeader {
361 name: section.name.clone(),
362 priority: section.priority,
363 estimated_tokens: header_tokens,
364 });
365 } else {
366 return chunks; }
368 }
369
370 match §ion.content {
372 SectionContent::Literal { value } => {
373 self.execute_literal_section(section, value, &mut chunks);
374 }
375 SectionContent::Search { collection, query, top_k, min_score } => {
376 self.execute_search_section(section, collection, query, *top_k, *min_score, &mut chunks);
377 }
378 SectionContent::Get { path } => {
379 let content = format!("{}:**", path.to_path_string());
381 self.execute_literal_section(section, &content, &mut chunks);
382 }
383 SectionContent::Last { count, table, where_clause: _ } => {
384 let content = format!("{}[{}]:\n (recent entries)", table, count);
386 self.execute_literal_section(section, &content, &mut chunks);
387 }
388 SectionContent::Select { columns, table, where_clause: _, limit } => {
389 let content = format!(
391 "{}[{}]{{{}}}:\n (query results)",
392 table,
393 limit.unwrap_or(10),
394 columns.join(",")
395 );
396 self.execute_literal_section(section, &content, &mut chunks);
397 }
398 SectionContent::Variable { name } => {
399 let content = format!("${}", name);
400 self.execute_literal_section(section, &content, &mut chunks);
401 }
402 SectionContent::ToolRegistry { include, exclude: _, include_schema } => {
403 let content = if include.is_empty() {
404 format!("tools[*]{{schema={}}}", include_schema)
405 } else {
406 format!("tools[{}]{{schema={}}}", include.join(","), include_schema)
407 };
408 self.execute_literal_section(section, &content, &mut chunks);
409 }
410 SectionContent::ToolCalls { count, tool_filter, status_filter: _, include_outputs } => {
411 let filter_str = tool_filter.as_deref().unwrap_or("*");
412 let content = format!(
413 "tool_calls[{}]{{tool={},outputs={}}}",
414 count, filter_str, include_outputs
415 );
416 self.execute_literal_section(section, &content, &mut chunks);
417 }
418 }
419
420 chunks
421 }
422
423 fn execute_literal_section(
425 &self,
426 section: &ContextSection,
427 content: &str,
428 chunks: &mut Vec<SectionChunk>,
429 ) {
430 let _total_tokens = self.estimator.estimate_text(content);
432 let mut consumed = 0;
433 let mut offset = 0;
434 let content_bytes = content.as_bytes();
435
436 while offset < content_bytes.len() && !self.budget.is_exhausted() {
437 let approx_bytes = (self.config.chunk_size as f32 * 4.0) as usize;
439 let end = (offset + approx_bytes).min(content_bytes.len());
440
441 let break_point = if end < content_bytes.len() {
443 content[offset..end]
444 .rfind('\n')
445 .or_else(|| content[offset..end].rfind(' '))
446 .map(|p| offset + p + 1)
447 .unwrap_or(end)
448 } else {
449 end
450 };
451
452 let chunk_content = &content[offset..break_point];
453 let chunk_tokens = self.estimator.estimate_text(chunk_content);
454
455 let actual = self.budget.try_consume(chunk_tokens);
456 if actual == 0 {
457 break;
458 }
459
460 consumed += actual;
461 chunks.push(SectionChunk::ContentBlock {
462 section_name: section.name.clone(),
463 content: chunk_content.to_string(),
464 tokens: actual,
465 });
466
467 offset = break_point;
468 }
469
470 chunks.push(SectionChunk::SectionComplete {
472 name: section.name.clone(),
473 total_tokens: consumed,
474 truncated: offset < content_bytes.len(),
475 });
476 }
477
478 fn execute_search_section(
480 &self,
481 section: &ContextSection,
482 collection: &str,
483 query: &SimilarityQuery,
484 top_k: usize,
485 min_score: Option<f32>,
486 chunks: &mut Vec<SectionChunk>,
487 ) {
488 let results = match query {
490 SimilarityQuery::Embedding(embedding) => {
491 self.vector_index.search_by_embedding(collection, embedding, top_k, min_score)
492 }
493 SimilarityQuery::Text(text) => {
494 self.vector_index.search_by_text(collection, text, top_k, min_score)
495 }
496 SimilarityQuery::Variable(_) => {
497 Ok(Vec::new())
499 }
500 };
501
502 match results {
503 Ok(results) => {
504 let mut section_tokens = 0;
505 let mut batch = Vec::new();
506
507 for result in results {
508 if self.budget.is_exhausted() {
509 break;
510 }
511
512 let result_content = format!(
513 "[{:.3}] {}: {}\n",
514 result.score, result.id, result.content
515 );
516 let tokens = self.estimator.estimate_text(&result_content);
517
518 let actual = self.budget.try_consume(tokens);
519 if actual == 0 {
520 break;
521 }
522
523 section_tokens += actual;
524 batch.push(StreamingSearchResult {
525 id: result.id,
526 score: result.score,
527 content: result.content,
528 });
529
530 if batch.len() >= 5 {
532 chunks.push(SectionChunk::SearchResultBlock {
533 section_name: section.name.clone(),
534 results: std::mem::take(&mut batch),
535 tokens: section_tokens,
536 });
537 section_tokens = 0;
538 }
539 }
540
541 if !batch.is_empty() {
543 chunks.push(SectionChunk::SearchResultBlock {
544 section_name: section.name.clone(),
545 results: batch,
546 tokens: section_tokens,
547 });
548 }
549
550 chunks.push(SectionChunk::SectionComplete {
551 name: section.name.clone(),
552 total_tokens: section_tokens,
553 truncated: self.budget.is_exhausted(),
554 });
555 }
556 Err(e) => {
557 chunks.push(SectionChunk::Error {
558 section_name: Some(section.name.clone()),
559 message: e,
560 });
561 }
562 }
563 }
564}
565
/// Lazy iterator over the chunks of one streaming execution.
///
/// Holds the scheduled sections in a priority queue and buffers the chunks
/// of the section currently being drained.
pub struct StreamingContextIter<'a, V: VectorIndex> {
    // Executor that owns the budget, estimator and vector index.
    executor: &'a StreamingContextExecutor<V>,
    // Sections not yet executed, popped in priority order.
    priority_queue: BinaryHeap<ScheduledSection>,
    // Section in flight: (scheduled entry, its chunks, next chunk index).
    current_section: Option<(ScheduledSection, Vec<SectionChunk>, usize)>,
    #[allow(dead_code)]
    current_section_tokens: usize,
    // Names of sections fully emitted (ended with `SectionComplete`).
    sections_included: Vec<String>,
    // Names of sections skipped (budget exhausted or nothing fit).
    sections_dropped: Vec<String>,
    // Set once `StreamComplete` has been emitted; `next` then yields `None`.
    completed: bool,
}
581
582impl<'a, V: VectorIndex> Iterator for StreamingContextIter<'a, V> {
583 type Item = SectionChunk;
584
585 fn next(&mut self) -> Option<Self::Item> {
586 if self.completed {
587 return None;
588 }
589
590 if self.executor.budget.is_exhausted() && self.current_section.is_none() {
592 while let Some(scheduled) = self.priority_queue.pop() {
594 self.sections_dropped.push(scheduled.section.name.clone());
595 }
596
597 self.completed = true;
598 return Some(SectionChunk::StreamComplete {
599 total_tokens: self.executor.budget.used(),
600 sections_included: std::mem::take(&mut self.sections_included),
601 sections_dropped: std::mem::take(&mut self.sections_dropped),
602 });
603 }
604
605 if let Some((_section, chunks, index)) = &mut self.current_section {
607 if *index < chunks.len() {
608 let chunk = chunks[*index].clone();
609 *index += 1;
610
611 if let SectionChunk::SectionComplete { name, .. } = &chunk {
613 self.sections_included.push(name.clone());
614 self.current_section = None;
615 }
616
617 return Some(chunk);
618 }
619 self.current_section = None;
620 }
621
622 if let Some(scheduled) = self.priority_queue.pop() {
624 let chunks = self.executor.execute_section(&scheduled.section);
625 if !chunks.is_empty() {
626 let first_chunk = chunks[0].clone();
627 self.current_section = Some((scheduled, chunks, 1));
628 return Some(first_chunk);
629 }
630 self.sections_dropped.push(scheduled.section.name.clone());
632 return self.next();
633 }
634
635 self.completed = true;
637 Some(SectionChunk::StreamComplete {
638 total_tokens: self.executor.budget.used(),
639 sections_included: std::mem::take(&mut self.sections_included),
640 sections_dropped: std::mem::take(&mut self.sections_dropped),
641 })
642 }
643}
644
#[cfg(feature = "async")]
pub mod async_stream {
    //! Optional `futures::Stream` adapter over the synchronous iterator.

    use super::*;
    use futures::Stream;
    // BUGFIX: `Pin`, `Context` and `Poll` were used below without being
    // imported, so this module failed to compile with the `async` feature.
    use std::pin::Pin;
    use std::task::{Context, Poll};

    /// Wraps a `'static` streaming iterator as a `Stream`.
    ///
    /// Each poll synchronously yields the next chunk; there is no real
    /// asynchrony here — the adapter only satisfies the `Stream` contract.
    pub struct AsyncStreamingContext<V: VectorIndex> {
        iter: StreamingContextIter<'static, V>,
    }

    impl<V: VectorIndex> Stream for AsyncStreamingContext<V> {
        type Item = SectionChunk;

        fn poll_next(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Option<Self::Item>> {
            // Always ready: delegate straight to the synchronous iterator.
            Poll::Ready(self.iter.next())
        }
    }
}
670
671pub fn create_streaming_executor<V: VectorIndex>(
677 vector_index: Arc<V>,
678 token_limit: usize,
679) -> StreamingContextExecutor<V> {
680 let config = StreamingConfig {
681 token_limit,
682 ..Default::default()
683 };
684 StreamingContextExecutor::new(vector_index, config)
685}
686
687pub fn collect_streaming_chunks<V: VectorIndex>(
689 executor: &StreamingContextExecutor<V>,
690 query: &ContextSelectQuery,
691) -> Vec<SectionChunk> {
692 executor.execute_streaming(query).collect()
693}
694
695pub fn materialize_context(chunks: &[SectionChunk], format: OutputFormat) -> String {
697 let mut output = String::new();
698
699 for chunk in chunks {
700 match chunk {
701 SectionChunk::SectionHeader { name, priority, .. } => {
702 match format {
703 OutputFormat::Soch => {
704 output.push_str(&format!("# {} [p={}]\n", name, priority));
705 }
706 OutputFormat::Markdown => {
707 output.push_str(&format!("## {}\n\n", name));
708 }
709 OutputFormat::Json => {
710 }
712 }
713 }
714 SectionChunk::ContentBlock { content, .. } => {
715 output.push_str(content);
716 }
717 SectionChunk::RowBlock { columns, rows, .. } => {
718 output.push_str(&format!("{{{}}}:\n", columns.join(",")));
720 for row in rows {
721 let values: Vec<String> = row.iter().map(|v| format!("{:?}", v)).collect();
722 output.push_str(&format!(" {}\n", values.join(",")));
723 }
724 }
725 SectionChunk::SearchResultBlock { results, .. } => {
726 for result in results {
727 output.push_str(&format!(
728 "[{:.3}] {}: {}\n",
729 result.score, result.id, result.content
730 ));
731 }
732 }
733 SectionChunk::SectionComplete { .. } => {
734 output.push('\n');
735 }
736 SectionChunk::StreamComplete { .. } => {
737 }
739 SectionChunk::Error { section_name, message } => {
740 let section = section_name.as_deref().unwrap_or("unknown");
741 output.push_str(&format!("# Error in {}: {}\n", section, message));
742 }
743 }
744 }
745
746 output
747}
748
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context_query::{
        ContextQueryOptions, SessionReference, PathExpression,
        VectorSearchResult, VectorIndexStats,
    };
    use std::collections::HashMap;

    /// Canned vector index that returns the same result list for every query.
    struct MockVectorIndex {
        results: Vec<VectorSearchResult>,
    }

    impl VectorIndex for MockVectorIndex {
        fn search_by_embedding(
            &self,
            _collection: &str,
            _embedding: &[f32],
            k: usize,
            _min_score: Option<f32>,
        ) -> Result<Vec<VectorSearchResult>, String> {
            Ok(self.results.iter().take(k).cloned().collect())
        }

        fn search_by_text(
            &self,
            _collection: &str,
            _text: &str,
            k: usize,
            _min_score: Option<f32>,
        ) -> Result<Vec<VectorSearchResult>, String> {
            Ok(self.results.iter().take(k).cloned().collect())
        }

        fn stats(&self, _collection: &str) -> Option<VectorIndexStats> {
            Some(VectorIndexStats {
                vector_count: self.results.len(),
                dimension: 128,
                metric: "cosine".to_string(),
            })
        }
    }

    #[test]
    fn test_rolling_budget() {
        let budget = RollingBudget::new(100);

        assert_eq!(budget.try_consume(30), 30);
        assert_eq!(budget.remaining(), 70);

        assert_eq!(budget.try_consume(50), 50);
        assert_eq!(budget.remaining(), 20);

        // Over-ask is clamped to the remainder and exhausts the budget.
        assert_eq!(budget.try_consume(30), 20);
        assert!(budget.is_exhausted());

        assert_eq!(budget.try_consume(10), 0);
    }

    #[test]
    fn test_streaming_context_basic() {
        let mock_index = Arc::new(MockVectorIndex {
            results: vec![
                VectorSearchResult {
                    id: "doc1".to_string(),
                    score: 0.95,
                    content: "First document".to_string(),
                    metadata: HashMap::new(),
                },
                VectorSearchResult {
                    id: "doc2".to_string(),
                    score: 0.85,
                    content: "Second document".to_string(),
                    metadata: HashMap::new(),
                },
            ],
        });

        let executor = StreamingContextExecutor::new(
            mock_index,
            StreamingConfig {
                token_limit: 1000,
                ..Default::default()
            },
        );

        let query = ContextSelectQuery {
            output_name: "test".to_string(),
            session: SessionReference::None,
            options: ContextQueryOptions::default(),
            sections: vec![
                ContextSection {
                    name: "INTRO".to_string(),
                    priority: 0,
                    content: SectionContent::Literal {
                        value: "Welcome to the test context.".to_string(),
                    },
                    transform: None,
                },
            ],
        };

        let chunks: Vec<_> = executor.execute_streaming(&query).collect();

        // At minimum: header, content, SectionComplete (plus StreamComplete).
        assert!(chunks.len() >= 3);

        if let Some(SectionChunk::StreamComplete { sections_included, .. }) = chunks.last() {
            assert!(sections_included.contains(&"INTRO".to_string()));
        } else {
            panic!("Expected StreamComplete as last chunk");
        }
    }

    #[test]
    fn test_priority_ordering() {
        let mock_index = Arc::new(MockVectorIndex { results: vec![] });

        let executor = StreamingContextExecutor::new(
            mock_index,
            StreamingConfig {
                token_limit: 10000,
                ..Default::default()
            },
        );

        let query = ContextSelectQuery {
            output_name: "test".to_string(),
            session: SessionReference::None,
            options: ContextQueryOptions::default(),
            sections: vec![
                ContextSection {
                    name: "LOW_PRIORITY".to_string(),
                    priority: 10,
                    content: SectionContent::Literal {
                        value: "Low priority content".to_string(),
                    },
                    transform: None,
                },
                ContextSection {
                    name: "HIGH_PRIORITY".to_string(),
                    priority: 0,
                    content: SectionContent::Literal {
                        value: "High priority content".to_string(),
                    },
                    transform: None,
                },
                ContextSection {
                    name: "MID_PRIORITY".to_string(),
                    priority: 5,
                    content: SectionContent::Literal {
                        value: "Mid priority content".to_string(),
                    },
                    transform: None,
                },
            ],
        };

        let chunks: Vec<_> = executor.execute_streaming(&query).collect();

        // Headers must come out in ascending priority-value order.
        let headers: Vec<_> = chunks.iter()
            .filter_map(|c| match c {
                SectionChunk::SectionHeader { name, .. } => Some(name.clone()),
                _ => None,
            })
            .collect();

        assert_eq!(headers, vec!["HIGH_PRIORITY", "MID_PRIORITY", "LOW_PRIORITY"]);
    }

    #[test]
    fn test_budget_exhaustion() {
        let mock_index = Arc::new(MockVectorIndex { results: vec![] });

        let executor = StreamingContextExecutor::new(
            mock_index,
            StreamingConfig {
                token_limit: 50, ..Default::default()
            },
        );

        let query = ContextSelectQuery {
            output_name: "test".to_string(),
            session: SessionReference::None,
            options: ContextQueryOptions::default(),
            sections: vec![
                ContextSection {
                    name: "FIRST".to_string(),
                    priority: 0,
                    content: SectionContent::Literal {
                        value: "This is a somewhat longer content that will consume budget.".to_string(),
                    },
                    transform: None,
                },
                ContextSection {
                    name: "SECOND".to_string(),
                    priority: 1,
                    content: SectionContent::Literal {
                        value: "This should be dropped.".to_string(),
                    },
                    transform: None,
                },
            ],
        };

        let chunks: Vec<_> = executor.execute_streaming(&query).collect();

        // BUGFIX: the previous assertion ended in `|| true`, making it
        // vacuous, and the `if let` silently passed when the last chunk was
        // not a StreamComplete. Assert the invariants that must always hold:
        // the stream ends with a summary, the budget cap is respected, and
        // every section is accounted for as either included or dropped.
        match chunks.last() {
            Some(SectionChunk::StreamComplete { total_tokens, sections_included, sections_dropped }) => {
                assert!(*total_tokens <= 50, "budget limit exceeded: {}", total_tokens);
                assert_eq!(
                    sections_included.len() + sections_dropped.len(),
                    2,
                    "every section must be either included or dropped"
                );
            }
            other => panic!("expected StreamComplete as last chunk, got {:?}", other),
        }
    }
}