Skip to main content

synaptic_middleware/
context_editing.rs

1use async_trait::async_trait;
2use synaptic_core::{Message, SynapticError};
3
4use crate::{AgentMiddleware, ModelRequest};
5
/// Strategy for editing the conversation context before each model call.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ContextStrategy {
    /// Keep the last N messages after the leading run of system messages.
    ///
    /// Only the *leading* system messages are pinned; a system message that
    /// appears later in the history counts like any other message and may be
    /// dropped. With `LastN(0)` only the leading system messages remain.
    LastN(usize),
    /// Remove tool traffic from the history: every tool-result message, plus
    /// every AI message that carries tool calls but no text content.
    ///
    /// All other messages (system, human, and AI messages with text — even if
    /// they also carry tool calls) are kept in their original order.
    StripToolCalls,
    /// Apply both: first strip tool traffic as in [`ContextStrategy::StripToolCalls`],
    /// then keep the last N remaining messages as in [`ContextStrategy::LastN`].
    StripAndTruncate(usize),
}
17
18/// Manages conversation context by trimming or filtering messages
19/// before each model invocation.
20///
21/// This is useful for keeping the context window manageable without
22/// full summarization.
23pub struct ContextEditingMiddleware {
24    strategy: ContextStrategy,
25}
26
27impl ContextEditingMiddleware {
28    pub fn new(strategy: ContextStrategy) -> Self {
29        Self { strategy }
30    }
31
32    /// Keep last N messages, always preserving leading system messages.
33    pub fn last_n(n: usize) -> Self {
34        Self::new(ContextStrategy::LastN(n))
35    }
36
37    /// Strip tool call/result message pairs from history.
38    pub fn strip_tool_calls() -> Self {
39        Self::new(ContextStrategy::StripToolCalls)
40    }
41
42    fn apply_last_n(messages: &mut Vec<Message>, n: usize) {
43        if messages.len() <= n {
44            return;
45        }
46
47        // Preserve leading system messages
48        let system_count = messages.iter().take_while(|m| m.is_system()).count();
49        let non_system = &messages[system_count..];
50        if non_system.len() <= n {
51            return;
52        }
53
54        let keep_from = non_system.len() - n;
55        let mut new_msgs: Vec<Message> = messages[..system_count].to_vec();
56        new_msgs.extend_from_slice(&messages[system_count + keep_from..]);
57        *messages = new_msgs;
58    }
59
60    fn apply_strip_tool_calls(messages: &mut Vec<Message>) {
61        messages.retain(|m| {
62            // Keep all non-tool messages, but strip AI messages that
63            // contain only tool calls (no text content)
64            if m.is_tool() {
65                return false;
66            }
67            if m.is_ai() && !m.tool_calls().is_empty() && m.content().is_empty() {
68                return false;
69            }
70            true
71        });
72    }
73}
74
75#[async_trait]
76impl AgentMiddleware for ContextEditingMiddleware {
77    async fn before_model(&self, request: &mut ModelRequest) -> Result<(), SynapticError> {
78        match &self.strategy {
79            ContextStrategy::LastN(n) => {
80                Self::apply_last_n(&mut request.messages, *n);
81            }
82            ContextStrategy::StripToolCalls => {
83                Self::apply_strip_tool_calls(&mut request.messages);
84            }
85            ContextStrategy::StripAndTruncate(n) => {
86                Self::apply_strip_tool_calls(&mut request.messages);
87                Self::apply_last_n(&mut request.messages, *n);
88            }
89        }
90        Ok(())
91    }
92}
93
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a single tool call with the given id, for AI tool-call messages.
    fn tool_call(id: &str) -> synaptic_core::ToolCall {
        synaptic_core::ToolCall {
            id: id.into(),
            name: "test".into(),
            arguments: serde_json::json!({}),
        }
    }

    #[test]
    fn last_n_preserves_system() {
        let mut msgs = vec![
            Message::system("sys"),
            Message::human("1"),
            Message::ai("2"),
            Message::human("3"),
            Message::ai("4"),
        ];
        ContextEditingMiddleware::apply_last_n(&mut msgs, 2);
        assert_eq!(msgs.len(), 3); // sys + last 2
        assert!(msgs[0].is_system());
        assert_eq!(msgs[1].content(), "3");
        assert_eq!(msgs[2].content(), "4");
    }

    #[test]
    fn last_n_is_noop_within_limit() {
        let mut msgs = vec![Message::system("sys"), Message::human("1"), Message::ai("2")];
        // Exactly at the limit (system messages don't count toward n).
        ContextEditingMiddleware::apply_last_n(&mut msgs, 2);
        assert_eq!(msgs.len(), 3);
        // Well under the limit.
        ContextEditingMiddleware::apply_last_n(&mut msgs, 10);
        assert_eq!(msgs.len(), 3);
    }

    #[test]
    fn last_n_zero_keeps_only_leading_system() {
        let mut msgs = vec![Message::system("sys"), Message::human("1"), Message::ai("2")];
        ContextEditingMiddleware::apply_last_n(&mut msgs, 0);
        assert_eq!(msgs.len(), 1);
        assert!(msgs[0].is_system());
    }

    #[test]
    fn last_n_only_pins_leading_system_messages() {
        let mut msgs = vec![
            Message::human("1"),
            Message::system("mid"),
            Message::ai("2"),
            Message::human("3"),
        ];
        ContextEditingMiddleware::apply_last_n(&mut msgs, 2);
        // The non-leading system message is not pinned and falls outside the window.
        assert_eq!(msgs.len(), 2);
        assert_eq!(msgs[0].content(), "2");
        assert_eq!(msgs[1].content(), "3");
    }

    #[test]
    fn strip_tool_calls() {
        let mut msgs = vec![
            Message::human("hello"),
            Message::ai_with_tool_calls("", vec![tool_call("1")]),
            Message::tool("result", "1"),
            Message::ai("final answer"),
        ];
        ContextEditingMiddleware::apply_strip_tool_calls(&mut msgs);
        assert_eq!(msgs.len(), 2);
        assert!(msgs[0].is_human());
        assert_eq!(msgs[1].content(), "final answer");
    }

    #[test]
    fn strip_keeps_ai_messages_with_text() {
        let mut msgs = vec![
            Message::ai_with_tool_calls("let me check", vec![tool_call("1")]),
            Message::tool("result", "1"),
            Message::ai("done"),
        ];
        ContextEditingMiddleware::apply_strip_tool_calls(&mut msgs);
        assert_eq!(msgs.len(), 2);
        assert!(msgs[0].is_ai());
        assert_eq!(msgs[0].content(), "let me check");
        assert_eq!(msgs[1].content(), "done");
    }

    #[test]
    fn strip_then_truncate_counts_only_surviving_messages() {
        let mut msgs = vec![
            Message::system("sys"),
            Message::human("q1"),
            Message::ai_with_tool_calls("", vec![tool_call("1")]),
            Message::tool("result", "1"),
            Message::ai("a1"),
            Message::human("q2"),
            Message::ai("a2"),
        ];
        // Same order as ContextStrategy::StripAndTruncate: strip first, then truncate.
        ContextEditingMiddleware::apply_strip_tool_calls(&mut msgs);
        ContextEditingMiddleware::apply_last_n(&mut msgs, 2);
        assert_eq!(msgs.len(), 3);
        assert!(msgs[0].is_system());
        assert_eq!(msgs[1].content(), "q2");
        assert_eq!(msgs[2].content(), "a2");
    }
}