Skip to main content

synaptic_middleware/
lib.rs

1mod context_editing;
2mod human_in_the_loop;
3mod model_call_limit;
4mod model_fallback;
5mod security;
6mod summarization;
7mod todo_list;
8mod tool_call_limit;
9mod tool_retry;
10
11pub use context_editing::{ContextEditingMiddleware, ContextStrategy};
12pub use human_in_the_loop::{ApprovalCallback, HumanInTheLoopMiddleware};
13pub use model_call_limit::ModelCallLimitMiddleware;
14pub use model_fallback::ModelFallbackMiddleware;
15pub use security::{
16    ConfirmationPolicy, RiskLevel, RuleBasedAnalyzer, SecurityAnalyzer,
17    SecurityConfirmationCallback, SecurityMiddleware, ThresholdConfirmationPolicy,
18};
19pub use summarization::SummarizationMiddleware;
20pub use todo_list::TodoListMiddleware;
21pub use tool_call_limit::ToolCallLimitMiddleware;
22pub use tool_retry::ToolRetryMiddleware;
23
24use std::sync::Arc;
25
26use async_trait::async_trait;
27use serde_json::Value;
28use synaptic_core::{
29    ChatModel, ChatRequest, ChatResponse, Message, SynapticError, TokenUsage, ToolCall, ToolChoice,
30    ToolDefinition,
31};
32
33// ---------------------------------------------------------------------------
34// ModelRequest / ModelResponse — middleware-visible request & response types
35// ---------------------------------------------------------------------------
36
37/// A model invocation request visible to middleware.
38///
39/// Contains all parameters that will be sent to the `ChatModel`, plus
40/// the optional system prompt managed by the agent builder.
41#[derive(Debug, Clone)]
42pub struct ModelRequest {
43    pub messages: Vec<Message>,
44    pub tools: Vec<ToolDefinition>,
45    pub tool_choice: Option<ToolChoice>,
46    pub system_prompt: Option<String>,
47}
48
49impl ModelRequest {
50    /// Convert to a `ChatRequest` suitable for calling a `ChatModel`.
51    pub fn to_chat_request(&self) -> ChatRequest {
52        let mut messages = Vec::new();
53        if let Some(ref prompt) = self.system_prompt {
54            messages.push(Message::system(prompt));
55        }
56        messages.extend(self.messages.clone());
57        let mut req = ChatRequest::new(messages).with_tools(self.tools.clone());
58        if let Some(ref choice) = self.tool_choice {
59            req = req.with_tool_choice(choice.clone());
60        }
61        req
62    }
63}
64
65/// A model invocation response visible to middleware.
66#[derive(Debug, Clone)]
67pub struct ModelResponse {
68    pub message: Message,
69    pub usage: Option<TokenUsage>,
70}
71
72impl From<ChatResponse> for ModelResponse {
73    fn from(resp: ChatResponse) -> Self {
74        Self {
75            message: resp.message,
76            usage: resp.usage,
77        }
78    }
79}
80
81// ---------------------------------------------------------------------------
82// ToolCallRequest — wrapper around a single tool call
83// ---------------------------------------------------------------------------
84
85/// A single tool call request visible to middleware.
86#[derive(Debug, Clone)]
87pub struct ToolCallRequest {
88    pub call: ToolCall,
89}
90
91// ---------------------------------------------------------------------------
92// ModelCaller / ToolCaller — "next" in the middleware chain
93// ---------------------------------------------------------------------------
94
95/// Trait representing the next step in the model call chain.
96///
97/// The innermost implementation calls the actual `ChatModel`; outer
98/// layers are middleware `wrap_model_call` implementations.
99#[async_trait]
100pub trait ModelCaller: Send + Sync {
101    async fn call(&self, request: ModelRequest) -> Result<ModelResponse, SynapticError>;
102}
103
104/// Trait representing the next step in the tool call chain.
105#[async_trait]
106pub trait ToolCaller: Send + Sync {
107    async fn call(&self, request: ToolCallRequest) -> Result<Value, SynapticError>;
108}
109
110// ---------------------------------------------------------------------------
111// AgentMiddleware trait
112// ---------------------------------------------------------------------------
113
114/// Middleware that can intercept and modify agent lifecycle events.
115///
116/// All methods have default no-op implementations, so middleware only
117/// needs to override the hooks it cares about.
118///
119/// # Lifecycle order
120///
121/// ```text
122/// before_agent
123///   loop {
124///     before_model  ->  wrap_model_call  ->  after_model
125///     for each tool_call { wrap_tool_call }
126///   }
127/// after_agent
128/// ```
129#[async_trait]
130pub trait AgentMiddleware: Send + Sync {
131    /// Called once when the agent starts executing.
132    async fn before_agent(&self, _messages: &mut Vec<Message>) -> Result<(), SynapticError> {
133        Ok(())
134    }
135
136    /// Called once when the agent finishes executing.
137    async fn after_agent(&self, _messages: &mut Vec<Message>) -> Result<(), SynapticError> {
138        Ok(())
139    }
140
141    /// Called before each model invocation. Can modify the request.
142    async fn before_model(&self, _request: &mut ModelRequest) -> Result<(), SynapticError> {
143        Ok(())
144    }
145
146    /// Called after each model invocation. Can modify the response.
147    async fn after_model(
148        &self,
149        _request: &ModelRequest,
150        _response: &mut ModelResponse,
151    ) -> Result<(), SynapticError> {
152        Ok(())
153    }
154
155    /// Wraps the model call. Override to intercept or replace the model invocation.
156    async fn wrap_model_call(
157        &self,
158        request: ModelRequest,
159        next: &dyn ModelCaller,
160    ) -> Result<ModelResponse, SynapticError> {
161        next.call(request).await
162    }
163
164    /// Wraps a tool call. Override to intercept or replace tool execution.
165    async fn wrap_tool_call(
166        &self,
167        request: ToolCallRequest,
168        next: &dyn ToolCaller,
169    ) -> Result<Value, SynapticError> {
170        next.call(request).await
171    }
172}
173
174// ---------------------------------------------------------------------------
175// MiddlewareChain — composes multiple middlewares
176// ---------------------------------------------------------------------------
177
178/// A chain of middlewares that executes them in order.
179pub struct MiddlewareChain {
180    middlewares: Vec<Arc<dyn AgentMiddleware>>,
181}
182
183impl MiddlewareChain {
184    pub fn new(middlewares: Vec<Arc<dyn AgentMiddleware>>) -> Self {
185        Self { middlewares }
186    }
187
188    pub fn is_empty(&self) -> bool {
189        self.middlewares.is_empty()
190    }
191
192    pub async fn run_before_agent(&self, messages: &mut Vec<Message>) -> Result<(), SynapticError> {
193        for mw in &self.middlewares {
194            mw.before_agent(messages).await?;
195        }
196        Ok(())
197    }
198
199    pub async fn run_after_agent(&self, messages: &mut Vec<Message>) -> Result<(), SynapticError> {
200        for mw in self.middlewares.iter().rev() {
201            mw.after_agent(messages).await?;
202        }
203        Ok(())
204    }
205
206    pub async fn run_before_model(&self, request: &mut ModelRequest) -> Result<(), SynapticError> {
207        for mw in &self.middlewares {
208            mw.before_model(request).await?;
209        }
210        Ok(())
211    }
212
213    pub async fn run_after_model(
214        &self,
215        request: &ModelRequest,
216        response: &mut ModelResponse,
217    ) -> Result<(), SynapticError> {
218        for mw in self.middlewares.iter().rev() {
219            mw.after_model(request, response).await?;
220        }
221        Ok(())
222    }
223
224    /// Execute a model call through the full middleware chain.
225    ///
226    /// Runs the complete lifecycle: `before_model` -> `wrap_model_call`
227    /// chain -> `after_model`.
228    pub async fn call_model(
229        &self,
230        mut request: ModelRequest,
231        base: &dyn ModelCaller,
232    ) -> Result<ModelResponse, SynapticError> {
233        // Run before_model hooks
234        self.run_before_model(&mut request).await?;
235
236        // Build the wrapped call chain (outermost first)
237        let mut response = if self.middlewares.is_empty() {
238            base.call(request.clone()).await?
239        } else {
240            let chain = WrapModelChain {
241                middlewares: &self.middlewares,
242                index: 0,
243                base,
244            };
245            chain.call(request.clone()).await?
246        };
247
248        // Run after_model hooks
249        self.run_after_model(&request, &mut response).await?;
250
251        Ok(response)
252    }
253
254    /// Execute a tool call through the full middleware chain.
255    pub async fn call_tool(
256        &self,
257        request: ToolCallRequest,
258        base: &dyn ToolCaller,
259    ) -> Result<Value, SynapticError> {
260        if self.middlewares.is_empty() {
261            base.call(request).await
262        } else {
263            let chain = WrapToolChain {
264                middlewares: &self.middlewares,
265                index: 0,
266                base,
267            };
268            chain.call(request).await
269        }
270    }
271}
272
273// Internal chain helpers for recursive wrap_model_call / wrap_tool_call
274
275struct WrapModelChain<'a> {
276    middlewares: &'a [Arc<dyn AgentMiddleware>],
277    index: usize,
278    base: &'a dyn ModelCaller,
279}
280
281#[async_trait]
282impl ModelCaller for WrapModelChain<'_> {
283    async fn call(&self, request: ModelRequest) -> Result<ModelResponse, SynapticError> {
284        if self.index >= self.middlewares.len() {
285            self.base.call(request).await
286        } else {
287            let next = WrapModelChain {
288                middlewares: self.middlewares,
289                index: self.index + 1,
290                base: self.base,
291            };
292            self.middlewares[self.index]
293                .wrap_model_call(request, &next)
294                .await
295        }
296    }
297}
298
299struct WrapToolChain<'a> {
300    middlewares: &'a [Arc<dyn AgentMiddleware>],
301    index: usize,
302    base: &'a dyn ToolCaller,
303}
304
305#[async_trait]
306impl ToolCaller for WrapToolChain<'_> {
307    async fn call(&self, request: ToolCallRequest) -> Result<Value, SynapticError> {
308        if self.index >= self.middlewares.len() {
309            self.base.call(request).await
310        } else {
311            let next = WrapToolChain {
312                middlewares: self.middlewares,
313                index: self.index + 1,
314                base: self.base,
315            };
316            self.middlewares[self.index]
317                .wrap_tool_call(request, &next)
318                .await
319        }
320    }
321}
322
323// ---------------------------------------------------------------------------
324// BaseChatModelCaller — calls the actual ChatModel
325// ---------------------------------------------------------------------------
326
327/// Wraps a `ChatModel` into a `ModelCaller`.
328pub struct BaseChatModelCaller {
329    model: Arc<dyn ChatModel>,
330}
331
332impl BaseChatModelCaller {
333    pub fn new(model: Arc<dyn ChatModel>) -> Self {
334        Self { model }
335    }
336}
337
338#[async_trait]
339impl ModelCaller for BaseChatModelCaller {
340    async fn call(&self, request: ModelRequest) -> Result<ModelResponse, SynapticError> {
341        let chat_request = request.to_chat_request();
342        let response = self.model.chat(chat_request).await?;
343        Ok(response.into())
344    }
345}
346
#[cfg(test)]
mod tests {
    use super::*;
    use std::future::Future;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::task::{Context, Poll, Wake, Waker};

    /// Middleware that counts how often its model hooks are invoked.
    struct CountingMiddleware {
        before_count: AtomicUsize,
        after_count: AtomicUsize,
    }

    impl CountingMiddleware {
        fn new() -> Self {
            Self {
                before_count: AtomicUsize::new(0),
                after_count: AtomicUsize::new(0),
            }
        }
    }

    #[async_trait]
    impl AgentMiddleware for CountingMiddleware {
        async fn before_model(&self, _request: &mut ModelRequest) -> Result<(), SynapticError> {
            self.before_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }

        async fn after_model(
            &self,
            _request: &ModelRequest,
            _response: &mut ModelResponse,
        ) -> Result<(), SynapticError> {
            self.after_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Base caller that records each invocation and answers with a fixed message.
    struct StubCaller {
        calls: AtomicUsize,
    }

    impl StubCaller {
        fn new() -> Self {
            Self {
                calls: AtomicUsize::new(0),
            }
        }
    }

    #[async_trait]
    impl ModelCaller for StubCaller {
        async fn call(&self, _request: ModelRequest) -> Result<ModelResponse, SynapticError> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            Ok(ModelResponse {
                message: Message::human("stub reply"),
                usage: None,
            })
        }
    }

    /// Waker that does nothing; the futures driven here never suspend.
    struct NoopWake;

    impl Wake for NoopWake {
        fn wake(self: Arc<Self>) {}
    }

    /// Poll `fut` once and return its output.
    ///
    /// A std-only stand-in for an async runtime, so these tests do not depend
    /// on any executor crate. Panics if the future is not ready on the first
    /// poll, which would mean a test started awaiting real I/O.
    fn block_on<F: Future>(fut: F) -> F::Output {
        let waker = Waker::from(Arc::new(NoopWake));
        let mut cx = Context::from_waker(&waker);
        let mut fut = Box::pin(fut);
        match fut.as_mut().poll(&mut cx) {
            Poll::Ready(output) => output,
            Poll::Pending => panic!("test future unexpectedly suspended"),
        }
    }

    /// Request holding a single human message and the given system prompt.
    fn human_request(system_prompt: Option<&str>) -> ModelRequest {
        ModelRequest {
            messages: vec![Message::human("hello")],
            tools: vec![],
            tool_choice: None,
            system_prompt: system_prompt.map(str::to_string),
        }
    }

    #[test]
    fn middleware_chain_creation() {
        let mw: Arc<dyn AgentMiddleware> = Arc::new(CountingMiddleware::new());
        let chain = MiddlewareChain::new(vec![mw]);
        assert!(!chain.is_empty());
    }

    #[test]
    fn empty_middleware_chain() {
        let chain = MiddlewareChain::new(vec![]);
        assert!(chain.is_empty());
    }

    #[test]
    fn model_request_to_chat_request() {
        let chat_req = human_request(Some("You are helpful.")).to_chat_request();
        assert_eq!(chat_req.messages.len(), 2);
        assert!(chat_req.messages[0].is_system());
        assert!(chat_req.messages[1].is_human());
    }

    #[test]
    fn model_request_without_system_prompt() {
        let chat_req = human_request(None).to_chat_request();
        assert_eq!(chat_req.messages.len(), 1);
        // No system message may be injected when no prompt is configured.
        assert!(chat_req.messages[0].is_human());
    }

    #[test]
    fn call_model_runs_hooks_and_base_once() {
        let counting = Arc::new(CountingMiddleware::new());
        let mw: Arc<dyn AgentMiddleware> = counting.clone();
        let chain = MiddlewareChain::new(vec![mw]);
        let base = StubCaller::new();

        let outcome = block_on(chain.call_model(human_request(None), &base));

        assert!(outcome.is_ok());
        assert_eq!(counting.before_count.load(Ordering::SeqCst), 1);
        assert_eq!(counting.after_count.load(Ordering::SeqCst), 1);
        assert_eq!(base.calls.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn call_model_with_empty_chain_reaches_base() {
        let chain = MiddlewareChain::new(vec![]);
        let base = StubCaller::new();

        let outcome = block_on(chain.call_model(human_request(None), &base));

        assert!(outcome.is_ok());
        assert_eq!(base.calls.load(Ordering::SeqCst), 1);
    }
}