Skip to main content

synaptic_parsers/
xml_parser.rs

1use std::collections::HashMap;
2
3use async_trait::async_trait;
4use synaptic_core::{RunnableConfig, SynapticError};
5use synaptic_runnables::Runnable;
6
7use crate::FormatInstructions;
8
/// Represents a parsed XML element.
///
/// Elements form a tree: each element owns its attributes, its own text
/// content, and its child elements in document order.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct XmlElement {
    /// Tag name exactly as written in the opening tag.
    pub tag: String,
    /// Text found directly inside the element, or `None` when there is none.
    ///
    /// The parser in this file trims each text fragment and joins the
    /// fragments with a single space.
    pub text: Option<String>,
    /// Attribute name/value pairs from the opening tag.
    pub attributes: HashMap<String, String>,
    /// Nested child elements, in the order they appear.
    pub children: Vec<XmlElement>,
}
17
/// Parses XML-formatted LLM output into an `XmlElement` tree.
///
/// Performs simple XML parsing without requiring a full XML library.
/// Handles basic tags with text content and nested elements.
///
/// The `Default` value is the same as [`XmlOutputParser::new`]: no root tag
/// filter is applied.
#[derive(Debug, Clone, Default)]
pub struct XmlOutputParser {
    /// Optional root tag to extract from within a larger output.
    root_tag: Option<String>,
}

impl XmlOutputParser {
    /// Creates a new parser with no root tag filter.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a new parser that only parses content within the specified root tag.
    ///
    /// `tag` is the bare tag name, without angle brackets.
    pub fn with_root_tag(tag: impl Into<String>) -> Self {
        Self {
            root_tag: Some(tag.into()),
        }
    }
}
46
47impl FormatInstructions for XmlOutputParser {
48    fn get_format_instructions(&self) -> String {
49        match &self.root_tag {
50            Some(tag) => {
51                format!("Your response should be valid XML wrapped in <{tag}>...</{tag}> tags.")
52            }
53            None => "Your response should be valid XML.".to_string(),
54        }
55    }
56}
57
58#[async_trait]
59impl Runnable<String, XmlElement> for XmlOutputParser {
60    async fn invoke(
61        &self,
62        input: String,
63        _config: &RunnableConfig,
64    ) -> Result<XmlElement, SynapticError> {
65        let xml = if let Some(root_tag) = &self.root_tag {
66            let open = format!("<{}", root_tag);
67            let close = format!("</{}>", root_tag);
68            let start = input.find(&open).ok_or_else(|| {
69                SynapticError::Parsing(format!("root tag <{}> not found in input", root_tag))
70            })?;
71            let end = input.find(&close).ok_or_else(|| {
72                SynapticError::Parsing(format!("closing tag </{}> not found in input", root_tag))
73            })?;
74            &input[start..end + close.len()]
75        } else {
76            input.trim()
77        };
78
79        let mut pos = 0;
80        parse_element(xml, &mut pos)
81    }
82}
83
/// Advance `pos` past any run of ASCII whitespace bytes starting at `pos`.
///
/// `pos` is left unchanged when it already points at a non-whitespace byte
/// or lies at/after the end of `input`.
fn skip_whitespace(input: &str, pos: &mut usize) {
    let run_len = input
        .as_bytes()
        .iter()
        .skip(*pos)
        .take_while(|byte| byte.is_ascii_whitespace())
        .count();
    *pos += run_len;
}
91
92/// Parse a single XML element starting at `pos`.
93fn parse_element(input: &str, pos: &mut usize) -> Result<XmlElement, SynapticError> {
94    skip_whitespace(input, pos);
95
96    if *pos >= input.len() || input.as_bytes()[*pos] != b'<' {
97        return Err(SynapticError::Parsing(format!(
98            "expected '<' at position {pos}",
99            pos = *pos
100        )));
101    }
102    *pos += 1; // skip '<'
103
104    // Parse tag name
105    let tag_start = *pos;
106    while *pos < input.len() {
107        let ch = input.as_bytes()[*pos] as char;
108        if ch.is_ascii_whitespace() || ch == '>' || ch == '/' {
109            break;
110        }
111        *pos += 1;
112    }
113    let tag = input[tag_start..*pos].to_string();
114    if tag.is_empty() {
115        return Err(SynapticError::Parsing("empty tag name".to_string()));
116    }
117
118    // Parse attributes
119    let attributes = parse_attributes(input, pos)?;
120
121    skip_whitespace(input, pos);
122
123    // Check for self-closing tag
124    if *pos < input.len() && input.as_bytes()[*pos] == b'/' {
125        *pos += 1; // skip '/'
126        if *pos >= input.len() || input.as_bytes()[*pos] != b'>' {
127            return Err(SynapticError::Parsing(
128                "expected '>' after '/' in self-closing tag".to_string(),
129            ));
130        }
131        *pos += 1; // skip '>'
132        return Ok(XmlElement {
133            tag,
134            text: None,
135            attributes,
136            children: Vec::new(),
137        });
138    }
139
140    // Expect '>'
141    if *pos >= input.len() || input.as_bytes()[*pos] != b'>' {
142        return Err(SynapticError::Parsing(format!(
143            "expected '>' for tag <{tag}>"
144        )));
145    }
146    *pos += 1; // skip '>'
147
148    // Parse content: text and/or child elements
149    let mut children = Vec::new();
150    let mut text_parts: Vec<String> = Vec::new();
151
152    loop {
153        if *pos >= input.len() {
154            return Err(SynapticError::Parsing(format!(
155                "unexpected end of input, missing closing tag </{tag}>"
156            )));
157        }
158
159        // Check for closing tag
160        let closing = format!("</{tag}>");
161        if input[*pos..].starts_with(&closing) {
162            *pos += closing.len();
163            break;
164        }
165
166        // Check for child element
167        if input.as_bytes()[*pos] == b'<' {
168            // Make sure it's not a closing tag for something else
169            if *pos + 1 < input.len() && input.as_bytes()[*pos + 1] == b'/' {
170                return Err(SynapticError::Parsing(format!(
171                    "unexpected closing tag at position {pos}, expected </{tag}>",
172                    pos = *pos
173                )));
174            }
175            let child = parse_element(input, pos)?;
176            children.push(child);
177        } else {
178            // Collect text content until we hit a '<'
179            let text_start = *pos;
180            while *pos < input.len() && input.as_bytes()[*pos] != b'<' {
181                *pos += 1;
182            }
183            let part = input[text_start..*pos].to_string();
184            let trimmed = part.trim().to_string();
185            if !trimmed.is_empty() {
186                text_parts.push(trimmed);
187            }
188        }
189    }
190
191    let text = if text_parts.is_empty() {
192        None
193    } else {
194        Some(text_parts.join(" "))
195    };
196
197    Ok(XmlElement {
198        tag,
199        text,
200        attributes,
201        children,
202    })
203}
204
205/// Parse attributes inside an opening tag. `pos` should be right after the tag name.
206fn parse_attributes(
207    input: &str,
208    pos: &mut usize,
209) -> Result<HashMap<String, String>, SynapticError> {
210    let mut attributes = HashMap::new();
211
212    loop {
213        skip_whitespace(input, pos);
214
215        if *pos >= input.len() {
216            break;
217        }
218
219        let ch = input.as_bytes()[*pos] as char;
220        if ch == '>' || ch == '/' {
221            break;
222        }
223
224        // Parse attribute name
225        let name_start = *pos;
226        while *pos < input.len() {
227            let c = input.as_bytes()[*pos] as char;
228            if c == '=' || c.is_ascii_whitespace() || c == '>' || c == '/' {
229                break;
230            }
231            *pos += 1;
232        }
233        let name = input[name_start..*pos].to_string();
234        if name.is_empty() {
235            return Err(SynapticError::Parsing("empty attribute name".to_string()));
236        }
237
238        skip_whitespace(input, pos);
239
240        // Expect '='
241        if *pos >= input.len() || input.as_bytes()[*pos] != b'=' {
242            return Err(SynapticError::Parsing(format!(
243                "expected '=' after attribute name '{name}'"
244            )));
245        }
246        *pos += 1; // skip '='
247
248        skip_whitespace(input, pos);
249
250        // Expect quoted value
251        if *pos >= input.len() {
252            return Err(SynapticError::Parsing(
253                "unexpected end of input in attribute value".to_string(),
254            ));
255        }
256
257        let quote = input.as_bytes()[*pos] as char;
258        if quote != '"' && quote != '\'' {
259            return Err(SynapticError::Parsing(format!(
260                "expected quote for attribute '{name}' value, got '{quote}'"
261            )));
262        }
263        *pos += 1; // skip opening quote
264
265        let value_start = *pos;
266        while *pos < input.len() && input.as_bytes()[*pos] as char != quote {
267            *pos += 1;
268        }
269        if *pos >= input.len() {
270            return Err(SynapticError::Parsing(format!(
271                "unterminated attribute value for '{name}'"
272            )));
273        }
274        let value = input[value_start..*pos].to_string();
275        *pos += 1; // skip closing quote
276
277        attributes.insert(name, value);
278    }
279
280    Ok(attributes)
281}