Skip to main content

synaptic_condenser/
token_budget.rs

1use std::sync::Arc;
2
3use crate::Condenser;
4use async_trait::async_trait;
5use synaptic_core::{Message, SynapticError, TokenCounter};
6
7/// Trims messages to fit within a token budget, keeping the most recent messages.
8pub struct TokenBudgetCondenser {
9    max_tokens: usize,
10    counter: Arc<dyn TokenCounter>,
11    include_system: bool,
12}
13
14impl TokenBudgetCondenser {
15    pub fn new(max_tokens: usize, counter: Arc<dyn TokenCounter>) -> Self {
16        Self {
17            max_tokens,
18            counter,
19            include_system: true,
20        }
21    }
22
23    pub fn with_include_system(mut self, include: bool) -> Self {
24        self.include_system = include;
25        self
26    }
27}
28
29#[async_trait]
30impl Condenser for TokenBudgetCondenser {
31    async fn condense(&self, messages: Vec<Message>) -> Result<Vec<Message>, SynapticError> {
32        let total = self.counter.count_messages(&messages);
33        if total <= self.max_tokens {
34            return Ok(messages);
35        }
36
37        // Preserve system message if configured
38        let (system_msg, rest) =
39            if self.include_system && !messages.is_empty() && messages[0].is_system() {
40                (Some(messages[0].clone()), &messages[1..])
41            } else {
42                (None, messages.as_slice())
43            };
44
45        let system_tokens = system_msg
46            .as_ref()
47            .map(|m| self.counter.count_messages(std::slice::from_ref(m)))
48            .unwrap_or(0);
49        let budget = self.max_tokens.saturating_sub(system_tokens);
50
51        // Keep most recent messages that fit
52        let mut kept = Vec::new();
53        let mut used = 0;
54        for msg in rest.iter().rev() {
55            let tokens = self.counter.count_messages(std::slice::from_ref(msg));
56            if used + tokens > budget {
57                break;
58            }
59            used += tokens;
60            kept.push(msg.clone());
61        }
62        kept.reverse();
63
64        let mut result = Vec::new();
65        if let Some(sys) = system_msg {
66            result.push(sys);
67        }
68        result.extend(kept);
69        Ok(result)
70    }
71}