Skip to main content

wesichain_core/callbacks/
mod.rs

1use std::collections::BTreeMap;
2use std::time::{Instant, SystemTime};
3
4use async_trait::async_trait;
5use serde::Serialize;
6use uuid::Uuid;
7
8use crate::Value;
9
10mod llm;
11mod wrappers;
12
13pub use llm::{LlmInput, LlmResult, TokenUsage};
14
15pub use wrappers::TracedRunnable;
16
/// Category of a traced run, stored in [`RunContext::run_type`].
///
/// This is a fieldless enum, so it is `Copy`: it can be passed and stored by value
/// without cloning. It is also `Hash`, so it can key maps/sets (e.g. per-type
/// aggregation in a handler).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum RunType {
    Chain,
    Llm,
    Tool,
    Graph,
    Agent,
    Retriever,
    Runnable,
}
27
28#[derive(Clone, Debug)]
29pub struct RunContext {
30    pub run_id: Uuid,
31    pub parent_run_id: Option<Uuid>,
32    pub trace_id: Uuid,
33    pub run_type: RunType,
34    pub name: String,
35    pub start_time: SystemTime,
36    pub start_instant: Instant,
37    pub tags: Vec<String>,
38    pub metadata: BTreeMap<String, Value>,
39}
40
41impl RunContext {
42    pub fn root(
43        run_type: RunType,
44        name: String,
45        tags: Vec<String>,
46        metadata: BTreeMap<String, Value>,
47    ) -> Self {
48        let run_id = Uuid::new_v4();
49        Self {
50            run_id,
51            parent_run_id: None,
52            trace_id: run_id,
53            run_type,
54            name,
55            start_time: SystemTime::now(),
56            start_instant: Instant::now(),
57            tags,
58            metadata,
59        }
60    }
61
62    pub fn child(&self, run_type: RunType, name: String) -> Self {
63        let run_id = Uuid::new_v4();
64        Self {
65            run_id,
66            parent_run_id: Some(self.run_id),
67            trace_id: self.trace_id,
68            run_type,
69            name,
70            start_time: SystemTime::now(),
71            start_instant: Instant::now(),
72            tags: self.tags.clone(),
73            metadata: self.metadata.clone(),
74        }
75    }
76}
77
78#[derive(Clone, Debug, Default)]
79pub struct RunConfig {
80    pub callbacks: Option<CallbackManager>,
81    pub tags: Vec<String>,
82    pub metadata: BTreeMap<String, Value>,
83    pub name_override: Option<String>,
84}
85
86#[async_trait]
87pub trait CallbackHandler: Send + Sync {
88    async fn on_start(&self, ctx: &RunContext, inputs: &Value);
89    async fn on_end(&self, ctx: &RunContext, outputs: &Value, duration_ms: u128);
90    async fn on_error(&self, ctx: &RunContext, error: &Value, duration_ms: u128);
91    async fn on_stream_chunk(&self, _ctx: &RunContext, _chunk: &Value) {}
92
93    /// Called when an LLM call starts. Override for structured LLM observability.
94    /// Default implementation calls `on_start` with serialized input.
95    async fn on_llm_start(&self, ctx: &RunContext, input: &LlmInput) {
96        self.on_start(ctx, &serde_json::to_value(input).unwrap_or_default())
97            .await
98    }
99
100    /// Called when an LLM call ends. Override for structured LLM observability.
101    /// Default implementation calls `on_end` with serialized result.
102    async fn on_llm_end(&self, ctx: &RunContext, result: &LlmResult, duration_ms: u128) {
103        self.on_end(
104            ctx,
105            &serde_json::to_value(result).unwrap_or_default(),
106            duration_ms,
107        )
108        .await
109    }
110
111    /// Generic event hook for custom events (e.g. checkpoint saved)
112    async fn on_event(&self, _ctx: &RunContext, _event: &str, _data: &Value) {}
113}
114
115#[derive(Clone, Default)]
116pub struct CallbackManager {
117    handlers: Vec<std::sync::Arc<dyn CallbackHandler>>,
118}
119
120impl std::fmt::Debug for CallbackManager {
121    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122        f.debug_struct("CallbackManager")
123            .field("handlers", &self.handlers.len())
124            .finish()
125    }
126}
127
128impl CallbackManager {
129    pub fn new(handlers: Vec<std::sync::Arc<dyn CallbackHandler>>) -> Self {
130        Self { handlers }
131    }
132
133    pub fn noop() -> Self {
134        Self { handlers: vec![] }
135    }
136
137    pub fn is_noop(&self) -> bool {
138        self.handlers.is_empty()
139    }
140
141    pub fn add_handler(&mut self, handler: std::sync::Arc<dyn CallbackHandler>) {
142        self.handlers.push(handler);
143    }
144
145    pub async fn on_start(&self, ctx: &RunContext, inputs: &Value) {
146        for handler in &self.handlers {
147            handler.on_start(ctx, inputs).await;
148        }
149    }
150
151    pub async fn on_end(&self, ctx: &RunContext, outputs: &Value, duration_ms: u128) {
152        for handler in &self.handlers {
153            handler.on_end(ctx, outputs, duration_ms).await;
154        }
155    }
156
157    pub async fn on_error(&self, ctx: &RunContext, error: &Value, duration_ms: u128) {
158        for handler in &self.handlers {
159            handler.on_error(ctx, error, duration_ms).await;
160        }
161    }
162
163    pub async fn on_stream_chunk(&self, ctx: &RunContext, chunk: &Value) {
164        for handler in &self.handlers {
165            handler.on_stream_chunk(ctx, chunk).await;
166        }
167    }
168
169    pub async fn on_llm_start(&self, ctx: &RunContext, input: &LlmInput) {
170        for handler in &self.handlers {
171            handler.on_llm_start(ctx, input).await;
172        }
173    }
174
175    pub async fn on_llm_end(&self, ctx: &RunContext, result: &LlmResult, duration_ms: u128) {
176        for handler in &self.handlers {
177            handler.on_llm_end(ctx, result, duration_ms).await;
178        }
179    }
180
181    pub async fn on_event(&self, ctx: &RunContext, event: &str, data: &Value) {
182        for handler in &self.handlers {
183            handler.on_event(ctx, event, data).await;
184        }
185    }
186}
187
188pub trait ToTraceInput {
189    fn to_trace_input(&self) -> Value;
190}
191
192pub trait ToTraceOutput {
193    fn to_trace_output(&self) -> Value;
194}
195
196impl<T> ToTraceInput for T
197where
198    T: Serialize,
199{
200    fn to_trace_input(&self) -> Value {
201        serde_json::to_value(self).unwrap_or(Value::Null)
202    }
203}
204
205impl<T> ToTraceOutput for T
206where
207    T: Serialize,
208{
209    fn to_trace_output(&self) -> Value {
210        serde_json::to_value(self).unwrap_or(Value::Null)
211    }
212}
213
214pub fn ensure_object(value: Value) -> Value {
215    match value {
216        Value::Object(_) => value,
217        other => Value::Object(serde_json::Map::from_iter([("value".to_string(), other)])),
218    }
219}