wesichain_core/callbacks/
wrappers.rs1use futures::stream::BoxStream;
2
3use crate::callbacks::{
4 ensure_object, CallbackManager, RunContext, RunType, ToTraceInput, ToTraceOutput,
5};
6use crate::{Runnable, StreamEvent, Value, WesichainError};
7
8#[allow(dead_code)]
9pub struct TracedRunnable<R> {
10 inner: R,
11 manager: CallbackManager,
12 parent: RunContext,
13 run_type: RunType,
14 name: String,
15}
16
17impl<R> TracedRunnable<R> {
18 #[allow(dead_code)]
19 pub fn new(
20 inner: R,
21 manager: CallbackManager,
22 parent: RunContext,
23 run_type: RunType,
24 name: String,
25 ) -> Self {
26 Self {
27 inner,
28 manager,
29 parent,
30 run_type,
31 name,
32 }
33 }
34}
35
36#[async_trait::async_trait]
37impl<Input, Output, R> Runnable<Input, Output> for TracedRunnable<R>
38where
39 Input: Send + Sync + ToTraceInput + Clone + 'static,
40 Output: Send + Sync + ToTraceOutput + 'static,
41 R: Runnable<Input, Output> + Send + Sync,
42{
43 async fn invoke(&self, input: Input) -> Result<Output, WesichainError> {
44 if self.manager.is_noop() {
45 return self.inner.invoke(input).await;
46 }
47
48 let ctx = self.parent.child(self.run_type.clone(), self.name.clone());
49 let inputs = ensure_object(input.to_trace_input());
50 self.manager.on_start(&ctx, &inputs).await;
51
52 let result = self.inner.invoke(input).await;
53 let duration_ms = ctx.start_instant.elapsed().as_millis();
54
55 match &result {
56 Ok(output) => {
57 let outputs = ensure_object(output.to_trace_output());
58 self.manager.on_end(&ctx, &outputs, duration_ms).await;
59 }
60 Err(err) => {
61 let error = ensure_object(err.to_string().to_trace_output());
62 self.manager.on_error(&ctx, &error, duration_ms).await;
63 }
64 }
65
66 result
67 }
68
69 fn stream(&self, input: Input) -> BoxStream<'_, Result<StreamEvent, WesichainError>> {
70 if self.manager.is_noop() {
71 return self.inner.stream(input);
72 }
73
74 let manager = self.manager.clone();
75 let parent = self.parent.clone();
76 let run_type = self.run_type.clone();
77 let name = self.name.clone();
78 let inner_stream = self.inner.stream(input.clone());
79
80 Box::pin(async_stream::stream! {
81 let ctx = parent.child(run_type, name);
82 let inputs = ensure_object(input.to_trace_input());
83 manager.on_start(&ctx, &inputs).await;
84
85 let mut got_final_answer = false;
86
87 for await event in inner_stream {
88 match &event {
89 Ok(StreamEvent::ContentChunk(chunk)) => {
90 let chunk_value = Value::String(chunk.clone());
91 manager.on_stream_chunk(&ctx, &chunk_value).await;
92 }
93 Ok(StreamEvent::ToolCallDelta { id: _, delta }) => {
94 manager.on_stream_chunk(&ctx, delta).await;
95 }
96 Ok(StreamEvent::FinalAnswer(_)) => {
97 got_final_answer = true;
98 let outputs = ensure_object(Value::String("final_answer".to_string()));
99 let duration_ms = ctx.start_instant.elapsed().as_millis();
100 manager.on_end(&ctx, &outputs, duration_ms).await;
101 }
102 Ok(_) => {
103 }
105 Err(err) => {
106 let error = ensure_object(err.to_string().to_trace_output());
107 let duration_ms = ctx.start_instant.elapsed().as_millis();
108 manager.on_error(&ctx, &error, duration_ms).await;
109 }
110 }
111 yield event;
112 }
113
114 if !got_final_answer {
116 let outputs = ensure_object(Value::Object(serde_json::Map::new()));
117 let duration_ms = ctx.start_instant.elapsed().as_millis();
118 manager.on_end(&ctx, &outputs, duration_ms).await;
119 }
120 })
121 }
122}