Skip to main content

sh_layer3/
output_parsers.rs

1//! # Output Parsers
2//!
3//! 输出解析器:解析 LLM 输出为结构化数据。
4
5use crate::types::Layer3Result;
6use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8
/// Output parser trait.
///
/// Defines the interface for parsing LLM output. Implementors must be
/// `Send + Sync`, as required by the supertrait bounds.
#[async_trait]
pub trait OutputParser: Send + Sync {
    /// Returns the parser's name.
    fn name(&self) -> &str;

    /// Parses LLM output into a [`ParsedOutput`].
    async fn parse(&self, output: &str) -> Layer3Result<ParsedOutput>;

    /// Returns the format instructions (intended for use in the prompt).
    fn get_format_instructions(&self) -> String;
}
23
/// Output produced by an [`OutputParser`].
#[derive(Debug, Clone)]
pub struct ParsedOutput {
    /// Parsed result, represented as JSON.
    pub data: serde_json::Value,
    /// Raw output text.
    pub raw: String,
    /// Whether parsing succeeded.
    pub success: bool,
    /// Parse error message, if any.
    pub error: Option<String>,
}
36
/// Parser that interprets LLM output as JSON.
///
/// The derived `Default` yields a non-strict parser (`strict == false`).
// `dead_code` is allowed because `strict` is stored but not read by the code in this file.
#[allow(dead_code)]
#[derive(Default)]
pub struct JsonParser {
    /// Strict-mode flag; kept for configuration, not consulted by the visible parsing code.
    #[allow(dead_code)]
    strict: bool,
}

impl JsonParser {
    /// Creates a parser with the given strict-mode setting.
    pub fn new(strict: bool) -> Self {
        JsonParser { strict }
    }
}
56
57#[async_trait]
58impl OutputParser for JsonParser {
59    fn name(&self) -> &str {
60        "json"
61    }
62
63    async fn parse(&self, output: &str) -> Layer3Result<ParsedOutput> {
64        // 尝试提取 JSON
65        let trimmed = output.trim();
66
67        // 尝试直接解析
68        if let Ok(data) = serde_json::from_str::<serde_json::Value>(trimmed) {
69            return Ok(ParsedOutput {
70                data,
71                raw: output.to_string(),
72                success: true,
73                error: None,
74            });
75        }
76
77        // 尝试从文本中提取 JSON 块
78        let json_start = trimmed.find('{').or_else(|| trimmed.find('['));
79        let json_end = trimmed.rfind('}').or_else(|| trimmed.rfind(']'));
80
81        if let (Some(start), Some(end)) = (json_start, json_end) {
82            let json_str = &trimmed[start..=end];
83            if let Ok(data) = serde_json::from_str::<serde_json::Value>(json_str) {
84                return Ok(ParsedOutput {
85                    data,
86                    raw: output.to_string(),
87                    success: true,
88                    error: None,
89                });
90            }
91        }
92
93        Ok(ParsedOutput {
94            data: serde_json::Value::Null,
95            raw: output.to_string(),
96            success: false,
97            error: Some("Failed to parse JSON".to_string()),
98        })
99    }
100
101    fn get_format_instructions(&self) -> String {
102        "Output should be a valid JSON object.".to_string()
103    }
104}
105
106/// 结构化解析器
107#[allow(dead_code)]
108pub struct StructuredParser<T: for<'de> Deserialize<'de> + Serialize + Send + Sync> {
109    #[allow(dead_code)]
110    schema: serde_json::Value,
111    _marker: std::marker::PhantomData<T>,
112}
113
114impl<T: for<'de> Deserialize<'de> + Serialize + Send + Sync> Default for StructuredParser<T> {
115    fn default() -> Self {
116        Self {
117            schema: serde_json::Value::Null,
118            _marker: std::marker::PhantomData,
119        }
120    }
121}
122
123impl<T: for<'de> Deserialize<'de> + Serialize + Send + Sync> StructuredParser<T> {
124    pub fn new() -> Self {
125        Self::default()
126    }
127
128    pub fn with_schema(schema: serde_json::Value) -> Self {
129        Self {
130            schema,
131            _marker: std::marker::PhantomData,
132        }
133    }
134}
135
/// Parser that splits LLM output into a list of items.
pub struct ListParser {
    /// Separator used to split the output into items.
    delimiter: String,
}

impl ListParser {
    /// Creates a list parser that splits on `delimiter`.
    pub fn new(delimiter: impl Into<String>) -> Self {
        let delimiter = delimiter.into();
        Self { delimiter }
    }
}

impl Default for ListParser {
    /// Uses a newline (`"\n"`) as the delimiter.
    fn default() -> Self {
        Self {
            delimiter: String::from("\n"),
        }
    }
}
154
155#[async_trait]
156impl OutputParser for ListParser {
157    fn name(&self) -> &str {
158        "list"
159    }
160
161    async fn parse(&self, output: &str) -> Layer3Result<ParsedOutput> {
162        let items: Vec<String> = output
163            .split(&self.delimiter)
164            .map(|s| s.trim().to_string())
165            .filter(|s| !s.is_empty())
166            .collect();
167
168        Ok(ParsedOutput {
169            data: serde_json::to_value(items)?,
170            raw: output.to_string(),
171            success: true,
172            error: None,
173        })
174    }
175
176    fn get_format_instructions(&self) -> String {
177        format!("Output should be a list separated by '{}'.", self.delimiter)
178    }
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    /// A bare JSON object parses directly and the raw text is preserved.
    #[tokio::test]
    async fn test_json_parser() {
        let input = "{\"key\": \"value\"}";
        let result = JsonParser::default().parse(input).await.unwrap();
        assert!(result.success);
        assert!(result.error.is_none());
        assert_eq!(result.raw, input);
        assert_eq!(result.data["key"], "value");
    }

    /// An object surrounded by prose is extracted from the text.
    #[tokio::test]
    async fn test_json_parser_extracts_embedded_object() {
        let input = "Here you go: {\"key\": 1} -- done";
        let result = JsonParser::default().parse(input).await.unwrap();
        assert!(result.success);
        assert_eq!(result.raw, input);
        assert_eq!(result.data["key"], 1);
    }

    /// Unparseable text is reported through `success`/`error`, not through `Err`.
    #[tokio::test]
    async fn test_json_parser_reports_failure() {
        let result = JsonParser::default()
            .parse("not json at all")
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.data.is_null());
        assert_eq!(result.error.as_deref(), Some("Failed to parse JSON"));
    }

    /// The default parser splits on newlines.
    #[tokio::test]
    async fn test_list_parser() {
        let result = ListParser::default().parse("a\nb\nc").await.unwrap();
        assert!(result.success);
        let items: Vec<String> = serde_json::from_value(result.data).unwrap();
        assert_eq!(items, ["a", "b", "c"]);
    }

    /// Items are trimmed and empty items are dropped, with a custom delimiter.
    #[tokio::test]
    async fn test_list_parser_trims_and_skips_empty_items() {
        let result = ListParser::new(",").parse(" a , ,b,, c ").await.unwrap();
        assert!(result.success);
        let items: Vec<String> = serde_json::from_value(result.data).unwrap();
        assert_eq!(items, ["a", "b", "c"]);
    }
}