1mod context_editing;
2mod human_in_the_loop;
3mod model_call_limit;
4mod model_fallback;
5mod summarization;
6mod todo_list;
7mod tool_call_limit;
8mod tool_retry;
9
10pub use context_editing::{ContextEditingMiddleware, ContextStrategy};
11pub use human_in_the_loop::{ApprovalCallback, HumanInTheLoopMiddleware};
12pub use model_call_limit::ModelCallLimitMiddleware;
13pub use model_fallback::ModelFallbackMiddleware;
14pub use summarization::SummarizationMiddleware;
15pub use todo_list::TodoListMiddleware;
16pub use tool_call_limit::ToolCallLimitMiddleware;
17pub use tool_retry::ToolRetryMiddleware;
18
19use std::sync::Arc;
20
21use async_trait::async_trait;
22use serde_json::Value;
23use synaptic_core::{
24 ChatModel, ChatRequest, ChatResponse, Message, SynapticError, TokenUsage, ToolCall, ToolChoice,
25 ToolDefinition,
26};
27
/// Everything that will be sent to the chat model for one call.
///
/// Middleware `before_model` hooks receive `&mut ModelRequest` and may edit
/// any field in place before it is converted into a `ChatRequest`.
#[derive(Debug, Clone)]
pub struct ModelRequest {
    /// Conversation messages, not including the system prompt.
    pub messages: Vec<Message>,
    /// Tool definitions the model is allowed to call.
    pub tools: Vec<ToolDefinition>,
    /// Optional constraint on which tool(s) the model must/may use.
    pub tool_choice: Option<ToolChoice>,
    /// Optional system prompt, prepended as the first message when converting
    /// to a `ChatRequest` (see `to_chat_request`).
    pub system_prompt: Option<String>,
}
43
44impl ModelRequest {
45 pub fn to_chat_request(&self) -> ChatRequest {
47 let mut messages = Vec::new();
48 if let Some(ref prompt) = self.system_prompt {
49 messages.push(Message::system(prompt));
50 }
51 messages.extend(self.messages.clone());
52 let mut req = ChatRequest::new(messages).with_tools(self.tools.clone());
53 if let Some(ref choice) = self.tool_choice {
54 req = req.with_tool_choice(choice.clone());
55 }
56 req
57 }
58}
59
/// Result of one model call, as seen by middleware `after_model` hooks.
#[derive(Debug, Clone)]
pub struct ModelResponse {
    /// Message returned by the model.
    pub message: Message,
    /// Token usage, when the underlying response reported it.
    pub usage: Option<TokenUsage>,
}
66
67impl From<ChatResponse> for ModelResponse {
68 fn from(resp: ChatResponse) -> Self {
69 Self {
70 message: resp.message,
71 usage: resp.usage,
72 }
73 }
74}
75
/// A single tool invocation routed through `wrap_tool_call` middleware.
#[derive(Debug, Clone)]
pub struct ToolCallRequest {
    /// The tool call to execute.
    pub call: ToolCall,
}
85
/// Object-safe callable that performs a model invocation.
///
/// Implemented by the terminal caller (`BaseChatModelCaller`) and by each link
/// of the `wrap_model_call` chain, so middleware can delegate via `next`.
#[async_trait]
pub trait ModelCaller: Send + Sync {
    /// Executes `request` and returns the model's response.
    async fn call(&self, request: ModelRequest) -> Result<ModelResponse, SynapticError>;
}
98
/// Object-safe callable that performs a tool invocation.
///
/// Implemented by the base tool executor and by each link of the
/// `wrap_tool_call` chain, so middleware can delegate via `next`.
#[async_trait]
pub trait ToolCaller: Send + Sync {
    /// Executes the tool call in `request` and returns its JSON result.
    async fn call(&self, request: ToolCallRequest) -> Result<Value, SynapticError>;
}
104
/// Hooks a middleware can implement to observe or modify an agent run.
///
/// Every method has a no-op (or pass-through) default, so implementors
/// override only what they need. `MiddlewareChain` runs `before_*` hooks in
/// registration order and `after_*` hooks in reverse order; `wrap_*` hooks
/// nest so the first registered middleware is outermost.
#[async_trait]
pub trait AgentMiddleware: Send + Sync {
    /// Called once with the message list before the agent run; may edit it
    /// in place. Default: no-op.
    async fn before_agent(&self, _messages: &mut Vec<Message>) -> Result<(), SynapticError> {
        Ok(())
    }

    /// Called once with the message list after the agent run; may edit it
    /// in place. Default: no-op.
    async fn after_agent(&self, _messages: &mut Vec<Message>) -> Result<(), SynapticError> {
        Ok(())
    }

    /// Called before each model invocation; may edit the request in place.
    /// Default: no-op.
    async fn before_model(&self, _request: &mut ModelRequest) -> Result<(), SynapticError> {
        Ok(())
    }

    /// Called after each model invocation with the (read-only) request and
    /// the mutable response. Default: no-op.
    async fn after_model(
        &self,
        _request: &ModelRequest,
        _response: &mut ModelResponse,
    ) -> Result<(), SynapticError> {
        Ok(())
    }

    /// Wraps the model call itself; implementors may short-circuit, retry, or
    /// transform before delegating to `next`. Default: pass straight through.
    async fn wrap_model_call(
        &self,
        request: ModelRequest,
        next: &dyn ModelCaller,
    ) -> Result<ModelResponse, SynapticError> {
        next.call(request).await
    }

    /// Wraps a tool call; implementors may short-circuit, retry, or transform
    /// before delegating to `next`. Default: pass straight through.
    async fn wrap_tool_call(
        &self,
        request: ToolCallRequest,
        next: &dyn ToolCaller,
    ) -> Result<Value, SynapticError> {
        next.call(request).await
    }
}
168
/// An ordered set of middlewares executed around agent, model, and tool calls.
///
/// Order matters: `run_before_*` iterates front-to-back, `run_after_*`
/// back-to-front, and the `call_model`/`call_tool` wrappers nest with the
/// first middleware outermost.
pub struct MiddlewareChain {
    middlewares: Vec<Arc<dyn AgentMiddleware>>,
}
177
178impl MiddlewareChain {
179 pub fn new(middlewares: Vec<Arc<dyn AgentMiddleware>>) -> Self {
180 Self { middlewares }
181 }
182
183 pub fn is_empty(&self) -> bool {
184 self.middlewares.is_empty()
185 }
186
187 pub async fn run_before_agent(&self, messages: &mut Vec<Message>) -> Result<(), SynapticError> {
188 for mw in &self.middlewares {
189 mw.before_agent(messages).await?;
190 }
191 Ok(())
192 }
193
194 pub async fn run_after_agent(&self, messages: &mut Vec<Message>) -> Result<(), SynapticError> {
195 for mw in self.middlewares.iter().rev() {
196 mw.after_agent(messages).await?;
197 }
198 Ok(())
199 }
200
201 pub async fn run_before_model(&self, request: &mut ModelRequest) -> Result<(), SynapticError> {
202 for mw in &self.middlewares {
203 mw.before_model(request).await?;
204 }
205 Ok(())
206 }
207
208 pub async fn run_after_model(
209 &self,
210 request: &ModelRequest,
211 response: &mut ModelResponse,
212 ) -> Result<(), SynapticError> {
213 for mw in self.middlewares.iter().rev() {
214 mw.after_model(request, response).await?;
215 }
216 Ok(())
217 }
218
219 pub async fn call_model(
224 &self,
225 mut request: ModelRequest,
226 base: &dyn ModelCaller,
227 ) -> Result<ModelResponse, SynapticError> {
228 self.run_before_model(&mut request).await?;
230
231 let mut response = if self.middlewares.is_empty() {
233 base.call(request.clone()).await?
234 } else {
235 let chain = WrapModelChain {
236 middlewares: &self.middlewares,
237 index: 0,
238 base,
239 };
240 chain.call(request.clone()).await?
241 };
242
243 self.run_after_model(&request, &mut response).await?;
245
246 Ok(response)
247 }
248
249 pub async fn call_tool(
251 &self,
252 request: ToolCallRequest,
253 base: &dyn ToolCaller,
254 ) -> Result<Value, SynapticError> {
255 if self.middlewares.is_empty() {
256 base.call(request).await
257 } else {
258 let chain = WrapToolChain {
259 middlewares: &self.middlewares,
260 index: 0,
261 base,
262 };
263 chain.call(request).await
264 }
265 }
266}
267
/// One link in the recursive `wrap_model_call` chain.
///
/// `index` selects which middleware this link dispatches to; each link builds
/// the next link at `index + 1`, and past the end the `base` caller runs.
struct WrapModelChain<'a> {
    middlewares: &'a [Arc<dyn AgentMiddleware>],
    index: usize,
    base: &'a dyn ModelCaller,
}
275
276#[async_trait]
277impl ModelCaller for WrapModelChain<'_> {
278 async fn call(&self, request: ModelRequest) -> Result<ModelResponse, SynapticError> {
279 if self.index >= self.middlewares.len() {
280 self.base.call(request).await
281 } else {
282 let next = WrapModelChain {
283 middlewares: self.middlewares,
284 index: self.index + 1,
285 base: self.base,
286 };
287 self.middlewares[self.index]
288 .wrap_model_call(request, &next)
289 .await
290 }
291 }
292}
293
/// One link in the recursive `wrap_tool_call` chain.
///
/// `index` selects which middleware this link dispatches to; each link builds
/// the next link at `index + 1`, and past the end the `base` caller runs.
struct WrapToolChain<'a> {
    middlewares: &'a [Arc<dyn AgentMiddleware>],
    index: usize,
    base: &'a dyn ToolCaller,
}
299
300#[async_trait]
301impl ToolCaller for WrapToolChain<'_> {
302 async fn call(&self, request: ToolCallRequest) -> Result<Value, SynapticError> {
303 if self.index >= self.middlewares.len() {
304 self.base.call(request).await
305 } else {
306 let next = WrapToolChain {
307 middlewares: self.middlewares,
308 index: self.index + 1,
309 base: self.base,
310 };
311 self.middlewares[self.index]
312 .wrap_tool_call(request, &next)
313 .await
314 }
315 }
316}
317
/// `ModelCaller` that forwards requests directly to a `ChatModel` — the
/// innermost caller a middleware chain bottoms out on.
pub struct BaseChatModelCaller {
    model: Arc<dyn ChatModel>,
}
326
327impl BaseChatModelCaller {
328 pub fn new(model: Arc<dyn ChatModel>) -> Self {
329 Self { model }
330 }
331}
332
333#[async_trait]
334impl ModelCaller for BaseChatModelCaller {
335 async fn call(&self, request: ModelRequest) -> Result<ModelResponse, SynapticError> {
336 let chat_request = request.to_chat_request();
337 let response = self.model.chat(chat_request).await?;
338 Ok(response.into())
339 }
340}
341
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Middleware that tallies how often the model hooks fire.
    struct CountingMiddleware {
        before_count: AtomicUsize,
        after_count: AtomicUsize,
    }

    impl CountingMiddleware {
        fn new() -> Self {
            CountingMiddleware {
                before_count: AtomicUsize::new(0),
                after_count: AtomicUsize::new(0),
            }
        }
    }

    #[async_trait]
    impl AgentMiddleware for CountingMiddleware {
        async fn before_model(&self, _request: &mut ModelRequest) -> Result<(), SynapticError> {
            self.before_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }

        async fn after_model(
            &self,
            _request: &ModelRequest,
            _response: &mut ModelResponse,
        ) -> Result<(), SynapticError> {
            self.after_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    #[test]
    fn empty_middleware_chain() {
        // A chain built from no middlewares reports itself empty.
        assert!(MiddlewareChain::new(Vec::new()).is_empty());
    }

    #[test]
    fn middleware_chain_creation() {
        // A chain holding one middleware is not empty.
        let mw = Arc::new(CountingMiddleware::new()) as Arc<dyn AgentMiddleware>;
        assert!(!MiddlewareChain::new(vec![mw]).is_empty());
    }

    #[test]
    fn model_request_to_chat_request() {
        // With a system prompt set, the converted request gains a leading
        // system message ahead of the human one.
        let req = ModelRequest {
            messages: vec![Message::human("hello")],
            tools: Vec::new(),
            tool_choice: None,
            system_prompt: Some("You are helpful.".to_string()),
        };
        let chat_req = req.to_chat_request();
        assert_eq!(chat_req.messages.len(), 2);
        assert!(chat_req.messages[0].is_system());
        assert!(chat_req.messages[1].is_human());
    }

    #[test]
    fn model_request_without_system_prompt() {
        // Without a system prompt, only the original messages survive.
        let req = ModelRequest {
            messages: vec![Message::human("hello")],
            tools: Vec::new(),
            tool_choice: None,
            system_prompt: None,
        };
        assert_eq!(req.to_chat_request().messages.len(), 1);
    }
}