vt_push_parser/
capture.rs

1//! Raw-input-capturing push parser.
2
3use crate::{VT_PARSER_INTEREST_DEFAULT, VTEvent, VTPushParser};
4
/// The type of capture mode to use after this event has been emitted.
///
/// This is the value returned by the callback passed to
/// [`VTCapturePushParser::feed_with`]. The captured data will be emitted as
/// [`VTCaptureEvent::Capture`] events (possibly several for one capture),
/// followed by a [`VTCaptureEvent::CaptureEnd`] event.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VTInputCapture {
    /// No capture mode. This must also be returned from any
    /// [`VTCaptureEvent::Capture`] or [`VTCaptureEvent::CaptureEnd`] event
    /// (`feed_with` does not look at the value returned for those events).
    None,
    /// Capture a fixed number of bytes.
    Count(usize),
    /// Capture a fixed number of UTF-8 chars.
    CountUtf8(usize),
    /// Capture bytes until a terminator is found. The terminator itself is
    /// consumed but is not part of the captured data.
    Terminator(&'static [u8]),
}
20
/// An event emitted by [`VTCapturePushParser::feed_with`].
///
/// Serializable via `serde` when the `serde` feature is enabled.
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[derive(Debug)]
pub enum VTCaptureEvent<'a> {
    /// A regular event produced by the underlying [`VTPushParser`].
    VTEvent(VTEvent<'a>),
    /// A piece of raw data collected while a capture is active. One capture
    /// may be delivered as several consecutive `Capture` events.
    Capture(&'a [u8]),
    /// The active capture has completed; subsequent input is parsed normally
    /// again.
    CaptureEnd,
}
28
/// Internal capture state.
///
/// Mirrors [`VTInputCapture`], but additionally carries the progress made so
/// far so that a single capture can span several input chunks.
enum VTCaptureInternal {
    /// Not capturing.
    None,
    /// Number of raw bytes that still have to be captured.
    Count(usize),
    /// Number of UTF-8 characters whose lead byte has not been seen yet.
    CountUtf8(usize),
    /// The terminator to look for, and how many of its leading bytes have
    /// already been matched (and withheld from the output) so far.
    Terminator(&'static [u8], usize),
}

impl VTCaptureInternal {
    /// Returns `true` if `byte` is a UTF-8 continuation byte (`0b10xx_xxxx`).
    fn is_utf8_continuation(byte: u8) -> bool {
        byte & 0xC0 == 0x80
    }

    /// Returns the total encoded length announced by a UTF-8 lead byte.
    ///
    /// Bytes that are not a multi-byte lead (ASCII, stray continuation bytes,
    /// invalid leads) are treated as one-byte sequences.
    fn utf8_sequence_len(lead: u8) -> usize {
        match lead {
            0xC0..=0xDF => 2,
            0xE0..=0xEF => 3,
            0xF0..=0xF7 => 4,
            _ => 1,
        }
    }

    /// Returns the length of the longest proper prefix of `pattern[..len]`
    /// that is also a suffix of `pattern[..len]` (0 if there is none).
    fn longest_border(pattern: &[u8], len: usize) -> usize {
        (1..len)
            .rev()
            .find(|&k| pattern[..k] == pattern[len - k..len])
            .unwrap_or(0)
    }

    /// Feeds `input` to the active capture.
    ///
    /// Consumed bytes are removed from the front of `input`. The return value
    /// is the next piece of captured data, if any; it borrows either from
    /// `input` or from the (static) terminator. When the capture is complete
    /// the state is reset to [`VTCaptureInternal::None`].
    ///
    /// For any non-empty `input` and any state other than `None`, every call
    /// either consumes input, returns data, or finishes the capture, so a
    /// caller looping over the input always makes progress.
    ///
    /// * `Count` / `CountUtf8`: if `input` holds less than the requested
    ///   amount, what is available is returned and the remainder is awaited
    ///   in later calls. A final UTF-8 character that is cut off by the end of
    ///   `input` is completed by capturing the number of bytes its lead byte
    ///   announces.
    /// * `Terminator`: bytes that tentatively matched the terminator are
    ///   withheld; if the match fails, the bytes that can no longer be part of
    ///   a match are returned. An empty terminator completes immediately.
    fn feed<'a>(&mut self, input: &mut &'a [u8]) -> Option<&'a [u8]> {
        // Copy the slice out so every sub-slice below carries the full `'a`.
        let data: &'a [u8] = *input;

        match self {
            VTCaptureInternal::None => None,
            VTCaptureInternal::Count(remaining) => {
                let take = (*remaining).min(data.len());
                let (capture, rest) = data.split_at(take);
                *input = rest;
                *remaining -= take;
                if *remaining == 0 {
                    *self = VTCaptureInternal::None;
                } else if capture.is_empty() {
                    // Nothing available yet; wait for more input.
                    return None;
                }
                Some(capture)
            }
            VTCaptureInternal::CountUtf8(remaining) => {
                let wanted = *remaining;
                if wanted == 0 {
                    // Nothing to capture: finish right away without consuming.
                    *self = VTCaptureInternal::None;
                    return Some(&data[..0]);
                }

                // Characters are counted by their lead byte. Continuation
                // bytes (including ones at the very start of a chunk, which
                // belong to a character counted earlier) ride along for free.
                let mut leads_seen = 0;
                for (index, &byte) in data.iter().enumerate() {
                    if Self::is_utf8_continuation(byte) {
                        continue;
                    }
                    leads_seen += 1;
                    if leads_seen < wanted {
                        continue;
                    }

                    // `index` is the lead byte of the last wanted character;
                    // take every continuation byte that directly follows it.
                    let trailing = data[index + 1..]
                        .iter()
                        .take_while(|&&b| Self::is_utf8_continuation(b))
                        .count();
                    let end = index + 1 + trailing;
                    let have = end - index;
                    let expected = Self::utf8_sequence_len(byte);

                    // If the chunk ended in the middle of the last character,
                    // keep capturing the bytes that are still missing.
                    *self = if end == data.len() && have < expected {
                        VTCaptureInternal::Count(expected - have)
                    } else {
                        VTCaptureInternal::None
                    };
                    *input = &data[end..];
                    return Some(&data[..end]);
                }

                // Not enough characters in this chunk: hand over everything
                // and remember how many characters are still outstanding.
                *remaining -= leads_seen;
                *input = &data[data.len()..];
                if data.is_empty() {
                    None
                } else {
                    Some(data)
                }
            }
            VTCaptureInternal::Terminator(terminator, found) => {
                let terminator: &'static [u8] = *terminator;
                let first = match terminator.first() {
                    Some(&byte) => byte,
                    None => {
                        // An empty terminator is matched trivially.
                        *self = VTCaptureInternal::None;
                        return None;
                    }
                };

                // Ground state: nothing of the terminator matched yet.
                if *found == 0 {
                    return match data.iter().position(|&b| b == first) {
                        Some(position) => {
                            // Withhold the matched byte, emit what precedes it.
                            *input = &data[position + 1..];
                            if terminator.len() == 1 {
                                // A one-byte terminator is complete already.
                                *self = VTCaptureInternal::None;
                            } else {
                                *found = 1;
                            }
                            Some(&data[..position])
                        }
                        None => {
                            *input = &data[data.len()..];
                            Some(data)
                        }
                    };
                }

                // Part of the terminator has been matched; try to extend it.
                let mut rest = data;
                while *found < terminator.len() {
                    match rest.first() {
                        None => {
                            *input = rest;
                            return None;
                        }
                        Some(&byte) if byte == terminator[*found] => {
                            *found += 1;
                            rest = &rest[1..];
                        }
                        Some(_) => {
                            // Mismatch. The withheld bytes equal
                            // `terminator[..matched]`; a suffix of them may
                            // still be the start of the real terminator, so
                            // only release the bytes in front of the longest
                            // such suffix. The mismatching byte is left in
                            // `input` and re-examined on the next call.
                            let matched = *found;
                            let kept = Self::longest_border(terminator, matched);
                            *found = kept;
                            *input = rest;
                            return Some(&terminator[..matched - kept]);
                        }
                    }
                }

                // We've matched the entire terminator.
                *input = rest;
                *self = VTCaptureInternal::None;
                None
            }
        }
    }
}
121
/// A parser that allows for "capturing" of input data, i.e. temporarily
/// transferring control of the parser to unparsed data events.
///
/// This functions in the same way as [`VTPushParser`], but emits
/// [`VTCaptureEvent`]s instead of [`VTEvent`]s. The callback passed to
/// `feed_with` returns a [`VTInputCapture`] for each event to request (or not
/// request) a capture of the input that follows.
pub struct VTCapturePushParser<const INTEREST: u8 = VT_PARSER_INTEREST_DEFAULT> {
    // The wrapped parser that handles all input while no capture is active.
    parser: VTPushParser<INTEREST>,
    // The capture currently in progress (`VTCaptureInternal::None` if none).
    capture: VTCaptureInternal,
}
131
132impl Default for VTCapturePushParser {
133    fn default() -> Self {
134        Self::new()
135    }
136}
137
138impl VTCapturePushParser {
139    pub const fn new() -> VTCapturePushParser {
140        VTCapturePushParser::new_with_interest::<VT_PARSER_INTEREST_DEFAULT>()
141    }
142
143    pub const fn new_with_interest<const INTEREST: u8>() -> VTCapturePushParser<INTEREST> {
144        VTCapturePushParser::new_with()
145    }
146}
147
impl<const INTEREST: u8> VTCapturePushParser<INTEREST> {
    /// Creates a parser with a freshly constructed underlying
    /// [`VTPushParser`] and no active capture.
    const fn new_with() -> Self {
        Self {
            parser: VTPushParser::new_with(),
            capture: VTCaptureInternal::None,
        }
    }

    /// Returns the underlying parser's `is_ground()` result.
    ///
    /// The capture state is not taken into account.
    pub fn is_ground(&self) -> bool {
        self.parser.is_ground()
    }

    /// Forwards to the underlying parser's `idle()` and wraps any event it
    /// returns in [`VTCaptureEvent::VTEvent`].
    ///
    /// The capture state is not touched.
    pub fn idle(&mut self) -> Option<VTCaptureEvent<'static>> {
        self.parser.idle().map(VTCaptureEvent::VTEvent)
    }

    /// Feeds `input` to the parser, calling `cb` for every event.
    ///
    /// While no capture is active, input is handed to the underlying parser
    /// and each of its events is passed to `cb` wrapped in
    /// [`VTCaptureEvent::VTEvent`]. The [`VTInputCapture`] returned by `cb`
    /// selects the capture mode (if any) to switch to.
    ///
    /// While a capture is active, input is handed to the capture state
    /// instead: captured data is passed to `cb` as
    /// [`VTCaptureEvent::Capture`], and [`VTCaptureEvent::CaptureEnd`] is
    /// emitted once the capture state has returned to "none". The value `cb`
    /// returns for these two events is ignored.
    ///
    /// The loop runs until `input` is empty, so it relies on each step either
    /// consuming input or changing the capture state.
    pub fn feed_with<'this, 'input, F: for<'any> FnMut(VTCaptureEvent<'any>) -> VTInputCapture>(
        &'this mut self,
        mut input: &'input [u8],
        cb: &mut F,
    ) {
        while !input.is_empty() {
            match &mut self.capture {
                VTCaptureInternal::None => {
                    // Normal parsing mode - feed to the underlying parser
                    let count = self.parser.feed_with_abortable(input, &mut |event| {
                        let capture_mode = cb(VTCaptureEvent::VTEvent(event));
                        match capture_mode {
                            VTInputCapture::None => {
                                // Stay in normal mode
                            }
                            VTInputCapture::Count(count) => {
                                self.capture = VTCaptureInternal::Count(count);
                            }
                            VTInputCapture::CountUtf8(count) => {
                                self.capture = VTCaptureInternal::CountUtf8(count);
                            }
                            VTInputCapture::Terminator(terminator) => {
                                // Start with zero terminator bytes matched
                                self.capture = VTCaptureInternal::Terminator(terminator, 0);
                            }
                        }
                        // NOTE(review): the meaning of this return value is
                        // defined by `feed_with_abortable`, which is not
                        // visible here. For a requested capture to start right
                        // after the triggering event, the underlying parser
                        // has to stop consuming at that point - confirm that
                        // `false` has that effect and that the original
                        // comment below is accurate.
                        false // Don't abort parsing
                    });

                    // Skip what the underlying parser reported as consumed
                    input = &input[count..];
                }
                capture => {
                    // Capture mode - collect data until capture is complete
                    if let Some(captured_data) = capture.feed(&mut input) {
                        cb(VTCaptureEvent::Capture(captured_data));
                    }

                    // Check if capture is complete
                    if matches!(self.capture, VTCaptureInternal::None) {
                        cb(VTCaptureEvent::CaptureEnd);
                    }
                }
            }
        }
    }
}
209
// Tests for the capturing parser. Each test records the `Debug` output of
// every event (one per line) and compares it against the expected transcript,
// or accumulates the captured bytes and compares those.
#[cfg(test)]
mod tests {
    use super::*;

    // A CSI with first parameter 200 starts a terminator capture that runs
    // until `ESC [ 201 ~`; the terminator itself is not part of the capture.
    #[test]
    fn test_capture_paste() {
        let mut output = String::new();
        let mut parser = VTCapturePushParser::new();
        parser.feed_with(b"raw\x1b[200~paste\x1b[201~raw", &mut |event| {
            output.push_str(&format!("{event:?}\n"));
            match event {
                VTCaptureEvent::VTEvent(VTEvent::Csi(csi)) => {
                    if csi.params.try_parse::<usize>(0).unwrap_or(0) == 200 {
                        VTInputCapture::Terminator(b"\x1b[201~")
                    } else {
                        VTInputCapture::None
                    }
                }
                _ => VTInputCapture::None,
            }
        });
        assert_eq!(
            output.trim(),
            r#"
VTEvent(Raw('raw'))
VTEvent(Csi('200', '', '~'))
Capture([112, 97, 115, 116, 101])
CaptureEnd
VTEvent(Raw('raw'))
"#
            .trim()
        );
    }

    // `CSI X` starts a fixed five-byte capture; parsing resumes afterwards.
    #[test]
    fn test_capture_count() {
        let mut output = String::new();
        let mut parser = VTCapturePushParser::new();
        parser.feed_with(b"raw\x1b[Xpaste\x1b[Yraw", &mut |event| {
            output.push_str(&format!("{event:?}\n"));
            match event {
                VTCaptureEvent::VTEvent(VTEvent::Csi(csi)) => {
                    if csi.final_byte == b'X' {
                        VTInputCapture::Count(5)
                    } else {
                        VTInputCapture::None
                    }
                }
                _ => VTInputCapture::None,
            }
        });
        assert_eq!(
            output.trim(),
            r#"
VTEvent(Raw('raw'))
VTEvent(Csi('', 'X'))
Capture([112, 97, 115, 116, 101])
CaptureEnd
VTEvent(Csi('', 'Y'))
VTEvent(Raw('raw'))
"#
            .trim()
        );
    }

    // Character-counted capture over ASCII input behaves like a byte count.
    #[test]
    fn test_capture_count_utf8_but_ascii() {
        let mut output = String::new();
        let mut parser = VTCapturePushParser::new();
        parser.feed_with(b"raw\x1b[Xpaste\x1b[Yraw", &mut |event| {
            output.push_str(&format!("{event:?}\n"));
            match event {
                VTCaptureEvent::VTEvent(VTEvent::Csi(csi)) => {
                    if csi.final_byte == b'X' {
                        VTInputCapture::CountUtf8(5)
                    } else {
                        VTInputCapture::None
                    }
                }
                _ => VTInputCapture::None,
            }
        });
        assert_eq!(
            output.trim(),
            r#"
VTEvent(Raw('raw'))
VTEvent(Csi('', 'X'))
Capture([112, 97, 115, 116, 101])
CaptureEnd
VTEvent(Csi('', 'Y'))
VTEvent(Raw('raw'))
"#
            .trim()
        );
    }

    // Character-counted capture over multi-byte characters: five characters
    // amount to 19 bytes here.
    #[test]
    fn test_capture_count_utf8() {
        let mut output = String::new();
        let mut parser = VTCapturePushParser::new();
        let input = "raw\u{001b}[X🤖🦕✅😀🕓\u{001b}[Yraw".as_bytes();
        parser.feed_with(input, &mut |event| {
            output.push_str(&format!("{event:?}\n"));
            match event {
                VTCaptureEvent::VTEvent(VTEvent::Csi(csi)) => {
                    if csi.final_byte == b'X' {
                        VTInputCapture::CountUtf8(5)
                    } else {
                        VTInputCapture::None
                    }
                }
                _ => VTInputCapture::None,
            }
        });
        assert_eq!(output.trim(), r#"
VTEvent(Raw('raw'))
VTEvent(Csi('', 'X'))
Capture([240, 159, 164, 150, 240, 159, 166, 149, 226, 156, 133, 240, 159, 152, 128, 240, 159, 149, 147])
CaptureEnd
VTEvent(Csi('', 'Y'))
VTEvent(Raw('raw'))
"#.trim());
    }

    // A sequence that only starts like the terminator (`ESC [ 201` followed
    // by `i`) must be handed back as captured data, in its own event.
    #[test]
    fn test_capture_terminator_partial_match() {
        let mut output = String::new();
        let mut parser = VTCapturePushParser::new();

        parser.feed_with(b"start\x1b[200~part\x1b[201ial\x1b[201~end", &mut |event| {
            output.push_str(&format!("{event:?}\n"));
            match event {
                VTCaptureEvent::VTEvent(VTEvent::Csi(csi)) => {
                    if csi.final_byte == b'~'
                        && csi.params.try_parse::<usize>(0).unwrap_or(0) == 200
                    {
                        VTInputCapture::Terminator(b"\x1b[201~")
                    } else {
                        VTInputCapture::None
                    }
                }
                _ => VTInputCapture::None,
            }
        });

        assert_eq!(
            output.trim(),
            r#"VTEvent(Raw('start'))
VTEvent(Csi('200', '', '~'))
Capture([112, 97, 114, 116])
Capture([27, 91, 50, 48, 49])
Capture([105, 97, 108])
CaptureEnd
VTEvent(Raw('end'))"#
        );
    }

    // Same input as above, but fed in chunks of 1 to 4 bytes; only the
    // concatenated captured bytes are checked.
    #[test]
    fn test_capture_terminator_partial_match_single_byte() {
        let input = b"start\x1b[200~part\x1b[201ial\x1b[201~end";

        for chunk_size in 1..5 {
            let (captured, output) = capture_chunk_size(input, chunk_size);
            assert_eq!(captured, b"part\x1b[201ial", "{output}",);
        }
    }

    // Feeds `input` in chunks of `chunk_size` bytes, starting a terminator
    // capture on `CSI 200 ~`. Returns the concatenated captured bytes and the
    // event transcript (used as the failure message by the caller).
    fn capture_chunk_size(input: &'static [u8; 32], chunk_size: usize) -> (Vec<u8>, String) {
        let mut output = String::new();
        let mut parser = VTCapturePushParser::new();
        let mut captured = Vec::new();
        for chunk in input.chunks(chunk_size) {
            parser.feed_with(chunk, &mut |event| {
                output.push_str(&format!("{event:?}\n"));
                match event {
                    VTCaptureEvent::Capture(data) => {
                        captured.extend_from_slice(data);
                        VTInputCapture::None
                    }
                    VTCaptureEvent::VTEvent(VTEvent::Csi(csi)) => {
                        if csi.final_byte == b'~'
                            && csi.params.try_parse::<usize>(0).unwrap_or(0) == 200
                        {
                            VTInputCapture::Terminator(b"\x1b[201~")
                        } else {
                            VTInputCapture::None
                        }
                    }
                    _ => VTInputCapture::None,
                }
            });
        }
        (captured, output)
    }
}