synaptic_deep/middleware/summarization.rs

use async_trait::async_trait;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use synaptic_core::{ChatModel, ChatRequest, Message, SynapticError};
use synaptic_middleware::{AgentMiddleware, ModelRequest};

use crate::backend::Backend;

/// Middleware that auto-summarizes conversation history when approaching the token limit.
///
/// Before each model call, estimates the token count. If it exceeds
/// `max_input_tokens * threshold_fraction`, older messages are summarized by the model
/// and the full history is offloaded to a file in the backend.
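///
/// A construction sketch (the `backend`/`model` handles and the
/// `agent.add_middleware` hook are illustrative placeholders for the host
/// application, not a confirmed API):
///
/// ```ignore
/// let middleware = DeepSummarizationMiddleware::new(
///     backend,  // Arc<dyn Backend>
///     model,    // Arc<dyn ChatModel>
///     128_000,  // max_input_tokens
///     0.8,      // threshold_fraction: summarize past 80% of the limit
/// );
/// agent.add_middleware(middleware);
/// ```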
pub struct DeepSummarizationMiddleware {
    backend: Arc<dyn Backend>,
    model: Arc<dyn ChatModel>,
    max_input_tokens: usize,
    threshold_fraction: f64,
    file_counter: AtomicUsize,
}

impl DeepSummarizationMiddleware {
    pub fn new(
        backend: Arc<dyn Backend>,
        model: Arc<dyn ChatModel>,
        max_input_tokens: usize,
        threshold_fraction: f64,
    ) -> Self {
        Self {
            backend,
            model,
            max_input_tokens,
            threshold_fraction,
            file_counter: AtomicUsize::new(0),
        }
    }

    fn estimate_tokens(messages: &[Message]) -> usize {
        // ~4 chars per token heuristic
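        // (e.g. a 40-char message estimates to 40 / 4 + 1 = 11 tokens;
        // the `+ 1` keeps even an empty message from counting as zero)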
        messages.iter().map(|m| m.content().len() / 4 + 1).sum()
    }
}

#[async_trait]
impl AgentMiddleware for DeepSummarizationMiddleware {
    async fn before_model(&self, request: &mut ModelRequest) -> Result<(), SynapticError> {
        let threshold = (self.max_input_tokens as f64 * self.threshold_fraction) as usize;
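        // e.g. max_input_tokens = 128_000 with threshold_fraction = 0.8
        // triggers summarization once the estimate exceeds 102_400 tokens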
        let estimated = Self::estimate_tokens(&request.messages);

        if estimated <= threshold || request.messages.len() <= 2 {
            return Ok(());
        }

        // Save full history to backend
        let counter = self.file_counter.fetch_add(1, Ordering::Relaxed);
        let history_path = format!(".context/history_{}.md", counter);
        let full_history = request
            .messages
            .iter()
            .map(|m| format!("## {}\n{}", m.role(), m.content()))
            .collect::<Vec<_>>()
            .join("\n\n");
        // Best-effort write: remember whether it succeeded so the summary
        // note below never claims a save that did not happen.
        let saved = self
            .backend
            .write_file(&history_path, &full_history)
            .await
            .is_ok();

        // Keep last 2 messages, summarize the rest
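        // (the early return above guarantees at least 3 messages here, so
        // `keep_count` is always 2 and `to_summarize` is non-empty; the
        // checks below are kept as defensive guards)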
        let keep_count = request.messages.len().min(2);
        let to_summarize = &request.messages[..request.messages.len() - keep_count];

        if to_summarize.is_empty() {
            return Ok(());
        }

        let summary_prompt = format!(
            "Summarize the following conversation concisely, \
             preserving key decisions, facts, and context:\n\n{}",
            to_summarize
                .iter()
                .map(|m| format!("{}: {}", m.role(), m.content()))
                .collect::<Vec<_>>()
                .join("\n")
        );

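        // The summarization request goes straight to the model handle rather
        // than through the agent's middleware stack, so it should not re-enter
        // this middleware (assuming `model` is not itself a wrapped handle).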
        let summary_request = ChatRequest::new(vec![Message::human(summary_prompt)]);
        let summary_response = self.model.chat(summary_request).await?;
        let summary = summary_response.message.content().to_string();

        // Replace old messages with summary + recent messages
        let recent: Vec<Message> = request.messages[request.messages.len() - keep_count..].to_vec();
        let note = if saved {
            format!(
                "[Conversation summary (full history saved to {})]\n{}",
                history_path, summary
            )
        } else {
            format!("[Conversation summary]\n{}", summary)
        };
        request.messages = vec![Message::system(note)];
        request.messages.extend(recent);

        Ok(())
    }
}
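
// A minimal sketch of the token heuristic under test. It relies only on items
// already used in this file (`Message::human` and `estimate_tokens`); if
// `Message::human` takes a different argument type in synaptic_core, adjust
// the constructors accordingly.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn estimate_follows_four_chars_per_token_plus_one() {
        // 40 chars -> 40 / 4 + 1 = 11; empty content still counts as 1.
        let messages = vec![
            Message::human("a".repeat(40)),
            Message::human(String::new()),
        ];
        assert_eq!(DeepSummarizationMiddleware::estimate_tokens(&messages), 12);
    }
}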