Skip to main content

synaptic_middleware/
lib.rs

1mod callback_adapter;
2mod circuit_breaker;
3mod context_editing;
4mod human_in_the_loop;
5mod model_call_limit;
6mod model_fallback;
7mod security;
8mod ssrf_guard;
9mod summarization;
10mod todo_list;
11mod tool_call_limit;
12mod tool_retry;
13
14pub use callback_adapter::CallbackMiddleware;
15pub use circuit_breaker::{CircuitBreakerConfig, CircuitBreakerMiddleware, CircuitState};
16pub use context_editing::{ContextEditingMiddleware, ContextStrategy};
17pub use human_in_the_loop::{ApprovalCallback, HumanInTheLoopMiddleware};
18pub use model_call_limit::ModelCallLimitMiddleware;
19pub use model_fallback::ModelFallbackMiddleware;
20pub use security::{
21    ConfirmationPolicy, RiskLevel, RuleBasedAnalyzer, SecurityAnalyzer,
22    SecurityConfirmationCallback, SecurityMiddleware, ThresholdConfirmationPolicy,
23};
24pub use ssrf_guard::{SsrfGuardConfig, SsrfGuardMiddleware};
25pub use summarization::SummarizationMiddleware;
26pub use todo_list::TodoListMiddleware;
27pub use tool_call_limit::ToolCallLimitMiddleware;
28pub use tool_retry::ToolRetryMiddleware;
29
30use std::sync::Arc;
31
32use async_trait::async_trait;
33use serde_json::Value;
34use synaptic_core::{
35    ChatModel, ChatRequest, ChatResponse, Message, SynapticError, TokenUsage, ToolCall, ToolChoice,
36    ToolDefinition,
37};
38
39// ---------------------------------------------------------------------------
40// ModelRequest / ModelResponse — middleware-visible request & response types
41// ---------------------------------------------------------------------------
42
43/// A model invocation request visible to middleware.
44///
45/// Contains all parameters that will be sent to the `ChatModel`, plus
46/// the optional system prompt managed by the agent builder.
47#[derive(Debug, Clone)]
48pub struct ModelRequest {
49    pub messages: Vec<Message>,
50    pub tools: Vec<ToolDefinition>,
51    pub tool_choice: Option<ToolChoice>,
52    pub system_prompt: Option<String>,
53}
54
55impl ModelRequest {
56    /// Convert to a `ChatRequest` suitable for calling a `ChatModel`.
57    pub fn to_chat_request(&self) -> ChatRequest {
58        let mut messages = Vec::new();
59        if let Some(ref prompt) = self.system_prompt {
60            messages.push(Message::system(prompt));
61        }
62        messages.extend(self.messages.clone());
63        let mut req = ChatRequest::new(messages).with_tools(self.tools.clone());
64        if let Some(ref choice) = self.tool_choice {
65            req = req.with_tool_choice(choice.clone());
66        }
67        req
68    }
69}
70
71/// A model invocation response visible to middleware.
72#[derive(Debug, Clone)]
73pub struct ModelResponse {
74    pub message: Message,
75    pub usage: Option<TokenUsage>,
76}
77
78impl From<ChatResponse> for ModelResponse {
79    fn from(resp: ChatResponse) -> Self {
80        Self {
81            message: resp.message,
82            usage: resp.usage,
83        }
84    }
85}
86
87// ---------------------------------------------------------------------------
88// ToolCallRequest — wrapper around a single tool call
89// ---------------------------------------------------------------------------
90
91/// A single tool call request visible to middleware.
92#[derive(Debug, Clone)]
93pub struct ToolCallRequest {
94    pub call: ToolCall,
95}
96
97// ---------------------------------------------------------------------------
98// File/Shell hook types
99// ---------------------------------------------------------------------------
100
101/// Describes a file operation intercepted by middleware.
102#[derive(Debug, Clone)]
103pub struct FileOp {
104    pub path: String,
105    pub kind: FileOpKind,
106}
107
/// Category of access that a file operation performs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FileOpKind {
    /// The file is read.
    Read,
    /// The file is written.
    Write,
    /// The file is deleted.
    Delete,
}
115
/// Outcome of a file operation, as reported to the `after_file_op` hook.
///
/// Derives `PartialEq`/`Eq` so results can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FileOpResult {
    /// Whether the operation succeeded.
    pub success: bool,
    /// Optional error description.
    pub error: Option<String>,
}
122
/// Verdict returned by a `before_file_op` hook.
///
/// Derives `PartialEq`/`Eq` so a verdict (including its denial reason) can be
/// compared directly instead of only pattern-matched.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FileOpDecision {
    /// Let the operation proceed.
    Allow,
    /// Block the operation; the payload is the reason.
    Deny(String),
}
131
/// A shell command presented to the `before_command` / `after_command` hooks.
///
/// Derives `PartialEq`/`Eq` so commands can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CommandOp {
    /// The command itself, separate from its arguments.
    pub command: String,
    /// Arguments for the command.
    pub args: Vec<String>,
    /// Optional working directory for the command.
    pub working_dir: Option<String>,
}
139
/// Outcome of a command execution, as reported to the `after_command` hook.
///
/// Derives `PartialEq`/`Eq` so results can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CommandResult {
    /// Exit code of the command.
    pub exit_code: i32,
    /// Captured standard output.
    pub stdout: String,
    /// Captured standard error.
    pub stderr: String,
}
147
/// Verdict returned by a `before_command` hook.
///
/// Derives `PartialEq`/`Eq` so a verdict (including its denial reason) can be
/// compared directly instead of only pattern-matched.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CommandDecision {
    /// Let the command run.
    Allow,
    /// Block the command; the payload is the reason.
    Deny(String),
}
154
155// ---------------------------------------------------------------------------
156// ModelCaller / ToolCaller — "next" in the middleware chain
157// ---------------------------------------------------------------------------
158
159/// Trait representing the next step in the model call chain.
160///
161/// The innermost implementation calls the actual `ChatModel`; outer
162/// layers are middleware `wrap_model_call` implementations.
163#[async_trait]
164pub trait ModelCaller: Send + Sync {
165    async fn call(&self, request: ModelRequest) -> Result<ModelResponse, SynapticError>;
166}
167
168/// Trait representing the next step in the tool call chain.
169#[async_trait]
170pub trait ToolCaller: Send + Sync {
171    async fn call(&self, request: ToolCallRequest) -> Result<Value, SynapticError>;
172}
173
174// ---------------------------------------------------------------------------
175// AgentMiddleware trait
176// ---------------------------------------------------------------------------
177
178/// Middleware that can intercept and modify agent lifecycle events.
179///
180/// All methods have default no-op implementations, so middleware only
181/// needs to override the hooks it cares about.
182///
183/// # Lifecycle order
184///
185/// ```text
186/// before_agent
187///   loop {
188///     before_model  ->  wrap_model_call  ->  after_model
189///     for each tool_call { wrap_tool_call }
190///   }
191/// after_agent
192/// ```
193#[async_trait]
194pub trait AgentMiddleware: Send + Sync {
195    /// Called once when the agent starts executing.
196    async fn before_agent(&self, _messages: &mut Vec<Message>) -> Result<(), SynapticError> {
197        Ok(())
198    }
199
200    /// Called once when the agent finishes executing.
201    async fn after_agent(&self, _messages: &mut Vec<Message>) -> Result<(), SynapticError> {
202        Ok(())
203    }
204
205    /// Called before each model invocation. Can modify the request.
206    async fn before_model(&self, _request: &mut ModelRequest) -> Result<(), SynapticError> {
207        Ok(())
208    }
209
210    /// Called after each model invocation. Can modify the response.
211    async fn after_model(
212        &self,
213        _request: &ModelRequest,
214        _response: &mut ModelResponse,
215    ) -> Result<(), SynapticError> {
216        Ok(())
217    }
218
219    /// Wraps the model call. Override to intercept or replace the model invocation.
220    async fn wrap_model_call(
221        &self,
222        request: ModelRequest,
223        next: &dyn ModelCaller,
224    ) -> Result<ModelResponse, SynapticError> {
225        next.call(request).await
226    }
227
228    /// Wraps a tool call. Override to intercept or replace tool execution.
229    async fn wrap_tool_call(
230        &self,
231        request: ToolCallRequest,
232        next: &dyn ToolCaller,
233    ) -> Result<Value, SynapticError> {
234        next.call(request).await
235    }
236
237    /// Called before a file operation. Return `Deny` to block it.
238    async fn before_file_op(&self, _op: &FileOp) -> Result<FileOpDecision, SynapticError> {
239        Ok(FileOpDecision::Allow)
240    }
241
242    /// Called after a file operation completes.
243    async fn after_file_op(
244        &self,
245        _op: &FileOp,
246        _result: &FileOpResult,
247    ) -> Result<(), SynapticError> {
248        Ok(())
249    }
250
251    /// Called before a shell command. Return `Deny` to block it.
252    async fn before_command(&self, _cmd: &CommandOp) -> Result<CommandDecision, SynapticError> {
253        Ok(CommandDecision::Allow)
254    }
255
256    /// Called after a shell command completes.
257    async fn after_command(
258        &self,
259        _cmd: &CommandOp,
260        _result: &CommandResult,
261    ) -> Result<(), SynapticError> {
262        Ok(())
263    }
264}
265
266// ---------------------------------------------------------------------------
267// MiddlewareChain — composes multiple middlewares
268// ---------------------------------------------------------------------------
269
270/// A chain of middlewares that executes them in order.
271pub struct MiddlewareChain {
272    middlewares: Vec<Arc<dyn AgentMiddleware>>,
273}
274
275impl MiddlewareChain {
276    pub fn new(middlewares: Vec<Arc<dyn AgentMiddleware>>) -> Self {
277        Self { middlewares }
278    }
279
280    pub fn is_empty(&self) -> bool {
281        self.middlewares.is_empty()
282    }
283
284    pub async fn run_before_agent(&self, messages: &mut Vec<Message>) -> Result<(), SynapticError> {
285        for mw in &self.middlewares {
286            mw.before_agent(messages).await?;
287        }
288        Ok(())
289    }
290
291    pub async fn run_after_agent(&self, messages: &mut Vec<Message>) -> Result<(), SynapticError> {
292        for mw in self.middlewares.iter().rev() {
293            mw.after_agent(messages).await?;
294        }
295        Ok(())
296    }
297
298    pub async fn run_before_model(&self, request: &mut ModelRequest) -> Result<(), SynapticError> {
299        for mw in &self.middlewares {
300            mw.before_model(request).await?;
301        }
302        Ok(())
303    }
304
305    pub async fn run_after_model(
306        &self,
307        request: &ModelRequest,
308        response: &mut ModelResponse,
309    ) -> Result<(), SynapticError> {
310        for mw in self.middlewares.iter().rev() {
311            mw.after_model(request, response).await?;
312        }
313        Ok(())
314    }
315
316    /// Execute a model call through the full middleware chain.
317    ///
318    /// Runs the complete lifecycle: `before_model` -> `wrap_model_call`
319    /// chain -> `after_model`.
320    pub async fn call_model(
321        &self,
322        mut request: ModelRequest,
323        base: &dyn ModelCaller,
324    ) -> Result<ModelResponse, SynapticError> {
325        // Run before_model hooks
326        self.run_before_model(&mut request).await?;
327
328        // Build the wrapped call chain (outermost first)
329        let mut response = if self.middlewares.is_empty() {
330            base.call(request.clone()).await?
331        } else {
332            let chain = WrapModelChain {
333                middlewares: &self.middlewares,
334                index: 0,
335                base,
336            };
337            chain.call(request.clone()).await?
338        };
339
340        // Run after_model hooks
341        self.run_after_model(&request, &mut response).await?;
342
343        Ok(response)
344    }
345
346    /// Execute a tool call through the full middleware chain.
347    pub async fn call_tool(
348        &self,
349        request: ToolCallRequest,
350        base: &dyn ToolCaller,
351    ) -> Result<Value, SynapticError> {
352        if self.middlewares.is_empty() {
353            base.call(request).await
354        } else {
355            let chain = WrapToolChain {
356                middlewares: &self.middlewares,
357                index: 0,
358                base,
359            };
360            chain.call(request).await
361        }
362    }
363
364    pub async fn run_before_file_op(&self, op: &FileOp) -> Result<FileOpDecision, SynapticError> {
365        for mw in &self.middlewares {
366            match mw.before_file_op(op).await? {
367                FileOpDecision::Allow => continue,
368                deny => return Ok(deny),
369            }
370        }
371        Ok(FileOpDecision::Allow)
372    }
373
374    pub async fn run_after_file_op(
375        &self,
376        op: &FileOp,
377        result: &FileOpResult,
378    ) -> Result<(), SynapticError> {
379        for mw in self.middlewares.iter().rev() {
380            mw.after_file_op(op, result).await?;
381        }
382        Ok(())
383    }
384
385    pub async fn run_before_command(
386        &self,
387        cmd: &CommandOp,
388    ) -> Result<CommandDecision, SynapticError> {
389        for mw in &self.middlewares {
390            match mw.before_command(cmd).await? {
391                CommandDecision::Allow => continue,
392                deny => return Ok(deny),
393            }
394        }
395        Ok(CommandDecision::Allow)
396    }
397
398    pub async fn run_after_command(
399        &self,
400        cmd: &CommandOp,
401        result: &CommandResult,
402    ) -> Result<(), SynapticError> {
403        for mw in self.middlewares.iter().rev() {
404            mw.after_command(cmd, result).await?;
405        }
406        Ok(())
407    }
408}
409
410// Internal chain helpers for recursive wrap_model_call / wrap_tool_call
411
412struct WrapModelChain<'a> {
413    middlewares: &'a [Arc<dyn AgentMiddleware>],
414    index: usize,
415    base: &'a dyn ModelCaller,
416}
417
418#[async_trait]
419impl ModelCaller for WrapModelChain<'_> {
420    async fn call(&self, request: ModelRequest) -> Result<ModelResponse, SynapticError> {
421        if self.index >= self.middlewares.len() {
422            self.base.call(request).await
423        } else {
424            let next = WrapModelChain {
425                middlewares: self.middlewares,
426                index: self.index + 1,
427                base: self.base,
428            };
429            self.middlewares[self.index]
430                .wrap_model_call(request, &next)
431                .await
432        }
433    }
434}
435
436struct WrapToolChain<'a> {
437    middlewares: &'a [Arc<dyn AgentMiddleware>],
438    index: usize,
439    base: &'a dyn ToolCaller,
440}
441
442#[async_trait]
443impl ToolCaller for WrapToolChain<'_> {
444    async fn call(&self, request: ToolCallRequest) -> Result<Value, SynapticError> {
445        if self.index >= self.middlewares.len() {
446            self.base.call(request).await
447        } else {
448            let next = WrapToolChain {
449                middlewares: self.middlewares,
450                index: self.index + 1,
451                base: self.base,
452            };
453            self.middlewares[self.index]
454                .wrap_tool_call(request, &next)
455                .await
456        }
457    }
458}
459
460// ---------------------------------------------------------------------------
461// BaseChatModelCaller — calls the actual ChatModel
462// ---------------------------------------------------------------------------
463
464/// Wraps a `ChatModel` into a `ModelCaller`.
465pub struct BaseChatModelCaller {
466    model: Arc<dyn ChatModel>,
467}
468
469impl BaseChatModelCaller {
470    pub fn new(model: Arc<dyn ChatModel>) -> Self {
471        Self { model }
472    }
473}
474
475#[async_trait]
476impl ModelCaller for BaseChatModelCaller {
477    async fn call(&self, request: ModelRequest) -> Result<ModelResponse, SynapticError> {
478        let chat_request = request.to_chat_request();
479        let response = self.model.chat(chat_request).await?;
480        Ok(response.into())
481    }
482}
483
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Middleware that counts how often its model hooks fire and otherwise
    /// relies on the trait defaults.
    struct CountingMiddleware {
        before_count: AtomicUsize,
        after_count: AtomicUsize,
    }

    impl CountingMiddleware {
        fn new() -> Self {
            Self {
                before_count: AtomicUsize::new(0),
                after_count: AtomicUsize::new(0),
            }
        }
    }

    #[async_trait]
    impl AgentMiddleware for CountingMiddleware {
        async fn before_model(&self, _request: &mut ModelRequest) -> Result<(), SynapticError> {
            self.before_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }

        async fn after_model(
            &self,
            _request: &ModelRequest,
            _response: &mut ModelResponse,
        ) -> Result<(), SynapticError> {
            self.after_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Middleware that vetoes every file operation and shell command.
    struct DenyAllMiddleware;

    #[async_trait]
    impl AgentMiddleware for DenyAllMiddleware {
        async fn before_file_op(&self, _op: &FileOp) -> Result<FileOpDecision, SynapticError> {
            Ok(FileOpDecision::Deny("blocked".to_string()))
        }

        async fn before_command(&self, _cmd: &CommandOp) -> Result<CommandDecision, SynapticError> {
            Ok(CommandDecision::Deny("blocked".to_string()))
        }
    }

    /// Innermost model caller that returns a fixed response without a model.
    struct CannedCaller;

    #[async_trait]
    impl ModelCaller for CannedCaller {
        async fn call(&self, _request: ModelRequest) -> Result<ModelResponse, SynapticError> {
            Ok(ModelResponse {
                message: Message::human("ok"),
                usage: None,
            })
        }
    }

    #[test]
    fn middleware_chain_creation() {
        let mw: Arc<dyn AgentMiddleware> = Arc::new(CountingMiddleware::new());
        let chain = MiddlewareChain::new(vec![mw]);
        assert!(!chain.is_empty());
    }

    #[test]
    fn empty_middleware_chain() {
        let chain = MiddlewareChain::new(vec![]);
        assert!(chain.is_empty());
    }

    #[test]
    fn model_request_to_chat_request() {
        let req = ModelRequest {
            messages: vec![Message::human("hello")],
            tools: vec![],
            tool_choice: None,
            system_prompt: Some("You are helpful.".to_string()),
        };
        let chat_req = req.to_chat_request();
        assert_eq!(chat_req.messages.len(), 2);
        assert!(chat_req.messages[0].is_system());
        assert!(chat_req.messages[1].is_human());
    }

    #[test]
    fn model_request_without_system_prompt() {
        let req = ModelRequest {
            messages: vec![Message::human("hello")],
            tools: vec![],
            tool_choice: None,
            system_prompt: None,
        };
        let chat_req = req.to_chat_request();
        assert_eq!(chat_req.messages.len(), 1);
    }

    #[tokio::test]
    async fn call_model_runs_before_and_after_hooks_once() {
        // Keep a typed handle so the counters stay readable after the
        // middleware has been type-erased into the chain.
        let counter = Arc::new(CountingMiddleware::new());
        let erased: Arc<dyn AgentMiddleware> = counter.clone();
        let chain = MiddlewareChain::new(vec![erased]);
        let request = ModelRequest {
            messages: vec![Message::human("hello")],
            tools: vec![],
            tool_choice: None,
            system_prompt: None,
        };

        let response = chain.call_model(request, &CannedCaller).await.unwrap();

        assert!(response.message.is_human());
        assert_eq!(counter.before_count.load(Ordering::SeqCst), 1);
        assert_eq!(counter.after_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn file_hook_default_allows() {
        let mw: Arc<dyn AgentMiddleware> = Arc::new(CountingMiddleware::new());
        let chain = MiddlewareChain::new(vec![mw]);
        let op = FileOp {
            path: "/tmp/test".to_string(),
            kind: FileOpKind::Write,
        };
        let decision = chain.run_before_file_op(&op).await.unwrap();
        assert!(matches!(decision, FileOpDecision::Allow));
    }

    #[tokio::test]
    async fn file_hook_deny_is_returned() {
        let allow: Arc<dyn AgentMiddleware> = Arc::new(CountingMiddleware::new());
        let deny: Arc<dyn AgentMiddleware> = Arc::new(DenyAllMiddleware);
        let chain = MiddlewareChain::new(vec![allow, deny]);
        let op = FileOp {
            path: "/tmp/test".to_string(),
            kind: FileOpKind::Delete,
        };
        let decision = chain.run_before_file_op(&op).await.unwrap();
        assert!(matches!(decision, FileOpDecision::Deny(ref reason) if reason == "blocked"));
    }

    #[tokio::test]
    async fn command_hook_default_allows() {
        let mw: Arc<dyn AgentMiddleware> = Arc::new(CountingMiddleware::new());
        let chain = MiddlewareChain::new(vec![mw]);
        let cmd = CommandOp {
            command: "ls".to_string(),
            args: vec![],
            working_dir: None,
        };
        let decision = chain.run_before_command(&cmd).await.unwrap();
        assert!(matches!(decision, CommandDecision::Allow));
    }

    #[tokio::test]
    async fn command_hook_deny_is_returned() {
        let allow: Arc<dyn AgentMiddleware> = Arc::new(CountingMiddleware::new());
        let deny: Arc<dyn AgentMiddleware> = Arc::new(DenyAllMiddleware);
        let chain = MiddlewareChain::new(vec![allow, deny]);
        let cmd = CommandOp {
            command: "rm".to_string(),
            args: vec!["-rf".to_string()],
            working_dir: None,
        };
        let decision = chain.run_before_command(&cmd).await.unwrap();
        assert!(matches!(decision, CommandDecision::Deny(ref reason) if reason == "blocked"));
    }
}