// synwire_core/agents/middleware.rs
use serde_json::Value;

use crate::BoxFuture;
use crate::agents::error::AgentError;
use crate::tools::Tool;

9#[derive(Debug, Clone)]
11pub struct MiddlewareInput {
12 pub messages: Vec<Value>,
14 pub context: Value,
16}
17
18#[derive(Debug)]
20#[non_exhaustive]
21pub enum MiddlewareResult {
22 Continue(MiddlewareInput),
24 Terminate(String),
26}
27
28pub trait Middleware: Send + Sync {
30 fn name(&self) -> &str;
32
33 fn process(
37 &self,
38 input: MiddlewareInput,
39 ) -> BoxFuture<'_, Result<MiddlewareResult, AgentError>> {
40 Box::pin(async move { Ok(MiddlewareResult::Continue(input)) })
41 }
42
43 fn tools(&self) -> Vec<Box<dyn Tool>> {
45 Vec::new()
46 }
47
48 fn system_prompt_additions(&self) -> Vec<String> {
52 Vec::new()
53 }
54}
55
56pub struct MiddlewareStack {
62 components: Vec<Box<dyn Middleware>>,
63}
64
65impl std::fmt::Debug for MiddlewareStack {
66 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67 f.debug_struct("MiddlewareStack")
68 .field(
69 "components",
70 &self.components.iter().map(|m| m.name()).collect::<Vec<_>>(),
71 )
72 .finish()
73 }
74}
75
76impl MiddlewareStack {
77 #[must_use]
79 pub fn new() -> Self {
80 Self {
81 components: Vec::new(),
82 }
83 }
84
85 pub fn push(&mut self, middleware: impl Middleware + 'static) {
87 self.components.push(Box::new(middleware));
88 }
89
90 pub async fn run(&self, mut input: MiddlewareInput) -> Result<MiddlewareResult, AgentError> {
92 for mw in &self.components {
93 match mw.process(input).await? {
94 MiddlewareResult::Continue(next_input) => input = next_input,
95 term @ MiddlewareResult::Terminate(_) => return Ok(term),
96 }
97 }
98 Ok(MiddlewareResult::Continue(input))
99 }
100
101 #[must_use]
103 pub fn system_prompt_additions(&self) -> Vec<String> {
104 self.components
105 .iter()
106 .flat_map(|m| m.system_prompt_additions())
107 .collect()
108 }
109
110 pub fn tools(&self) -> Vec<Box<dyn Tool>> {
112 self.components.iter().flat_map(|m| m.tools()).collect()
113 }
114}
115
116impl Default for MiddlewareStack {
117 fn default() -> Self {
118 Self::new()
119 }
120}
121
#[cfg(test)]
// Tests may unwrap/expect/panic freely; `unnecessary_literal_bound` fires on
// the `fn name(&self) -> &str` impls below that return string literals.
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::unnecessary_literal_bound
)]
mod tests {
    use super::*;

    /// Appends its name to a shared log every time it processes input.
    struct OrderRecorder {
        name: &'static str,
        order: std::sync::Arc<std::sync::Mutex<Vec<&'static str>>>,
    }

    impl Middleware for OrderRecorder {
        fn name(&self) -> &str {
            self.name
        }

        fn process(
            &self,
            input: MiddlewareInput,
        ) -> BoxFuture<'_, Result<MiddlewareResult, AgentError>> {
            Box::pin(async move {
                // Fail loudly on a poisoned lock rather than silently skipping
                // the record, which would make ordering assertions misleading.
                self.order
                    .lock()
                    .expect("order lock poisoned")
                    .push(self.name);
                Ok(MiddlewareResult::Continue(input))
            })
        }

        fn system_prompt_additions(&self) -> Vec<String> {
            vec![format!("[{}]", self.name)]
        }
    }

    /// Always terminates the pipeline.
    struct EarlyTerminator;
    impl Middleware for EarlyTerminator {
        fn name(&self) -> &str {
            "terminator"
        }
        fn process(
            &self,
            _input: MiddlewareInput,
        ) -> BoxFuture<'_, Result<MiddlewareResult, AgentError>> {
            Box::pin(async { Ok(MiddlewareResult::Terminate("stop".to_string())) })
        }
    }

    /// Overrides only `name`, so it exercises every default trait method.
    struct Passthrough;
    impl Middleware for Passthrough {
        fn name(&self) -> &str {
            "passthrough"
        }
    }

    fn base_input() -> MiddlewareInput {
        MiddlewareInput {
            messages: Vec::new(),
            context: serde_json::json!({}),
        }
    }

    #[tokio::test]
    async fn test_stack_order() {
        let order = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
        let mut stack = MiddlewareStack::new();
        stack.push(OrderRecorder {
            name: "a",
            order: order.clone(),
        });
        stack.push(OrderRecorder {
            name: "b",
            order: order.clone(),
        });
        let _ = stack.run(base_input()).await.expect("run");
        let seen = order.lock().expect("lock").clone();
        assert_eq!(seen, vec!["a", "b"]);
    }

    #[tokio::test]
    async fn test_early_termination() {
        let mut stack = MiddlewareStack::new();
        stack.push(EarlyTerminator);
        let order = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
        stack.push(OrderRecorder {
            name: "after",
            order: order.clone(),
        });
        let result = stack.run(base_input()).await.expect("run");
        assert!(matches!(result, MiddlewareResult::Terminate(_)));
        assert!(order.lock().expect("lock").is_empty());
    }

    #[tokio::test]
    async fn test_system_prompt_composition_order() {
        let order = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
        let mut stack = MiddlewareStack::new();
        stack.push(OrderRecorder {
            name: "first",
            order: order.clone(),
        });
        stack.push(OrderRecorder {
            name: "second",
            order: order.clone(),
        });
        let additions = stack.system_prompt_additions();
        assert_eq!(additions, vec!["[first]", "[second]"]);
    }

    #[tokio::test]
    async fn test_default_methods_pass_through() {
        let mut stack = MiddlewareStack::new();
        stack.push(Passthrough);
        let input = MiddlewareInput {
            messages: vec![serde_json::json!({"role": "user"})],
            context: serde_json::json!({"key": 1}),
        };
        match stack.run(input).await.expect("run") {
            MiddlewareResult::Continue(out) => {
                assert_eq!(out.messages, vec![serde_json::json!({"role": "user"})]);
                assert_eq!(out.context, serde_json::json!({"key": 1}));
            }
            MiddlewareResult::Terminate(reason) => {
                panic!("unexpected termination: {reason}")
            }
        }
        assert!(stack.tools().is_empty());
        assert!(stack.system_prompt_additions().is_empty());
        assert_eq!(
            format!("{stack:?}"),
            r#"MiddlewareStack { components: ["passthrough"] }"#
        );
    }
}