sh_layer3/
output_parsers.rs1use crate::types::Layer3Result;
6use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8
9#[async_trait]
13pub trait OutputParser: Send + Sync {
14 fn name(&self) -> &str;
16
17 async fn parse(&self, output: &str) -> Layer3Result<ParsedOutput>;
19
20 fn get_format_instructions(&self) -> String;
22}
23
24#[derive(Debug, Clone)]
26pub struct ParsedOutput {
27 pub data: serde_json::Value,
29 pub raw: String,
31 pub success: bool,
33 pub error: Option<String>,
35}
36
37#[allow(dead_code)]
39pub struct JsonParser {
40 #[allow(dead_code)]
42 strict: bool,
43}
44
45impl JsonParser {
46 pub fn new(strict: bool) -> Self {
47 Self { strict }
48 }
49}
50
51impl Default for JsonParser {
52 fn default() -> Self {
53 Self::new(false)
54 }
55}
56
57#[async_trait]
58impl OutputParser for JsonParser {
59 fn name(&self) -> &str {
60 "json"
61 }
62
63 async fn parse(&self, output: &str) -> Layer3Result<ParsedOutput> {
64 let trimmed = output.trim();
66
67 if let Ok(data) = serde_json::from_str::<serde_json::Value>(trimmed) {
69 return Ok(ParsedOutput {
70 data,
71 raw: output.to_string(),
72 success: true,
73 error: None,
74 });
75 }
76
77 let json_start = trimmed.find('{').or_else(|| trimmed.find('['));
79 let json_end = trimmed.rfind('}').or_else(|| trimmed.rfind(']'));
80
81 if let (Some(start), Some(end)) = (json_start, json_end) {
82 let json_str = &trimmed[start..=end];
83 if let Ok(data) = serde_json::from_str::<serde_json::Value>(json_str) {
84 return Ok(ParsedOutput {
85 data,
86 raw: output.to_string(),
87 success: true,
88 error: None,
89 });
90 }
91 }
92
93 Ok(ParsedOutput {
94 data: serde_json::Value::Null,
95 raw: output.to_string(),
96 success: false,
97 error: Some("Failed to parse JSON".to_string()),
98 })
99 }
100
101 fn get_format_instructions(&self) -> String {
102 "Output should be a valid JSON object.".to_string()
103 }
104}
105
106#[allow(dead_code)]
108pub struct StructuredParser<T: for<'de> Deserialize<'de> + Serialize + Send + Sync> {
109 #[allow(dead_code)]
110 schema: serde_json::Value,
111 _marker: std::marker::PhantomData<T>,
112}
113
114impl<T: for<'de> Deserialize<'de> + Serialize + Send + Sync> Default for StructuredParser<T> {
115 fn default() -> Self {
116 Self {
117 schema: serde_json::Value::Null,
118 _marker: std::marker::PhantomData,
119 }
120 }
121}
122
123impl<T: for<'de> Deserialize<'de> + Serialize + Send + Sync> StructuredParser<T> {
124 pub fn new() -> Self {
125 Self::default()
126 }
127
128 pub fn with_schema(schema: serde_json::Value) -> Self {
129 Self {
130 schema,
131 _marker: std::marker::PhantomData,
132 }
133 }
134}
135
/// Parser that splits output on a delimiter into a list of trimmed,
/// non-empty items.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ListParser {
    /// Separator between items in the output.
    delimiter: String,
}

impl ListParser {
    /// Creates a parser that splits on `delimiter`.
    pub fn new(delimiter: impl Into<String>) -> Self {
        Self {
            delimiter: delimiter.into(),
        }
    }
}

impl Default for ListParser {
    /// Splits on `"\n"`, i.e. one item per line.
    fn default() -> Self {
        Self::new("\n")
    }
}
154
155#[async_trait]
156impl OutputParser for ListParser {
157 fn name(&self) -> &str {
158 "list"
159 }
160
161 async fn parse(&self, output: &str) -> Layer3Result<ParsedOutput> {
162 let items: Vec<String> = output
163 .split(&self.delimiter)
164 .map(|s| s.trim().to_string())
165 .filter(|s| !s.is_empty())
166 .collect();
167
168 Ok(ParsedOutput {
169 data: serde_json::to_value(items)?,
170 raw: output.to_string(),
171 success: true,
172 error: None,
173 })
174 }
175
176 fn get_format_instructions(&self) -> String {
177 format!("Output should be a list separated by '{}'.", self.delimiter)
178 }
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_json_parser() {
        let parser = JsonParser::default();
        let result = parser.parse("{\"key\": \"value\"}").await.unwrap();
        assert!(result.success);
        assert!(result.error.is_none());
        assert_eq!(result.data, serde_json::json!({"key": "value"}));
    }

    #[tokio::test]
    async fn test_json_parser_extracts_embedded_object() {
        let parser = JsonParser::default();
        let result = parser
            .parse("Here you go: {\"a\": 1} hope that helps")
            .await
            .unwrap();
        assert!(result.success);
        assert_eq!(result.data, serde_json::json!({"a": 1}));
    }

    #[tokio::test]
    async fn test_json_parser_reports_failure() {
        let parser = JsonParser::default();
        let result = parser.parse("no json here").await.unwrap();
        assert!(!result.success);
        assert!(result.data.is_null());
        assert!(result.error.is_some());
        // `raw` keeps the original text untouched.
        assert_eq!(result.raw, "no json here");
    }

    #[tokio::test]
    async fn test_list_parser() {
        let parser = ListParser::default();
        let result = parser.parse("a\nb\nc").await.unwrap();
        assert!(result.success);
        let items: Vec<String> = serde_json::from_value(result.data).unwrap();
        assert_eq!(items, vec!["a", "b", "c"]);
    }

    #[tokio::test]
    async fn test_list_parser_trims_and_skips_empty() {
        let parser = ListParser::new(",");
        let result = parser.parse(" a, ,b ,, c ").await.unwrap();
        assert!(result.success);
        let items: Vec<String> = serde_json::from_value(result.data).unwrap();
        assert_eq!(items, vec!["a", "b", "c"]);
    }
}