Skip to main content

wesichain_core/
output_parsers.rs

1use crate::{LlmRequest, LlmResponse, Message, Role, Runnable, StreamEvent, WesichainError};
2use async_trait::async_trait;
3use futures::stream::BoxStream;
4use futures::StreamExt;
5use serde::de::DeserializeOwned;
6use serde_json::Value;
7use std::marker::PhantomData;
8
9/// Trait for output parsers that can transform input into a specific output.
10/// This is a specialized version of Runnable for parsing logic.
11#[async_trait]
12pub trait BaseOutputParser<Input: Send + Sync + 'static, Output: Send + Sync + 'static>:
13    Runnable<Input, Output> + Send + Sync
14{
15    async fn parse(&self, input: Input) -> Result<Output, WesichainError>;
16}
17
/// Parser that produces plain text as a `String`.
///
/// It is a [`Runnable`] over both `LlmResponse` (yielding the response's
/// `content`) and `String` (yielding the input unchanged).
#[derive(Default, Clone)]
pub struct StrOutputParser;
22
23#[async_trait]
24impl Runnable<LlmResponse, String> for StrOutputParser {
25    async fn invoke(&self, input: LlmResponse) -> Result<String, WesichainError> {
26        Ok(input.content)
27    }
28
29    fn stream(&self, input: LlmResponse) -> BoxStream<'_, Result<StreamEvent, WesichainError>> {
30        futures::stream::once(async move { Ok(StreamEvent::ContentChunk(input.content)) }).boxed()
31    }
32
33    fn to_serializable(&self) -> Option<crate::serde::SerializableRunnable> {
34        Some(crate::serde::SerializableRunnable::Parser {
35            kind: "str".to_string(),
36            target_type: None,
37        })
38    }
39}
40
41#[async_trait]
42impl Runnable<String, String> for StrOutputParser {
43    async fn invoke(&self, input: String) -> Result<String, WesichainError> {
44        Ok(input)
45    }
46
47    fn stream(&self, input: String) -> BoxStream<'_, Result<StreamEvent, WesichainError>> {
48        futures::stream::once(async move { Ok(StreamEvent::ContentChunk(input)) }).boxed()
49    }
50}
51
52/// A parser that parses a JSON string (or LlmResponse content) into a structured type or Value.
53#[derive(Clone, Default)]
54pub struct JsonOutputParser<T = Value> {
55    _marker: PhantomData<T>,
56}
57
58impl<T> JsonOutputParser<T> {
59    pub fn new() -> Self {
60        Self {
61            _marker: PhantomData,
62        }
63    }
64}
65
66#[async_trait]
67impl<T: DeserializeOwned + serde::Serialize + Send + Sync + 'static> Runnable<String, T>
68    for JsonOutputParser<T>
69{
70    async fn invoke(&self, input: String) -> Result<T, WesichainError> {
71        // Basic cleanup of markdown code blocks if present
72        let cleaned = input.trim();
73        let cleaned = if cleaned.starts_with("```json") {
74            cleaned
75                .trim_start_matches("```json")
76                .trim_end_matches("```")
77                .trim()
78        } else if cleaned.starts_with("```") {
79            cleaned
80                .trim_start_matches("```")
81                .trim_end_matches("```")
82                .trim()
83        } else {
84            cleaned
85        };
86
87        serde_json::from_str(cleaned).map_err(WesichainError::Serde)
88    }
89
90    fn stream(&self, input: String) -> BoxStream<'_, Result<StreamEvent, WesichainError>> {
91        futures::stream::once(async move {
92            let res = self.invoke(input).await?;
93            Ok(StreamEvent::Metadata {
94                key: "param".to_string(),
95                value: serde_json::to_value(res).unwrap_or(Value::Null),
96            })
97        })
98        .boxed()
99    }
100
101    fn to_serializable(&self) -> Option<crate::serde::SerializableRunnable> {
102        Some(crate::serde::SerializableRunnable::Parser {
103            kind: "json".to_string(),
104            target_type: Some(std::any::type_name::<T>().to_string()),
105        })
106    }
107}
108
109#[async_trait]
110impl<T: DeserializeOwned + serde::Serialize + Send + Sync + 'static> Runnable<LlmResponse, T>
111    for JsonOutputParser<T>
112{
113    async fn invoke(&self, input: LlmResponse) -> Result<T, WesichainError> {
114        // First check for JSON content
115        // If that fails, or if empty, we might check tool calls?
116        // But JsonOutputParser specifically targets JSON string content.
117        // For structured output via tools, we need a different parser or logic.
118        Runnable::<String, T>::invoke(self, input.content).await
119    }
120
121    fn stream(&self, input: LlmResponse) -> BoxStream<'_, Result<StreamEvent, WesichainError>> {
122        Runnable::<String, T>::stream(self, input.content)
123    }
124}
125
126/// A parser that extracts structured output from `LlmResponse`.
127/// It prioritizes `tool_calls` (first call args), then falls back to parsing `content` as JSON.
128#[derive(Clone, Default)]
129pub struct StructuredOutputParser<T = Value> {
130    _marker: PhantomData<T>,
131}
132
133impl<T> StructuredOutputParser<T> {
134    pub fn new() -> Self {
135        Self {
136            _marker: PhantomData,
137        }
138    }
139}
140
141#[async_trait]
142impl<T: DeserializeOwned + serde::Serialize + Send + Sync + 'static> Runnable<LlmResponse, T>
143    for StructuredOutputParser<T>
144{
145    async fn invoke(&self, input: LlmResponse) -> Result<T, WesichainError> {
146        // 1. Check tool calls
147        if let Some(call) = input.tool_calls.first() {
148            // Args is expected to be Value, which we can deserialize to T
149            return serde_json::from_value(call.args.clone()).map_err(WesichainError::Serde);
150        }
151
152        // 2. Fallback to content parsing (reuse logic from JsonOutputParser)
153        let content = input.content.trim();
154        let cleaned = if content.starts_with("```json") {
155            content
156                .trim_start_matches("```json")
157                .trim_end_matches("```")
158                .trim()
159        } else if content.starts_with("```") {
160            content
161                .trim_start_matches("```")
162                .trim_end_matches("```")
163                .trim()
164        } else {
165            content
166        };
167
168        if cleaned.is_empty() {
169            return Err(WesichainError::Custom(
170                "No structured output found in tool calls or content".to_string(),
171            ));
172        }
173
174        serde_json::from_str(cleaned).map_err(WesichainError::Serde)
175    }
176
177    fn stream(&self, _input: LlmResponse) -> BoxStream<'_, Result<StreamEvent, WesichainError>> {
178        // Structured parser hard to stream partial results unless we implement partial JSON parsing.
179        // For now, empty stream or just wait for invoke?
180        // Let's just return empty stream as we rely on invoke.
181        futures::stream::empty().boxed()
182    }
183
184    fn to_serializable(&self) -> Option<crate::serde::SerializableRunnable> {
185        Some(crate::serde::SerializableRunnable::Parser {
186            kind: "structured".to_string(),
187            target_type: Some(std::any::type_name::<T>().to_string()),
188        })
189    }
190}
191
/// Chain that pairs an LLM with a parser and asks the LLM to repair output
/// the parser rejects.
///
/// The LLM's response is handed to the parser. When parsing fails, the failed
/// response and the error text are appended to the conversation so the LLM can
/// produce a corrected answer.
#[derive(Clone)]
pub struct OutputFixingParser<L, P> {
    /// LLM side of the chain; produces the responses to parse.
    llm: L,
    /// Parser applied to each response.
    parser: P,
    /// Retry budget consulted when parsing fails.
    max_retries: usize,
}

impl<L, P> OutputFixingParser<L, P> {
    /// Builds a fixing parser from its three parts, stored as given.
    pub fn new(llm: L, parser: P, max_retries: usize) -> Self {
        OutputFixingParser {
            llm,
            parser,
            max_retries,
        }
    }
}
211
212#[async_trait]
213impl<L, P, O> Runnable<LlmRequest, O> for OutputFixingParser<L, P>
214where
215    L: Runnable<LlmRequest, LlmResponse> + Clone + Send + Sync,
216    P: Runnable<LlmResponse, O> + Clone + Send + Sync,
217    O: Send + Sync + 'static,
218{
219    async fn invoke(&self, input: LlmRequest) -> Result<O, WesichainError> {
220        let mut current_request = input.clone();
221        let mut attempts = 0;
222
223        loop {
224            // 1. Invoke LLM
225            let response = self.llm.invoke(current_request.clone()).await?;
226
227            // 2. Try to parse
228            match self.parser.invoke(response.clone()).await {
229                Ok(output) => return Ok(output),
230                Err(e) => {
231                    attempts += 1;
232                    if attempts >= self.max_retries {
233                        return Err(e);
234                    }
235
236                    // 3. Prepare retry request
237                    // Append bad response and error message
238                    current_request.messages.push(Message {
239                        role: Role::Assistant,
240                        content: response.content.into(),
241                        tool_call_id: None,
242                        tool_calls: Vec::new(),
243                    });
244                    current_request.messages.push(Message {
245                        role: Role::User,
246                        content: format!(
247                            "The previous response failed to parse with error: {}. Please fix the output to match the required format.",
248                            e
249                        ).into(),
250                        tool_call_id: None,
251                        tool_calls: Vec::new(),
252                    });
253                }
254            }
255        }
256    }
257
258    fn stream(&self, input: LlmRequest) -> BoxStream<'_, Result<StreamEvent, WesichainError>> {
259        futures::stream::once(async move {
260            let _res = self.invoke(input).await?;
261            Ok(StreamEvent::Metadata {
262                key: "fixed_output".to_string(),
263                value: serde_json::to_value("Processing complete").unwrap_or(Value::Null),
264            })
265        })
266        .boxed()
267    }
268
269    fn to_serializable(&self) -> Option<crate::serde::SerializableRunnable> {
270        Some(crate::serde::SerializableRunnable::Parser {
271            kind: "output_fixing".to_string(),
272            target_type: None,
273        })
274    }
275}