rusmpp_core/tokio_codec/
mod.rs

//! [`Encoder`] and [`Decoder`] implementations (from `tokio_util::codec`) for [`CommandCodec`].
2
3use core::num::TryFromIntError;
4
5use bytes::Buf;
6use tokio_util::{
7    bytes::{BufMut, BytesMut},
8    codec::{Decoder, Encoder},
9};
10
11use crate::{
12    command::owned::Command,
13    decode::owned::DecodeWithLength,
14    encode::{Length, owned::Encode},
15    logging::{debug, error, trace},
16};
17
18#[cfg(test)]
19mod tests;
20
/// State of the incremental decoder inside [`CommandCodec`].
#[derive(Debug)]
enum DecodeState {
    /// Decoding the command length.
    Length,
    /// Decoding the command (Header without length + Body).
    ///
    /// `command_length` is the full value read from the wire, i.e. it includes
    /// the 4 bytes of the length field itself.
    Command { command_length: usize },
}

/// Codec for encoding and decoding `SMPP` PDUs.
///
/// When decoding, PDUs whose announced `command_length` exceeds
/// [`max_length`](Self::max_length) are rejected. The limit defaults to `8192`
/// bytes (see [`new`](Self::new)).
#[derive(Debug)]
pub struct CommandCodec {
    /// Upper bound on an accepted `command_length`; `None` disables the check.
    max_length: Option<usize>,
    /// Where the decoder is within the current PDU.
    state: DecodeState,
}

impl CommandCodec {
    /// Creates a new [`CommandCodec`] with a default maximum length of `8192` bytes.
    #[inline]
    pub const fn new() -> Self {
        Self {
            max_length: Some(8192),
            state: DecodeState::Length,
        }
    }

    /// Returns the maximum accepted `command_length`, or `None` if decoding is unbounded.
    #[inline]
    pub const fn max_length(&self) -> Option<usize> {
        self.max_length
    }

    /// Returns this codec with the maximum accepted `command_length` set to `max_length`.
    ///
    /// A PDU announcing a larger length is rejected with
    /// [`DecodeError::MaxLength`] as soon as its length field is read, before
    /// any buffer space is reserved for its body.
    #[inline]
    #[must_use = "this consumes the codec and returns the configured one"]
    pub const fn with_max_length(mut self, max_length: usize) -> Self {
        self.max_length = Some(max_length);
        self
    }

    /// Returns this codec with the maximum length check disabled.
    ///
    /// Without a limit, the decoder reserves buffer space for whatever
    /// `command_length` the peer announces (up to `u32::MAX` bytes). Only use
    /// this with trusted peers.
    #[inline]
    #[must_use = "this consumes the codec and returns the configured one"]
    pub const fn without_max_length(mut self) -> Self {
        self.max_length = None;
        self
    }

    /// Sets the decoder state to decode the command length.
    #[inline]
    const fn decode_length(&mut self) {
        self.state = DecodeState::Length;
    }

    /// Sets the decoder state to decode the rest of the command.
    #[inline]
    const fn decode_command(&mut self, command_length: usize) {
        self.state = DecodeState::Command { command_length };
    }
}

impl Default for CommandCodec {
    /// Equivalent to [`CommandCodec::new`].
    fn default() -> Self {
        Self::new()
    }
}
81
/// An error that can occur when encoding a `Command`.
#[derive(Debug)]
#[non_exhaustive]
pub enum EncodeError {
    /// I/O error.
    Io(std::io::Error),
}

impl From<std::io::Error> for EncodeError {
    fn from(e: std::io::Error) -> Self {
        EncodeError::Io(e)
    }
}

impl core::fmt::Display for EncodeError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            EncodeError::Io(e) => write!(f, "I/O error: {e}"),
        }
    }
}

// Only `source` is implemented: the deprecated `cause` method's default body
// already forwards to `source`, so overriding it would be redundant.
impl core::error::Error for EncodeError {
    fn source(&self) -> Option<&(dyn core::error::Error + 'static)> {
        match self {
            EncodeError::Io(e) => Some(e),
        }
    }
}
115
116impl Encoder<&Command> for CommandCodec {
117    type Error = EncodeError;
118
119    fn encode(&mut self, command: &Command, dst: &mut BytesMut) -> Result<(), Self::Error> {
120        let command_length = 4 + command.length();
121
122        dst.reserve(command_length);
123        dst.put_u32(command_length as u32);
124        command.encode(dst);
125
126        debug!(target: "rusmpp::codec::encode", command=?command, "Encoding");
127        debug!(target: "rusmpp::codec::encode", encoded=?crate::formatter::Formatter(&dst[..command_length]), encoded_length=command.length(), command_length, "Encoded");
128
129        Ok(())
130    }
131}
132
133impl Encoder<Command> for CommandCodec {
134    type Error = EncodeError;
135
136    fn encode(&mut self, command: Command, dst: &mut BytesMut) -> Result<(), Self::Error> {
137        self.encode(&command, dst)
138    }
139}
140
141/// An error that can occur when decoding a `Command`.
142#[derive(Debug)]
143#[non_exhaustive]
144pub enum DecodeError {
145    /// I/O error.
146    Io(std::io::Error),
147    /// Decode error.
148    Decode(crate::decode::DecodeError),
149    /// Minimum command length not met.
150    MinLength { actual: usize, min: usize },
151    /// Maximum command length exceeded.
152    MaxLength { actual: usize, max: usize },
153    /// Integral type conversion failed.
154    InvalidLength(TryFromIntError),
155}
156
157impl From<std::io::Error> for DecodeError {
158    fn from(e: std::io::Error) -> Self {
159        DecodeError::Io(e)
160    }
161}
162
163impl core::fmt::Display for DecodeError {
164    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
165        match self {
166            DecodeError::Io(e) => write!(f, "I/O error: {e}"),
167            DecodeError::Decode(e) => write!(f, "Decode error: {e}"),
168            DecodeError::MinLength { actual, min } => {
169                write!(
170                    f,
171                    "Minimum command length not met. actual: {actual}, min: {min}"
172                )
173            }
174            DecodeError::MaxLength { actual, max } => {
175                write!(
176                    f,
177                    "Maximum command length exceeded. actual: {actual}, max: {max}"
178                )
179            }
180            DecodeError::InvalidLength(e) => {
181                write!(f, "Integral type conversion failed: {e}")
182            }
183        }
184    }
185}
186
187impl core::error::Error for DecodeError {
188    fn source(&self) -> Option<&(dyn core::error::Error + 'static)> {
189        match self {
190            DecodeError::Io(e) => Some(e),
191            DecodeError::Decode(e) => Some(e),
192            DecodeError::MinLength { .. } => None,
193            DecodeError::MaxLength { .. } => None,
194            DecodeError::InvalidLength(e) => Some(e),
195        }
196    }
197
198    fn cause(&self) -> Option<&dyn core::error::Error> {
199        self.source()
200    }
201}
202
203impl Decoder for CommandCodec {
204    type Item = Command;
205    type Error = DecodeError;
206
207    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
208        const HEADER_LENGTH: usize = 16;
209
210        loop {
211            match self.state {
212                DecodeState::Length => {
213                    if src.len() < HEADER_LENGTH {
214                        trace!(target: "rusmpp::codec::decode", source_length=src.len(), "Not enough bytes to read the header");
215
216                        return Ok(None);
217                    }
218
219                    let command_length = usize::try_from(src.get_u32()).map_err(|err|
220                        {
221                            error!(target: "rusmpp::codec::decode", ?err, "Failed to convert command length to usize");
222
223                            DecodeError::InvalidLength(err)
224                        })?;
225
226                    trace!(target: "rusmpp::codec::decode", command_length);
227
228                    if command_length < HEADER_LENGTH {
229                        error!(target: "rusmpp::codec::decode", command_length, min_command_length=HEADER_LENGTH, "Minimum command length not met");
230
231                        return Err(DecodeError::MinLength {
232                            actual: command_length,
233                            min: HEADER_LENGTH,
234                        });
235                    }
236
237                    // XXX: keep msrv lower than 1.90
238                    #[allow(clippy::collapsible_if)]
239                    if let Some(max_command_length) = self.max_length {
240                        if command_length > max_command_length {
241                            error!(target: "rusmpp::codec::decode", command_length, max_command_length, "Maximum command length exceeded");
242
243                            return Err(DecodeError::MaxLength {
244                                actual: command_length,
245                                max: max_command_length,
246                            });
247                        }
248                    }
249
250                    self.decode_command(command_length);
251                }
252                DecodeState::Command { command_length } => {
253                    // command_length is at least 16 (bytes)
254                    let pdu_length = command_length - 4;
255
256                    if src.len() < pdu_length {
257                        // Reserve enough space to read the entire command
258                        src.reserve(pdu_length - src.len());
259
260                        trace!(target: "rusmpp::codec::decode", command_length, decode_length=pdu_length, "Not enough bytes to read the entire command");
261
262                        return Ok(None);
263                    }
264
265                    debug!(target: "rusmpp::codec::decode", command_length, decode_length=pdu_length, decoding=?crate::formatter::Formatter(&src[..pdu_length]), "Decoding");
266
267                    let (command, _size) = match Command::decode(src, pdu_length) {
268                        Ok((command, size)) => {
269                            debug!(target: "rusmpp::codec::decode", command=?command, command_length, decoded_length=size, "Decoded");
270
271                            (command, size)
272                        }
273                        Err(err) => {
274                            error!(target: "rusmpp::codec::decode", ?err);
275
276                            self.decode_length();
277
278                            return Err(DecodeError::Decode(err));
279                        }
280                    };
281
282                    self.decode_length();
283
284                    return Ok(Some(command));
285                }
286            }
287        }
288    }
289}