1use crate::{LlmRequest, LlmResponse, Message, Role, Runnable, StreamEvent, WesichainError};
2use async_trait::async_trait;
3use futures::stream::BoxStream;
4use futures::StreamExt;
5use serde::de::DeserializeOwned;
6use serde_json::Value;
7use std::marker::PhantomData;
8
9#[async_trait]
12pub trait BaseOutputParser<Input: Send + Sync + 'static, Output: Send + Sync + 'static>:
13 Runnable<Input, Output> + Send + Sync
14{
15 async fn parse(&self, input: Input) -> Result<Output, WesichainError>;
16}
17
/// Output parser that yields the text of its input unchanged.
///
/// Accepts either an `LlmResponse` (its text content is returned) or a plain
/// `String` (returned as-is).
#[derive(Clone, Debug, Default)]
pub struct StrOutputParser;
22
23#[async_trait]
24impl Runnable<LlmResponse, String> for StrOutputParser {
25 async fn invoke(&self, input: LlmResponse) -> Result<String, WesichainError> {
26 Ok(input.content)
27 }
28
29 fn stream(&self, input: LlmResponse) -> BoxStream<'_, Result<StreamEvent, WesichainError>> {
30 futures::stream::once(async move { Ok(StreamEvent::ContentChunk(input.content)) }).boxed()
31 }
32
33 fn to_serializable(&self) -> Option<crate::serde::SerializableRunnable> {
34 Some(crate::serde::SerializableRunnable::Parser {
35 kind: "str".to_string(),
36 target_type: None,
37 })
38 }
39}
40
41#[async_trait]
42impl Runnable<String, String> for StrOutputParser {
43 async fn invoke(&self, input: String) -> Result<String, WesichainError> {
44 Ok(input)
45 }
46
47 fn stream(&self, input: String) -> BoxStream<'_, Result<StreamEvent, WesichainError>> {
48 futures::stream::once(async move { Ok(StreamEvent::ContentChunk(input)) }).boxed()
49 }
50}
51
52#[derive(Clone, Default)]
54pub struct JsonOutputParser<T = Value> {
55 _marker: PhantomData<T>,
56}
57
58impl<T> JsonOutputParser<T> {
59 pub fn new() -> Self {
60 Self {
61 _marker: PhantomData,
62 }
63 }
64}
65
66#[async_trait]
67impl<T: DeserializeOwned + serde::Serialize + Send + Sync + 'static> Runnable<String, T>
68 for JsonOutputParser<T>
69{
70 async fn invoke(&self, input: String) -> Result<T, WesichainError> {
71 let cleaned = input.trim();
73 let cleaned = if cleaned.starts_with("```json") {
74 cleaned
75 .trim_start_matches("```json")
76 .trim_end_matches("```")
77 .trim()
78 } else if cleaned.starts_with("```") {
79 cleaned
80 .trim_start_matches("```")
81 .trim_end_matches("```")
82 .trim()
83 } else {
84 cleaned
85 };
86
87 serde_json::from_str(cleaned).map_err(WesichainError::Serde)
88 }
89
90 fn stream(&self, input: String) -> BoxStream<'_, Result<StreamEvent, WesichainError>> {
91 futures::stream::once(async move {
92 let res = self.invoke(input).await?;
93 Ok(StreamEvent::Metadata {
94 key: "param".to_string(),
95 value: serde_json::to_value(res).unwrap_or(Value::Null),
96 })
97 })
98 .boxed()
99 }
100
101 fn to_serializable(&self) -> Option<crate::serde::SerializableRunnable> {
102 Some(crate::serde::SerializableRunnable::Parser {
103 kind: "json".to_string(),
104 target_type: Some(std::any::type_name::<T>().to_string()),
105 })
106 }
107}
108
109#[async_trait]
110impl<T: DeserializeOwned + serde::Serialize + Send + Sync + 'static> Runnable<LlmResponse, T>
111 for JsonOutputParser<T>
112{
113 async fn invoke(&self, input: LlmResponse) -> Result<T, WesichainError> {
114 Runnable::<String, T>::invoke(self, input.content).await
119 }
120
121 fn stream(&self, input: LlmResponse) -> BoxStream<'_, Result<StreamEvent, WesichainError>> {
122 Runnable::<String, T>::stream(self, input.content)
123 }
124}
125
126#[derive(Clone, Default)]
129pub struct StructuredOutputParser<T = Value> {
130 _marker: PhantomData<T>,
131}
132
133impl<T> StructuredOutputParser<T> {
134 pub fn new() -> Self {
135 Self {
136 _marker: PhantomData,
137 }
138 }
139}
140
141#[async_trait]
142impl<T: DeserializeOwned + serde::Serialize + Send + Sync + 'static> Runnable<LlmResponse, T>
143 for StructuredOutputParser<T>
144{
145 async fn invoke(&self, input: LlmResponse) -> Result<T, WesichainError> {
146 if let Some(call) = input.tool_calls.first() {
148 return serde_json::from_value(call.args.clone()).map_err(WesichainError::Serde);
150 }
151
152 let content = input.content.trim();
154 let cleaned = if content.starts_with("```json") {
155 content
156 .trim_start_matches("```json")
157 .trim_end_matches("```")
158 .trim()
159 } else if content.starts_with("```") {
160 content
161 .trim_start_matches("```")
162 .trim_end_matches("```")
163 .trim()
164 } else {
165 content
166 };
167
168 if cleaned.is_empty() {
169 return Err(WesichainError::Custom(
170 "No structured output found in tool calls or content".to_string(),
171 ));
172 }
173
174 serde_json::from_str(cleaned).map_err(WesichainError::Serde)
175 }
176
177 fn stream(&self, _input: LlmResponse) -> BoxStream<'_, Result<StreamEvent, WesichainError>> {
178 futures::stream::empty().boxed()
182 }
183
184 fn to_serializable(&self) -> Option<crate::serde::SerializableRunnable> {
185 Some(crate::serde::SerializableRunnable::Parser {
186 kind: "structured".to_string(),
187 target_type: Some(std::any::type_name::<T>().to_string()),
188 })
189 }
190}
191
/// Wraps an LLM and an output parser. When the LLM's output fails to parse,
/// the LLM is re-prompted with the parse error.
#[derive(Clone, Debug)]
pub struct OutputFixingParser<L, P> {
    /// Model that produces, and is asked to repair, the raw output.
    llm: L,
    /// Parser applied to each LLM response.
    parser: P,
    /// Retry budget used by `invoke` when parsing fails.
    max_retries: usize,
}
201
202impl<L, P> OutputFixingParser<L, P> {
203 pub fn new(llm: L, parser: P, max_retries: usize) -> Self {
204 Self {
205 llm,
206 parser,
207 max_retries,
208 }
209 }
210}
211
212#[async_trait]
213impl<L, P, O> Runnable<LlmRequest, O> for OutputFixingParser<L, P>
214where
215 L: Runnable<LlmRequest, LlmResponse> + Clone + Send + Sync,
216 P: Runnable<LlmResponse, O> + Clone + Send + Sync,
217 O: Send + Sync + 'static,
218{
219 async fn invoke(&self, input: LlmRequest) -> Result<O, WesichainError> {
220 let mut current_request = input.clone();
221 let mut attempts = 0;
222
223 loop {
224 let response = self.llm.invoke(current_request.clone()).await?;
226
227 match self.parser.invoke(response.clone()).await {
229 Ok(output) => return Ok(output),
230 Err(e) => {
231 attempts += 1;
232 if attempts >= self.max_retries {
233 return Err(e);
234 }
235
236 current_request.messages.push(Message {
239 role: Role::Assistant,
240 content: response.content.into(),
241 tool_call_id: None,
242 tool_calls: Vec::new(),
243 });
244 current_request.messages.push(Message {
245 role: Role::User,
246 content: format!(
247 "The previous response failed to parse with error: {}. Please fix the output to match the required format.",
248 e
249 ).into(),
250 tool_call_id: None,
251 tool_calls: Vec::new(),
252 });
253 }
254 }
255 }
256 }
257
258 fn stream(&self, input: LlmRequest) -> BoxStream<'_, Result<StreamEvent, WesichainError>> {
259 futures::stream::once(async move {
260 let _res = self.invoke(input).await?;
261 Ok(StreamEvent::Metadata {
262 key: "fixed_output".to_string(),
263 value: serde_json::to_value("Processing complete").unwrap_or(Value::Null),
264 })
265 })
266 .boxed()
267 }
268
269 fn to_serializable(&self) -> Option<crate::serde::SerializableRunnable> {
270 Some(crate::serde::SerializableRunnable::Parser {
271 kind: "output_fixing".to_string(),
272 target_type: None,
273 })
274 }
275}