//! Source listing: wesichain_core/callbacks/wrappers.rs

1use futures::stream::BoxStream;
2
3use crate::callbacks::{
4    ensure_object, CallbackManager, RunContext, RunType, ToTraceInput, ToTraceOutput,
5};
6use crate::{Runnable, StreamEvent, Value, WesichainError};
7
8#[allow(dead_code)]
9pub struct TracedRunnable<R> {
10    inner: R,
11    manager: CallbackManager,
12    parent: RunContext,
13    run_type: RunType,
14    name: String,
15}
16
17impl<R> TracedRunnable<R> {
18    #[allow(dead_code)]
19    pub fn new(
20        inner: R,
21        manager: CallbackManager,
22        parent: RunContext,
23        run_type: RunType,
24        name: String,
25    ) -> Self {
26        Self {
27            inner,
28            manager,
29            parent,
30            run_type,
31            name,
32        }
33    }
34}
35
36#[async_trait::async_trait]
37impl<Input, Output, R> Runnable<Input, Output> for TracedRunnable<R>
38where
39    Input: Send + Sync + ToTraceInput + Clone + 'static,
40    Output: Send + Sync + ToTraceOutput + 'static,
41    R: Runnable<Input, Output> + Send + Sync,
42{
43    async fn invoke(&self, input: Input) -> Result<Output, WesichainError> {
44        if self.manager.is_noop() {
45            return self.inner.invoke(input).await;
46        }
47
48        let ctx = self.parent.child(self.run_type.clone(), self.name.clone());
49        let inputs = ensure_object(input.to_trace_input());
50        self.manager.on_start(&ctx, &inputs).await;
51
52        let result = self.inner.invoke(input).await;
53        let duration_ms = ctx.start_instant.elapsed().as_millis();
54
55        match &result {
56            Ok(output) => {
57                let outputs = ensure_object(output.to_trace_output());
58                self.manager.on_end(&ctx, &outputs, duration_ms).await;
59            }
60            Err(err) => {
61                let error = ensure_object(err.to_string().to_trace_output());
62                self.manager.on_error(&ctx, &error, duration_ms).await;
63            }
64        }
65
66        result
67    }
68
69    fn stream(&self, input: Input) -> BoxStream<'_, Result<StreamEvent, WesichainError>> {
70        if self.manager.is_noop() {
71            return self.inner.stream(input);
72        }
73
74        let manager = self.manager.clone();
75        let parent = self.parent.clone();
76        let run_type = self.run_type.clone();
77        let name = self.name.clone();
78        let inner_stream = self.inner.stream(input.clone());
79
80        Box::pin(async_stream::stream! {
81            let ctx = parent.child(run_type, name);
82            let inputs = ensure_object(input.to_trace_input());
83            manager.on_start(&ctx, &inputs).await;
84
85            let mut got_final_answer = false;
86
87            for await event in inner_stream {
88                match &event {
89                    Ok(StreamEvent::ContentChunk(chunk)) => {
90                        let chunk_value = Value::String(chunk.clone());
91                        manager.on_stream_chunk(&ctx, &chunk_value).await;
92                    }
93                    Ok(StreamEvent::ToolCallDelta { id: _, delta }) => {
94                        manager.on_stream_chunk(&ctx, delta).await;
95                    }
96                    Ok(StreamEvent::FinalAnswer(_)) => {
97                        got_final_answer = true;
98                        let outputs = ensure_object(Value::String("final_answer".to_string()));
99                        let duration_ms = ctx.start_instant.elapsed().as_millis();
100                        manager.on_end(&ctx, &outputs, duration_ms).await;
101                    }
102                    Ok(_) => {
103                        // Other variants don't trigger specific callbacks
104                    }
105                    Err(err) => {
106                        let error = ensure_object(err.to_string().to_trace_output());
107                        let duration_ms = ctx.start_instant.elapsed().as_millis();
108                        manager.on_error(&ctx, &error, duration_ms).await;
109                    }
110                }
111                yield event;
112            }
113
114            // If stream ended without FinalAnswer, call on_end
115            if !got_final_answer {
116                let outputs = ensure_object(Value::Object(serde_json::Map::new()));
117                let duration_ms = ctx.start_instant.elapsed().as_millis();
118                manager.on_end(&ctx, &outputs, duration_ms).await;
119            }
120        })
121    }
122}