Skip to main content

synaptic_memory/
token_buffer.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use synaptic_core::{MemoryStore, Message, SynapseError};
5
6/// A memory strategy that keeps messages within a token budget.
7///
8/// Uses a simple estimator (~4 chars per token) to approximate token counts.
9/// On `load`, removes the oldest messages until the total estimated tokens
10/// fit within `max_tokens`.
11pub struct ConversationTokenBufferMemory {
12    store: Arc<dyn MemoryStore>,
13    max_tokens: usize,
14}
15
16impl ConversationTokenBufferMemory {
17    /// Create a new token buffer memory wrapping the given store.
18    pub fn new(store: Arc<dyn MemoryStore>, max_tokens: usize) -> Self {
19        Self { store, max_tokens }
20    }
21
22    /// Estimate the number of tokens in a text string.
23    ///
24    /// Uses the simple heuristic of ~4 characters per token, with a minimum of 1.
25    pub fn estimate_tokens(text: &str) -> usize {
26        text.len() / 4 + 1
27    }
28}
29
30#[async_trait]
31impl MemoryStore for ConversationTokenBufferMemory {
32    async fn append(&self, session_id: &str, message: Message) -> Result<(), SynapseError> {
33        self.store.append(session_id, message).await
34    }
35
36    async fn load(&self, session_id: &str) -> Result<Vec<Message>, SynapseError> {
37        let messages = self.store.load(session_id).await?;
38
39        // Calculate total tokens for all messages
40        let total_tokens: usize = messages
41            .iter()
42            .map(|m| Self::estimate_tokens(m.content()))
43            .sum();
44
45        if total_tokens <= self.max_tokens {
46            return Ok(messages);
47        }
48
49        // Remove oldest messages until we fit within the budget
50        let mut kept = messages;
51        let mut current_tokens: usize = kept
52            .iter()
53            .map(|m| Self::estimate_tokens(m.content()))
54            .sum();
55
56        while current_tokens > self.max_tokens && !kept.is_empty() {
57            let removed = kept.remove(0);
58            current_tokens -= Self::estimate_tokens(removed.content());
59        }
60
61        Ok(kept)
62    }
63
64    async fn clear(&self, session_id: &str) -> Result<(), SynapseError> {
65        self.store.clear(session_id).await
66    }
67}