Skip to main content

synaptic_middleware/
callback_adapter.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use synaptic_core::{CallbackHandler, Message, RunEvent, SynapticError};
5use uuid::Uuid;
6
7use crate::{AgentMiddleware, ModelRequest, ModelResponse};
8
9/// Adapts a [`CallbackHandler`] into an [`AgentMiddleware`].
10///
11/// This allows any callback handler (e.g., `OpenTelemetryCallback`,
12/// `RecordingCallback`) to be used in the middleware stack of a Deep Agent
13/// or graph agent that only accepts `Vec<Arc<dyn AgentMiddleware>>`.
14///
15/// The adapter fires lifecycle events at each middleware hook:
16/// - `before_agent` → `RunEvent::RunStarted`
17/// - `after_agent` → `RunEvent::RunFinished`
18/// - `before_model` → `RunEvent::BeforeMessage`
19/// - `after_model` → `RunEvent::LlmCalled`
20pub struct CallbackMiddleware {
21    handler: Arc<dyn CallbackHandler>,
22    run_id: String,
23}
24
25impl CallbackMiddleware {
26    /// Create a new adapter wrapping the given callback handler.
27    pub fn new(handler: Arc<dyn CallbackHandler>) -> Self {
28        Self {
29            handler,
30            run_id: Uuid::new_v4().to_string(),
31        }
32    }
33
34    /// Create with a specific run ID (useful for correlating spans).
35    pub fn with_run_id(handler: Arc<dyn CallbackHandler>, run_id: String) -> Self {
36        Self { handler, run_id }
37    }
38}
39
40#[async_trait]
41impl AgentMiddleware for CallbackMiddleware {
42    async fn before_agent(&self, _messages: &mut Vec<Message>) -> Result<(), SynapticError> {
43        self.handler
44            .on_event(RunEvent::RunStarted {
45                run_id: self.run_id.clone(),
46                session_id: String::new(),
47            })
48            .await
49    }
50
51    async fn after_agent(&self, messages: &mut Vec<Message>) -> Result<(), SynapticError> {
52        let output = messages
53            .last()
54            .map(|m| m.content().to_string())
55            .unwrap_or_default();
56        self.handler
57            .on_event(RunEvent::RunFinished {
58                run_id: self.run_id.clone(),
59                output,
60            })
61            .await
62    }
63
64    async fn before_model(&self, request: &mut ModelRequest) -> Result<(), SynapticError> {
65        self.handler
66            .on_event(RunEvent::BeforeMessage {
67                run_id: self.run_id.clone(),
68                message_count: request.messages.len(),
69            })
70            .await
71    }
72
73    async fn after_model(
74        &self,
75        _request: &ModelRequest,
76        response: &mut ModelResponse,
77    ) -> Result<(), SynapticError> {
78        self.handler
79            .on_event(RunEvent::LlmCalled {
80                run_id: self.run_id.clone(),
81                message_count: response.message.content().len(),
82            })
83            .await
84    }
85}
86
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;
    use synaptic_core::RunEvent;

    /// Test double that remembers, in arrival order, the variant name of
    /// every event it is handed.
    #[derive(Default)]
    struct LabelRecorder {
        seen: Mutex<Vec<String>>,
    }

    impl LabelRecorder {
        /// Snapshot of the labels recorded so far.
        fn labels(&self) -> Vec<String> {
            self.seen.lock().unwrap().clone()
        }
    }

    /// Short name for an event, as compared against in the assertions below.
    fn label_of(event: &RunEvent) -> &'static str {
        match event {
            RunEvent::RunStarted { .. } => "RunStarted",
            RunEvent::RunFinished { .. } => "RunFinished",
            RunEvent::BeforeMessage { .. } => "BeforeMessage",
            RunEvent::LlmCalled { .. } => "LlmCalled",
            _ => "Other",
        }
    }

    #[async_trait]
    impl CallbackHandler for LabelRecorder {
        async fn on_event(&self, event: RunEvent) -> Result<(), SynapticError> {
            let label = label_of(&event).to_string();
            // The guard is a temporary dropped at the end of this statement,
            // so it is never held across an `.await`.
            self.seen.lock().unwrap().push(label);
            Ok(())
        }
    }

    #[tokio::test]
    async fn callback_middleware_fires_events() {
        let recorder = Arc::new(LabelRecorder::default());
        let middleware = CallbackMiddleware::new(recorder.clone());

        // Agent starts.
        let mut conversation = vec![Message::human("hello")];
        middleware.before_agent(&mut conversation).await.unwrap();

        // Model is about to be called.
        let mut request = ModelRequest {
            messages: vec![Message::human("hello")],
            tools: vec![],
            tool_choice: None,
            system_prompt: None,
        };
        middleware.before_model(&mut request).await.unwrap();

        // Model has answered.
        let mut response = ModelResponse {
            message: Message::ai("world"),
            usage: None,
        };
        middleware.after_model(&request, &mut response).await.unwrap();

        // Agent finishes with the model's reply as the last message.
        conversation.push(Message::ai("world"));
        middleware.after_agent(&mut conversation).await.unwrap();

        let expected = vec!["RunStarted", "BeforeMessage", "LlmCalled", "RunFinished"];
        assert_eq!(recorder.labels(), expected);
    }
}