Skip to main content

synwire_core/agents/
middleware.rs

1//! Middleware stack for cross-cutting concerns.
2
3use serde_json::Value;
4
5use crate::BoxFuture;
6use crate::agents::error::AgentError;
7use crate::tools::Tool;
8
9/// Input passed through the middleware chain.
10#[derive(Debug, Clone)]
11pub struct MiddlewareInput {
12    /// Conversation messages as JSON values.
13    pub messages: Vec<Value>,
14    /// Arbitrary context metadata.
15    pub context: Value,
16}
17
18/// Outcome from a middleware component.
19#[derive(Debug)]
20#[non_exhaustive]
21pub enum MiddlewareResult {
22    /// Pass the (possibly modified) input to the next middleware.
23    Continue(MiddlewareInput),
24    /// Terminate the chain immediately with a message.
25    Terminate(String),
26}
27
28/// Cross-cutting concern injected into the agent loop.
29pub trait Middleware: Send + Sync {
30    /// Middleware identifier (for logging and ordering).
31    fn name(&self) -> &str;
32
33    /// Process the input and optionally call through to the next layer.
34    ///
35    /// The default implementation calls `next` unchanged.
36    fn process(
37        &self,
38        input: MiddlewareInput,
39    ) -> BoxFuture<'_, Result<MiddlewareResult, AgentError>> {
40        Box::pin(async move { Ok(MiddlewareResult::Continue(input)) })
41    }
42
43    /// Tools injected into the agent context by this middleware.
44    fn tools(&self) -> Vec<Box<dyn Tool>> {
45        Vec::new()
46    }
47
48    /// System prompt additions contributed by this middleware.
49    ///
50    /// Additions are concatenated in stack order.
51    fn system_prompt_additions(&self) -> Vec<String> {
52        Vec::new()
53    }
54}
55
56/// Executes a slice of middleware in order.
57///
58/// - If any middleware returns `Terminate`, the chain stops.
59/// - System prompt additions are collected from all middleware in order.
60/// - Tools are collected from all middleware in order.
61pub struct MiddlewareStack {
62    components: Vec<Box<dyn Middleware>>,
63}
64
65impl std::fmt::Debug for MiddlewareStack {
66    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67        f.debug_struct("MiddlewareStack")
68            .field(
69                "components",
70                &self.components.iter().map(|m| m.name()).collect::<Vec<_>>(),
71            )
72            .finish()
73    }
74}
75
76impl MiddlewareStack {
77    /// Create an empty middleware stack.
78    #[must_use]
79    pub fn new() -> Self {
80        Self {
81            components: Vec::new(),
82        }
83    }
84
85    /// Append a middleware component to the stack.
86    pub fn push(&mut self, middleware: impl Middleware + 'static) {
87        self.components.push(Box::new(middleware));
88    }
89
90    /// Run the input through all middleware in order.
91    pub async fn run(&self, mut input: MiddlewareInput) -> Result<MiddlewareResult, AgentError> {
92        for mw in &self.components {
93            match mw.process(input).await? {
94                MiddlewareResult::Continue(next_input) => input = next_input,
95                term @ MiddlewareResult::Terminate(_) => return Ok(term),
96            }
97        }
98        Ok(MiddlewareResult::Continue(input))
99    }
100
101    /// Collect all system prompt additions from all middleware in order.
102    #[must_use]
103    pub fn system_prompt_additions(&self) -> Vec<String> {
104        self.components
105            .iter()
106            .flat_map(|m| m.system_prompt_additions())
107            .collect()
108    }
109
110    /// Collect all tools from all middleware in order.
111    pub fn tools(&self) -> Vec<Box<dyn Tool>> {
112        self.components.iter().flat_map(|m| m.tools()).collect()
113    }
114}
115
116impl Default for MiddlewareStack {
117    fn default() -> Self {
118        Self::new()
119    }
120}
121
#[cfg(test)]
// Test code may unwrap/expect/panic freely, and the stub `name()` impls return string literals
// through the trait's `&str` signature.
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::unnecessary_literal_bound
)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    /// Shared log of middleware names, in the order their `process` ran.
    type OrderLog = Arc<Mutex<Vec<&'static str>>>;

    fn new_log() -> OrderLog {
        Arc::new(Mutex::new(Vec::new()))
    }

    /// Appends its name to a shared log when processed; contributes `[name]` to the prompt.
    struct OrderRecorder {
        name: &'static str,
        order: OrderLog,
    }

    impl Middleware for OrderRecorder {
        fn name(&self) -> &str {
            self.name
        }

        fn process(
            &self,
            input: MiddlewareInput,
        ) -> BoxFuture<'_, Result<MiddlewareResult, AgentError>> {
            let order = Arc::clone(&self.order);
            Box::pin(async move {
                if let Ok(mut g) = order.lock() {
                    g.push(self.name);
                }
                Ok(MiddlewareResult::Continue(input))
            })
        }

        fn system_prompt_additions(&self) -> Vec<String> {
            vec![format!("[{}]", self.name)]
        }
    }

    /// Always terminates the chain with the message `"stop"`.
    struct EarlyTerminator;
    impl Middleware for EarlyTerminator {
        fn name(&self) -> &str {
            "terminator"
        }
        fn process(
            &self,
            _input: MiddlewareInput,
        ) -> BoxFuture<'_, Result<MiddlewareResult, AgentError>> {
            Box::pin(async { Ok(MiddlewareResult::Terminate("stop".to_string())) })
        }
    }

    /// Relies entirely on the trait's default `process`, `tools` and prompt additions.
    struct Passthrough;
    impl Middleware for Passthrough {
        fn name(&self) -> &str {
            "passthrough"
        }
    }

    fn base_input() -> MiddlewareInput {
        MiddlewareInput {
            messages: Vec::new(),
            context: serde_json::json!({}),
        }
    }

    #[tokio::test]
    async fn test_stack_order() {
        let order = new_log();
        let mut stack = MiddlewareStack::new();
        stack.push(OrderRecorder {
            name: "a",
            order: Arc::clone(&order),
        });
        stack.push(OrderRecorder {
            name: "b",
            order: Arc::clone(&order),
        });
        let _ = stack.run(base_input()).await.expect("run");
        let seen = order.lock().expect("lock").clone();
        assert_eq!(seen, vec!["a", "b"]);
    }

    #[tokio::test]
    async fn test_early_termination() {
        let mut stack = MiddlewareStack::new();
        stack.push(EarlyTerminator);
        // Second middleware should NOT run.
        let order = new_log();
        stack.push(OrderRecorder {
            name: "after",
            order: Arc::clone(&order),
        });
        let result = stack.run(base_input()).await.expect("run");
        assert!(matches!(result, MiddlewareResult::Terminate(ref msg) if msg == "stop"));
        assert!(order.lock().expect("lock").is_empty());
    }

    #[tokio::test]
    async fn test_system_prompt_composition_order() {
        let order = new_log();
        let mut stack = MiddlewareStack::new();
        stack.push(OrderRecorder {
            name: "first",
            order: Arc::clone(&order),
        });
        stack.push(OrderRecorder {
            name: "second",
            order: Arc::clone(&order),
        });
        let additions = stack.system_prompt_additions();
        assert_eq!(additions, vec!["[first]", "[second]"]);
    }

    #[tokio::test]
    async fn test_empty_stack_passes_input_through() {
        let stack = MiddlewareStack::default();
        let input = MiddlewareInput {
            messages: vec![serde_json::json!({"role": "user"})],
            context: serde_json::json!({"k": 1}),
        };
        match stack.run(input.clone()).await.expect("run") {
            MiddlewareResult::Continue(out) => {
                assert_eq!(out.messages, input.messages);
                assert_eq!(out.context, input.context);
            }
            other => panic!("expected Continue, got {other:?}"),
        }
        assert!(stack.system_prompt_additions().is_empty());
        assert!(stack.tools().is_empty());
    }

    #[tokio::test]
    async fn test_default_trait_methods_are_passthrough() {
        let mut stack = MiddlewareStack::new();
        stack.push(Passthrough);
        let result = stack.run(base_input()).await.expect("run");
        assert!(matches!(result, MiddlewareResult::Continue(_)));
        assert!(stack.system_prompt_additions().is_empty());
        assert!(stack.tools().is_empty());
    }

    #[test]
    fn test_debug_lists_component_names() {
        let mut stack = MiddlewareStack::new();
        stack.push(Passthrough);
        stack.push(EarlyTerminator);
        assert_eq!(
            format!("{stack:?}"),
            r#"MiddlewareStack { components: ["passthrough", "terminator"] }"#
        );
    }
}