sml_rs/
transport.rs

1//! SML transport protocol (version 1).
2//!
3//! *Hint: This crate currently only implements version 1 of the SML transport
4//! protocol. If you need support for version 2, let me know!*
5//!
6//! # SML Transport Protocol - Version 1
7//!
8//! Version 1 of the SML Transport Protocol is a simple format that encodes binary messages using escape sequences. A message consists of the following parts (numbers in hex):
9//!
10//! - **Start sequence**: `1b1b1b1b 01010101`
11//! - **Escaped data**: The data that should be encoded. If the escape sequence (`1b1b1b1b`) occurs in the data, it is escaped by an escape sequence (`1b1b1b1b`). For example, the data `001b1b1b 1b010203` would be encoded as `001b1b1b 1b1b1b1b 1b010203`.
12//! - **Padding**: The data is zero-padded to the next multiple of four. Therefore, zero to three `0x00` bytes are inserted.
13//! - **End sequence**: `1b1b1b1b 1aXXYYZZ`
14//!   - `XX`: number of padding bytes
15//!   - `YY`/`ZZ`: CRC checksum
16//!
17//! ## Encoding
18//!
19//! This crate implements both a streaming and a more traditional encoder.
20//!
21//! - `encode`: takes a sequence of bytes as input and returns a buffer containing the encoded message
22//! - `encode_streaming`: an iterator adapter that encodes the input on the fly
23//!
24//!
25//! ## Decoding
26//!
27//! - `decode`: takes a sequence of bytes and decodes them into a vector of messages / errors. Requires feature "alloc".
28//! - `decode_streaming`: takes a sequence of bytes and returns an iterator over the decoded messages / errors.
29//! - using `Decoder` directly: instantiate a `Decoder` manually, call `push_byte()` on it when data becomes available. Call `finalize()` when all data has been pushed.
30
31mod decoder_reader;
32
33pub use decoder_reader::{DecoderReader, ReadDecodedError};
34
35use core::{borrow::Borrow, fmt};
36
37use crate::util::{Buffer, OutOfMemory, CRC_X25};
38
39#[cfg(feature = "alloc")]
40use alloc::vec::Vec;
41
/// Counter from which the encoder derives the number of padding bytes.
///
/// The counter is decremented (with wrap-around) once per payload byte. Its two lowest bits
/// therefore always hold the number of zero bytes needed to fill the payload up to the next
/// multiple of four.
struct Padding(u8);

impl Padding {
    /// Creates a counter for an empty payload, which needs no padding.
    const fn new() -> Self {
        Self(0)
    }

    /// Registers one additional payload byte.
    fn bump(&mut self) {
        // Counting downwards turns "bytes seen" directly into "bytes missing" modulo four.
        self.0 = self.0.wrapping_sub(1);
    }

    /// Returns the number of padding bytes (`0..=3`) required for the bytes registered so far.
    const fn get(&self) -> u8 {
        self.0 % 4
    }
}
57
/// State machine driving [`Encoder`].
///
/// The value carried by each variant is the position within the respective phase.
#[derive(Debug, Clone, Copy)]
enum EncoderState {
    /// Emitting the start sequence; holds the number of start bytes emitted so far (`0..=8`).
    Init(u8),
    /// Forwarding payload bytes; holds the length of the current run of `0x1b` bytes (`0..=4`).
    LookingForEscape(u8),
    /// Emitting the four extra `0x1b` bytes that escape an escape sequence found in the payload;
    /// holds the number of bytes emitted so far (`0..=4`).
    HandlingEscape(u8),
    /// Emitting padding and end sequence: negative values count the remaining padding bytes,
    /// `0..=7` is the index into the end sequence and `8` means the encoder is exhausted.
    End(i8),
}
65
66/// An iterator that encodes the bytes of an underlying iterator using the SML Transport Protocol v1.
67pub struct Encoder<I>
68where
69    I: Iterator<Item = u8>,
70{
71    state: EncoderState,
72    crc: crc::Digest<'static, u16>,
73    padding: Padding,
74    iter: I,
75}
76
77impl<I> Encoder<I>
78where
79    I: Iterator<Item = u8>,
80{
81    /// Creates an `Encoder` from a byte iterator.
82    pub fn new(iter: I) -> Self {
83        let mut crc = CRC_X25.digest();
84        crc.update(&[0x1b, 0x1b, 0x1b, 0x1b, 0x01, 0x01, 0x01, 0x01]);
85        Encoder {
86            state: EncoderState::Init(0),
87            crc,
88            padding: Padding::new(),
89            iter,
90        }
91    }
92
93    fn read_from_iter(&mut self) -> Option<u8> {
94        let ret = self.iter.next();
95        if ret.is_some() {
96            self.padding.bump();
97        }
98        ret
99    }
100
101    fn next_from_state(&mut self, state: EncoderState) -> (Option<u8>, EncoderState) {
102        self.state = state;
103        let out = self.next();
104        (out, self.state)
105    }
106}
107
108impl<I> Iterator for Encoder<I>
109where
110    I: Iterator<Item = u8>,
111{
112    type Item = u8;
113
114    fn next(&mut self) -> Option<u8> {
115        use EncoderState::*;
116        let (out, state) = match self.state {
117            Init(n) if n < 4 => (Some(0x1b), Init(n + 1)),
118            Init(n) if n < 8 => (Some(0x01), Init(n + 1)),
119            Init(n) => {
120                assert_eq!(n, 8);
121                self.next_from_state(LookingForEscape(0))
122            }
123            LookingForEscape(n) if n < 4 => {
124                match self.read_from_iter() {
125                    Some(b) => {
126                        self.crc.update(&[b]);
127                        (Some(b), LookingForEscape((n + 1) * u8::from(b == 0x1b)))
128                    }
129                    None => {
130                        let padding = self.padding.get();
131                        // finalize crc
132                        for _ in 0..padding {
133                            self.crc.update(&[0x00]);
134                        }
135                        self.crc.update(&[0x1b, 0x1b, 0x1b, 0x1b, 0x1a, padding]);
136                        self.next_from_state(End(-(padding as i8)))
137                    }
138                }
139            }
140            LookingForEscape(n) => {
141                assert_eq!(n, 4);
142                self.crc.update(&[0x1b; 4]);
143                self.next_from_state(HandlingEscape(0))
144            }
145            HandlingEscape(n) if n < 4 => (Some(0x1b), HandlingEscape(n + 1)),
146            HandlingEscape(n) => {
147                assert_eq!(n, 4);
148                self.next_from_state(LookingForEscape(0))
149            }
150            End(n) => {
151                let out = match n {
152                    n if n < 0 => 0x00,
153                    n if n < 4 => 0x1b,
154                    4 => 0x1a,
155                    5 => self.padding.get(),
156                    n if n < 8 => {
157                        let crc_bytes = self.crc.clone().finalize().to_le_bytes();
158                        crc_bytes[(n - 6) as usize]
159                    }
160                    8 => {
161                        return None;
162                    }
163                    _ => unreachable!(),
164                };
165                (Some(out), End(n + 1))
166            }
167        };
168        self.state = state;
169        out
170    }
171}
172
173/// Takes a slice of bytes as input and returns a buffer containing the encoded message.
174///
175/// Returns `Err(())` when the buffer can't be grown to hold the entire output.
176///
177/// # Examples
178///
179/// ```
180/// // example data
181/// let bytes = [0x12, 0x34, 0x56, 0x78];
182/// let expected = [0x1b, 0x1b, 0x1b, 0x1b, 0x01, 0x01, 0x01, 0x01, 0x12, 0x34, 0x56, 0x78, 0x1b, 0x1b, 0x1b, 0x1b, 0x1a, 0x00, 0xb8, 0x7b];
183/// ```
184///
185/// ### Using alloc::Vec
186///
187/// ```
188/// # #[cfg(feature = "alloc")] {
189/// # use sml_rs::transport::encode;
190/// # let bytes = [0x12, 0x34, 0x56, 0x78];
191/// # let expected = [0x1b, 0x1b, 0x1b, 0x1b, 0x01, 0x01, 0x01, 0x01, 0x12, 0x34, 0x56, 0x78, 0x1b, 0x1b, 0x1b, 0x1b, 0x1a, 0x00, 0xb8, 0x7b];
192/// let encoded = encode::<Vec<u8>>(&bytes);
193/// assert!(encoded.is_ok());
194/// assert_eq!(encoded.unwrap().as_slice(), &expected);
195/// # }
196/// ```
197///
198/// ### Using `ArrayBuf`
199///
200/// ```
201/// # use sml_rs::{util::{ArrayBuf, OutOfMemory}, transport::encode};
202/// # let bytes = [0x12, 0x34, 0x56, 0x78];
203/// # let expected = [0x1b, 0x1b, 0x1b, 0x1b, 0x01, 0x01, 0x01, 0x01, 0x12, 0x34, 0x56, 0x78, 0x1b, 0x1b, 0x1b, 0x1b, 0x1a, 0x00, 0xb8, 0x7b];
204/// let encoded = encode::<ArrayBuf<20>>(&bytes);
205/// assert!(encoded.is_ok());
206/// assert_eq!(&*encoded.unwrap(), &expected);
207///
208/// // encoding returns `Err(())` if the encoded message does not fit into the vector
209/// let encoded = encode::<ArrayBuf<19>>(&bytes);
210/// assert_eq!(encoded, Err(OutOfMemory));
211/// ```
212///
213pub fn encode<B: Buffer>(
214    iter: impl IntoIterator<Item = impl Borrow<u8>>,
215) -> Result<B, OutOfMemory> {
216    let mut res: B = Default::default();
217
218    // start escape sequence
219    res.extend_from_slice(&[0x1b, 0x1b, 0x1b, 0x1b, 0x01, 0x01, 0x01, 0x01])?;
220
221    // encode data
222    let mut num_1b = 0;
223    for b in iter.into_iter() {
224        let b = *b.borrow();
225        if b == 0x1b {
226            num_1b += 1;
227        } else {
228            num_1b = 0;
229        }
230
231        res.push(b)?;
232
233        if num_1b == 4 {
234            res.extend_from_slice(&[0x1b; 4])?;
235            num_1b = 0;
236        }
237    }
238
239    // padding bytes
240    let num_padding_bytes = (4 - (res.len() % 4)) % 4;
241    res.extend_from_slice(&[0x0; 3][..num_padding_bytes])?;
242
243    res.extend_from_slice(&[0x1b, 0x1b, 0x1b, 0x1b, 0x1a, num_padding_bytes as u8])?;
244    let crc = CRC_X25.checksum(&res[..]);
245
246    res.extend_from_slice(&crc.to_le_bytes())?;
247
248    Ok(res)
249}
250
251/// Takes an iterator over bytes and returns an iterator that produces the encoded message.
252///
253/// # Examples
254/// ```
255/// # use sml_rs::transport::encode_streaming;
256/// // example data
257/// let bytes = [0x12, 0x34, 0x56, 0x78];
258/// let expected = [0x1b, 0x1b, 0x1b, 0x1b, 0x01, 0x01, 0x01, 0x01, 0x12, 0x34, 0x56, 0x78, 0x1b, 0x1b, 0x1b, 0x1b, 0x1a, 0x00, 0xb8, 0x7b];
259/// let iter = encode_streaming(bytes);
260/// assert!(iter.eq(expected));
261/// ```
262pub fn encode_streaming(
263    iter: impl IntoIterator<Item = impl Borrow<u8>>,
264) -> Encoder<impl Iterator<Item = u8>> {
265    Encoder::new(iter.into_iter().map(|x| *x.borrow()))
266}
267
/// An error which can be returned when decoding an sml message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DecodeErr {
    /// Some bytes could not be parsed and were discarded
    DiscardedBytes(usize),
    /// An invalid escape sequence has been read
    InvalidEsc([u8; 4]),
    /// The buffer used internally by the decoder is full. When using vec, allocation has failed.
    OutOfMemory,
    /// The decoded message is invalid.
    InvalidMessage {
        /// (expected, found) checksums
        checksum_mismatch: (u16, u16),
        /// whether the end escape sequence wasn't aligned to a 4-byte boundary
        end_esc_misaligned: bool,
        /// the number of padding bytes.
        num_padding_bytes: u8,
    },
}

impl fmt::Display for DecodeErr {
    /// Formats the error using its `Debug` representation.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(self, f)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for DecodeErr {}
296
/// State machine driving [`Decoder`].
#[derive(Debug)]
enum DecodeState {
    /// Scanning the input for the start sequence `1b1b1b1b 01010101`.
    LookingForMessageStart {
        /// Number of bytes seen that turned out not to belong to a start sequence.
        num_discarded_bytes: u16,
        /// Number of bytes of the start sequence matched so far (`0..=8`).
        num_init_seq_bytes: u8,
    },
    /// Reading regular message data.
    ParsingNormal,
    /// Inside a run of `0x1b` bytes that may become an escape sequence; holds the run length.
    ParsingEscChars(u8),
    /// Reading the four bytes following an escape sequence; holds the number of bytes read so far.
    ParsingEscPayload(u8),
    /// A complete message has been decoded and is available in the buffer.
    Done,
}
308
309/// Decoder for sml transport v1.
310///
311/// # Examples
312///
313/// ```
314/// # use sml_rs::{util::ArrayBuf, transport::Decoder};
315/// let bytes = [0x1b, 0x1b, 0x1b, 0x1b, 0x01, 0x01, 0x01, 0x01, 0x12, 0x34, 0x56, 0x78, 0x1b, 0x1b, 0x1b, 0x1b, 0x1a, 0x00, 0xb8, 0x7b];
316/// let expected = [0x12, 0x34, 0x56, 0x78];
317///
318/// let mut decoder = Decoder::<ArrayBuf<20>>::new();
319/// for b in bytes {
320///     match decoder.push_byte(b) {
321///         Ok(None) => {},  // nothing to output currently
322///         Ok(Some(decoded)) => {  // complete and valid message was decoded
323///             assert_eq!(decoded, expected);
324///         }
325///         Err(e) => {
326///             panic!("Unexpected Error: {:?}", e);
327///         }
328///     }
329/// }
330/// assert_eq!(decoder.finalize(), None)
331/// ```
332pub struct Decoder<B: Buffer> {
333    buf: B,
334    raw_msg_len: usize,
335    crc: crc::Digest<'static, u16>,
336    crc_idx: usize,
337    state: DecodeState,
338}
339
340impl<B: Buffer> Default for Decoder<B> {
341    fn default() -> Self {
342        Self::new()
343    }
344}
345
346impl<B: Buffer> Decoder<B> {
347    /// Constructs a new decoder.
348    #[must_use]
349    pub fn new() -> Self {
350        Self::from_buf(Default::default())
351    }
352
353    /// Constructs a new decoder using an existing buffer `buf`.
354    pub fn from_buf(mut buf: B) -> Self {
355        buf.clear();
356        Decoder {
357            buf,
358            raw_msg_len: 0,
359            crc: CRC_X25.digest(),
360            crc_idx: 0,
361            state: DecodeState::LookingForMessageStart {
362                num_discarded_bytes: 0,
363                num_init_seq_bytes: 0,
364            },
365        }
366    }
367
368    /// Pushes a byte `b` into the decoder, advances the parser state and possibly returns
369    /// a transmission or an decoder error.
370    pub fn push_byte(&mut self, b: u8) -> Result<Option<&[u8]>, DecodeErr> {
371        self._push_byte(b)
372            .map(|b| if b { Some(self.borrow_buf()) } else { None })
373    }
374
375    /// Resets the `Decoder` and returns an error if it contained an incomplete message.
376    pub fn finalize(&mut self) -> Option<DecodeErr> {
377        use DecodeState::*;
378        let res = match self.state {
379            LookingForMessageStart {
380                num_discarded_bytes: 0,
381                num_init_seq_bytes: 0,
382            } => None,
383            Done => None,
384            _ => Some(DecodeErr::DiscardedBytes(self.raw_msg_len)),
385        };
386        self.reset();
387        res
388    }
389
390    /// Main function of the parser.
391    ///
392    /// Returns
393    /// - `Ok(true)` if a complete message is ready.
394    /// - `Ok(false)` when more bytes are necessary to complete parsing a message.
395    /// - `Err(_)` if an error occurred during parsing
396    pub(crate) fn _push_byte(&mut self, b: u8) -> Result<bool, DecodeErr> {
397        use DecodeState::*;
398        self.raw_msg_len += 1;
399        match self.state {
400            LookingForMessageStart {
401                ref mut num_discarded_bytes,
402                ref mut num_init_seq_bytes,
403            } => {
404                if (b == 0x1b && *num_init_seq_bytes < 4) || (b == 0x01 && *num_init_seq_bytes >= 4)
405                {
406                    *num_init_seq_bytes += 1;
407                } else {
408                    *num_discarded_bytes += 1 + u16::from(*num_init_seq_bytes);
409                    *num_init_seq_bytes = 0;
410                }
411                if *num_init_seq_bytes == 8 {
412                    let num_discarded_bytes = *num_discarded_bytes;
413                    self.state = ParsingNormal;
414                    self.raw_msg_len = 8;
415                    assert_eq!(self.buf.len(), 0);
416                    assert_eq!(self.crc_idx, 0);
417                    self.crc = CRC_X25.digest();
418                    self.crc
419                        .update(&[0x1b, 0x1b, 0x1b, 0x1b, 0x01, 0x01, 0x01, 0x01]);
420                    if num_discarded_bytes > 0 {
421                        return Err(DecodeErr::DiscardedBytes(num_discarded_bytes as usize));
422                    }
423                }
424            }
425            ParsingNormal => {
426                if b == 0x1b {
427                    // this could be the first byte of an escape sequence
428                    self.state = ParsingEscChars(1);
429                } else {
430                    // regular data
431                    self.push(b)?;
432                }
433            }
434            ParsingEscChars(n) => {
435                if b != 0x1b {
436                    // push previous 0x1b bytes as they didn't belong to an escape sequence
437                    for _ in 0..n {
438                        self.push(0x1b)?;
439                    }
440                    // push current byte
441                    self.push(b)?;
442                    // continue in regular parsing state
443                    self.state = ParsingNormal;
444                } else if n == 3 {
445                    // this is the fourth 0x1b byte, so we're seeing an escape sequence.
446                    // continue by parsing the escape sequence's payload.
447
448                    // also update the crc here. the escape bytes aren't stored in `buf`, but
449                    // still need to count for the crc calculation
450                    // (1) add everything that's in the buffer and hasn't been added to the crc previously
451                    self.crc.update(&self.buf[self.crc_idx..self.buf.len()]);
452                    // (2) add the four escape bytes
453                    self.crc.update(&[0x1b, 0x1b, 0x1b, 0x1b]);
454                    // update crc_idx to indicate that everything that's currently in the buffer has already
455                    // been used to update the crc
456                    self.crc_idx = self.buf.len();
457
458                    self.state = ParsingEscPayload(0);
459                } else {
460                    self.state = ParsingEscChars(n + 1);
461                }
462            }
463            ParsingEscPayload(n) => {
464                self.push(b)?;
465                if n < 3 {
466                    self.state = ParsingEscPayload(n + 1);
467                } else {
468                    // last 4 elements in self.buf are the escape sequence payload
469                    let payload = &self.buf[self.buf.len() - 4..self.buf.len()];
470                    if payload == [0x1b, 0x1b, 0x1b, 0x1b] {
471                        // escape sequence in user data
472
473                        // nothing to do here as the input has already been added to the buffer (see above)
474                        self.state = ParsingNormal;
475                    } else if payload == [0x01, 0x01, 0x01, 0x01] {
476                        // another transmission start
477
478                        // ignore everything that has previously been read and start reading a new transmission
479                        let ignored_bytes = self.raw_msg_len - 8;
480                        self.raw_msg_len = 8;
481                        self.buf.clear();
482                        self.crc = CRC_X25.digest();
483                        self.crc
484                            .update(&[0x1b, 0x1b, 0x1b, 0x1b, 0x01, 0x01, 0x01, 0x01]);
485                        self.crc_idx = 0;
486                        self.state = ParsingNormal;
487                        return Err(DecodeErr::DiscardedBytes(ignored_bytes));
488                    } else if payload[0] == 0x1a {
489                        // end sequence (layout: [0x1a, num_padding_bytes, crc, crc])
490
491                        // check number of padding bytes
492                        let num_padding_bytes = payload[1];
493
494                        // compute and compare checksum
495                        let read_crc = u16::from_le_bytes([payload[2], payload[3]]);
496                        // update the crc, but exclude the last two bytes (which contain the crc itself)
497                        self.crc
498                            .update(&self.buf[self.crc_idx..(self.buf.len() - 2)]);
499                        // get the calculated crc and reset it afterwards
500                        let calculated_crc = {
501                            let mut crc = CRC_X25.digest();
502                            core::mem::swap(&mut crc, &mut self.crc);
503                            crc.finalize()
504                        };
505
506                        // check alignment (end marker needs to have 4-byte alignment)
507                        let misaligned = self.buf.len() % 4 != 0;
508
509                        // check if padding is larger than the message length
510                        let padding_too_large = num_padding_bytes > 3
511                            || (num_padding_bytes as usize + 4) > self.buf.len();
512
513                        if read_crc != calculated_crc || misaligned || padding_too_large {
514                            self.reset();
515                            return Err(DecodeErr::InvalidMessage {
516                                checksum_mismatch: (read_crc, calculated_crc),
517                                end_esc_misaligned: misaligned,
518                                num_padding_bytes,
519                            });
520                        }
521
522                        // subtract padding bytes and escape payload length from buffer length
523                        self.buf
524                            .truncate(self.buf.len() - num_padding_bytes as usize - 4);
525
526                        self.set_done();
527
528                        return Ok(true);
529                    } else {
530                        // special case of message ending with incomplete escape sequence
531                        // Explanation:
532                        // when a message ends with 1-3 0x1b bytes and there's no padding bytes,
533                        // we end up in this branch because there's four consecutive 0x1b bytes
534                        // that aren't followed by a known escape sequence. The problem is that
535                        // the first 1-3 0x1b bytes belong to the message, not to the end escape
536                        // code.
537                        // Example:
538                        //                  detected as escape sequence
539                        //                  vvvv vvvv
540                        // Message: ... 12341b1b 1b1b1b1b 1a00abcd
541                        //                       ^^^^^^^^
542                        //                       real escape sequence
543                        //
544                        // The solution for this issue is to check whether the read esacpe code
545                        // isn't aligned to a 4-byte boundary and followed by an aligned end
546                        // escape sequence (`1b1b1b1b 1a...`).
547                        // If that's the case, simply reset the parser state by 1-3 steps. This
548                        // will parse the 0x1b bytes in the message as regular bytes and check
549                        // for the end escape code at the right position.
550                        let bytes_until_alignment = (4 - (self.buf.len() % 4)) % 4;
551                        if bytes_until_alignment > 0
552                            && payload[..bytes_until_alignment].iter().all(|x| *x == 0x1b)
553                            && payload[bytes_until_alignment] == 0x1a
554                        {
555                            self.state = ParsingEscPayload(4 - bytes_until_alignment as u8);
556                            return Ok(false);
557                        }
558
559                        // invalid escape sequence
560
561                        // unwrap is safe here because payload is guaranteed to have size 4
562                        let esc_bytes: [u8; 4] = payload.try_into().unwrap();
563                        self.reset();
564                        return Err(DecodeErr::InvalidEsc(esc_bytes));
565                    }
566                }
567            }
568            Done => {
569                // reset and let's go again
570                self.reset();
571                return self._push_byte(b);
572            }
573        }
574        Ok(false)
575    }
576
577    pub(crate) fn borrow_buf(&self) -> &[u8] {
578        if !matches!(self.state, DecodeState::Done) {
579            panic!("Reading from the internal buffer is only allowed when a complete message is present (DecodeState::Done). Found state {:?}.", self.state);
580        }
581        &self.buf[..self.buf.len()]
582    }
583
584    fn set_done(&mut self) {
585        self.state = DecodeState::Done;
586    }
587
588    /// Resets the `Decoder` and returns the number of bytes that were discarded
589    pub fn reset(&mut self) -> usize {
590        let num_discarded = match self.state {
591            DecodeState::Done => 0,
592            _ => self.raw_msg_len,
593        };
594        self.state = DecodeState::LookingForMessageStart {
595            num_discarded_bytes: 0,
596            num_init_seq_bytes: 0,
597        };
598        self.buf.clear();
599        self.crc_idx = 0;
600        self.raw_msg_len = 0;
601        num_discarded
602    }
603
604    fn push(&mut self, b: u8) -> Result<(), DecodeErr> {
605        if self.buf.push(b).is_err() {
606            self.reset();
607            return Err(DecodeErr::OutOfMemory);
608        }
609        Ok(())
610    }
611}
612
/// Decodes a given sequence of bytes and returns a vector of messages / errors.
///
/// If the input ends in the middle of a message, the resulting error is appended as the
/// last element.
///
/// *This function is available only if sml-rs is built with the `"alloc"` feature.*
///
/// # Examples
/// ```
/// # use sml_rs::transport::decode;
/// // example data
/// let bytes = [0x1b, 0x1b, 0x1b, 0x1b, 0x01, 0x01, 0x01, 0x01, 0x12, 0x34, 0x56, 0x78, 0x1b, 0x1b, 0x1b, 0x1b, 0x1a, 0x00, 0xb8, 0x7b];
/// let expected = [0x12, 0x34, 0x56, 0x78];
/// let decoded = decode(&bytes);
/// assert_eq!(decoded, vec!(Ok(expected.to_vec())));
/// ```
#[cfg(feature = "alloc")]
#[must_use]
pub fn decode(iter: impl IntoIterator<Item = impl Borrow<u8>>) -> Vec<Result<Vec<u8>, DecodeErr>> {
    let mut decoder = Decoder::<Vec<u8>>::new();
    let mut results = Vec::new();
    for item in iter {
        match decoder.push_byte(*item.borrow()) {
            // the message is copied because the decoder reuses its buffer
            Ok(Some(msg)) => results.push(Ok(msg.to_vec())),
            Ok(None) => {}
            Err(err) => results.push(Err(err)),
        }
    }
    // report a trailing incomplete message, if any
    results.extend(decoder.finalize().map(Err));
    results
}
642
643/// Iterator over decoded messages / errors.
644pub struct DecodeIterator<B: Buffer, I: Iterator<Item = u8>> {
645    decoder: Decoder<B>,
646    bytes: I,
647    done: bool,
648}
649
650impl<B: Buffer, I: Iterator<Item = u8>> DecodeIterator<B, I> {
651    fn new(bytes: I) -> Self {
652        DecodeIterator {
653            decoder: Decoder::new(),
654            bytes,
655            done: false,
656        }
657    }
658
659    /// Returns the next message / error.
660    #[allow(clippy::should_implement_trait)]
661    pub fn next(&mut self) -> Option<Result<&[u8], DecodeErr>> {
662        if self.done {
663            return None;
664        }
665        loop {
666            match self.bytes.next() {
667                Some(b) => {
668                    match self.decoder._push_byte(b) {
669                        Ok(true) => return Some(Ok(self.decoder.borrow_buf())),
670                        Err(e) => {
671                            return Some(Err(e));
672                        }
673                        Ok(false) => {
674                            // take next byte...
675                        }
676                    }
677                }
678                None => {
679                    self.done = true;
680                    return self.decoder.finalize().map(Err);
681                }
682            }
683        }
684    }
685}
686
687/// Takes an iterator over bytes and returns an iterator that yields decoded messages / decoding errors.
688///
689/// # Examples
690/// ```
691/// # use sml_rs::{util::ArrayBuf, transport::decode_streaming};
692/// // example data
693/// let bytes = [
694///     // first message
695///     0x1b, 0x1b, 0x1b, 0x1b, 0x01, 0x01, 0x01, 0x01, 0x12, 0x34, 0x56, 0x78, 0x1b, 0x1b, 0x1b, 0x1b, 0x1a, 0x00, 0xb8, 0x7b,
696///     // second message
697///     0x1b, 0x1b, 0x1b, 0x1b, 0x01, 0x01, 0x01, 0x01, 0x13, 0x24, 0x35, 0x46, 0x1b, 0x1b, 0x1b, 0x1b, 0x1a, 0x00, 0xb1, 0xa1,
698/// ];
699/// let mut decode_iterator = decode_streaming::<ArrayBuf<10>>(&bytes);
700/// assert_eq!(decode_iterator.next(), Some(Ok([0x12, 0x34, 0x56, 0x78].as_slice())));
701/// assert_eq!(decode_iterator.next(), Some(Ok([0x13, 0x24, 0x35, 0x46].as_slice())));
702/// assert_eq!(decode_iterator.next(), None);
703pub fn decode_streaming<B: Buffer>(
704    iter: impl IntoIterator<Item = impl Borrow<u8>>,
705) -> DecodeIterator<B, impl Iterator<Item = u8>> {
706    DecodeIterator::new(iter.into_iter().map(|x| *x.borrow()))
707}
708
#[cfg(test)]
mod tests {
    use super::*;
    use crate::util::ArrayBuf;
    use hex_literal::hex;

    // assert_eq macro that prints its arguments as hex when they don't match.
    // (adapted from the `assert_hex` crate)
    macro_rules! assert_eq_hex {
        ($left:expr, $right:expr $(,)?) => {{
            match (&$left, &$right) {
                (left_val, right_val) => {
                    if !(*left_val == *right_val) {
                        // The reborrows below are intentional. Without them, the stack slot for the
                        // borrow is initialized even before the values are compared, leading to a
                        // noticeable slow down.
                        panic!(
                            "assertion failed: `(left == right)`\n  left: `{:02x?}`,\n right: `{:02x?}`",
                            &*left_val, &*right_val
                        )
                    }
                }
            }
        }};
    }

    // Checks that both encoders turn `bytes` into `exp_encoded_bytes` and (if `decode` is
    // available) that decoding the expected output yields `bytes` again.
    fn test_encoding<const N: usize>(bytes: &[u8], exp_encoded_bytes: &[u8; N]) {
        let encoded = encode::<ArrayBuf<N>>(bytes).expect("ran out of memory");
        compare_encoded_bytes(exp_encoded_bytes, &encoded);

        let streamed: ArrayBuf<N> = encode_streaming(bytes).collect();
        compare_encoded_bytes(exp_encoded_bytes, &streamed);

        #[cfg(feature = "alloc")]
        assert_eq_hex!(alloc::vec![Ok(bytes.to_vec())], decode(exp_encoded_bytes));
    }

    fn compare_encoded_bytes(expected: &[u8], actual: &[u8]) {
        assert_eq_hex!(expected, actual);
    }

    #[test]
    fn basic() {
        let payload = hex!("12345678");
        let encoded = hex!("1b1b1b1b 01010101 12345678 1b1b1b1b 1a00b87b");
        test_encoding(&payload, &encoded);
    }

    #[test]
    fn empty() {
        let payload = hex!("");
        let encoded = hex!("1b1b1b1b 01010101 1b1b1b1b 1a00c6e5");
        test_encoding(&payload, &encoded);
    }

    #[test]
    fn padding() {
        let payload = hex!("123456");
        let encoded = hex!("1b1b1b1b 01010101 12345600 1b1b1b1b 1a0191a5");
        test_encoding(&payload, &encoded);
    }

    #[test]
    fn escape_in_user_data() {
        let payload = hex!("121b1b1b1b");
        let encoded = hex!("1b1b1b1b 01010101 12 1b1b1b1b 1b1b1b1b 000000 1b1b1b1b 1a03be25");
        test_encoding(&payload, &encoded);
    }

    #[test]
    fn almost_escape_in_user_data() {
        let payload = hex!("121b1b1bFF");
        let encoded = hex!("1b1b1b1b 01010101 12 1b1b1bFF 000000 1b1b1b1b 1a0324d9");
        test_encoding(&payload, &encoded);
    }

    #[test]
    fn ending_with_1b_no_padding() {
        let payload = hex!("12345678 12341b1b");
        let encoded = hex!("1b1b1b1b 01010101 12345678 12341b1b 1b1b1b1b 1a001ac5");
        test_encoding(&payload, &encoded);
    }
}
796
#[cfg(test)]
mod decode_tests {
    use super::*;
    use crate::util::ArrayBuf;
    use hex_literal::hex;
    use DecodeErr::*;

    /// Runs `bytes` through both decoding front ends and checks the outcome.
    ///
    /// 1. A `DecodeIterator` using buffer type `B` must yield exactly the
    ///    items listed in `exp`, in order, and nothing more.
    /// 2. A `Decoder` that is fed the input byte by byte via `push_byte()` and
    ///    then `finalize()`d must report the same items as a fresh
    ///    `DecodeIterator` over the same input.
    ///
    /// Panics (failing the test) on the first deviation.
    fn test_parse_input<B: Buffer>(bytes: &[u8], exp: &[Result<&[u8], DecodeErr>]) {
        // Part 1: the streaming decoder has to match the expectation list.
        let mut remaining = exp.iter();
        let mut iter_decoder = DecodeIterator::<B, _>::new(bytes.iter().copied());

        while let Some(actual) = iter_decoder.next() {
            if let Some(wanted) = remaining.next() {
                assert_eq!(actual, *wanted);
            } else {
                panic!("Additional decoded item: {:?}", actual);
            }
        }
        // Every expected item must have been consumed.
        assert_eq!(remaining.next(), None);

        // Part 2: `Decoder` and `DecodeIterator` have to agree item by item.
        let mut push_decoder = Decoder::<B>::new();
        let mut iter_decoder = DecodeIterator::<B, _>::new(bytes.iter().copied());
        for &byte in bytes {
            // `Ok(None)` means this byte produced no item; only actual items
            // are compared against the iterator.
            let pushed = match push_decoder.push_byte(byte) {
                Ok(None) => continue,
                other => other,
            };
            match (pushed, iter_decoder.next()) {
                (Ok(Some(a)), Some(Ok(b))) => assert_eq!(a, b),
                (Err(a), Some(Err(b))) => assert_eq!(a, b),
                (a, b) => panic!(
                    "Mismatch between decoder and streaming_decoder: {:?} vs. {:?}",
                    a, b
                ),
            }
        }

        // Whatever `finalize()` reports has to be the iterator's last item.
        let leftover = push_decoder.finalize();
        let trailing = iter_decoder.next();
        match (leftover, trailing) {
            (None, None) => {}
            (Some(a), Some(Err(b))) => assert_eq!(a, b),
            (a, b) => panic!(
                "Mismatch between decoder and streaming_decoder on the final element: {:?} vs. {:?}",
                a, b
            ),
        }
    }

    /// A single well-formed transmission decodes to its four payload bytes.
    #[test]
    fn basic() {
        let input = hex!("1b1b1b1b 01010101 12345678 1b1b1b1b 1a00b87b");
        let payload = hex!("12345678");
        let expected = [Ok(&payload[..])];

        test_parse_input::<ArrayBuf<8>>(&input, &expected);
    }

    /// The transmission from `basic`, decoded with a 7-byte `ArrayBuf` instead
    /// of an 8-byte one, is expected to fail with `OutOfMemory`.
    #[test]
    fn out_of_memory() {
        let input = hex!("1b1b1b1b 01010101 12345678 1b1b1b1b 1a00b87b");
        let expected = [Err(DecodeErr::OutOfMemory)];

        test_parse_input::<ArrayBuf<7>>(&input, &expected);
    }

    /// The transmission from `basic` with its last checksum byte replaced by
    /// `FF`; the error is expected to carry both differing checksum values.
    #[test]
    fn invalid_crc() {
        let input = hex!("1b1b1b1b 01010101 12345678 1b1b1b1b 1a00b8FF");
        let failure = InvalidMessage {
            checksum_mismatch: (0xFFb8, 0x7bb8),
            end_esc_misaligned: false,
            num_padding_bytes: 0,
        };
        let expected = [Err(failure)];

        test_parse_input::<ArrayBuf<8>>(&input, &expected);
    }

    /// A stray `FF` byte in front of the end sequence shifts the end escape
    /// off the four-byte grid, which is reported as `end_esc_misaligned`.
    #[test]
    fn msg_end_misaligned() {
        let input = hex!("1b1b1b1b 01010101 12345678 FF 1b1b1b1b 1a0013b6");
        let failure = InvalidMessage {
            checksum_mismatch: (0xb613, 0xb613),
            end_esc_misaligned: true,
            num_padding_bytes: 0,
        };
        let expected = [Err(failure)];

        test_parse_input::<ArrayBuf<16>>(&input, &expected);
    }

    /// The end sequence announces four padding bytes (`1a04`), more than the
    /// zero to three the format allows, so the message is rejected.
    #[test]
    fn padding_too_large() {
        let input = hex!("1b1b1b1b 01010101 12345678 12345678 1b1b1b1b 1a04f950");
        let failure = InvalidMessage {
            checksum_mismatch: (0x50f9, 0x50f9),
            end_esc_misaligned: false,
            num_padding_bytes: 4,
        };
        let expected = [Err(failure)];

        test_parse_input::<ArrayBuf<16>>(&input, &expected);
    }

    /// A transmission without any payload bytes that nevertheless announces
    /// one padding byte is rejected.
    #[test]
    fn empty_msg_with_padding() {
        let input = hex!("1b1b1b1b 01010101 1b1b1b1b 1a014FF4");
        let failure = InvalidMessage {
            checksum_mismatch: (0xf44f, 0xf44f),
            end_esc_misaligned: false,
            num_padding_bytes: 1,
        };
        let expected = [Err(failure)];

        test_parse_input::<ArrayBuf<16>>(&input, &expected);
    }

    /// Three bytes before and two bytes after a valid transmission are each
    /// reported as `DiscardedBytes`, with the decoded message in between.
    #[test]
    fn additional_bytes() {
        let input = hex!("000102 1b1b1b1b 01010101 12345678 1b1b1b1b 1a00b87b 1234");
        let payload = hex!("12345678");
        let expected = [
            Err(DiscardedBytes(3)),
            Ok(&payload[..]),
            Err(DiscardedBytes(2)),
        ];

        test_parse_input::<ArrayBuf<128>>(&input, &expected);
    }

    /// Input that stops after the start sequence and three data bytes: all
    /// eleven bytes are reported as discarded.
    #[test]
    fn incomplete_message() {
        let input = hex!("1b1b1b1b 01010101 123456");
        let expected = [Err(DiscardedBytes(11))];

        test_parse_input::<ArrayBuf<128>>(&input, &expected);
    }

    /// An escape followed by the unknown code `1c000000` yields `InvalidEsc`;
    /// the remaining twelve bytes of the input are reported as discarded.
    #[test]
    fn invalid_esc_sequence() {
        let input =
            hex!("1b1b1b1b 01010101 12345678 1b1b1b1b 1c000000 12345678 1b1b1b1b 1a03be25");
        let expected = [
            Err(InvalidEsc([0x1c, 0x0, 0x0, 0x0])),
            Err(DiscardedBytes(12)),
        ];

        test_parse_input::<ArrayBuf<128>>(&input, &expected);
    }

    /// `1b1b1b00` is not an escape sequence and is kept as payload; the three
    /// announced padding bytes are removed from the end of the data.
    #[test]
    fn incomplete_esc_sequence() {
        let input = hex!("1b1b1b1b 01010101 12345678 1b1b1b00 12345678 1b1b1b1b 1a030A07");
        let payload = hex!("12345678 1b1b1b00 12");
        let expected = [Ok(&payload[..])];

        test_parse_input::<ArrayBuf<128>>(&input, &expected);
    }

    /// A second start sequence arriving in the middle of a message: the
    /// thirteen bytes received so far are discarded and the second message is
    /// decoded normally.
    #[test]
    fn double_msg_start() {
        let input =
            hex!("1b1b1b1b 01010101 09 87654321 1b1b1b1b 01010101 12345678 1b1b1b1b 1a00b87b");
        let payload = hex!("12345678");
        let expected = [Err(DiscardedBytes(13)), Ok(&payload[..])];

        test_parse_input::<ArrayBuf<128>>(&input, &expected);
    }

    /// One announced padding byte (`1a01`) is stripped from the decoded data.
    #[test]
    fn padding() {
        let input = hex!("1b1b1b1b 01010101 12345600 1b1b1b1b 1a0191a5");
        let payload = hex!("123456");
        let expected = [Ok(&payload[..])];

        test_parse_input::<ArrayBuf<128>>(&input, &expected);
    }

    /// A doubled escape sequence inside the data decodes to a single
    /// `1b1b1b1b` in the payload.
    #[test]
    fn escape_in_user_data() {
        let input = hex!("1b1b1b1b 01010101 12 1b1b1b1b 1b1b1b1b 000000 1b1b1b1b 1a03be25");
        let payload = hex!("121b1b1b1b");
        let expected = [Ok(&payload[..])];

        test_parse_input::<ArrayBuf<128>>(&input, &expected);
    }

    /// Unpadded payload ending in one `1b` byte right before the end sequence.
    #[test]
    fn ending_with_1b_no_padding_1() {
        let input = hex!("1b1b1b1b 01010101 12345678 1234561b 1b1b1b1b 1a00361a");
        let payload = hex!("12345678 1234561b");
        let expected = [Ok(&payload[..])];

        test_parse_input::<ArrayBuf<128>>(&input, &expected);
    }

    /// Unpadded payload ending in two `1b` bytes right before the end sequence.
    #[test]
    fn ending_with_1b_no_padding_2() {
        let input = hex!("1b1b1b1b 01010101 12345678 12341b1b 1b1b1b1b 1a001ac5");
        let payload = hex!("12345678 12341b1b");
        let expected = [Ok(&payload[..])];

        test_parse_input::<ArrayBuf<128>>(&input, &expected);
    }

    /// Unpadded payload ending in three `1b` bytes right before the end
    /// sequence.
    #[test]
    fn ending_with_1b_no_padding_3() {
        let input = hex!("1b1b1b1b 01010101 12345678 121b1b1b 1b1b1b1b 1a000ba4");
        let payload = hex!("12345678 121b1b1b");
        let expected = [Ok(&payload[..])];

        test_parse_input::<ArrayBuf<128>>(&input, &expected);
    }

    /// Same scenario as `basic`, but with a `Vec<u8>` as the decoder buffer;
    /// only compiled when the "alloc" feature is enabled.
    #[cfg(feature = "alloc")]
    #[test]
    fn alloc_basic() {
        let input = hex!("1b1b1b1b 01010101 12345678 1b1b1b1b 1a00b87b");
        let payload = hex!("12345678");
        let expected = [Ok(&payload[..])];

        test_parse_input::<Vec<u8>>(&input, &expected);
    }
}