Skip to main content

synaptic_middleware/
lib.rs

1mod context_editing;
2mod human_in_the_loop;
3mod model_call_limit;
4mod model_fallback;
5mod summarization;
6mod todo_list;
7mod tool_call_limit;
8mod tool_retry;
9
10pub use context_editing::{ContextEditingMiddleware, ContextStrategy};
11pub use human_in_the_loop::{ApprovalCallback, HumanInTheLoopMiddleware};
12pub use model_call_limit::ModelCallLimitMiddleware;
13pub use model_fallback::ModelFallbackMiddleware;
14pub use summarization::SummarizationMiddleware;
15pub use todo_list::TodoListMiddleware;
16pub use tool_call_limit::ToolCallLimitMiddleware;
17pub use tool_retry::ToolRetryMiddleware;
18
19use std::sync::Arc;
20
21use async_trait::async_trait;
22use serde_json::Value;
23use synaptic_core::{
24    ChatModel, ChatRequest, ChatResponse, Message, SynapticError, TokenUsage, ToolCall, ToolChoice,
25    ToolDefinition,
26};
27
28// ---------------------------------------------------------------------------
29// ModelRequest / ModelResponse — middleware-visible request & response types
30// ---------------------------------------------------------------------------
31
32/// A model invocation request visible to middleware.
33///
34/// Contains all parameters that will be sent to the `ChatModel`, plus
35/// the optional system prompt managed by the agent builder.
36#[derive(Debug, Clone)]
37pub struct ModelRequest {
38    pub messages: Vec<Message>,
39    pub tools: Vec<ToolDefinition>,
40    pub tool_choice: Option<ToolChoice>,
41    pub system_prompt: Option<String>,
42}
43
44impl ModelRequest {
45    /// Convert to a `ChatRequest` suitable for calling a `ChatModel`.
46    pub fn to_chat_request(&self) -> ChatRequest {
47        let mut messages = Vec::new();
48        if let Some(ref prompt) = self.system_prompt {
49            messages.push(Message::system(prompt));
50        }
51        messages.extend(self.messages.clone());
52        let mut req = ChatRequest::new(messages).with_tools(self.tools.clone());
53        if let Some(ref choice) = self.tool_choice {
54            req = req.with_tool_choice(choice.clone());
55        }
56        req
57    }
58}
59
60/// A model invocation response visible to middleware.
61#[derive(Debug, Clone)]
62pub struct ModelResponse {
63    pub message: Message,
64    pub usage: Option<TokenUsage>,
65}
66
67impl From<ChatResponse> for ModelResponse {
68    fn from(resp: ChatResponse) -> Self {
69        Self {
70            message: resp.message,
71            usage: resp.usage,
72        }
73    }
74}
75
76// ---------------------------------------------------------------------------
77// ToolCallRequest — wrapper around a single tool call
78// ---------------------------------------------------------------------------
79
80/// A single tool call request visible to middleware.
81#[derive(Debug, Clone)]
82pub struct ToolCallRequest {
83    pub call: ToolCall,
84}
85
86// ---------------------------------------------------------------------------
87// ModelCaller / ToolCaller — "next" in the middleware chain
88// ---------------------------------------------------------------------------
89
90/// Trait representing the next step in the model call chain.
91///
92/// The innermost implementation calls the actual `ChatModel`; outer
93/// layers are middleware `wrap_model_call` implementations.
94#[async_trait]
95pub trait ModelCaller: Send + Sync {
96    async fn call(&self, request: ModelRequest) -> Result<ModelResponse, SynapticError>;
97}
98
99/// Trait representing the next step in the tool call chain.
100#[async_trait]
101pub trait ToolCaller: Send + Sync {
102    async fn call(&self, request: ToolCallRequest) -> Result<Value, SynapticError>;
103}
104
105// ---------------------------------------------------------------------------
106// AgentMiddleware trait
107// ---------------------------------------------------------------------------
108
109/// Middleware that can intercept and modify agent lifecycle events.
110///
111/// All methods have default no-op implementations, so middleware only
112/// needs to override the hooks it cares about.
113///
114/// # Lifecycle order
115///
116/// ```text
117/// before_agent
118///   loop {
119///     before_model  ->  wrap_model_call  ->  after_model
120///     for each tool_call { wrap_tool_call }
121///   }
122/// after_agent
123/// ```
124#[async_trait]
125pub trait AgentMiddleware: Send + Sync {
126    /// Called once when the agent starts executing.
127    async fn before_agent(&self, _messages: &mut Vec<Message>) -> Result<(), SynapticError> {
128        Ok(())
129    }
130
131    /// Called once when the agent finishes executing.
132    async fn after_agent(&self, _messages: &mut Vec<Message>) -> Result<(), SynapticError> {
133        Ok(())
134    }
135
136    /// Called before each model invocation. Can modify the request.
137    async fn before_model(&self, _request: &mut ModelRequest) -> Result<(), SynapticError> {
138        Ok(())
139    }
140
141    /// Called after each model invocation. Can modify the response.
142    async fn after_model(
143        &self,
144        _request: &ModelRequest,
145        _response: &mut ModelResponse,
146    ) -> Result<(), SynapticError> {
147        Ok(())
148    }
149
150    /// Wraps the model call. Override to intercept or replace the model invocation.
151    async fn wrap_model_call(
152        &self,
153        request: ModelRequest,
154        next: &dyn ModelCaller,
155    ) -> Result<ModelResponse, SynapticError> {
156        next.call(request).await
157    }
158
159    /// Wraps a tool call. Override to intercept or replace tool execution.
160    async fn wrap_tool_call(
161        &self,
162        request: ToolCallRequest,
163        next: &dyn ToolCaller,
164    ) -> Result<Value, SynapticError> {
165        next.call(request).await
166    }
167}
168
169// ---------------------------------------------------------------------------
170// MiddlewareChain — composes multiple middlewares
171// ---------------------------------------------------------------------------
172
173/// A chain of middlewares that executes them in order.
174pub struct MiddlewareChain {
175    middlewares: Vec<Arc<dyn AgentMiddleware>>,
176}
177
178impl MiddlewareChain {
179    pub fn new(middlewares: Vec<Arc<dyn AgentMiddleware>>) -> Self {
180        Self { middlewares }
181    }
182
183    pub fn is_empty(&self) -> bool {
184        self.middlewares.is_empty()
185    }
186
187    pub async fn run_before_agent(&self, messages: &mut Vec<Message>) -> Result<(), SynapticError> {
188        for mw in &self.middlewares {
189            mw.before_agent(messages).await?;
190        }
191        Ok(())
192    }
193
194    pub async fn run_after_agent(&self, messages: &mut Vec<Message>) -> Result<(), SynapticError> {
195        for mw in self.middlewares.iter().rev() {
196            mw.after_agent(messages).await?;
197        }
198        Ok(())
199    }
200
201    pub async fn run_before_model(&self, request: &mut ModelRequest) -> Result<(), SynapticError> {
202        for mw in &self.middlewares {
203            mw.before_model(request).await?;
204        }
205        Ok(())
206    }
207
208    pub async fn run_after_model(
209        &self,
210        request: &ModelRequest,
211        response: &mut ModelResponse,
212    ) -> Result<(), SynapticError> {
213        for mw in self.middlewares.iter().rev() {
214            mw.after_model(request, response).await?;
215        }
216        Ok(())
217    }
218
219    /// Execute a model call through the full middleware chain.
220    ///
221    /// Runs the complete lifecycle: `before_model` -> `wrap_model_call`
222    /// chain -> `after_model`.
223    pub async fn call_model(
224        &self,
225        mut request: ModelRequest,
226        base: &dyn ModelCaller,
227    ) -> Result<ModelResponse, SynapticError> {
228        // Run before_model hooks
229        self.run_before_model(&mut request).await?;
230
231        // Build the wrapped call chain (outermost first)
232        let mut response = if self.middlewares.is_empty() {
233            base.call(request.clone()).await?
234        } else {
235            let chain = WrapModelChain {
236                middlewares: &self.middlewares,
237                index: 0,
238                base,
239            };
240            chain.call(request.clone()).await?
241        };
242
243        // Run after_model hooks
244        self.run_after_model(&request, &mut response).await?;
245
246        Ok(response)
247    }
248
249    /// Execute a tool call through the full middleware chain.
250    pub async fn call_tool(
251        &self,
252        request: ToolCallRequest,
253        base: &dyn ToolCaller,
254    ) -> Result<Value, SynapticError> {
255        if self.middlewares.is_empty() {
256            base.call(request).await
257        } else {
258            let chain = WrapToolChain {
259                middlewares: &self.middlewares,
260                index: 0,
261                base,
262            };
263            chain.call(request).await
264        }
265    }
266}
267
268// Internal chain helpers for recursive wrap_model_call / wrap_tool_call
269
270struct WrapModelChain<'a> {
271    middlewares: &'a [Arc<dyn AgentMiddleware>],
272    index: usize,
273    base: &'a dyn ModelCaller,
274}
275
276#[async_trait]
277impl ModelCaller for WrapModelChain<'_> {
278    async fn call(&self, request: ModelRequest) -> Result<ModelResponse, SynapticError> {
279        if self.index >= self.middlewares.len() {
280            self.base.call(request).await
281        } else {
282            let next = WrapModelChain {
283                middlewares: self.middlewares,
284                index: self.index + 1,
285                base: self.base,
286            };
287            self.middlewares[self.index]
288                .wrap_model_call(request, &next)
289                .await
290        }
291    }
292}
293
294struct WrapToolChain<'a> {
295    middlewares: &'a [Arc<dyn AgentMiddleware>],
296    index: usize,
297    base: &'a dyn ToolCaller,
298}
299
300#[async_trait]
301impl ToolCaller for WrapToolChain<'_> {
302    async fn call(&self, request: ToolCallRequest) -> Result<Value, SynapticError> {
303        if self.index >= self.middlewares.len() {
304            self.base.call(request).await
305        } else {
306            let next = WrapToolChain {
307                middlewares: self.middlewares,
308                index: self.index + 1,
309                base: self.base,
310            };
311            self.middlewares[self.index]
312                .wrap_tool_call(request, &next)
313                .await
314        }
315    }
316}
317
318// ---------------------------------------------------------------------------
319// BaseChatModelCaller — calls the actual ChatModel
320// ---------------------------------------------------------------------------
321
322/// Wraps a `ChatModel` into a `ModelCaller`.
323pub struct BaseChatModelCaller {
324    model: Arc<dyn ChatModel>,
325}
326
327impl BaseChatModelCaller {
328    pub fn new(model: Arc<dyn ChatModel>) -> Self {
329        Self { model }
330    }
331}
332
333#[async_trait]
334impl ModelCaller for BaseChatModelCaller {
335    async fn call(&self, request: ModelRequest) -> Result<ModelResponse, SynapticError> {
336        let chat_request = request.to_chat_request();
337        let response = self.model.chat(chat_request).await?;
338        Ok(response.into())
339    }
340}
341
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Middleware that only counts how many times its model hooks fire.
    #[derive(Default)]
    struct CountingMiddleware {
        before_count: AtomicUsize,
        after_count: AtomicUsize,
    }

    impl CountingMiddleware {
        fn new() -> Self {
            Self::default()
        }
    }

    #[async_trait]
    impl AgentMiddleware for CountingMiddleware {
        async fn before_model(&self, _request: &mut ModelRequest) -> Result<(), SynapticError> {
            self.before_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }

        async fn after_model(
            &self,
            _request: &ModelRequest,
            _response: &mut ModelResponse,
        ) -> Result<(), SynapticError> {
            self.after_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Build a single-message ("hello") request, optionally with a system prompt.
    fn hello_request(system_prompt: Option<&str>) -> ModelRequest {
        ModelRequest {
            messages: vec![Message::human("hello")],
            tools: Vec::new(),
            tool_choice: None,
            system_prompt: system_prompt.map(String::from),
        }
    }

    #[test]
    fn middleware_chain_creation() {
        let counting: Arc<dyn AgentMiddleware> = Arc::new(CountingMiddleware::new());
        assert!(!MiddlewareChain::new(vec![counting]).is_empty());
    }

    #[test]
    fn empty_middleware_chain() {
        assert!(MiddlewareChain::new(Vec::new()).is_empty());
    }

    #[test]
    fn model_request_to_chat_request() {
        let chat_req = hello_request(Some("You are helpful.")).to_chat_request();
        assert_eq!(chat_req.messages.len(), 2);
        // The system prompt must come before the conversation messages.
        assert!(chat_req.messages[0].is_system());
        assert!(chat_req.messages[1].is_human());
    }

    #[test]
    fn model_request_without_system_prompt() {
        let chat_req = hello_request(None).to_chat_request();
        assert_eq!(chat_req.messages.len(), 1);
    }
}