1use core::num::TryFromIntError;
4
5use bytes::Buf;
6use tokio_util::{
7 bytes::{BufMut, BytesMut},
8 codec::{Decoder, Encoder},
9};
10
11use crate::{
12 command::owned::Command,
13 decode::owned::DecodeWithLength,
14 encode::{Length, owned::Encode},
15 logging::{debug, error, trace},
16};
17
18#[cfg(test)]
19mod tests;
20
/// Which part of the next frame the decoder is waiting for.
#[derive(Debug)]
enum DecodeState {
    /// Waiting for enough bytes to read the 4-byte length prefix of the next command.
    Length,
    /// The length prefix has been consumed; waiting for the rest of the command.
    Command {
        /// Total command length as announced by the prefix, including the 4 prefix bytes.
        command_length: usize,
    },
}
28
29#[derive(Debug)]
31pub struct CommandCodec {
32 max_length: Option<usize>,
33 state: DecodeState,
34}
35
36impl CommandCodec {
37 #[inline]
39 pub const fn new() -> Self {
40 Self {
41 max_length: Some(8192),
42 state: DecodeState::Length,
43 }
44 }
45
46 #[inline]
47 pub const fn max_length(&self) -> Option<usize> {
48 self.max_length
49 }
50
51 #[inline]
52 pub fn with_max_length(mut self, max_length: usize) -> Self {
53 self.max_length = Some(max_length);
54 self
55 }
56
57 #[inline]
58 pub fn without_max_length(mut self) -> Self {
59 self.max_length = None;
60 self
61 }
62
63 #[inline]
65 const fn decode_length(&mut self) {
66 self.state = DecodeState::Length;
67 }
68
69 #[inline]
71 const fn decode_command(&mut self, command_length: usize) {
72 self.state = DecodeState::Command { command_length };
73 }
74}
75
76impl Default for CommandCodec {
77 fn default() -> Self {
78 Self::new()
79 }
80}
81
/// Error produced by the [`Encoder`] implementations of [`CommandCodec`].
#[derive(Debug)]
#[non_exhaustive]
pub enum EncodeError {
    /// An underlying I/O error.
    Io(std::io::Error),
}
89
90impl From<std::io::Error> for EncodeError {
91 fn from(e: std::io::Error) -> Self {
92 EncodeError::Io(e)
93 }
94}
95
96impl core::fmt::Display for EncodeError {
97 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
98 match self {
99 EncodeError::Io(e) => write!(f, "I/O error: {e}"),
100 }
101 }
102}
103
104impl core::error::Error for EncodeError {
105 fn source(&self) -> Option<&(dyn core::error::Error + 'static)> {
106 match self {
107 EncodeError::Io(e) => Some(e),
108 }
109 }
110
111 fn cause(&self) -> Option<&dyn core::error::Error> {
112 self.source()
113 }
114}
115
116impl Encoder<&Command> for CommandCodec {
117 type Error = EncodeError;
118
119 fn encode(&mut self, command: &Command, dst: &mut BytesMut) -> Result<(), Self::Error> {
120 let command_length = 4 + command.length();
121
122 dst.reserve(command_length);
123 dst.put_u32(command_length as u32);
124 command.encode(dst);
125
126 debug!(target: "rusmpp::codec::encode", command=?command, "Encoding");
127 debug!(target: "rusmpp::codec::encode", encoded=?crate::formatter::Formatter(&dst[..command_length]), encoded_length=command.length(), command_length, "Encoded");
128
129 Ok(())
130 }
131}
132
133impl Encoder<Command> for CommandCodec {
134 type Error = EncodeError;
135
136 fn encode(&mut self, command: Command, dst: &mut BytesMut) -> Result<(), Self::Error> {
137 self.encode(&command, dst)
138 }
139}
140
141#[derive(Debug)]
143#[non_exhaustive]
144pub enum DecodeError {
145 Io(std::io::Error),
147 Decode(crate::decode::DecodeError),
149 MinLength { actual: usize, min: usize },
151 MaxLength { actual: usize, max: usize },
153 InvalidLength(TryFromIntError),
155}
156
157impl From<std::io::Error> for DecodeError {
158 fn from(e: std::io::Error) -> Self {
159 DecodeError::Io(e)
160 }
161}
162
163impl core::fmt::Display for DecodeError {
164 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
165 match self {
166 DecodeError::Io(e) => write!(f, "I/O error: {e}"),
167 DecodeError::Decode(e) => write!(f, "Decode error: {e}"),
168 DecodeError::MinLength { actual, min } => {
169 write!(
170 f,
171 "Minimum command length not met. actual: {actual}, min: {min}"
172 )
173 }
174 DecodeError::MaxLength { actual, max } => {
175 write!(
176 f,
177 "Maximum command length exceeded. actual: {actual}, max: {max}"
178 )
179 }
180 DecodeError::InvalidLength(e) => {
181 write!(f, "Integral type conversion failed: {e}")
182 }
183 }
184 }
185}
186
187impl core::error::Error for DecodeError {
188 fn source(&self) -> Option<&(dyn core::error::Error + 'static)> {
189 match self {
190 DecodeError::Io(e) => Some(e),
191 DecodeError::Decode(e) => Some(e),
192 DecodeError::MinLength { .. } => None,
193 DecodeError::MaxLength { .. } => None,
194 DecodeError::InvalidLength(e) => Some(e),
195 }
196 }
197
198 fn cause(&self) -> Option<&dyn core::error::Error> {
199 self.source()
200 }
201}
202
203impl Decoder for CommandCodec {
204 type Item = Command;
205 type Error = DecodeError;
206
207 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
208 const HEADER_LENGTH: usize = 16;
209
210 loop {
211 match self.state {
212 DecodeState::Length => {
213 if src.len() < HEADER_LENGTH {
214 trace!(target: "rusmpp::codec::decode", source_length=src.len(), "Not enough bytes to read the header");
215
216 return Ok(None);
217 }
218
219 let command_length = usize::try_from(src.get_u32()).map_err(|err|
220 {
221 error!(target: "rusmpp::codec::decode", ?err, "Failed to convert command length to usize");
222
223 DecodeError::InvalidLength(err)
224 })?;
225
226 trace!(target: "rusmpp::codec::decode", command_length);
227
228 if command_length < HEADER_LENGTH {
229 error!(target: "rusmpp::codec::decode", command_length, min_command_length=HEADER_LENGTH, "Minimum command length not met");
230
231 return Err(DecodeError::MinLength {
232 actual: command_length,
233 min: HEADER_LENGTH,
234 });
235 }
236
237 #[allow(clippy::collapsible_if)]
239 if let Some(max_command_length) = self.max_length {
240 if command_length > max_command_length {
241 error!(target: "rusmpp::codec::decode", command_length, max_command_length, "Maximum command length exceeded");
242
243 return Err(DecodeError::MaxLength {
244 actual: command_length,
245 max: max_command_length,
246 });
247 }
248 }
249
250 self.decode_command(command_length);
251 }
252 DecodeState::Command { command_length } => {
253 let pdu_length = command_length - 4;
255
256 if src.len() < pdu_length {
257 src.reserve(pdu_length - src.len());
259
260 trace!(target: "rusmpp::codec::decode", command_length, decode_length=pdu_length, "Not enough bytes to read the entire command");
261
262 return Ok(None);
263 }
264
265 debug!(target: "rusmpp::codec::decode", command_length, decode_length=pdu_length, decoding=?crate::formatter::Formatter(&src[..pdu_length]), "Decoding");
266
267 let (command, _size) = match Command::decode(src, pdu_length) {
268 Ok((command, size)) => {
269 debug!(target: "rusmpp::codec::decode", command=?command, command_length, decoded_length=size, "Decoded");
270
271 (command, size)
272 }
273 Err(err) => {
274 error!(target: "rusmpp::codec::decode", ?err);
275
276 self.decode_length();
277
278 return Err(DecodeError::Decode(err));
279 }
280 };
281
282 self.decode_length();
283
284 return Ok(Some(command));
285 }
286 }
287 }
288 }
289}