synaptic_middleware/
callback_adapter.rs1use std::sync::Arc;
2
3use async_trait::async_trait;
4use synaptic_core::{CallbackHandler, Message, RunEvent, SynapticError};
5use uuid::Uuid;
6
7use crate::{AgentMiddleware, ModelRequest, ModelResponse};
8
9pub struct CallbackMiddleware {
21 handler: Arc<dyn CallbackHandler>,
22 run_id: String,
23}
24
25impl CallbackMiddleware {
26 pub fn new(handler: Arc<dyn CallbackHandler>) -> Self {
28 Self {
29 handler,
30 run_id: Uuid::new_v4().to_string(),
31 }
32 }
33
34 pub fn with_run_id(handler: Arc<dyn CallbackHandler>, run_id: String) -> Self {
36 Self { handler, run_id }
37 }
38}
39
40#[async_trait]
41impl AgentMiddleware for CallbackMiddleware {
42 async fn before_agent(&self, _messages: &mut Vec<Message>) -> Result<(), SynapticError> {
43 self.handler
44 .on_event(RunEvent::RunStarted {
45 run_id: self.run_id.clone(),
46 session_id: String::new(),
47 })
48 .await
49 }
50
51 async fn after_agent(&self, messages: &mut Vec<Message>) -> Result<(), SynapticError> {
52 let output = messages
53 .last()
54 .map(|m| m.content().to_string())
55 .unwrap_or_default();
56 self.handler
57 .on_event(RunEvent::RunFinished {
58 run_id: self.run_id.clone(),
59 output,
60 })
61 .await
62 }
63
64 async fn before_model(&self, request: &mut ModelRequest) -> Result<(), SynapticError> {
65 self.handler
66 .on_event(RunEvent::BeforeMessage {
67 run_id: self.run_id.clone(),
68 message_count: request.messages.len(),
69 })
70 .await
71 }
72
73 async fn after_model(
74 &self,
75 _request: &ModelRequest,
76 response: &mut ModelResponse,
77 ) -> Result<(), SynapticError> {
78 self.handler
79 .on_event(RunEvent::LlmCalled {
80 run_id: self.run_id.clone(),
81 message_count: response.message.content().len(),
82 })
83 .await
84 }
85}
86
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;
    use synaptic_core::RunEvent;

    /// Test double that stores a short label for every event it receives.
    struct RecordHandler {
        events: Mutex<Vec<String>>,
    }

    impl RecordHandler {
        /// Starts with an empty event log.
        fn new() -> Self {
            Self {
                events: Mutex::default(),
            }
        }

        /// Returns a copy of the labels recorded so far, in arrival order.
        fn events(&self) -> Vec<String> {
            let recorded = self.events.lock().unwrap();
            recorded.clone()
        }
    }

    /// Maps an event to the label stored by `RecordHandler`.
    fn label_of(event: &RunEvent) -> &'static str {
        match event {
            RunEvent::RunStarted { .. } => "RunStarted",
            RunEvent::RunFinished { .. } => "RunFinished",
            RunEvent::BeforeMessage { .. } => "BeforeMessage",
            RunEvent::LlmCalled { .. } => "LlmCalled",
            _ => "Other",
        }
    }

    #[async_trait]
    impl CallbackHandler for RecordHandler {
        async fn on_event(&self, event: RunEvent) -> Result<(), SynapticError> {
            let label = label_of(&event).to_string();
            self.events.lock().unwrap().push(label);
            Ok(())
        }
    }

    /// Drives all four hooks in agent order and checks that one event per
    /// hook reaches the handler, in that order.
    #[tokio::test]
    async fn callback_middleware_fires_events() {
        let recorder = Arc::new(RecordHandler::new());
        let middleware = CallbackMiddleware::new(recorder.clone());

        // Agent start.
        let mut messages = vec![Message::human("hello")];
        middleware.before_agent(&mut messages).await.unwrap();

        // Model round trip.
        let mut request = ModelRequest {
            messages: vec![Message::human("hello")],
            tools: vec![],
            tool_choice: None,
            system_prompt: None,
        };
        middleware.before_model(&mut request).await.unwrap();

        let mut response = ModelResponse {
            message: Message::ai("world"),
            usage: None,
        };
        middleware
            .after_model(&request, &mut response)
            .await
            .unwrap();

        // Agent finish.
        messages.push(Message::ai("world"));
        middleware.after_agent(&mut messages).await.unwrap();

        let expected = vec!["RunStarted", "BeforeMessage", "LlmCalled", "RunFinished"];
        assert_eq!(recorder.events(), expected);
    }
}