wesichain_core/callbacks/
mod.rs1use std::collections::BTreeMap;
2use std::time::{Instant, SystemTime};
3
4use async_trait::async_trait;
5use serde::Serialize;
6use uuid::Uuid;
7
8use crate::Value;
9
10mod llm;
11mod wrappers;
12
13pub use llm::{LlmInput, LlmResult, TokenUsage};
14
15pub use wrappers::TracedRunnable;
16
/// Kind of run described by a [`RunContext`] (stored in its `run_type` field).
///
/// The variants carry no data; the type supports cloning, debug formatting and
/// equality comparison.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RunType {
    Chain,
    Llm,
    Tool,
    Graph,
    Agent,
    Retriever,
    Runnable,
}
27
28#[derive(Clone, Debug)]
29pub struct RunContext {
30 pub run_id: Uuid,
31 pub parent_run_id: Option<Uuid>,
32 pub trace_id: Uuid,
33 pub run_type: RunType,
34 pub name: String,
35 pub start_time: SystemTime,
36 pub start_instant: Instant,
37 pub tags: Vec<String>,
38 pub metadata: BTreeMap<String, Value>,
39}
40
41impl RunContext {
42 pub fn root(
43 run_type: RunType,
44 name: String,
45 tags: Vec<String>,
46 metadata: BTreeMap<String, Value>,
47 ) -> Self {
48 let run_id = Uuid::new_v4();
49 Self {
50 run_id,
51 parent_run_id: None,
52 trace_id: run_id,
53 run_type,
54 name,
55 start_time: SystemTime::now(),
56 start_instant: Instant::now(),
57 tags,
58 metadata,
59 }
60 }
61
62 pub fn child(&self, run_type: RunType, name: String) -> Self {
63 let run_id = Uuid::new_v4();
64 Self {
65 run_id,
66 parent_run_id: Some(self.run_id),
67 trace_id: self.trace_id,
68 run_type,
69 name,
70 start_time: SystemTime::now(),
71 start_instant: Instant::now(),
72 tags: self.tags.clone(),
73 metadata: self.metadata.clone(),
74 }
75 }
76}
77
78#[derive(Clone, Debug, Default)]
79pub struct RunConfig {
80 pub callbacks: Option<CallbackManager>,
81 pub tags: Vec<String>,
82 pub metadata: BTreeMap<String, Value>,
83 pub name_override: Option<String>,
84}
85
86#[async_trait]
87pub trait CallbackHandler: Send + Sync {
88 async fn on_start(&self, ctx: &RunContext, inputs: &Value);
89 async fn on_end(&self, ctx: &RunContext, outputs: &Value, duration_ms: u128);
90 async fn on_error(&self, ctx: &RunContext, error: &Value, duration_ms: u128);
91 async fn on_stream_chunk(&self, _ctx: &RunContext, _chunk: &Value) {}
92
93 async fn on_llm_start(&self, ctx: &RunContext, input: &LlmInput) {
96 self.on_start(ctx, &serde_json::to_value(input).unwrap_or_default())
97 .await
98 }
99
100 async fn on_llm_end(&self, ctx: &RunContext, result: &LlmResult, duration_ms: u128) {
103 self.on_end(
104 ctx,
105 &serde_json::to_value(result).unwrap_or_default(),
106 duration_ms,
107 )
108 .await
109 }
110
111 async fn on_event(&self, _ctx: &RunContext, _event: &str, _data: &Value) {}
113}
114
115#[derive(Clone, Default)]
116pub struct CallbackManager {
117 handlers: Vec<std::sync::Arc<dyn CallbackHandler>>,
118}
119
120impl std::fmt::Debug for CallbackManager {
121 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122 f.debug_struct("CallbackManager")
123 .field("handlers", &self.handlers.len())
124 .finish()
125 }
126}
127
128impl CallbackManager {
129 pub fn new(handlers: Vec<std::sync::Arc<dyn CallbackHandler>>) -> Self {
130 Self { handlers }
131 }
132
133 pub fn noop() -> Self {
134 Self { handlers: vec![] }
135 }
136
137 pub fn is_noop(&self) -> bool {
138 self.handlers.is_empty()
139 }
140
141 pub fn add_handler(&mut self, handler: std::sync::Arc<dyn CallbackHandler>) {
142 self.handlers.push(handler);
143 }
144
145 pub async fn on_start(&self, ctx: &RunContext, inputs: &Value) {
146 for handler in &self.handlers {
147 handler.on_start(ctx, inputs).await;
148 }
149 }
150
151 pub async fn on_end(&self, ctx: &RunContext, outputs: &Value, duration_ms: u128) {
152 for handler in &self.handlers {
153 handler.on_end(ctx, outputs, duration_ms).await;
154 }
155 }
156
157 pub async fn on_error(&self, ctx: &RunContext, error: &Value, duration_ms: u128) {
158 for handler in &self.handlers {
159 handler.on_error(ctx, error, duration_ms).await;
160 }
161 }
162
163 pub async fn on_stream_chunk(&self, ctx: &RunContext, chunk: &Value) {
164 for handler in &self.handlers {
165 handler.on_stream_chunk(ctx, chunk).await;
166 }
167 }
168
169 pub async fn on_llm_start(&self, ctx: &RunContext, input: &LlmInput) {
170 for handler in &self.handlers {
171 handler.on_llm_start(ctx, input).await;
172 }
173 }
174
175 pub async fn on_llm_end(&self, ctx: &RunContext, result: &LlmResult, duration_ms: u128) {
176 for handler in &self.handlers {
177 handler.on_llm_end(ctx, result, duration_ms).await;
178 }
179 }
180
181 pub async fn on_event(&self, ctx: &RunContext, event: &str, data: &Value) {
182 for handler in &self.handlers {
183 handler.on_event(ctx, event, data).await;
184 }
185 }
186}
187
188pub trait ToTraceInput {
189 fn to_trace_input(&self) -> Value;
190}
191
192pub trait ToTraceOutput {
193 fn to_trace_output(&self) -> Value;
194}
195
196impl<T> ToTraceInput for T
197where
198 T: Serialize,
199{
200 fn to_trace_input(&self) -> Value {
201 serde_json::to_value(self).unwrap_or(Value::Null)
202 }
203}
204
205impl<T> ToTraceOutput for T
206where
207 T: Serialize,
208{
209 fn to_trace_output(&self) -> Value {
210 serde_json::to_value(self).unwrap_or(Value::Null)
211 }
212}
213
214pub fn ensure_object(value: Value) -> Value {
215 match value {
216 Value::Object(_) => value,
217 other => Value::Object(serde_json::Map::from_iter([("value".to_string(), other)])),
218 }
219}