1use crate::soch_ql::SochValue;
48use std::sync::atomic::{AtomicUsize, Ordering};
49
/// Coefficients that drive the heuristic token estimates.
///
/// Each `*_factor` scales the rendered length of one kind of value before it
/// is divided by `bytes_per_token`; the `*_tokens` fields are flat per-item
/// costs.
#[derive(Debug, Clone)]
pub struct TokenEstimatorConfig {
    /// Multiplier applied to the digit count of integers.
    pub int_factor: f32,
    /// Multiplier applied to the formatted length of floats.
    pub float_factor: f32,
    /// Multiplier applied to the byte length of text.
    pub string_factor: f32,
    /// Multiplier applied to the hex-rendered length of binary data.
    pub hex_factor: f32,
    /// Divisor that converts a (scaled) byte length into tokens.
    pub bytes_per_token: f32,
    /// Tokens charged for each separator between the values of a row.
    pub separator_tokens: usize,
    /// Tokens charged once per row for its line terminator.
    pub newline_tokens: usize,
    /// Fixed base cost of a table header.
    pub header_tokens: usize,
}

impl Default for TokenEstimatorConfig {
    /// Baseline coefficients: 4 bytes per token and mild per-type factors.
    fn default() -> Self {
        TokenEstimatorConfig {
            header_tokens: 10,
            newline_tokens: 1,
            separator_tokens: 1,
            bytes_per_token: 4.0,
            hex_factor: 2.5,
            string_factor: 1.1,
            float_factor: 1.2,
            int_factor: 1.0,
        }
    }
}

impl TokenEstimatorConfig {
    /// Defaults with `bytes_per_token` lowered to 3.8.
    pub fn gpt4() -> Self {
        Self {
            bytes_per_token: 3.8,
            ..Self::default()
        }
    }

    /// Defaults with `bytes_per_token` raised to 4.2.
    pub fn claude() -> Self {
        Self {
            bytes_per_token: 4.2,
            ..Self::default()
        }
    }

    /// Pessimistic preset: larger per-type factors and fewer bytes per token,
    /// so every estimate comes out at least as high as with the defaults.
    pub fn conservative() -> Self {
        Self {
            bytes_per_token: 3.5,
            hex_factor: 3.0,
            string_factor: 1.3,
            float_factor: 1.4,
            int_factor: 1.2,
            ..Self::default()
        }
    }
}
119
/// Heuristic token counter for values, rows, tables and plain text.
///
/// All estimates are derived from byte/digit lengths scaled by the
/// coefficients held in `config`; no real tokenizer is involved.
pub struct TokenEstimator {
    /// Coefficients used by every `estimate_*` method.
    config: TokenEstimatorConfig,
}
124
125impl TokenEstimator {
126 pub fn new() -> Self {
128 Self {
129 config: TokenEstimatorConfig::default(),
130 }
131 }
132
133 pub fn with_config(config: TokenEstimatorConfig) -> Self {
135 Self { config }
136 }
137
138 pub fn estimate_value(&self, value: &SochValue) -> usize {
140 match value {
141 SochValue::Null => 1,
142 SochValue::Bool(_) => 1, SochValue::Int(n) => {
144 let digits = if *n == 0 {
146 1
147 } else {
148 ((*n).abs() as f64).log10().ceil() as usize + if *n < 0 { 1 } else { 0 }
149 };
150 ((digits as f32 * self.config.int_factor) / self.config.bytes_per_token).ceil()
151 as usize
152 }
153 SochValue::UInt(n) => {
154 let digits = if *n == 0 {
155 1
156 } else {
157 ((*n as f64).log10().ceil() as usize).max(1)
158 };
159 ((digits as f32 * self.config.int_factor) / self.config.bytes_per_token).ceil()
160 as usize
161 }
162 SochValue::Float(f) => {
163 let s = format!("{:.2}", f);
165 ((s.len() as f32 * self.config.float_factor) / self.config.bytes_per_token).ceil()
166 as usize
167 }
168 SochValue::Text(s) => {
169 ((s.len() as f32 * self.config.string_factor) / self.config.bytes_per_token).ceil()
171 as usize
172 }
173 SochValue::Binary(b) => {
174 let hex_len = 2 + b.len() * 2;
176 ((hex_len as f32 * self.config.hex_factor) / self.config.bytes_per_token).ceil()
177 as usize
178 }
179 SochValue::Array(arr) => {
180 let elem_tokens: usize = arr.iter().map(|v| self.estimate_value(v)).sum();
182 let separator_tokens = if arr.is_empty() { 0 } else { arr.len() - 1 };
183 2 + elem_tokens + separator_tokens }
185 }
186 }
187
188 pub fn estimate_row(&self, values: &[SochValue]) -> usize {
190 if values.is_empty() {
191 return 0;
192 }
193
194 let value_tokens: usize = values.iter().map(|v| self.estimate_value(v)).sum();
195 let separator_tokens = (values.len() - 1) * self.config.separator_tokens;
196 let newline = self.config.newline_tokens;
197
198 value_tokens + separator_tokens + newline
199 }
200
201 pub fn estimate_header(&self, table: &str, columns: &[String], row_count: usize) -> usize {
203 let base = self.config.header_tokens;
205 let table_tokens = ((table.len() as f32) / self.config.bytes_per_token).ceil() as usize;
206 let count_tokens = ((row_count as f64).log10().ceil() as usize).max(1);
207 let col_tokens: usize = columns
208 .iter()
209 .map(|c| ((c.len() as f32) / self.config.bytes_per_token).ceil() as usize)
210 .sum();
211
212 base + table_tokens + count_tokens + col_tokens
213 }
214
215 pub fn estimate_table(
217 &self,
218 table: &str,
219 columns: &[String],
220 rows: &[Vec<SochValue>],
221 ) -> usize {
222 let header = self.estimate_header(table, columns, rows.len());
223 let row_tokens: usize = rows.iter().map(|r| self.estimate_row(r)).sum();
224 header + row_tokens
225 }
226
227 pub fn estimate_text(&self, text: &str) -> usize {
229 ((text.len() as f32) / self.config.bytes_per_token).ceil() as usize
230 }
231
232 pub fn truncate_to_tokens(&self, text: &str, max_tokens: usize) -> String {
236 truncate_to_tokens(text, max_tokens, self, "...")
237 }
238}
239
240impl Default for TokenEstimator {
241 fn default() -> Self {
242 Self::new()
243 }
244}
245
/// Result of distributing a token budget across a set of sections.
#[derive(Debug, Clone)]
pub struct BudgetAllocation {
    /// Names of sections that received their full estimate.
    pub full_sections: Vec<String>,
    /// Sections that received less than requested, as
    /// `(name, requested_tokens, allocated_tokens)`.
    pub truncated_sections: Vec<(String, usize, usize)>,
    /// Names of sections that received nothing.
    pub dropped_sections: Vec<String>,
    /// Sum of tokens handed out to full and truncated sections.
    pub tokens_allocated: usize,
    /// Budget left over after allocation.
    pub tokens_remaining: usize,
    /// One decision record per processed section, in processing order.
    pub explain: Vec<AllocationDecision>,
}
266
/// Explanation of what happened to a single section during allocation.
#[derive(Debug, Clone)]
pub struct AllocationDecision {
    /// Name of the section this decision is about.
    pub section: String,
    /// Priority copied from the section.
    pub priority: i32,
    /// Tokens the section asked for (its estimate).
    pub requested: usize,
    /// Tokens actually granted; 0 when dropped.
    pub allocated: usize,
    /// Whether the section was kept in full, truncated, or dropped.
    pub outcome: AllocationOutcome,
    /// Human-readable justification for the outcome.
    pub reason: String,
}
283
/// Outcome recorded for a section in an [`AllocationDecision`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AllocationOutcome {
    /// The section received everything it requested.
    Full,
    /// The section received fewer tokens than it requested.
    Truncated,
    /// The section received no tokens.
    Dropped,
}
294
/// A candidate section competing for part of the token budget.
#[derive(Debug, Clone)]
pub struct BudgetSection {
    /// Identifier reported back in the allocation result.
    pub name: String,
    /// Ordering key; sections are processed in ascending priority order.
    pub priority: i32,
    /// Tokens the section would like to receive.
    pub estimated_tokens: usize,
    /// Smallest useful allocation; `None` means the section cannot be
    /// truncated by the greedy strategy.
    pub minimum_tokens: Option<usize>,
    /// Marks the section as mandatory for the strict-priority strategy.
    pub required: bool,
    /// Relative share used by the proportional strategy.
    pub weight: f32,
}

impl Default for BudgetSection {
    /// An unnamed, optional, zero-sized section with unit weight.
    fn default() -> Self {
        BudgetSection {
            weight: 1.0,
            required: false,
            minimum_tokens: None,
            estimated_tokens: 0,
            priority: 0,
            name: String::new(),
        }
    }
}
324
/// Algorithm used by `TokenBudgetEnforcer::allocate_sections`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum AllocationStrategy {
    /// Walk sections in ascending priority order, giving each what fits.
    #[default]
    GreedyPriority,
    /// Split the budget according to each section's `weight`.
    Proportional,
    /// Serve `required` sections first, then optional ones by priority.
    StrictPriority,
}
336
/// Tracks a token budget and distributes it across sections.
pub struct TokenBudgetEnforcer {
    /// Total number of tokens available.
    budget: usize,
    /// Running total granted through `try_allocate`; cleared by `reset`.
    allocated: AtomicUsize,
    /// Estimator exposed through `estimator()`.
    estimator: TokenEstimator,
    /// Tokens held back from the budget and never handed out.
    reserved: usize,
    /// Strategy used by `allocate_sections`.
    strategy: AllocationStrategy,
}
353
/// Construction parameters for [`TokenBudgetEnforcer::new`].
#[derive(Debug, Clone)]
pub struct TokenBudgetConfig {
    /// Total number of tokens available.
    pub total_budget: usize,
    /// Tokens held back from the budget.
    pub reserved_tokens: usize,
    /// NOTE(review): not read by `TokenBudgetEnforcer::new` in this file —
    /// confirm whether it is consumed elsewhere.
    pub strict: bool,
    /// NOTE(review): not read by `TokenBudgetEnforcer::new` in this file —
    /// confirm whether it is consumed elsewhere.
    pub default_priority: i32,
    /// Strategy the enforcer uses for `allocate_sections`.
    pub strategy: AllocationStrategy,
}
368
369impl Default for TokenBudgetConfig {
370 fn default() -> Self {
371 Self {
372 total_budget: 4096,
373 reserved_tokens: 100,
374 strict: false,
375 default_priority: 10,
376 strategy: AllocationStrategy::GreedyPriority,
377 }
378 }
379}
380
381impl TokenBudgetEnforcer {
382 pub fn new(config: TokenBudgetConfig) -> Self {
384 Self {
385 budget: config.total_budget,
386 allocated: AtomicUsize::new(0),
387 estimator: TokenEstimator::new(),
388 reserved: config.reserved_tokens,
389 strategy: config.strategy,
390 }
391 }
392
393 pub fn with_budget(budget: usize) -> Self {
395 Self {
396 budget,
397 allocated: AtomicUsize::new(0),
398 estimator: TokenEstimator::new(),
399 reserved: 0,
400 strategy: AllocationStrategy::GreedyPriority,
401 }
402 }
403
404 pub fn with_estimator(budget: usize, estimator: TokenEstimator) -> Self {
406 Self {
407 budget,
408 allocated: AtomicUsize::new(0),
409 estimator,
410 reserved: 0,
411 strategy: AllocationStrategy::GreedyPriority,
412 }
413 }
414
415 pub fn with_strategy(mut self, strategy: AllocationStrategy) -> Self {
417 self.strategy = strategy;
418 self
419 }
420
    /// Sets the number of reserved tokens, replacing (not adding to) any
    /// previous reservation.
    pub fn reserve(&mut self, tokens: usize) {
        self.reserved = tokens;
    }
425
426 pub fn available(&self) -> usize {
428 let allocated = self.allocated.load(Ordering::Acquire);
429 self.budget.saturating_sub(self.reserved + allocated)
430 }
431
    /// Returns the total budget, ignoring reservations and allocations.
    pub fn total_budget(&self) -> usize {
        self.budget
    }
436
    /// Returns the running total granted through `try_allocate`.
    pub fn allocated(&self) -> usize {
        self.allocated.load(Ordering::Acquire)
    }
441
442 pub fn try_allocate(&self, tokens: usize) -> bool {
444 loop {
445 let current = self.allocated.load(Ordering::Acquire);
446 let new_total = current + tokens;
447
448 if new_total + self.reserved > self.budget {
449 return false;
450 }
451
452 if self
453 .allocated
454 .compare_exchange(current, new_total, Ordering::AcqRel, Ordering::Acquire)
455 .is_ok()
456 {
457 return true;
458 }
459 }
461 }
462
463 pub fn allocate_sections(&self, sections: &[BudgetSection]) -> BudgetAllocation {
465 match self.strategy {
466 AllocationStrategy::GreedyPriority => self.allocate_greedy(sections),
467 AllocationStrategy::Proportional => self.allocate_proportional(sections),
468 AllocationStrategy::StrictPriority => self.allocate_strict(sections),
469 }
470 }
471
472 fn allocate_greedy(&self, sections: &[BudgetSection]) -> BudgetAllocation {
474 let mut sorted: Vec<_> = sections.iter().collect();
476 sorted.sort_by_key(|s| s.priority);
477
478 let mut allocation = BudgetAllocation {
479 full_sections: Vec::new(),
480 truncated_sections: Vec::new(),
481 dropped_sections: Vec::new(),
482 tokens_allocated: 0,
483 tokens_remaining: self.budget.saturating_sub(self.reserved),
484 explain: Vec::new(),
485 };
486
487 for section in sorted {
488 let remaining = allocation.tokens_remaining;
489
490 if section.estimated_tokens <= remaining {
491 allocation.full_sections.push(section.name.clone());
493 allocation.tokens_allocated += section.estimated_tokens;
494 allocation.tokens_remaining -= section.estimated_tokens;
495 allocation.explain.push(AllocationDecision {
496 section: section.name.clone(),
497 priority: section.priority,
498 requested: section.estimated_tokens,
499 allocated: section.estimated_tokens,
500 outcome: AllocationOutcome::Full,
501 reason: format!("Fits in remaining budget ({} tokens)", remaining),
502 });
503 } else if let Some(min) = section.minimum_tokens {
504 if min <= remaining {
506 let truncated_to = remaining;
507 allocation.truncated_sections.push((
508 section.name.clone(),
509 section.estimated_tokens,
510 truncated_to,
511 ));
512 allocation.tokens_allocated += truncated_to;
513 allocation.explain.push(AllocationDecision {
514 section: section.name.clone(),
515 priority: section.priority,
516 requested: section.estimated_tokens,
517 allocated: truncated_to,
518 outcome: AllocationOutcome::Truncated,
519 reason: format!(
520 "Truncated from {} to {} tokens (min: {})",
521 section.estimated_tokens, truncated_to, min
522 ),
523 });
524 allocation.tokens_remaining = 0;
525 } else {
526 allocation.dropped_sections.push(section.name.clone());
527 allocation.explain.push(AllocationDecision {
528 section: section.name.clone(),
529 priority: section.priority,
530 requested: section.estimated_tokens,
531 allocated: 0,
532 outcome: AllocationOutcome::Dropped,
533 reason: format!(
534 "Minimum {} exceeds remaining {} tokens",
535 min, remaining
536 ),
537 });
538 }
539 } else {
540 allocation.dropped_sections.push(section.name.clone());
542 allocation.explain.push(AllocationDecision {
543 section: section.name.clone(),
544 priority: section.priority,
545 requested: section.estimated_tokens,
546 allocated: 0,
547 outcome: AllocationOutcome::Dropped,
548 reason: format!(
549 "Requested {} exceeds remaining {} (no truncation allowed)",
550 section.estimated_tokens, remaining
551 ),
552 });
553 }
554 }
555
556 allocation
557 }
558
559 fn allocate_proportional(&self, sections: &[BudgetSection]) -> BudgetAllocation {
566 let available = self.budget.saturating_sub(self.reserved);
567 let total_weight: f32 = sections.iter().map(|s| s.weight).sum();
568
569 if total_weight == 0.0 {
570 return self.allocate_greedy(sections);
571 }
572
573 let mut allocation = BudgetAllocation {
574 full_sections: Vec::new(),
575 truncated_sections: Vec::new(),
576 dropped_sections: Vec::new(),
577 tokens_allocated: 0,
578 tokens_remaining: available,
579 explain: Vec::new(),
580 };
581
582 let mut allocations: Vec<(usize, usize, bool)> = sections
584 .iter()
585 .map(|s| {
586 let proportional = ((available as f32) * s.weight / total_weight).floor() as usize;
587 let capped = proportional.min(s.estimated_tokens);
588 let min = s.minimum_tokens.unwrap_or(0);
589 (capped.max(min), s.estimated_tokens, capped < s.estimated_tokens)
590 })
591 .collect();
592
593 let mut total: usize = allocations.iter().map(|(a, _, _)| *a).sum();
595
596 while total > available {
598 let max_idx = allocations
600 .iter()
601 .enumerate()
602 .filter(|(i, (a, _, _))| {
603 *a > sections[*i].minimum_tokens.unwrap_or(0)
604 })
605 .max_by_key(|(_, (a, _, _))| *a)
606 .map(|(i, _)| i);
607
608 match max_idx {
609 Some(idx) => {
610 let reduce = (total - available).min(allocations[idx].0 - sections[idx].minimum_tokens.unwrap_or(0));
611 allocations[idx].0 -= reduce;
612 total -= reduce;
613 }
614 None => break, }
616 }
617
618 for (i, section) in sections.iter().enumerate() {
620 let (allocated, requested, truncated) = allocations[i];
621
622 if allocated == 0 {
623 allocation.dropped_sections.push(section.name.clone());
624 allocation.explain.push(AllocationDecision {
625 section: section.name.clone(),
626 priority: section.priority,
627 requested,
628 allocated: 0,
629 outcome: AllocationOutcome::Dropped,
630 reason: "No budget available after proportional allocation".to_string(),
631 });
632 } else if truncated {
633 allocation.truncated_sections.push((
634 section.name.clone(),
635 requested,
636 allocated,
637 ));
638 allocation.tokens_allocated += allocated;
639 allocation.tokens_remaining = allocation.tokens_remaining.saturating_sub(allocated);
640 allocation.explain.push(AllocationDecision {
641 section: section.name.clone(),
642 priority: section.priority,
643 requested,
644 allocated,
645 outcome: AllocationOutcome::Truncated,
646 reason: format!(
647 "Proportional allocation: {:.1}% of budget (weight {:.1})",
648 (allocated as f32 / available as f32) * 100.0,
649 section.weight
650 ),
651 });
652 } else {
653 allocation.full_sections.push(section.name.clone());
654 allocation.tokens_allocated += allocated;
655 allocation.tokens_remaining = allocation.tokens_remaining.saturating_sub(allocated);
656 allocation.explain.push(AllocationDecision {
657 section: section.name.clone(),
658 priority: section.priority,
659 requested,
660 allocated,
661 outcome: AllocationOutcome::Full,
662 reason: format!(
663 "Full allocation within proportional budget (weight {:.1})",
664 section.weight
665 ),
666 });
667 }
668 }
669
670 allocation
671 }
672
673 fn allocate_strict(&self, sections: &[BudgetSection]) -> BudgetAllocation {
675 let mut sorted: Vec<_> = sections.iter().collect();
676 sorted.sort_by_key(|s| (if s.required { 0 } else { 1 }, s.priority));
677
678 let mut allocation = BudgetAllocation {
680 full_sections: Vec::new(),
681 truncated_sections: Vec::new(),
682 dropped_sections: Vec::new(),
683 tokens_allocated: 0,
684 tokens_remaining: self.budget.saturating_sub(self.reserved),
685 explain: Vec::new(),
686 };
687
688 for section in sorted.iter().filter(|s| s.required) {
690 let remaining = allocation.tokens_remaining;
691 let min = section.minimum_tokens.unwrap_or(section.estimated_tokens);
692
693 if section.estimated_tokens <= remaining {
694 allocation.full_sections.push(section.name.clone());
695 allocation.tokens_allocated += section.estimated_tokens;
696 allocation.tokens_remaining -= section.estimated_tokens;
697 allocation.explain.push(AllocationDecision {
698 section: section.name.clone(),
699 priority: section.priority,
700 requested: section.estimated_tokens,
701 allocated: section.estimated_tokens,
702 outcome: AllocationOutcome::Full,
703 reason: "Required section - full allocation".to_string(),
704 });
705 } else if min <= remaining {
706 allocation.truncated_sections.push((
707 section.name.clone(),
708 section.estimated_tokens,
709 remaining,
710 ));
711 allocation.tokens_allocated += remaining;
712 allocation.explain.push(AllocationDecision {
713 section: section.name.clone(),
714 priority: section.priority,
715 requested: section.estimated_tokens,
716 allocated: remaining,
717 outcome: AllocationOutcome::Truncated,
718 reason: "Required section - truncated to fit".to_string(),
719 });
720 allocation.tokens_remaining = 0;
721 }
722 }
724
725 for section in sorted.iter().filter(|s| !s.required) {
727 let remaining = allocation.tokens_remaining;
728
729 if remaining == 0 {
730 allocation.dropped_sections.push(section.name.clone());
731 allocation.explain.push(AllocationDecision {
732 section: section.name.clone(),
733 priority: section.priority,
734 requested: section.estimated_tokens,
735 allocated: 0,
736 outcome: AllocationOutcome::Dropped,
737 reason: "No budget remaining after required sections".to_string(),
738 });
739 continue;
740 }
741
742 if section.estimated_tokens <= remaining {
743 allocation.full_sections.push(section.name.clone());
744 allocation.tokens_allocated += section.estimated_tokens;
745 allocation.tokens_remaining -= section.estimated_tokens;
746 allocation.explain.push(AllocationDecision {
747 section: section.name.clone(),
748 priority: section.priority,
749 requested: section.estimated_tokens,
750 allocated: section.estimated_tokens,
751 outcome: AllocationOutcome::Full,
752 reason: "Optional section - fits in remaining budget".to_string(),
753 });
754 } else if let Some(min) = section.minimum_tokens {
755 if min <= remaining {
756 allocation.truncated_sections.push((
757 section.name.clone(),
758 section.estimated_tokens,
759 remaining,
760 ));
761 allocation.tokens_allocated += remaining;
762 allocation.explain.push(AllocationDecision {
763 section: section.name.clone(),
764 priority: section.priority,
765 requested: section.estimated_tokens,
766 allocated: remaining,
767 outcome: AllocationOutcome::Truncated,
768 reason: "Optional section - truncated to fit".to_string(),
769 });
770 allocation.tokens_remaining = 0;
771 } else {
772 allocation.dropped_sections.push(section.name.clone());
773 allocation.explain.push(AllocationDecision {
774 section: section.name.clone(),
775 priority: section.priority,
776 requested: section.estimated_tokens,
777 allocated: 0,
778 outcome: AllocationOutcome::Dropped,
779 reason: format!("Minimum {} exceeds remaining {}", min, remaining),
780 });
781 }
782 } else {
783 allocation.dropped_sections.push(section.name.clone());
784 allocation.explain.push(AllocationDecision {
785 section: section.name.clone(),
786 priority: section.priority,
787 requested: section.estimated_tokens,
788 allocated: 0,
789 outcome: AllocationOutcome::Dropped,
790 reason: format!("Requested {} exceeds remaining {}", section.estimated_tokens, remaining),
791 });
792 }
793 }
794
795 allocation
796 }
797
    /// Clears the running total granted through `try_allocate`; the budget,
    /// reservation and strategy are left untouched.
    pub fn reset(&self) {
        self.allocated.store(0, Ordering::Release);
    }
802
    /// Returns the estimator this enforcer was built with.
    pub fn estimator(&self) -> &TokenEstimator {
        &self.estimator
    }
807}
808
809impl BudgetAllocation {
814 pub fn explain_text(&self) -> String {
816 let mut output = String::new();
817 output.push_str("=== CONTEXT BUDGET ALLOCATION ===\n\n");
818 output.push_str(&format!(
819 "Total Allocated: {} tokens\n",
820 self.tokens_allocated
821 ));
822 output.push_str(&format!("Remaining: {} tokens\n\n", self.tokens_remaining));
823
824 output.push_str("SECTIONS:\n");
825 for decision in &self.explain {
826 let status = match decision.outcome {
827 AllocationOutcome::Full => "✓ FULL",
828 AllocationOutcome::Truncated => "◐ TRUNCATED",
829 AllocationOutcome::Dropped => "✗ DROPPED",
830 };
831 output.push_str(&format!(
832 " [{:^12}] {} (priority {})\n",
833 status, decision.section, decision.priority
834 ));
835 output.push_str(&format!(
836 " Requested: {}, Allocated: {}\n",
837 decision.requested, decision.allocated
838 ));
839 output.push_str(&format!(" Reason: {}\n", decision.reason));
840 }
841
842 output
843 }
844
845 pub fn explain_json(&self) -> String {
847 serde_json::to_string_pretty(&ExplainOutput {
848 tokens_allocated: self.tokens_allocated,
849 tokens_remaining: self.tokens_remaining,
850 full_sections: self.full_sections.clone(),
851 truncated_sections: self.truncated_sections.clone(),
852 dropped_sections: self.dropped_sections.clone(),
853 decisions: self.explain.iter().map(|d| ExplainDecision {
854 section: d.section.clone(),
855 priority: d.priority,
856 requested: d.requested,
857 allocated: d.allocated,
858 outcome: format!("{:?}", d.outcome),
859 reason: d.reason.clone(),
860 }).collect(),
861 }).unwrap_or_else(|_| "{}".to_string())
862 }
863}
864
/// Serializable mirror of `BudgetAllocation`, built by `explain_json`.
#[derive(serde::Serialize)]
struct ExplainOutput {
    /// Sum of tokens handed out.
    tokens_allocated: usize,
    /// Budget left over.
    tokens_remaining: usize,
    /// Names of sections granted in full.
    full_sections: Vec<String>,
    /// `(name, requested, allocated)` for each truncated section.
    truncated_sections: Vec<(String, usize, usize)>,
    /// Names of sections that received nothing.
    dropped_sections: Vec<String>,
    /// Per-section decision records.
    decisions: Vec<ExplainDecision>,
}
874
/// Serializable mirror of `AllocationDecision`, built by `explain_json`.
#[derive(serde::Serialize)]
struct ExplainDecision {
    /// Name of the section.
    section: String,
    /// Priority copied from the section.
    priority: i32,
    /// Tokens requested.
    requested: usize,
    /// Tokens granted.
    allocated: usize,
    /// `Debug` rendering of the `AllocationOutcome`.
    outcome: String,
    /// Human-readable justification.
    reason: String,
}
884
885pub fn truncate_to_tokens(
891 text: &str,
892 max_tokens: usize,
893 estimator: &TokenEstimator,
894 suffix: &str,
895) -> String {
896 let current = estimator.estimate_text(text);
897
898 if current <= max_tokens {
899 return text.to_string();
900 }
901
902 let suffix_tokens = estimator.estimate_text(suffix);
903 let target_tokens = max_tokens.saturating_sub(suffix_tokens);
904
905 if target_tokens == 0 {
906 return suffix.to_string();
907 }
908
909 let mut low = 0;
911 let mut high = text.len();
912
913 while low < high {
914 let mid = (low + high).div_ceil(2);
915
916 let boundary = text
918 .char_indices()
919 .take_while(|(i, _)| *i < mid)
920 .last()
921 .map(|(i, c)| i + c.len_utf8())
922 .unwrap_or(0);
923
924 let truncated = &text[..boundary];
925 let tokens = estimator.estimate_text(truncated);
926
927 if tokens <= target_tokens {
928 low = boundary;
929 } else {
930 high = boundary.saturating_sub(1);
931 }
932 }
933
934 let truncated = &text[..low];
936 let word_boundary = truncated.rfind(|c: char| c.is_whitespace()).unwrap_or(low);
937
938 format!("{}{}", &text[..word_boundary], suffix)
939}
940
941pub fn truncate_rows(
943 rows: &[Vec<SochValue>],
944 max_tokens: usize,
945 estimator: &TokenEstimator,
946) -> Vec<Vec<SochValue>> {
947 let mut result = Vec::new();
948 let mut used = 0;
949
950 for row in rows {
951 let row_tokens = estimator.estimate_row(row);
952
953 if used + row_tokens <= max_tokens {
954 result.push(row.clone());
955 used += row_tokens;
956 } else {
957 break; }
959 }
960
961 result
962}
963
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a unit-weight section for the allocation tests.
    fn section(
        name: &str,
        priority: i32,
        estimated_tokens: usize,
        minimum_tokens: Option<usize>,
        required: bool,
    ) -> BudgetSection {
        BudgetSection {
            name: name.to_string(),
            priority,
            estimated_tokens,
            minimum_tokens,
            required,
            weight: 1.0,
        }
    }

    #[test]
    fn test_estimate_value_int() {
        let est = TokenEstimator::new();

        assert!(est.estimate_value(&SochValue::Int(0)) >= 1);
        assert!(est.estimate_value(&SochValue::Int(42)) >= 1);

        // More digits must never produce a smaller estimate.
        let small = est.estimate_value(&SochValue::Int(42));
        let large = est.estimate_value(&SochValue::Int(1_000_000_000));
        assert!(large >= small);
    }

    #[test]
    fn test_estimate_value_text() {
        let est = TokenEstimator::new();

        let short = est.estimate_value(&SochValue::Text("hello".to_string()));
        let long = est.estimate_value(&SochValue::Text(
            "hello world this is a longer string".to_string(),
        ));

        assert!(long > short);
    }

    #[test]
    // 3.14 is an arbitrary sample value here, not an approximation of pi.
    #[allow(clippy::approx_constant)]
    fn test_estimate_row() {
        let est = TokenEstimator::new();

        let row = vec![
            SochValue::Int(1),
            SochValue::Text("Alice".to_string()),
            SochValue::Float(3.14),
        ];

        let tokens = est.estimate_row(&row);

        assert!(tokens >= 3);
    }

    #[test]
    fn test_estimate_table() {
        let est = TokenEstimator::new();

        let columns = vec!["id".to_string(), "name".to_string()];
        let rows = vec![
            vec![SochValue::Int(1), SochValue::Text("Alice".to_string())],
            vec![SochValue::Int(2), SochValue::Text("Bob".to_string())],
        ];

        let tokens = est.estimate_table("users", &columns, &rows);

        // The table must cost more than its rows alone because of the header.
        assert!(tokens > est.estimate_row(&rows[0]) * 2);
    }

    #[test]
    fn test_budget_enforcer_allocation() {
        let enforcer = TokenBudgetEnforcer::with_budget(1000);

        assert!(enforcer.try_allocate(500));
        assert_eq!(enforcer.allocated(), 500);
        assert_eq!(enforcer.available(), 500);

        assert!(enforcer.try_allocate(400));
        assert_eq!(enforcer.allocated(), 900);

        // A rejected request must leave the running total untouched.
        assert!(!enforcer.try_allocate(200));
        assert_eq!(enforcer.allocated(), 900);
    }

    #[test]
    fn test_budget_enforcer_reset() {
        let enforcer = TokenBudgetEnforcer::with_budget(1000);

        enforcer.try_allocate(800);
        assert_eq!(enforcer.allocated(), 800);

        enforcer.reset();
        assert_eq!(enforcer.allocated(), 0);
    }

    #[test]
    fn test_allocate_sections() {
        let enforcer = TokenBudgetEnforcer::with_budget(1000);

        let sections = vec![
            section("A", 0, 300, None, true),
            section("B", 1, 400, Some(200), false),
            section("C", 2, 500, None, false),
        ];

        let allocation = enforcer.allocate_sections(&sections);

        assert!(allocation.full_sections.contains(&"A".to_string()));
        assert!(allocation.dropped_sections.contains(&"C".to_string()));
        assert!(allocation.tokens_allocated <= 1000);
    }

    #[test]
    fn test_allocate_by_priority() {
        let enforcer = TokenBudgetEnforcer::with_budget(500);

        // Listed out of priority order on purpose.
        let sections = vec![
            section("LowPriority", 10, 200, None, false),
            section("HighPriority", 0, 400, None, true),
        ];

        let allocation = enforcer.allocate_sections(&sections);

        assert!(
            allocation
                .full_sections
                .contains(&"HighPriority".to_string())
        );
        assert!(
            allocation
                .dropped_sections
                .contains(&"LowPriority".to_string())
        );
    }

    #[test]
    fn test_truncate_to_tokens() {
        let est = TokenEstimator::new();

        let text = "This is a long text that needs to be truncated to fit within the token budget";
        let truncated = truncate_to_tokens(text, 10, &est, "...");

        assert!(truncated.len() < text.len());
        assert!(truncated.ends_with("..."));
        assert!(est.estimate_text(&truncated) <= 10);
    }

    #[test]
    fn test_truncate_rows() {
        let est = TokenEstimator::new();

        let rows: Vec<Vec<SochValue>> = (0..100)
            .map(|i| vec![SochValue::Int(i), SochValue::Text(format!("row{}", i))])
            .collect();

        let truncated = truncate_rows(&rows, 50, &est);

        assert!(truncated.len() < rows.len());

        let total: usize = truncated.iter().map(|r| est.estimate_row(r)).sum();
        assert!(total <= 50);
    }

    #[test]
    fn test_reserved_budget() {
        let mut enforcer = TokenBudgetEnforcer::with_budget(1000);
        enforcer.reserve(200);

        assert_eq!(enforcer.available(), 800);

        assert!(enforcer.try_allocate(700));
        assert_eq!(enforcer.available(), 100);

        assert!(!enforcer.try_allocate(200));
    }

    #[test]
    fn test_estimator_configs() {
        let default = TokenEstimator::new();
        let gpt4 = TokenEstimator::with_config(TokenEstimatorConfig::gpt4());
        let conservative = TokenEstimator::with_config(TokenEstimatorConfig::conservative());

        let text = "Hello, this is a test string for comparing token estimation across different configurations.";

        let default_est = default.estimate_text(text);
        let gpt4_est = gpt4.estimate_text(text);
        let conservative_est = conservative.estimate_text(text);

        assert!(conservative_est >= default_est);

        assert!(default_est > 0);
        assert!(gpt4_est > 0);
        assert!(conservative_est > 0);
    }

    #[test]
    fn test_section_with_truncation() {
        let enforcer = TokenBudgetEnforcer::with_budget(600);

        let sections = vec![
            section("Required", 0, 500, None, true),
            section("Optional", 1, 300, Some(50), false),
        ];

        let allocation = enforcer.allocate_sections(&sections);

        assert!(allocation.full_sections.contains(&"Required".to_string()));
        assert!(
            allocation
                .truncated_sections
                .iter()
                .any(|(n, _, _)| n == "Optional")
        );
    }
}