1use crate::soch_ql::SochValue;
48use std::sync::atomic::{AtomicUsize, Ordering};
49
/// Tunable multipliers and fixed overheads used when estimating token counts.
///
/// The `*_factor` fields scale the rendered length of one kind of value,
/// `bytes_per_token` converts a scaled length into tokens, and
/// `safety_margin` multiplies a raw estimate before it is rounded up.
#[derive(Debug, Clone)]
pub struct TokenEstimatorConfig {
    /// Multiplier applied to the digit count of integer values.
    pub int_factor: f32,
    /// Multiplier applied to the formatted length of float values.
    pub float_factor: f32,
    /// Multiplier applied to the byte length of text values.
    pub string_factor: f32,
    /// Multiplier applied to the hex-rendered length of binary values.
    pub hex_factor: f32,
    /// Divisor that converts a scaled length into a token count.
    pub bytes_per_token: f32,
    /// Factor applied on top of a raw estimate before rounding up.
    pub safety_margin: f32,
    /// Tokens charged for each separator between two values of a row.
    pub separator_tokens: usize,
    /// Tokens charged once per non-empty row.
    pub newline_tokens: usize,
    /// Fixed token overhead charged for a table header.
    pub header_tokens: usize,
}

impl Default for TokenEstimatorConfig {
    /// Baseline configuration: 4 bytes per token with a 15% safety margin.
    fn default() -> Self {
        Self {
            // Per-kind length multipliers.
            int_factor: 1.0,
            float_factor: 1.2,
            string_factor: 1.1,
            hex_factor: 2.5,
            // Length-to-token conversion and padding.
            bytes_per_token: 4.0,
            safety_margin: 1.15,
            // Fixed structural overheads.
            separator_tokens: 1,
            newline_tokens: 1,
            header_tokens: 10,
        }
    }
}

impl TokenEstimatorConfig {
    /// Baseline configuration with only the bytes-per-token ratio and the
    /// safety margin overridden.
    fn tuned(bytes_per_token: f32, safety_margin: f32) -> Self {
        Self {
            bytes_per_token,
            safety_margin,
            ..Self::default()
        }
    }

    /// Preset named for GPT-4: 3.8 bytes per token, 15% safety margin.
    pub fn gpt4() -> Self {
        Self::tuned(3.8, 1.15)
    }

    /// Preset named for Claude: 4.2 bytes per token, 15% safety margin.
    pub fn claude() -> Self {
        Self::tuned(4.2, 1.15)
    }

    /// Preset that over-estimates on purpose: larger per-kind factors, fewer
    /// bytes per token and a 25% safety margin. Structural overheads keep
    /// their baseline values.
    pub fn conservative() -> Self {
        let baseline = Self::default();
        Self {
            int_factor: 1.2,
            float_factor: 1.4,
            string_factor: 1.3,
            hex_factor: 3.0,
            bytes_per_token: 3.5,
            safety_margin: 1.25,
            ..baseline
        }
    }
}
137
/// Heuristic token-count estimator driven by a [`TokenEstimatorConfig`].
pub struct TokenEstimator {
    /// Factors and overheads applied by every estimate.
    config: TokenEstimatorConfig,
}
142
143impl TokenEstimator {
144 pub fn new() -> Self {
146 Self {
147 config: TokenEstimatorConfig::default(),
148 }
149 }
150
    /// Creates an estimator that uses the supplied configuration.
    pub fn with_config(config: TokenEstimatorConfig) -> Self {
        Self { config }
    }
155
156 pub fn estimate_value(&self, value: &SochValue) -> usize {
161 let raw = self.estimate_value_raw(value);
162 ((raw as f32) * self.config.safety_margin).ceil() as usize
163 }
164
165 fn estimate_value_raw(&self, value: &SochValue) -> usize {
167 match value {
168 SochValue::Null => 1,
169 SochValue::Bool(_) => 1, SochValue::Int(n) => {
171 let digits = if *n == 0 {
173 1
174 } else {
175 ((*n).abs() as f64).log10().ceil() as usize + if *n < 0 { 1 } else { 0 }
176 };
177 ((digits as f32 * self.config.int_factor) / self.config.bytes_per_token).ceil()
178 as usize
179 }
180 SochValue::UInt(n) => {
181 let digits = if *n == 0 {
182 1
183 } else {
184 ((*n as f64).log10().ceil() as usize).max(1)
185 };
186 ((digits as f32 * self.config.int_factor) / self.config.bytes_per_token).ceil()
187 as usize
188 }
189 SochValue::Float(f) => {
190 let s = format!("{:.2}", f);
192 ((s.len() as f32 * self.config.float_factor) / self.config.bytes_per_token).ceil()
193 as usize
194 }
195 SochValue::Text(s) => {
196 ((s.len() as f32 * self.config.string_factor) / self.config.bytes_per_token).ceil()
198 as usize
199 }
200 SochValue::Binary(b) => {
201 let hex_len = 2 + b.len() * 2;
203 ((hex_len as f32 * self.config.hex_factor) / self.config.bytes_per_token).ceil()
204 as usize
205 }
206 SochValue::Array(arr) => {
207 let elem_tokens: usize = arr.iter().map(|v| self.estimate_value(v)).sum();
209 let separator_tokens = if arr.is_empty() { 0 } else { arr.len() - 1 };
210 2 + elem_tokens + separator_tokens }
212 }
213 }
214
215 pub fn estimate_row(&self, values: &[SochValue]) -> usize {
217 if values.is_empty() {
218 return 0;
219 }
220
221 let value_tokens: usize = values.iter().map(|v| self.estimate_value(v)).sum();
222 let separator_tokens = (values.len() - 1) * self.config.separator_tokens;
223 let newline = self.config.newline_tokens;
224
225 value_tokens + separator_tokens + newline
226 }
227
228 pub fn estimate_header(&self, table: &str, columns: &[String], row_count: usize) -> usize {
230 let base = self.config.header_tokens;
232 let table_tokens = ((table.len() as f32) / self.config.bytes_per_token).ceil() as usize;
233 let count_tokens = ((row_count as f64).log10().ceil() as usize).max(1);
234 let col_tokens: usize = columns
235 .iter()
236 .map(|c| ((c.len() as f32) / self.config.bytes_per_token).ceil() as usize)
237 .sum();
238
239 base + table_tokens + count_tokens + col_tokens
240 }
241
242 pub fn estimate_table(
244 &self,
245 table: &str,
246 columns: &[String],
247 rows: &[Vec<SochValue>],
248 ) -> usize {
249 let header = self.estimate_header(table, columns, rows.len());
250 let row_tokens: usize = rows.iter().map(|r| self.estimate_row(r)).sum();
251 header + row_tokens
252 }
253
254 pub fn estimate_text(&self, text: &str) -> usize {
256 let raw = ((text.len() as f32) / self.config.bytes_per_token).ceil() as usize;
257 ((raw as f32) * self.config.safety_margin).ceil() as usize
258 }
259
    /// Truncates `text` to at most `max_tokens` estimated tokens.
    ///
    /// Delegates to the free function `truncate_to_tokens` with this
    /// estimator and a fixed `"..."` suffix.
    pub fn truncate_to_tokens(&self, text: &str, max_tokens: usize) -> String {
        truncate_to_tokens(text, max_tokens, self, "...")
    }
266}
267
impl Default for TokenEstimator {
    /// Same as [`TokenEstimator::new`].
    fn default() -> Self {
        Self::new()
    }
}
273
/// Result of distributing a token budget across a set of sections.
#[derive(Debug, Clone)]
pub struct BudgetAllocation {
    /// Names of sections that received their full estimate.
    pub full_sections: Vec<String>,
    /// Sections granted less than requested: `(name, requested, allocated)`.
    pub truncated_sections: Vec<(String, usize, usize)>,
    /// Names of sections that received no tokens.
    pub dropped_sections: Vec<String>,
    /// Total tokens granted across all sections.
    pub tokens_allocated: usize,
    /// Tokens still unassigned once allocation finished.
    pub tokens_remaining: usize,
    /// Per-section record of what was decided and why.
    pub explain: Vec<AllocationDecision>,
}
294
/// Explanation of how a single section was treated during allocation.
#[derive(Debug, Clone)]
pub struct AllocationDecision {
    /// Name of the section the decision is about.
    pub section: String,
    /// Priority copied from the section.
    pub priority: i32,
    /// Tokens the section asked for (its estimate).
    pub requested: usize,
    /// Tokens actually granted; 0 when dropped.
    pub allocated: usize,
    /// Whether the section was kept in full, truncated or dropped.
    pub outcome: AllocationOutcome,
    /// Human-readable justification for the outcome.
    pub reason: String,
}
311
/// How a section fared during budget allocation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AllocationOutcome {
    /// The section received every token it requested.
    Full,
    /// The section received fewer tokens than it requested.
    Truncated,
    /// The section received no tokens.
    Dropped,
}
322
/// A named piece of content competing for part of the token budget.
#[derive(Debug, Clone)]
pub struct BudgetSection {
    /// Identifier reported back in the allocation result.
    pub name: String,
    /// Ordering key for the priority-based strategies; lower values are
    /// served first.
    pub priority: i32,
    /// Tokens the section needs to be included in full.
    pub estimated_tokens: usize,
    /// Smallest allocation the section accepts when it cannot fit in full.
    pub minimum_tokens: Option<usize>,
    /// Whether the strict strategy serves this section before optional ones.
    pub required: bool,
    /// Relative share used by the proportional strategy.
    pub weight: f32,
}

impl Default for BudgetSection {
    /// An unnamed, optional section that requests no tokens and carries a
    /// weight of 1.0.
    fn default() -> Self {
        Self {
            weight: 1.0,
            required: false,
            minimum_tokens: None,
            estimated_tokens: 0,
            priority: 0,
            name: String::new(),
        }
    }
}
352
/// Algorithm used to split the budget between sections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum AllocationStrategy {
    /// Serve sections in ascending priority order; the default strategy.
    #[default]
    GreedyPriority,
    /// Split the budget according to each section's weight.
    Proportional,
    /// Serve required sections first, then optional ones by priority.
    StrictPriority,
}
364
/// Tracks a token budget and distributes it across sections.
pub struct TokenBudgetEnforcer {
    /// Total number of tokens that may be handed out.
    budget: usize,
    /// Running total of tokens claimed through `try_allocate`.
    allocated: AtomicUsize,
    /// Estimator exposed to callers via `estimator()`.
    estimator: TokenEstimator,
    /// Tokens held back from the budget and never handed out.
    reserved: usize,
    /// Algorithm used by `allocate_sections`.
    strategy: AllocationStrategy,
}
381
/// Settings used to construct a [`TokenBudgetEnforcer`].
#[derive(Debug, Clone)]
pub struct TokenBudgetConfig {
    /// Total number of tokens available.
    pub total_budget: usize,
    /// Tokens held back from allocation.
    pub reserved_tokens: usize,
    /// NOTE(review): not read by `TokenBudgetEnforcer::new` in this file;
    /// confirm where strict mode is meant to take effect.
    pub strict: bool,
    /// NOTE(review): not read by `TokenBudgetEnforcer::new` in this file;
    /// presumably the priority for sections that do not set one.
    pub default_priority: i32,
    /// Algorithm used to split the budget between sections.
    pub strategy: AllocationStrategy,
}
396
397impl Default for TokenBudgetConfig {
398 fn default() -> Self {
399 Self {
400 total_budget: 4096,
401 reserved_tokens: 100,
402 strict: false,
403 default_priority: 10,
404 strategy: AllocationStrategy::GreedyPriority,
405 }
406 }
407}
408
409impl TokenBudgetEnforcer {
410 pub fn new(config: TokenBudgetConfig) -> Self {
412 Self {
413 budget: config.total_budget,
414 allocated: AtomicUsize::new(0),
415 estimator: TokenEstimator::new(),
416 reserved: config.reserved_tokens,
417 strategy: config.strategy,
418 }
419 }
420
421 pub fn with_budget(budget: usize) -> Self {
423 Self {
424 budget,
425 allocated: AtomicUsize::new(0),
426 estimator: TokenEstimator::new(),
427 reserved: 0,
428 strategy: AllocationStrategy::GreedyPriority,
429 }
430 }
431
432 pub fn with_estimator(budget: usize, estimator: TokenEstimator) -> Self {
434 Self {
435 budget,
436 allocated: AtomicUsize::new(0),
437 estimator,
438 reserved: 0,
439 strategy: AllocationStrategy::GreedyPriority,
440 }
441 }
442
    /// Builder-style setter: returns the enforcer with `strategy` replaced.
    pub fn with_strategy(mut self, strategy: AllocationStrategy) -> Self {
        self.strategy = strategy;
        self
    }
448
    /// Sets the number of reserved tokens, replacing any previous reserve.
    pub fn reserve(&mut self, tokens: usize) {
        self.reserved = tokens;
    }
453
454 pub fn available(&self) -> usize {
456 let allocated = self.allocated.load(Ordering::Acquire);
457 self.budget.saturating_sub(self.reserved + allocated)
458 }
459
    /// Returns the total budget, including any reserved tokens.
    pub fn total_budget(&self) -> usize {
        self.budget
    }
464
    /// Returns the tokens claimed so far through `try_allocate`.
    pub fn allocated(&self) -> usize {
        self.allocated.load(Ordering::Acquire)
    }
469
470 pub fn try_allocate(&self, tokens: usize) -> bool {
472 loop {
473 let current = self.allocated.load(Ordering::Acquire);
474 let new_total = current + tokens;
475
476 if new_total + self.reserved > self.budget {
477 return false;
478 }
479
480 if self
481 .allocated
482 .compare_exchange(current, new_total, Ordering::AcqRel, Ordering::Acquire)
483 .is_ok()
484 {
485 return true;
486 }
487 }
489 }
490
491 pub fn allocate_sections(&self, sections: &[BudgetSection]) -> BudgetAllocation {
493 match self.strategy {
494 AllocationStrategy::GreedyPriority => self.allocate_greedy(sections),
495 AllocationStrategy::Proportional => self.allocate_proportional(sections),
496 AllocationStrategy::StrictPriority => self.allocate_strict(sections),
497 }
498 }
499
500 fn allocate_greedy(&self, sections: &[BudgetSection]) -> BudgetAllocation {
502 let mut sorted: Vec<_> = sections.iter().collect();
504 sorted.sort_by_key(|s| s.priority);
505
506 let mut allocation = BudgetAllocation {
507 full_sections: Vec::new(),
508 truncated_sections: Vec::new(),
509 dropped_sections: Vec::new(),
510 tokens_allocated: 0,
511 tokens_remaining: self.budget.saturating_sub(self.reserved),
512 explain: Vec::new(),
513 };
514
515 for section in sorted {
516 let remaining = allocation.tokens_remaining;
517
518 if section.estimated_tokens <= remaining {
519 allocation.full_sections.push(section.name.clone());
521 allocation.tokens_allocated += section.estimated_tokens;
522 allocation.tokens_remaining -= section.estimated_tokens;
523 allocation.explain.push(AllocationDecision {
524 section: section.name.clone(),
525 priority: section.priority,
526 requested: section.estimated_tokens,
527 allocated: section.estimated_tokens,
528 outcome: AllocationOutcome::Full,
529 reason: format!("Fits in remaining budget ({} tokens)", remaining),
530 });
531 } else if let Some(min) = section.minimum_tokens {
532 if min <= remaining {
534 let truncated_to = remaining;
535 allocation.truncated_sections.push((
536 section.name.clone(),
537 section.estimated_tokens,
538 truncated_to,
539 ));
540 allocation.tokens_allocated += truncated_to;
541 allocation.explain.push(AllocationDecision {
542 section: section.name.clone(),
543 priority: section.priority,
544 requested: section.estimated_tokens,
545 allocated: truncated_to,
546 outcome: AllocationOutcome::Truncated,
547 reason: format!(
548 "Truncated from {} to {} tokens (min: {})",
549 section.estimated_tokens, truncated_to, min
550 ),
551 });
552 allocation.tokens_remaining = 0;
553 } else {
554 allocation.dropped_sections.push(section.name.clone());
555 allocation.explain.push(AllocationDecision {
556 section: section.name.clone(),
557 priority: section.priority,
558 requested: section.estimated_tokens,
559 allocated: 0,
560 outcome: AllocationOutcome::Dropped,
561 reason: format!(
562 "Minimum {} exceeds remaining {} tokens",
563 min, remaining
564 ),
565 });
566 }
567 } else {
568 allocation.dropped_sections.push(section.name.clone());
570 allocation.explain.push(AllocationDecision {
571 section: section.name.clone(),
572 priority: section.priority,
573 requested: section.estimated_tokens,
574 allocated: 0,
575 outcome: AllocationOutcome::Dropped,
576 reason: format!(
577 "Requested {} exceeds remaining {} (no truncation allowed)",
578 section.estimated_tokens, remaining
579 ),
580 });
581 }
582 }
583
584 allocation
585 }
586
587 fn allocate_proportional(&self, sections: &[BudgetSection]) -> BudgetAllocation {
594 let available = self.budget.saturating_sub(self.reserved);
595 let total_weight: f32 = sections.iter().map(|s| s.weight).sum();
596
597 if total_weight == 0.0 {
598 return self.allocate_greedy(sections);
599 }
600
601 let mut allocation = BudgetAllocation {
602 full_sections: Vec::new(),
603 truncated_sections: Vec::new(),
604 dropped_sections: Vec::new(),
605 tokens_allocated: 0,
606 tokens_remaining: available,
607 explain: Vec::new(),
608 };
609
610 let mut allocations: Vec<(usize, usize, bool)> = sections
612 .iter()
613 .map(|s| {
614 let proportional = ((available as f32) * s.weight / total_weight).floor() as usize;
615 let capped = proportional.min(s.estimated_tokens);
616 let min = s.minimum_tokens.unwrap_or(0);
617 (capped.max(min), s.estimated_tokens, capped < s.estimated_tokens)
618 })
619 .collect();
620
621 let mut total: usize = allocations.iter().map(|(a, _, _)| *a).sum();
623
624 while total > available {
626 let max_idx = allocations
628 .iter()
629 .enumerate()
630 .filter(|(i, (a, _, _))| {
631 *a > sections[*i].minimum_tokens.unwrap_or(0)
632 })
633 .max_by_key(|(_, (a, _, _))| *a)
634 .map(|(i, _)| i);
635
636 match max_idx {
637 Some(idx) => {
638 let reduce = (total - available).min(allocations[idx].0 - sections[idx].minimum_tokens.unwrap_or(0));
639 allocations[idx].0 -= reduce;
640 total -= reduce;
641 }
642 None => break, }
644 }
645
646 for (i, section) in sections.iter().enumerate() {
648 let (allocated, requested, truncated) = allocations[i];
649
650 if allocated == 0 {
651 allocation.dropped_sections.push(section.name.clone());
652 allocation.explain.push(AllocationDecision {
653 section: section.name.clone(),
654 priority: section.priority,
655 requested,
656 allocated: 0,
657 outcome: AllocationOutcome::Dropped,
658 reason: "No budget available after proportional allocation".to_string(),
659 });
660 } else if truncated {
661 allocation.truncated_sections.push((
662 section.name.clone(),
663 requested,
664 allocated,
665 ));
666 allocation.tokens_allocated += allocated;
667 allocation.tokens_remaining = allocation.tokens_remaining.saturating_sub(allocated);
668 allocation.explain.push(AllocationDecision {
669 section: section.name.clone(),
670 priority: section.priority,
671 requested,
672 allocated,
673 outcome: AllocationOutcome::Truncated,
674 reason: format!(
675 "Proportional allocation: {:.1}% of budget (weight {:.1})",
676 (allocated as f32 / available as f32) * 100.0,
677 section.weight
678 ),
679 });
680 } else {
681 allocation.full_sections.push(section.name.clone());
682 allocation.tokens_allocated += allocated;
683 allocation.tokens_remaining = allocation.tokens_remaining.saturating_sub(allocated);
684 allocation.explain.push(AllocationDecision {
685 section: section.name.clone(),
686 priority: section.priority,
687 requested,
688 allocated,
689 outcome: AllocationOutcome::Full,
690 reason: format!(
691 "Full allocation within proportional budget (weight {:.1})",
692 section.weight
693 ),
694 });
695 }
696 }
697
698 allocation
699 }
700
701 fn allocate_strict(&self, sections: &[BudgetSection]) -> BudgetAllocation {
703 let mut sorted: Vec<_> = sections.iter().collect();
704 sorted.sort_by_key(|s| (if s.required { 0 } else { 1 }, s.priority));
705
706 let mut allocation = BudgetAllocation {
708 full_sections: Vec::new(),
709 truncated_sections: Vec::new(),
710 dropped_sections: Vec::new(),
711 tokens_allocated: 0,
712 tokens_remaining: self.budget.saturating_sub(self.reserved),
713 explain: Vec::new(),
714 };
715
716 for section in sorted.iter().filter(|s| s.required) {
718 let remaining = allocation.tokens_remaining;
719 let min = section.minimum_tokens.unwrap_or(section.estimated_tokens);
720
721 if section.estimated_tokens <= remaining {
722 allocation.full_sections.push(section.name.clone());
723 allocation.tokens_allocated += section.estimated_tokens;
724 allocation.tokens_remaining -= section.estimated_tokens;
725 allocation.explain.push(AllocationDecision {
726 section: section.name.clone(),
727 priority: section.priority,
728 requested: section.estimated_tokens,
729 allocated: section.estimated_tokens,
730 outcome: AllocationOutcome::Full,
731 reason: "Required section - full allocation".to_string(),
732 });
733 } else if min <= remaining {
734 allocation.truncated_sections.push((
735 section.name.clone(),
736 section.estimated_tokens,
737 remaining,
738 ));
739 allocation.tokens_allocated += remaining;
740 allocation.explain.push(AllocationDecision {
741 section: section.name.clone(),
742 priority: section.priority,
743 requested: section.estimated_tokens,
744 allocated: remaining,
745 outcome: AllocationOutcome::Truncated,
746 reason: "Required section - truncated to fit".to_string(),
747 });
748 allocation.tokens_remaining = 0;
749 }
750 }
752
753 for section in sorted.iter().filter(|s| !s.required) {
755 let remaining = allocation.tokens_remaining;
756
757 if remaining == 0 {
758 allocation.dropped_sections.push(section.name.clone());
759 allocation.explain.push(AllocationDecision {
760 section: section.name.clone(),
761 priority: section.priority,
762 requested: section.estimated_tokens,
763 allocated: 0,
764 outcome: AllocationOutcome::Dropped,
765 reason: "No budget remaining after required sections".to_string(),
766 });
767 continue;
768 }
769
770 if section.estimated_tokens <= remaining {
771 allocation.full_sections.push(section.name.clone());
772 allocation.tokens_allocated += section.estimated_tokens;
773 allocation.tokens_remaining -= section.estimated_tokens;
774 allocation.explain.push(AllocationDecision {
775 section: section.name.clone(),
776 priority: section.priority,
777 requested: section.estimated_tokens,
778 allocated: section.estimated_tokens,
779 outcome: AllocationOutcome::Full,
780 reason: "Optional section - fits in remaining budget".to_string(),
781 });
782 } else if let Some(min) = section.minimum_tokens {
783 if min <= remaining {
784 allocation.truncated_sections.push((
785 section.name.clone(),
786 section.estimated_tokens,
787 remaining,
788 ));
789 allocation.tokens_allocated += remaining;
790 allocation.explain.push(AllocationDecision {
791 section: section.name.clone(),
792 priority: section.priority,
793 requested: section.estimated_tokens,
794 allocated: remaining,
795 outcome: AllocationOutcome::Truncated,
796 reason: "Optional section - truncated to fit".to_string(),
797 });
798 allocation.tokens_remaining = 0;
799 } else {
800 allocation.dropped_sections.push(section.name.clone());
801 allocation.explain.push(AllocationDecision {
802 section: section.name.clone(),
803 priority: section.priority,
804 requested: section.estimated_tokens,
805 allocated: 0,
806 outcome: AllocationOutcome::Dropped,
807 reason: format!("Minimum {} exceeds remaining {}", min, remaining),
808 });
809 }
810 } else {
811 allocation.dropped_sections.push(section.name.clone());
812 allocation.explain.push(AllocationDecision {
813 section: section.name.clone(),
814 priority: section.priority,
815 requested: section.estimated_tokens,
816 allocated: 0,
817 outcome: AllocationOutcome::Dropped,
818 reason: format!("Requested {} exceeds remaining {}", section.estimated_tokens, remaining),
819 });
820 }
821 }
822
823 allocation
824 }
825
    /// Clears the running total of allocated tokens back to 0.
    pub fn reset(&self) {
        self.allocated.store(0, Ordering::Release);
    }
830
    /// Returns the estimator held by this enforcer.
    pub fn estimator(&self) -> &TokenEstimator {
        &self.estimator
    }
835}
836
837impl BudgetAllocation {
842 pub fn explain_text(&self) -> String {
844 let mut output = String::new();
845 output.push_str("=== CONTEXT BUDGET ALLOCATION ===\n\n");
846 output.push_str(&format!(
847 "Total Allocated: {} tokens\n",
848 self.tokens_allocated
849 ));
850 output.push_str(&format!("Remaining: {} tokens\n\n", self.tokens_remaining));
851
852 output.push_str("SECTIONS:\n");
853 for decision in &self.explain {
854 let status = match decision.outcome {
855 AllocationOutcome::Full => "✓ FULL",
856 AllocationOutcome::Truncated => "◐ TRUNCATED",
857 AllocationOutcome::Dropped => "✗ DROPPED",
858 };
859 output.push_str(&format!(
860 " [{:^12}] {} (priority {})\n",
861 status, decision.section, decision.priority
862 ));
863 output.push_str(&format!(
864 " Requested: {}, Allocated: {}\n",
865 decision.requested, decision.allocated
866 ));
867 output.push_str(&format!(" Reason: {}\n", decision.reason));
868 }
869
870 output
871 }
872
873 pub fn explain_json(&self) -> String {
875 serde_json::to_string_pretty(&ExplainOutput {
876 tokens_allocated: self.tokens_allocated,
877 tokens_remaining: self.tokens_remaining,
878 full_sections: self.full_sections.clone(),
879 truncated_sections: self.truncated_sections.clone(),
880 dropped_sections: self.dropped_sections.clone(),
881 decisions: self.explain.iter().map(|d| ExplainDecision {
882 section: d.section.clone(),
883 priority: d.priority,
884 requested: d.requested,
885 allocated: d.allocated,
886 outcome: format!("{:?}", d.outcome),
887 reason: d.reason.clone(),
888 }).collect(),
889 }).unwrap_or_else(|_| "{}".to_string())
890 }
891}
892
/// Serializable snapshot of a `BudgetAllocation`, built by `explain_json`.
#[derive(serde::Serialize)]
struct ExplainOutput {
    /// Total tokens granted.
    tokens_allocated: usize,
    /// Tokens left unassigned.
    tokens_remaining: usize,
    /// Names of sections granted in full.
    full_sections: Vec<String>,
    /// Truncated sections as `(name, requested, allocated)`.
    truncated_sections: Vec<(String, usize, usize)>,
    /// Names of sections that received nothing.
    dropped_sections: Vec<String>,
    /// One entry per recorded allocation decision.
    decisions: Vec<ExplainDecision>,
}
902
/// Serializable form of an `AllocationDecision`, built by `explain_json`.
#[derive(serde::Serialize)]
struct ExplainDecision {
    /// Name of the section.
    section: String,
    /// Priority copied from the section.
    priority: i32,
    /// Tokens the section asked for.
    requested: usize,
    /// Tokens actually granted.
    allocated: usize,
    /// `Debug` rendering of the outcome variant.
    outcome: String,
    /// Human-readable justification.
    reason: String,
}
912
913pub fn truncate_to_tokens(
919 text: &str,
920 max_tokens: usize,
921 estimator: &TokenEstimator,
922 suffix: &str,
923) -> String {
924 let current = estimator.estimate_text(text);
925
926 if current <= max_tokens {
927 return text.to_string();
928 }
929
930 let suffix_tokens = estimator.estimate_text(suffix);
931 let target_tokens = max_tokens.saturating_sub(suffix_tokens);
932
933 if target_tokens == 0 {
934 return suffix.to_string();
935 }
936
937 let mut low = 0;
939 let mut high = text.len();
940
941 while low < high {
942 let mid = (low + high).div_ceil(2);
943
944 let boundary = text
946 .char_indices()
947 .take_while(|(i, _)| *i < mid)
948 .last()
949 .map(|(i, c)| i + c.len_utf8())
950 .unwrap_or(0);
951
952 let truncated = &text[..boundary];
953 let tokens = estimator.estimate_text(truncated);
954
955 if tokens <= target_tokens {
956 low = boundary;
957 } else {
958 high = boundary.saturating_sub(1);
959 }
960 }
961
962 let truncated = &text[..low];
964 let word_boundary = truncated.rfind(|c: char| c.is_whitespace()).unwrap_or(low);
965
966 format!("{}{}", &text[..word_boundary], suffix)
967}
968
969pub fn truncate_rows(
971 rows: &[Vec<SochValue>],
972 max_tokens: usize,
973 estimator: &TokenEstimator,
974) -> Vec<Vec<SochValue>> {
975 let mut result = Vec::new();
976 let mut used = 0;
977
978 for row in rows {
979 let row_tokens = estimator.estimate_row(row);
980
981 if used + row_tokens <= max_tokens {
982 result.push(row.clone());
983 used += row_tokens;
984 } else {
985 break; }
987 }
988
989 result
990}
991
#[cfg(test)]
mod tests {
    use super::*;

    /// Integer estimates are at least one token and do not shrink as the
    /// number of digits grows.
    #[test]
    fn test_estimate_value_int() {
        let est = TokenEstimator::new();

        assert!(est.estimate_value(&SochValue::Int(0)) >= 1);
        assert!(est.estimate_value(&SochValue::Int(42)) >= 1);

        let small = est.estimate_value(&SochValue::Int(42));
        let large = est.estimate_value(&SochValue::Int(1_000_000_000));
        assert!(large >= small);
    }

    /// Longer text costs more tokens than shorter text.
    #[test]
    fn test_estimate_value_text() {
        let est = TokenEstimator::new();

        let short = est.estimate_value(&SochValue::Text("hello".to_string()));
        let long = est.estimate_value(&SochValue::Text(
            "hello world this is a longer string".to_string(),
        ));

        assert!(long > short);
    }

    /// A three-value row costs at least one token per value.
    #[test]
    #[allow(clippy::approx_constant)]
    fn test_estimate_row() {
        let est = TokenEstimator::new();

        let row = vec![
            SochValue::Int(1),
            SochValue::Text("Alice".to_string()),
            SochValue::Float(3.14),
        ];

        let tokens = est.estimate_row(&row);

        assert!(tokens >= 3);
    }

    /// A table estimate includes header overhead on top of its rows.
    #[test]
    fn test_estimate_table() {
        let est = TokenEstimator::new();

        let columns = vec!["id".to_string(), "name".to_string()];
        let rows = vec![
            vec![SochValue::Int(1), SochValue::Text("Alice".to_string())],
            vec![SochValue::Int(2), SochValue::Text("Bob".to_string())],
        ];

        let tokens = est.estimate_table("users", &columns, &rows);

        assert!(tokens > est.estimate_row(&rows[0]) * 2);
    }

    /// `try_allocate` accepts claims within budget and rejects the one that
    /// would exceed it, leaving the running total untouched.
    #[test]
    fn test_budget_enforcer_allocation() {
        let enforcer = TokenBudgetEnforcer::with_budget(1000);

        assert!(enforcer.try_allocate(500));
        assert_eq!(enforcer.allocated(), 500);
        assert_eq!(enforcer.available(), 500);

        assert!(enforcer.try_allocate(400));
        assert_eq!(enforcer.allocated(), 900);

        assert!(!enforcer.try_allocate(200));
        assert_eq!(enforcer.allocated(), 900);
    }

    /// `reset` clears the running total.
    #[test]
    fn test_budget_enforcer_reset() {
        let enforcer = TokenBudgetEnforcer::with_budget(1000);

        enforcer.try_allocate(800);
        assert_eq!(enforcer.allocated(), 800);

        enforcer.reset();
        assert_eq!(enforcer.allocated(), 0);
    }

    /// Greedy allocation keeps the first sections and drops the one that no
    /// longer fits and cannot be truncated.
    #[test]
    fn test_allocate_sections() {
        let enforcer = TokenBudgetEnforcer::with_budget(1000);

        let sections = vec![
            BudgetSection {
                name: "A".to_string(),
                priority: 0,
                estimated_tokens: 300,
                minimum_tokens: None,
                required: true,
                weight: 1.0,
            },
            BudgetSection {
                name: "B".to_string(),
                priority: 1,
                estimated_tokens: 400,
                minimum_tokens: Some(200),
                required: false,
                weight: 1.0,
            },
            BudgetSection {
                name: "C".to_string(),
                priority: 2,
                estimated_tokens: 500,
                minimum_tokens: None,
                required: false,
                weight: 1.0,
            },
        ];

        let allocation = enforcer.allocate_sections(&sections);

        assert!(allocation.full_sections.contains(&"A".to_string()));

        assert!(allocation.dropped_sections.contains(&"C".to_string()));

        assert!(allocation.tokens_allocated <= 1000);
    }

    /// Sections are served by priority value, not by input order.
    #[test]
    fn test_allocate_by_priority() {
        let enforcer = TokenBudgetEnforcer::with_budget(500);

        let sections = vec![
            BudgetSection {
                name: "LowPriority".to_string(),
                priority: 10,
                estimated_tokens: 200,
                minimum_tokens: None,
                required: false,
                weight: 1.0,
            },
            BudgetSection {
                name: "HighPriority".to_string(),
                priority: 0,
                estimated_tokens: 400,
                minimum_tokens: None,
                required: true,
                weight: 1.0,
            },
        ];

        let allocation = enforcer.allocate_sections(&sections);

        assert!(
            allocation
                .full_sections
                .contains(&"HighPriority".to_string())
        );

        assert!(
            allocation
                .dropped_sections
                .contains(&"LowPriority".to_string())
        );
    }

    /// Truncated text is shorter, ends with the suffix and fits the budget.
    #[test]
    fn test_truncate_to_tokens() {
        let est = TokenEstimator::new();

        let text = "This is a long text that needs to be truncated to fit within the token budget";
        let truncated = truncate_to_tokens(text, 10, &est, "...");

        assert!(truncated.len() < text.len());

        assert!(truncated.ends_with("..."));

        assert!(est.estimate_text(&truncated) <= 10);
    }

    /// Row truncation keeps a prefix of the rows that fits the budget.
    #[test]
    fn test_truncate_rows() {
        let est = TokenEstimator::new();

        let rows: Vec<Vec<SochValue>> = (0..100)
            .map(|i| vec![SochValue::Int(i), SochValue::Text(format!("row{}", i))])
            .collect();

        let truncated = truncate_rows(&rows, 50, &est);

        assert!(truncated.len() < rows.len());

        let total: usize = truncated.iter().map(|r| est.estimate_row(r)).sum();
        assert!(total <= 50);
    }

    /// Reserved tokens are excluded from what `try_allocate` can hand out.
    #[test]
    fn test_reserved_budget() {
        let mut enforcer = TokenBudgetEnforcer::with_budget(1000);
        enforcer.reserve(200);

        assert_eq!(enforcer.available(), 800);

        assert!(enforcer.try_allocate(700));
        assert_eq!(enforcer.available(), 100);

        assert!(!enforcer.try_allocate(200));
    }

    /// The conservative preset never estimates below the default one.
    #[test]
    fn test_estimator_configs() {
        let default = TokenEstimator::new();
        let gpt4 = TokenEstimator::with_config(TokenEstimatorConfig::gpt4());
        let conservative = TokenEstimator::with_config(TokenEstimatorConfig::conservative());

        let text = "Hello, this is a test string for comparing token estimation across different configurations.";

        let default_est = default.estimate_text(text);
        let gpt4_est = gpt4.estimate_text(text);
        let conservative_est = conservative.estimate_text(text);

        assert!(conservative_est >= default_est);

        assert!(default_est > 0);
        assert!(gpt4_est > 0);
        assert!(conservative_est > 0);
    }

    /// A section with a minimum is truncated into the leftover budget rather
    /// than dropped.
    #[test]
    fn test_section_with_truncation() {
        let enforcer = TokenBudgetEnforcer::with_budget(600);

        let sections = vec![
            BudgetSection {
                name: "Required".to_string(),
                priority: 0,
                estimated_tokens: 500,
                minimum_tokens: None,
                required: true,
                weight: 1.0,
            },
            BudgetSection {
                name: "Optional".to_string(),
                priority: 1,
                estimated_tokens: 300,
                minimum_tokens: Some(50),
                required: false,
                weight: 1.0,
            },
        ];

        let allocation = enforcer.allocate_sections(&sections);

        assert!(allocation.full_sections.contains(&"Required".to_string()));

        assert!(
            allocation
                .truncated_sections
                .iter()
                .any(|(n, _, _)| n == "Optional")
        );
    }
}