synaptic_middleware/
context_editing.rs1use async_trait::async_trait;
2use synaptic_core::{Message, SynapticError};
3
4use crate::{AgentMiddleware, ModelRequest};
5
/// Strategy used by `ContextEditingMiddleware` to rewrite the conversation
/// history before it is sent to the model.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ContextStrategy {
    /// Keep at most the last `n` messages that follow the leading run of
    /// system messages; that leading run is always preserved and does not
    /// count toward `n`.
    LastN(usize),
    /// Remove tool-result messages, and AI messages that carry tool calls but
    /// have empty text content.
    StripToolCalls,
    /// Apply [`ContextStrategy::StripToolCalls`] first, then
    /// [`ContextStrategy::LastN`] with the given `n`.
    StripAndTruncate(usize),
}
17
18pub struct ContextEditingMiddleware {
24 strategy: ContextStrategy,
25}
26
27impl ContextEditingMiddleware {
28 pub fn new(strategy: ContextStrategy) -> Self {
29 Self { strategy }
30 }
31
32 pub fn last_n(n: usize) -> Self {
34 Self::new(ContextStrategy::LastN(n))
35 }
36
37 pub fn strip_tool_calls() -> Self {
39 Self::new(ContextStrategy::StripToolCalls)
40 }
41
42 fn apply_last_n(messages: &mut Vec<Message>, n: usize) {
43 if messages.len() <= n {
44 return;
45 }
46
47 let system_count = messages.iter().take_while(|m| m.is_system()).count();
49 let non_system = &messages[system_count..];
50 if non_system.len() <= n {
51 return;
52 }
53
54 let keep_from = non_system.len() - n;
55 let mut new_msgs: Vec<Message> = messages[..system_count].to_vec();
56 new_msgs.extend_from_slice(&messages[system_count + keep_from..]);
57 *messages = new_msgs;
58 }
59
60 fn apply_strip_tool_calls(messages: &mut Vec<Message>) {
61 messages.retain(|m| {
62 if m.is_tool() {
65 return false;
66 }
67 if m.is_ai() && !m.tool_calls().is_empty() && m.content().is_empty() {
68 return false;
69 }
70 true
71 });
72 }
73}
74
75#[async_trait]
76impl AgentMiddleware for ContextEditingMiddleware {
77 async fn before_model(&self, request: &mut ModelRequest) -> Result<(), SynapticError> {
78 match &self.strategy {
79 ContextStrategy::LastN(n) => {
80 Self::apply_last_n(&mut request.messages, *n);
81 }
82 ContextStrategy::StripToolCalls => {
83 Self::apply_strip_tool_calls(&mut request.messages);
84 }
85 ContextStrategy::StripAndTruncate(n) => {
86 Self::apply_strip_tool_calls(&mut request.messages);
87 Self::apply_last_n(&mut request.messages, *n);
88 }
89 }
90 Ok(())
91 }
92}
93
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn last_n_preserves_system() {
        let mut msgs = vec![
            Message::system("sys"),
            Message::human("1"),
            Message::ai("2"),
            Message::human("3"),
            Message::ai("4"),
        ];
        ContextEditingMiddleware::apply_last_n(&mut msgs, 2);
        assert_eq!(msgs.len(), 3);
        assert!(msgs[0].is_system());
        assert_eq!(msgs[1].content(), "3");
        assert_eq!(msgs[2].content(), "4");
    }

    #[test]
    fn last_n_is_noop_when_history_is_short() {
        let mut msgs = vec![Message::system("sys"), Message::human("1"), Message::ai("2")];
        ContextEditingMiddleware::apply_last_n(&mut msgs, 5);
        assert_eq!(msgs.len(), 3);
        assert!(msgs[0].is_system());
        assert_eq!(msgs[2].content(), "2");
    }

    #[test]
    fn last_n_zero_keeps_only_system() {
        let mut msgs = vec![Message::system("sys"), Message::human("1"), Message::ai("2")];
        ContextEditingMiddleware::apply_last_n(&mut msgs, 0);
        assert_eq!(msgs.len(), 1);
        assert!(msgs[0].is_system());
    }

    #[test]
    fn strip_tool_calls() {
        let mut msgs = vec![
            Message::human("hello"),
            Message::ai_with_tool_calls(
                "",
                vec![synaptic_core::ToolCall {
                    id: "1".into(),
                    name: "test".into(),
                    arguments: serde_json::json!({}),
                }],
            ),
            Message::tool("result", "1"),
            Message::ai("final answer"),
        ];
        ContextEditingMiddleware::apply_strip_tool_calls(&mut msgs);
        assert_eq!(msgs.len(), 2);
        assert!(msgs[0].is_human());
        assert_eq!(msgs[1].content(), "final answer");
    }
}