1use std::collections::{BTreeMap, BTreeSet};
12
13use thiserror::Error;
14
/// Tag attributes, keyed by attribute name.
///
/// A `BTreeMap` is used so attributes are always iterated (and `Debug`-printed)
/// in sorted key order.
pub type Attributes = BTreeMap<String, String>;
16
17#[derive(Debug, Clone, PartialEq, Eq)]
18pub enum ParserEvent {
19 Text(String),
20 StartTag(String, Attributes),
21 Chunk(Vec<u8>),
22 EndTag(String),
23}
24
25#[derive(Debug, Clone, PartialEq, Eq)]
26pub enum Instruction {
27 Text(String),
28 StartTag {
29 name: String,
30 attributes: Attributes,
31 },
32 EndTag(String),
33 WriteChunk(Vec<u8>),
34 RawChunk {
35 tag: String,
36 bytes: Vec<u8>,
37 },
38}
39
40#[derive(Debug, Error, PartialEq, Eq)]
41pub enum ParserError {
42 #[error("invalid utf-8 in text")]
43 InvalidUtf8,
44 #[error("malformed tag: {0}")]
45 MalformedTag(String),
46 #[error("unexpected closing tag: </{found}>")]
47 UnexpectedClosingTag { found: String },
48 #[error("mismatched closing tag: expected </{expected}> but found </{found}>")]
49 MismatchedClosingTag { expected: String, found: String },
50 #[error("unterminated tag")]
51 UnterminatedTag,
52 #[error("unterminated raw section for <{0}>")]
53 UnterminatedRawSection(String),
54 #[error("unclosed tag(s): {0}")]
55 UnclosedTags(String),
56 #[error("chunk emitted with no active raw tag")]
57 UnexpectedChunk,
58}
59
/// Incremental tokenizer for the tag-based instruction markup.
///
/// Input is appended with `feed`, which returns every [`ParserEvent`] that can
/// already be produced; `finish` flushes what is left and checks that the input
/// was complete.
#[derive(Debug)]
pub struct StreamingParser {
    /// Received bytes that have not been turned into events yet
    /// (for example a tag whose `>` has not arrived).
    buffer: Vec<u8>,
    /// Names of start tags still waiting for their closing tag, outermost first.
    open_tags: Vec<String>,
    /// The raw tag whose body is currently being read, if any.
    raw_tag: Option<String>,
    /// Tag names whose bodies are passed through as bytes instead of parsed.
    raw_tags: BTreeSet<String>,
}
67
68impl Default for StreamingParser {
69 fn default() -> Self {
70 Self::new()
71 }
72}
73
74impl StreamingParser {
75 pub fn new() -> Self {
76 Self::with_raw_tags(default_raw_tags())
77 }
78
79 pub fn with_raw_tags<I, S>(raw_tags: I) -> Self
80 where
81 I: IntoIterator<Item = S>,
82 S: Into<String>,
83 {
84 let raw_tags = raw_tags.into_iter().map(Into::into).collect();
85 Self {
86 buffer: Vec::new(),
87 open_tags: Vec::new(),
88 raw_tag: None,
89 raw_tags,
90 }
91 }
92
93 pub fn feed(&mut self, input: &[u8]) -> Result<Vec<ParserEvent>, ParserError> {
94 if !input.is_empty() {
95 self.buffer.extend_from_slice(input);
96 }
97
98 let mut events = Vec::new();
99
100 loop {
101 if let Some(active_raw) = self.raw_tag.clone() {
102 let needle = format!("</{active_raw}>").into_bytes();
103 if let Some(idx) = find_subsequence(&self.buffer, &needle) {
104 if idx > 0 {
105 events.push(ParserEvent::Chunk(self.buffer[..idx].to_vec()));
106 }
107 self.buffer.drain(..idx + needle.len());
108
109 self.raw_tag = None;
110 self.pop_expected_tag(&active_raw)?;
111 events.push(ParserEvent::EndTag(active_raw));
112 continue;
113 }
114
115 let keep_tail = needle.len().saturating_sub(1);
116 if self.buffer.len() > keep_tail {
117 let emit_len = self.buffer.len() - keep_tail;
118 events.push(ParserEvent::Chunk(self.buffer[..emit_len].to_vec()));
119 self.buffer.drain(..emit_len);
120 }
121 break;
122 }
123
124 let Some(tag_start) = self.buffer.iter().position(|b| *b == b'<') else {
125 if !self.buffer.is_empty() {
126 events.push(ParserEvent::Text(to_utf8_lossless(&self.buffer)?));
127 self.buffer.clear();
128 }
129 break;
130 };
131
132 if tag_start > 0 {
133 events.push(ParserEvent::Text(to_utf8_lossless(
134 &self.buffer[..tag_start],
135 )?));
136 self.buffer.drain(..tag_start);
137 continue;
138 }
139
140 let Some(tag_end) = find_tag_end(&self.buffer) else {
141 break;
142 };
143
144 let tag_bytes: Vec<u8> = self.buffer.drain(..=tag_end).collect();
145 let parsed = parse_tag(&tag_bytes)?;
146 match parsed {
147 ParsedTag::Start {
148 name,
149 attributes,
150 self_closing,
151 } => {
152 events.push(ParserEvent::StartTag(name.clone(), attributes));
153 if self_closing {
154 events.push(ParserEvent::EndTag(name));
155 } else {
156 self.open_tags.push(name.clone());
157 if self.raw_tags.contains(&name) {
158 self.raw_tag = Some(name);
159 }
160 }
161 }
162 ParsedTag::End { name } => {
163 self.pop_expected_tag(&name)?;
164 events.push(ParserEvent::EndTag(name));
165 }
166 }
167 }
168
169 Ok(events)
170 }
171
172 pub fn finish(&mut self) -> Result<Vec<ParserEvent>, ParserError> {
173 let mut events = self.feed(&[])?;
174
175 if let Some(raw_tag) = &self.raw_tag {
176 return Err(ParserError::UnterminatedRawSection(raw_tag.to_string()));
177 }
178
179 if !self.buffer.is_empty() {
180 if self.buffer[0] == b'<' {
181 return Err(ParserError::UnterminatedTag);
182 }
183 events.push(ParserEvent::Text(to_utf8_lossless(&self.buffer)?));
184 self.buffer.clear();
185 }
186
187 if !self.open_tags.is_empty() {
188 return Err(ParserError::UnclosedTags(self.open_tags.join(", ")));
189 }
190
191 Ok(events)
192 }
193
194 fn pop_expected_tag(&mut self, closing_tag: &str) -> Result<(), ParserError> {
195 let Some(last) = self.open_tags.pop() else {
196 return Err(ParserError::UnexpectedClosingTag {
197 found: closing_tag.to_string(),
198 });
199 };
200
201 if last != closing_tag {
202 return Err(ParserError::MismatchedClosingTag {
203 expected: last,
204 found: closing_tag.to_string(),
205 });
206 }
207
208 Ok(())
209 }
210}
211
212#[derive(Debug, Default)]
213pub struct InstructionParser {
214 parser: StreamingParser,
215 raw_tag_stack: Vec<String>,
216}
217
218impl InstructionParser {
219 pub fn new() -> Self {
220 Self::default()
221 }
222
223 pub fn feed(&mut self, input: &[u8]) -> Result<Vec<Instruction>, ParserError> {
224 let events = self.parser.feed(input)?;
225 self.map_events(events)
226 }
227
228 pub fn finish(&mut self) -> Result<Vec<Instruction>, ParserError> {
229 let events = self.parser.finish()?;
230 self.map_events(events)
231 }
232
233 fn map_events(&mut self, events: Vec<ParserEvent>) -> Result<Vec<Instruction>, ParserError> {
234 let mut instructions = Vec::with_capacity(events.len());
235 for event in events {
236 match event {
237 ParserEvent::Text(text) => instructions.push(Instruction::Text(text)),
238 ParserEvent::StartTag(name, attributes) => {
239 if is_raw_tag(&name) {
240 self.raw_tag_stack.push(name.clone());
241 }
242 instructions.push(Instruction::StartTag { name, attributes });
243 }
244 ParserEvent::EndTag(name) => {
245 if self.raw_tag_stack.last().is_some_and(|v| v == &name) {
246 self.raw_tag_stack.pop();
247 }
248 instructions.push(Instruction::EndTag(name));
249 }
250 ParserEvent::Chunk(bytes) => {
251 let Some(active_tag) = self.raw_tag_stack.last() else {
252 return Err(ParserError::UnexpectedChunk);
253 };
254 if active_tag == "write_file" {
255 instructions.push(Instruction::WriteChunk(bytes));
256 } else {
257 instructions.push(Instruction::RawChunk {
258 tag: active_tag.clone(),
259 bytes,
260 });
261 }
262 }
263 }
264 }
265 Ok(instructions)
266 }
267}
268
269fn to_utf8_lossless(bytes: &[u8]) -> Result<String, ParserError> {
270 String::from_utf8(bytes.to_vec()).map_err(|_| ParserError::InvalidUtf8)
271}
272
/// Tag names whose bodies `StreamingParser::new` passes through as raw bytes.
fn default_raw_tags() -> [&'static str; 4] {
    const RAW_TAGS: [&str; 4] = ["write_file", "search", "replace", "terminal"];
    RAW_TAGS
}
276
277fn is_raw_tag(tag: &str) -> bool {
278 default_raw_tags().contains(&tag)
279}
280
281#[derive(Debug)]
282enum ParsedTag {
283 Start {
284 name: String,
285 attributes: Attributes,
286 self_closing: bool,
287 },
288 End {
289 name: String,
290 },
291}
292
293fn parse_tag(tag_bytes: &[u8]) -> Result<ParsedTag, ParserError> {
294 let tag = std::str::from_utf8(tag_bytes).map_err(|_| ParserError::InvalidUtf8)?;
295 if !tag.starts_with('<') || !tag.ends_with('>') {
296 return Err(ParserError::MalformedTag(tag.to_string()));
297 }
298
299 let mut inner = tag[1..tag.len() - 1].trim();
300 if inner.is_empty() {
301 return Err(ParserError::MalformedTag(tag.to_string()));
302 }
303
304 if let Some(stripped) = inner.strip_prefix('/') {
305 let name = stripped.trim();
306 if !is_valid_name(name) {
307 return Err(ParserError::MalformedTag(tag.to_string()));
308 }
309 return Ok(ParsedTag::End {
310 name: name.to_string(),
311 });
312 }
313
314 let self_closing = inner.ends_with('/');
315 if self_closing {
316 inner = inner[..inner.len() - 1].trim_end();
317 }
318
319 let name_end = inner.find(char::is_whitespace).unwrap_or(inner.len());
320
321 let name = &inner[..name_end];
322 if !is_valid_name(name) {
323 return Err(ParserError::MalformedTag(tag.to_string()));
324 }
325
326 let mut attributes = Attributes::new();
327 let mut cursor = inner[name_end..].trim_start();
328 while !cursor.is_empty() {
329 let eq_idx = cursor
330 .find('=')
331 .ok_or_else(|| ParserError::MalformedTag(tag.to_string()))?;
332
333 let key = cursor[..eq_idx].trim();
334 if !is_valid_name(key) {
335 return Err(ParserError::MalformedTag(tag.to_string()));
336 }
337
338 cursor = cursor[eq_idx + 1..].trim_start();
339 let Some(quote) = cursor.chars().next() else {
340 return Err(ParserError::MalformedTag(tag.to_string()));
341 };
342 if quote != '"' && quote != '\'' {
343 return Err(ParserError::MalformedTag(tag.to_string()));
344 }
345
346 cursor = &cursor[quote.len_utf8()..];
347 let mut value_end = None;
348 for (idx, ch) in cursor.char_indices() {
349 if ch == quote {
350 value_end = Some(idx);
351 break;
352 }
353 }
354 let Some(value_end) = value_end else {
355 return Err(ParserError::MalformedTag(tag.to_string()));
356 };
357
358 let value = &cursor[..value_end];
359 attributes.insert(key.to_string(), value.to_string());
360 cursor = cursor[value_end + quote.len_utf8()..].trim_start();
361 }
362
363 Ok(ParsedTag::Start {
364 name: name.to_string(),
365 attributes,
366 self_closing,
367 })
368}
369
/// Returns whether `name` is acceptable as a tag or attribute name.
///
/// A name must be non-empty. It must not contain whitespace or any character with
/// structural meaning inside a tag: `/`, `<`, `>`, `=`, `"` or `'`.
///
/// Rejecting `=` and quotes makes input such as `<write_file="a.rs">` fail as a
/// malformed tag. Otherwise it would be accepted as a tag literally named
/// `write_file="a.rs"`, which is not a raw tag, so its body would be parsed as markup.
fn is_valid_name(name: &str) -> bool {
    !name.is_empty()
        && !name
            .chars()
            .any(|ch| ch.is_whitespace() || matches!(ch, '/' | '<' | '>' | '=' | '"' | '\''))
}
377
/// Returns the index of the `>` that closes the tag starting at `bytes[0]`.
///
/// A `>` inside a `"…"` or `'…'` quoted region is ignored. Inside one kind of quote,
/// the other kind has no effect. Returns `None` if the closing `>` has not been
/// received yet.
fn find_tag_end(bytes: &[u8]) -> Option<usize> {
    let mut open_quote: Option<u8> = None;
    for (idx, &byte) in bytes.iter().enumerate().skip(1) {
        match open_quote {
            Some(quote) => {
                if byte == quote {
                    open_quote = None;
                }
            }
            None => match byte {
                b'>' => return Some(idx),
                b'"' | b'\'' => open_quote = Some(byte),
                _ => {}
            },
        }
    }
    None
}
390
/// Returns the index of the first occurrence of `needle` in `haystack`.
///
/// An empty `needle` never matches, so `None` is returned for it.
fn find_subsequence(haystack: &[u8], needle: &[u8]) -> Option<usize> {
    if needle.is_empty() {
        return None;
    }
    // `windows` yields nothing when `needle` is longer than `haystack`.
    haystack
        .windows(needle.len())
        .position(|candidate| candidate == needle)
}
400
401#[cfg(test)]
402mod tests {
403 use super::*;
404
405 #[test]
406 fn parses_self_closing_tag() {
407 let mut parser = StreamingParser::new();
408 let events = parser
409 .feed(br#"<read_file path="src/lib.rs" start_line="1" end_line="5" />"#)
410 .expect("parse should succeed");
411 let finished = parser.finish().expect("finish should succeed");
412
413 assert_eq!(
414 events,
415 vec![
416 ParserEvent::StartTag(
417 "read_file".to_string(),
418 BTreeMap::from([
419 ("end_line".to_string(), "5".to_string()),
420 ("path".to_string(), "src/lib.rs".to_string()),
421 ("start_line".to_string(), "1".to_string()),
422 ])
423 ),
424 ParserEvent::EndTag("read_file".to_string())
425 ]
426 );
427 assert!(finished.is_empty());
428 }
429
430 #[test]
431 fn parses_raw_write_file_body_even_with_pseudo_tags_inside() {
432 let mut parser = StreamingParser::new();
433
434 let first = parser
435 .feed(b"<write_file path=\"src/main.rs\">fn main() {<not_a_tag>")
436 .expect("first parse should succeed");
437 assert!(matches!(
438 first.first(),
439 Some(ParserEvent::StartTag(name, _)) if name == "write_file"
440 ));
441
442 let second = parser
443 .feed(b" println!(\"ok\"); }</write_file>")
444 .expect("second parse should succeed");
445 assert!(matches!(
446 second.last(),
447 Some(ParserEvent::EndTag(name)) if name == "write_file"
448 ));
449
450 let full_body = collect_chunks(&[first.clone(), second.clone()]);
451 assert_eq!(full_body, b"fn main() {<not_a_tag> println!(\"ok\"); }");
452
453 assert!(parser.finish().expect("finish should succeed").is_empty());
454 }
455
456 #[test]
457 fn handles_raw_tag_close_across_chunk_boundaries() {
458 let mut parser = StreamingParser::new();
459
460 let events_1 = parser
461 .feed(b"<terminal>cargo t")
462 .expect("feed should succeed");
463 assert_eq!(
464 events_1,
465 vec![ParserEvent::StartTag(
466 "terminal".to_string(),
467 BTreeMap::new()
468 )]
469 );
470
471 let events_2 = parser.feed(b"est</ter").expect("feed should succeed");
472 assert!(!events_2.is_empty());
473
474 let events_3 = parser.feed(b"minal>").expect("feed should succeed");
475 assert!(matches!(
476 events_3.last(),
477 Some(ParserEvent::EndTag(name)) if name == "terminal"
478 ));
479 let full_body = collect_chunks(&[events_1, events_2, events_3]);
480 assert_eq!(full_body, b"cargo test");
481
482 assert!(parser.finish().expect("finish should succeed").is_empty());
483 }
484
485 #[test]
486 fn returns_mismatched_closing_tag_error() {
487 let mut parser = StreamingParser::new();
488
489 let _ = parser
490 .feed(b"<apply_edit path=\"src/lib.rs\">")
491 .expect("opening tag should parse");
492 let err = parser.feed(b"</write_file>").expect_err("should fail");
493
494 assert_eq!(
495 err,
496 ParserError::MismatchedClosingTag {
497 expected: "apply_edit".to_string(),
498 found: "write_file".to_string(),
499 }
500 );
501 }
502
503 #[test]
504 fn returns_unterminated_tag_error_on_finish() {
505 let mut parser = StreamingParser::new();
506 let _ = parser
507 .feed(b"<read_file path=\"a")
508 .expect("feed should work");
509 let err = parser.finish().expect_err("finish should fail");
510
511 assert_eq!(err, ParserError::UnterminatedTag);
512 }
513
514 #[test]
515 fn returns_unterminated_raw_section_error_on_finish() {
516 let mut parser = StreamingParser::new();
517 let _ = parser
518 .feed(b"<write_file path=\"x\">partial")
519 .expect("feed should work");
520 let err = parser.finish().expect_err("finish should fail");
521
522 assert_eq!(
523 err,
524 ParserError::UnterminatedRawSection("write_file".to_string())
525 );
526 }
527
528 #[test]
529 fn instruction_parser_emits_write_chunk_for_write_file_content() {
530 let mut parser = InstructionParser::new();
531 let batch_1 = parser
532 .feed(b"<write_file path=\"src/main.rs\">hello")
533 .expect("instruction batch 1 should parse");
534 let batch_2 = parser
535 .feed(b" world</write_file>")
536 .expect("instruction batch 2 should parse");
537 let final_batch = parser.finish().expect("finish should parse");
538
539 assert!(matches!(
540 batch_1.first(),
541 Some(Instruction::StartTag { name, .. }) if name == "write_file"
542 ));
543
544 let mut body = Vec::new();
545 for batch in [&batch_1, &batch_2, &final_batch] {
546 for instruction in batch.iter() {
547 if let Instruction::WriteChunk(bytes) = instruction {
548 body.extend_from_slice(bytes);
549 }
550 }
551 }
552
553 assert_eq!(body, b"hello world");
554 assert!(matches!(
555 batch_2.last(),
556 Some(Instruction::EndTag(name)) if name == "write_file"
557 ));
558 }
559
560 fn collect_chunks(batches: &[Vec<ParserEvent>]) -> Vec<u8> {
561 let mut out = Vec::new();
562 for batch in batches {
563 for event in batch {
564 if let ParserEvent::Chunk(bytes) = event {
565 out.extend_from_slice(bytes);
566 }
567 }
568 }
569 out
570 }
571}