1mod callback_adapter;
2mod circuit_breaker;
3mod context_editing;
4mod human_in_the_loop;
5mod model_call_limit;
6mod model_fallback;
7mod security;
8mod ssrf_guard;
9mod summarization;
10mod todo_list;
11mod tool_call_limit;
12mod tool_retry;
13
14pub use callback_adapter::CallbackMiddleware;
15pub use circuit_breaker::{CircuitBreakerConfig, CircuitBreakerMiddleware, CircuitState};
16pub use context_editing::{ContextEditingMiddleware, ContextStrategy};
17pub use human_in_the_loop::{ApprovalCallback, HumanInTheLoopMiddleware};
18pub use model_call_limit::ModelCallLimitMiddleware;
19pub use model_fallback::ModelFallbackMiddleware;
20pub use security::{
21 ConfirmationPolicy, RiskLevel, RuleBasedAnalyzer, SecurityAnalyzer,
22 SecurityConfirmationCallback, SecurityMiddleware, ThresholdConfirmationPolicy,
23};
24pub use ssrf_guard::{SsrfGuardConfig, SsrfGuardMiddleware};
25pub use summarization::SummarizationMiddleware;
26pub use todo_list::TodoListMiddleware;
27pub use tool_call_limit::ToolCallLimitMiddleware;
28pub use tool_retry::ToolRetryMiddleware;
29
30use std::sync::Arc;
31
32use async_trait::async_trait;
33use serde_json::Value;
34use synaptic_core::{
35 ChatModel, ChatRequest, ChatResponse, Message, SynapticError, TokenUsage, ToolCall, ToolChoice,
36 ToolDefinition,
37};
38
39#[derive(Debug, Clone)]
48pub struct ModelRequest {
49 pub messages: Vec<Message>,
50 pub tools: Vec<ToolDefinition>,
51 pub tool_choice: Option<ToolChoice>,
52 pub system_prompt: Option<String>,
53}
54
55impl ModelRequest {
56 pub fn to_chat_request(&self) -> ChatRequest {
58 let mut messages = Vec::new();
59 if let Some(ref prompt) = self.system_prompt {
60 messages.push(Message::system(prompt));
61 }
62 messages.extend(self.messages.clone());
63 let mut req = ChatRequest::new(messages).with_tools(self.tools.clone());
64 if let Some(ref choice) = self.tool_choice {
65 req = req.with_tool_choice(choice.clone());
66 }
67 req
68 }
69}
70
71#[derive(Debug, Clone)]
73pub struct ModelResponse {
74 pub message: Message,
75 pub usage: Option<TokenUsage>,
76}
77
78impl From<ChatResponse> for ModelResponse {
79 fn from(resp: ChatResponse) -> Self {
80 Self {
81 message: resp.message,
82 usage: resp.usage,
83 }
84 }
85}
86
87#[derive(Debug, Clone)]
93pub struct ToolCallRequest {
94 pub call: ToolCall,
95}
96
/// A filesystem operation offered to [`AgentMiddleware::before_file_op`] and
/// [`AgentMiddleware::after_file_op`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FileOp {
    /// Path targeted by the operation, as supplied by the caller.
    pub path: String,
    /// Which kind of operation is being performed.
    pub kind: FileOpKind,
}

/// Kind of filesystem operation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FileOpKind {
    Read,
    Write,
    Delete,
}

/// Outcome of a filesystem operation, passed to [`AgentMiddleware::after_file_op`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FileOpResult {
    /// Whether the operation succeeded.
    pub success: bool,
    /// Error description, when one is available.
    pub error: Option<String>,
}

/// Verdict returned by [`AgentMiddleware::before_file_op`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FileOpDecision {
    /// Let the operation proceed.
    Allow,
    /// Block the operation; the string carries the reason.
    Deny(String),
}
131
/// An external command offered to [`AgentMiddleware::before_command`] and
/// [`AgentMiddleware::after_command`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CommandOp {
    /// Program to run.
    pub command: String,
    /// Arguments for the program.
    pub args: Vec<String>,
    /// Working directory for the command, if one was given.
    pub working_dir: Option<String>,
}

/// Result of a finished command, passed to [`AgentMiddleware::after_command`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CommandResult {
    /// Exit code reported for the command.
    pub exit_code: i32,
    /// Captured standard output.
    pub stdout: String,
    /// Captured standard error.
    pub stderr: String,
}

/// Verdict returned by [`AgentMiddleware::before_command`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CommandDecision {
    /// Let the command run.
    Allow,
    /// Block the command; the string carries the reason.
    Deny(String),
}
154
155#[async_trait]
164pub trait ModelCaller: Send + Sync {
165 async fn call(&self, request: ModelRequest) -> Result<ModelResponse, SynapticError>;
166}
167
168#[async_trait]
170pub trait ToolCaller: Send + Sync {
171 async fn call(&self, request: ToolCallRequest) -> Result<Value, SynapticError>;
172}
173
174#[async_trait]
194pub trait AgentMiddleware: Send + Sync {
195 async fn before_agent(&self, _messages: &mut Vec<Message>) -> Result<(), SynapticError> {
197 Ok(())
198 }
199
200 async fn after_agent(&self, _messages: &mut Vec<Message>) -> Result<(), SynapticError> {
202 Ok(())
203 }
204
205 async fn before_model(&self, _request: &mut ModelRequest) -> Result<(), SynapticError> {
207 Ok(())
208 }
209
210 async fn after_model(
212 &self,
213 _request: &ModelRequest,
214 _response: &mut ModelResponse,
215 ) -> Result<(), SynapticError> {
216 Ok(())
217 }
218
219 async fn wrap_model_call(
221 &self,
222 request: ModelRequest,
223 next: &dyn ModelCaller,
224 ) -> Result<ModelResponse, SynapticError> {
225 next.call(request).await
226 }
227
228 async fn wrap_tool_call(
230 &self,
231 request: ToolCallRequest,
232 next: &dyn ToolCaller,
233 ) -> Result<Value, SynapticError> {
234 next.call(request).await
235 }
236
237 async fn before_file_op(&self, _op: &FileOp) -> Result<FileOpDecision, SynapticError> {
239 Ok(FileOpDecision::Allow)
240 }
241
242 async fn after_file_op(
244 &self,
245 _op: &FileOp,
246 _result: &FileOpResult,
247 ) -> Result<(), SynapticError> {
248 Ok(())
249 }
250
251 async fn before_command(&self, _cmd: &CommandOp) -> Result<CommandDecision, SynapticError> {
253 Ok(CommandDecision::Allow)
254 }
255
256 async fn after_command(
258 &self,
259 _cmd: &CommandOp,
260 _result: &CommandResult,
261 ) -> Result<(), SynapticError> {
262 Ok(())
263 }
264}
265
266pub struct MiddlewareChain {
272 middlewares: Vec<Arc<dyn AgentMiddleware>>,
273}
274
275impl MiddlewareChain {
276 pub fn new(middlewares: Vec<Arc<dyn AgentMiddleware>>) -> Self {
277 Self { middlewares }
278 }
279
280 pub fn is_empty(&self) -> bool {
281 self.middlewares.is_empty()
282 }
283
284 pub async fn run_before_agent(&self, messages: &mut Vec<Message>) -> Result<(), SynapticError> {
285 for mw in &self.middlewares {
286 mw.before_agent(messages).await?;
287 }
288 Ok(())
289 }
290
291 pub async fn run_after_agent(&self, messages: &mut Vec<Message>) -> Result<(), SynapticError> {
292 for mw in self.middlewares.iter().rev() {
293 mw.after_agent(messages).await?;
294 }
295 Ok(())
296 }
297
298 pub async fn run_before_model(&self, request: &mut ModelRequest) -> Result<(), SynapticError> {
299 for mw in &self.middlewares {
300 mw.before_model(request).await?;
301 }
302 Ok(())
303 }
304
305 pub async fn run_after_model(
306 &self,
307 request: &ModelRequest,
308 response: &mut ModelResponse,
309 ) -> Result<(), SynapticError> {
310 for mw in self.middlewares.iter().rev() {
311 mw.after_model(request, response).await?;
312 }
313 Ok(())
314 }
315
316 pub async fn call_model(
321 &self,
322 mut request: ModelRequest,
323 base: &dyn ModelCaller,
324 ) -> Result<ModelResponse, SynapticError> {
325 self.run_before_model(&mut request).await?;
327
328 let mut response = if self.middlewares.is_empty() {
330 base.call(request.clone()).await?
331 } else {
332 let chain = WrapModelChain {
333 middlewares: &self.middlewares,
334 index: 0,
335 base,
336 };
337 chain.call(request.clone()).await?
338 };
339
340 self.run_after_model(&request, &mut response).await?;
342
343 Ok(response)
344 }
345
346 pub async fn call_tool(
348 &self,
349 request: ToolCallRequest,
350 base: &dyn ToolCaller,
351 ) -> Result<Value, SynapticError> {
352 if self.middlewares.is_empty() {
353 base.call(request).await
354 } else {
355 let chain = WrapToolChain {
356 middlewares: &self.middlewares,
357 index: 0,
358 base,
359 };
360 chain.call(request).await
361 }
362 }
363
364 pub async fn run_before_file_op(&self, op: &FileOp) -> Result<FileOpDecision, SynapticError> {
365 for mw in &self.middlewares {
366 match mw.before_file_op(op).await? {
367 FileOpDecision::Allow => continue,
368 deny => return Ok(deny),
369 }
370 }
371 Ok(FileOpDecision::Allow)
372 }
373
374 pub async fn run_after_file_op(
375 &self,
376 op: &FileOp,
377 result: &FileOpResult,
378 ) -> Result<(), SynapticError> {
379 for mw in self.middlewares.iter().rev() {
380 mw.after_file_op(op, result).await?;
381 }
382 Ok(())
383 }
384
385 pub async fn run_before_command(
386 &self,
387 cmd: &CommandOp,
388 ) -> Result<CommandDecision, SynapticError> {
389 for mw in &self.middlewares {
390 match mw.before_command(cmd).await? {
391 CommandDecision::Allow => continue,
392 deny => return Ok(deny),
393 }
394 }
395 Ok(CommandDecision::Allow)
396 }
397
398 pub async fn run_after_command(
399 &self,
400 cmd: &CommandOp,
401 result: &CommandResult,
402 ) -> Result<(), SynapticError> {
403 for mw in self.middlewares.iter().rev() {
404 mw.after_command(cmd, result).await?;
405 }
406 Ok(())
407 }
408}
409
410struct WrapModelChain<'a> {
413 middlewares: &'a [Arc<dyn AgentMiddleware>],
414 index: usize,
415 base: &'a dyn ModelCaller,
416}
417
418#[async_trait]
419impl ModelCaller for WrapModelChain<'_> {
420 async fn call(&self, request: ModelRequest) -> Result<ModelResponse, SynapticError> {
421 if self.index >= self.middlewares.len() {
422 self.base.call(request).await
423 } else {
424 let next = WrapModelChain {
425 middlewares: self.middlewares,
426 index: self.index + 1,
427 base: self.base,
428 };
429 self.middlewares[self.index]
430 .wrap_model_call(request, &next)
431 .await
432 }
433 }
434}
435
436struct WrapToolChain<'a> {
437 middlewares: &'a [Arc<dyn AgentMiddleware>],
438 index: usize,
439 base: &'a dyn ToolCaller,
440}
441
442#[async_trait]
443impl ToolCaller for WrapToolChain<'_> {
444 async fn call(&self, request: ToolCallRequest) -> Result<Value, SynapticError> {
445 if self.index >= self.middlewares.len() {
446 self.base.call(request).await
447 } else {
448 let next = WrapToolChain {
449 middlewares: self.middlewares,
450 index: self.index + 1,
451 base: self.base,
452 };
453 self.middlewares[self.index]
454 .wrap_tool_call(request, &next)
455 .await
456 }
457 }
458}
459
460pub struct BaseChatModelCaller {
466 model: Arc<dyn ChatModel>,
467}
468
469impl BaseChatModelCaller {
470 pub fn new(model: Arc<dyn ChatModel>) -> Self {
471 Self { model }
472 }
473}
474
475#[async_trait]
476impl ModelCaller for BaseChatModelCaller {
477 async fn call(&self, request: ModelRequest) -> Result<ModelResponse, SynapticError> {
478 let chat_request = request.to_chat_request();
479 let response = self.model.chat(chat_request).await?;
480 Ok(response.into())
481 }
482}
483
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Mutex;

    /// Counts how often the model hooks fire.
    struct CountingMiddleware {
        before_count: AtomicUsize,
        after_count: AtomicUsize,
    }

    impl CountingMiddleware {
        fn new() -> Self {
            Self {
                before_count: AtomicUsize::new(0),
                after_count: AtomicUsize::new(0),
            }
        }
    }

    #[async_trait]
    impl AgentMiddleware for CountingMiddleware {
        async fn before_model(&self, _request: &mut ModelRequest) -> Result<(), SynapticError> {
            self.before_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }

        async fn after_model(
            &self,
            _request: &ModelRequest,
            _response: &mut ModelResponse,
        ) -> Result<(), SynapticError> {
            self.after_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Appends `"<hook>:<name>"` to a shared log on every model hook.
    struct RecordingMiddleware {
        name: &'static str,
        log: Arc<Mutex<Vec<String>>>,
    }

    impl RecordingMiddleware {
        fn record(&self, hook: &str) {
            // The guard is dropped at the end of this statement, never held across an await.
            self.log.lock().unwrap().push(format!("{}:{}", hook, self.name));
        }
    }

    #[async_trait]
    impl AgentMiddleware for RecordingMiddleware {
        async fn before_model(&self, _request: &mut ModelRequest) -> Result<(), SynapticError> {
            self.record("before");
            Ok(())
        }

        async fn after_model(
            &self,
            _request: &ModelRequest,
            _response: &mut ModelResponse,
        ) -> Result<(), SynapticError> {
            self.record("after");
            Ok(())
        }
    }

    /// Denies every write and every `rm` command; allows everything else.
    struct DenyWritesMiddleware;

    #[async_trait]
    impl AgentMiddleware for DenyWritesMiddleware {
        async fn before_file_op(&self, op: &FileOp) -> Result<FileOpDecision, SynapticError> {
            if op.kind == FileOpKind::Write {
                Ok(FileOpDecision::Deny("read-only".to_string()))
            } else {
                Ok(FileOpDecision::Allow)
            }
        }

        async fn before_command(&self, cmd: &CommandOp) -> Result<CommandDecision, SynapticError> {
            if cmd.command == "rm" {
                Ok(CommandDecision::Deny("no rm".to_string()))
            } else {
                Ok(CommandDecision::Allow)
            }
        }
    }

    /// Base model caller that answers every request with a fixed message.
    struct FixedReplyCaller {
        calls: AtomicUsize,
    }

    impl FixedReplyCaller {
        fn new() -> Self {
            Self {
                calls: AtomicUsize::new(0),
            }
        }
    }

    #[async_trait]
    impl ModelCaller for FixedReplyCaller {
        async fn call(&self, _request: ModelRequest) -> Result<ModelResponse, SynapticError> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            Ok(ModelResponse {
                message: Message::human("done"),
                usage: None,
            })
        }
    }

    fn sample_request() -> ModelRequest {
        ModelRequest {
            messages: vec![Message::human("hello")],
            tools: vec![],
            tool_choice: None,
            system_prompt: None,
        }
    }

    #[test]
    fn middleware_chain_creation() {
        let mw: Arc<dyn AgentMiddleware> = Arc::new(CountingMiddleware::new());
        let chain = MiddlewareChain::new(vec![mw]);
        assert!(!chain.is_empty());
    }

    #[test]
    fn empty_middleware_chain() {
        let chain = MiddlewareChain::new(vec![]);
        assert!(chain.is_empty());
    }

    #[test]
    fn model_request_to_chat_request() {
        let req = ModelRequest {
            messages: vec![Message::human("hello")],
            tools: vec![],
            tool_choice: None,
            system_prompt: Some("You are helpful.".to_string()),
        };
        let chat_req = req.to_chat_request();
        assert_eq!(chat_req.messages.len(), 2);
        assert!(chat_req.messages[0].is_system());
        assert!(chat_req.messages[1].is_human());
    }

    #[test]
    fn model_request_without_system_prompt() {
        let req = ModelRequest {
            messages: vec![Message::human("hello")],
            tools: vec![],
            tool_choice: None,
            system_prompt: None,
        };
        let chat_req = req.to_chat_request();
        assert_eq!(chat_req.messages.len(), 1);
    }

    #[tokio::test]
    async fn call_model_without_middleware_reaches_base() {
        let chain = MiddlewareChain::new(vec![]);
        let base = FixedReplyCaller::new();
        chain.call_model(sample_request(), &base).await.unwrap();
        assert_eq!(base.calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn call_model_runs_before_and_after_hooks_once() {
        let counter = Arc::new(CountingMiddleware::new());
        let mw: Arc<dyn AgentMiddleware> = counter.clone();
        let chain = MiddlewareChain::new(vec![mw]);
        let base = FixedReplyCaller::new();

        chain.call_model(sample_request(), &base).await.unwrap();

        assert_eq!(counter.before_count.load(Ordering::SeqCst), 1);
        assert_eq!(counter.after_count.load(Ordering::SeqCst), 1);
        assert_eq!(base.calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn after_model_hooks_run_in_reverse_order() {
        let log = Arc::new(Mutex::new(Vec::new()));
        let first: Arc<dyn AgentMiddleware> = Arc::new(RecordingMiddleware {
            name: "a",
            log: log.clone(),
        });
        let second: Arc<dyn AgentMiddleware> = Arc::new(RecordingMiddleware {
            name: "b",
            log: log.clone(),
        });
        let chain = MiddlewareChain::new(vec![first, second]);

        chain
            .call_model(sample_request(), &FixedReplyCaller::new())
            .await
            .unwrap();

        assert_eq!(
            *log.lock().unwrap(),
            vec!["before:a", "before:b", "after:b", "after:a"]
        );
    }

    #[tokio::test]
    async fn file_hook_default_allows() {
        let mw: Arc<dyn AgentMiddleware> = Arc::new(CountingMiddleware::new());
        let chain = MiddlewareChain::new(vec![mw]);
        let op = FileOp {
            path: "/tmp/test".to_string(),
            kind: FileOpKind::Write,
        };
        let decision = chain.run_before_file_op(&op).await.unwrap();
        assert!(matches!(decision, FileOpDecision::Allow));
    }

    #[tokio::test]
    async fn file_hook_denial_is_returned() {
        let allow: Arc<dyn AgentMiddleware> = Arc::new(CountingMiddleware::new());
        let deny: Arc<dyn AgentMiddleware> = Arc::new(DenyWritesMiddleware);
        let chain = MiddlewareChain::new(vec![allow, deny]);

        let write = FileOp {
            path: "/tmp/test".to_string(),
            kind: FileOpKind::Write,
        };
        let decision = chain.run_before_file_op(&write).await.unwrap();
        assert!(matches!(decision, FileOpDecision::Deny(ref reason) if reason.as_str() == "read-only"));

        let read = FileOp {
            path: "/tmp/test".to_string(),
            kind: FileOpKind::Read,
        };
        let decision = chain.run_before_file_op(&read).await.unwrap();
        assert!(matches!(decision, FileOpDecision::Allow));
    }

    #[tokio::test]
    async fn command_hook_default_allows() {
        let mw: Arc<dyn AgentMiddleware> = Arc::new(CountingMiddleware::new());
        let chain = MiddlewareChain::new(vec![mw]);
        let cmd = CommandOp {
            command: "ls".to_string(),
            args: vec![],
            working_dir: None,
        };
        let decision = chain.run_before_command(&cmd).await.unwrap();
        assert!(matches!(decision, CommandDecision::Allow));
    }

    #[tokio::test]
    async fn command_hook_denial_is_returned() {
        let deny: Arc<dyn AgentMiddleware> = Arc::new(DenyWritesMiddleware);
        let chain = MiddlewareChain::new(vec![deny]);
        let cmd = CommandOp {
            command: "rm".to_string(),
            args: vec!["-rf".to_string()],
            working_dir: None,
        };
        let decision = chain.run_before_command(&cmd).await.unwrap();
        assert!(matches!(decision, CommandDecision::Deny(ref reason) if reason.as_str() == "no rm"));
    }
}