//! Summarization middleware (`synaptic_middleware/summarization.rs`):
//! condenses older conversation history into a single summary message
//! when the estimated token total exceeds a configured budget.
1use std::sync::Arc;
2
3use async_trait::async_trait;
4use synaptic_core::{ChatModel, ChatRequest, Message, SynapticError};
5
6use crate::{AgentMiddleware, ModelRequest};
7
/// Automatically summarizes conversation history when it exceeds a token limit.
///
/// Uses a configurable token counter to estimate per-message sizes. When the
/// running total exceeds `max_tokens`, older messages are condensed into a
/// single summary message generated by the provided `ChatModel`, while the
/// most recent messages (up to roughly half the budget) are kept verbatim.
///
/// NOTE(review): the implementation does not currently special-case a leading
/// system prompt — it is summarized along with the other older messages.
/// Confirm whether excluding the system prompt is the intended contract.
pub struct SummarizationMiddleware {
    // Model invoked to generate the summary text.
    model: Arc<dyn ChatModel>,
    // Token budget; summarization triggers once the estimated total exceeds it.
    max_tokens: usize,
    // Heuristic per-message token estimator supplied by the caller.
    token_counter: Box<dyn Fn(&Message) -> usize + Send + Sync>,
}
19
20impl SummarizationMiddleware {
21    /// Create a new summarization middleware.
22    ///
23    /// * `model` — The model to use for generating summaries.
24    /// * `max_tokens` — When total tokens exceed this, summarize older messages.
25    /// * `token_counter` — Function that estimates the token count for a message.
26    pub fn new(
27        model: Arc<dyn ChatModel>,
28        max_tokens: usize,
29        token_counter: impl Fn(&Message) -> usize + Send + Sync + 'static,
30    ) -> Self {
31        Self {
32            model,
33            max_tokens,
34            token_counter: Box::new(token_counter),
35        }
36    }
37}
38
39#[async_trait]
40impl AgentMiddleware for SummarizationMiddleware {
41    async fn before_model(&self, request: &mut ModelRequest) -> Result<(), SynapticError> {
42        let total: usize = request
43            .messages
44            .iter()
45            .map(|m| (self.token_counter)(m))
46            .sum();
47        if total <= self.max_tokens {
48            return Ok(());
49        }
50
51        // Keep the most recent messages that fit in half the budget,
52        // summarize everything before them.
53        let half_budget = self.max_tokens / 2;
54        let mut keep_from = request.messages.len();
55        let mut kept_tokens = 0;
56        for (i, msg) in request.messages.iter().enumerate().rev() {
57            let t = (self.token_counter)(msg);
58            if kept_tokens + t > half_budget {
59                break;
60            }
61            kept_tokens += t;
62            keep_from = i;
63        }
64
65        if keep_from == 0 {
66            // Everything fits or there's nothing to summarize
67            return Ok(());
68        }
69
70        let to_summarize: Vec<_> = request.messages[..keep_from].to_vec();
71
72        // Build a summary request
73        let summary_prompt = Message::human(
74            "Summarize the following conversation concisely, preserving key facts and context:\n\n"
75                .to_string()
76                + &to_summarize
77                    .iter()
78                    .map(|m| format!("{}: {}", m.role(), m.content()))
79                    .collect::<Vec<_>>()
80                    .join("\n"),
81        );
82
83        let summary_req = ChatRequest::new(vec![
84            Message::system("You are a conversation summarizer. Output a brief summary."),
85            summary_prompt,
86        ]);
87
88        let summary_resp = self.model.chat(summary_req).await?;
89        let summary_text = summary_resp.message.content().to_string();
90
91        // Replace old messages with the summary
92        let mut new_messages = vec![Message::system(format!(
93            "[Previous conversation summary]: {summary_text}"
94        ))];
95        new_messages.extend_from_slice(&request.messages[keep_from..]);
96        request.messages = new_messages;
97
98        Ok(())
99    }
100}