1use core::num::TryFromIntError;
4
5use tokio_util::{
6 bytes::{Buf, BufMut, BytesMut},
7 codec::{Decoder, Encoder},
8};
9
10use crate::{
11 command::owned::Command,
12 decode::owned::DecodeWithLength,
13 encode::{Encode, Length},
14 logging::{debug, error, trace},
15};
16
17#[cfg(test)]
18mod tests;
19
/// Length-delimited codec for [`Command`]s.
///
/// Frames are prefixed with a big-endian `u32` total length. When decoding, a frame whose
/// declared length exceeds the configured limit (8192 bytes unless changed) is rejected.
#[derive(Debug)]
pub struct CommandCodec {
    /// Largest accepted declared frame length in bytes; `None` disables the check.
    max_length: Option<usize>,
}

impl CommandCodec {
    /// Limit installed by [`CommandCodec::new`] (and therefore by [`Default`]).
    const DEFAULT_MAX_LENGTH: usize = 8192;

    /// Creates a codec whose maximum frame length is 8192 bytes.
    #[inline]
    pub const fn new() -> Self {
        Self {
            max_length: Some(Self::DEFAULT_MAX_LENGTH),
        }
    }

    /// Returns the configured maximum frame length, or `None` when no limit is enforced.
    #[inline]
    pub const fn max_length(&self) -> Option<usize> {
        self.max_length
    }

    /// Consumes the codec and returns it with the maximum frame length set to `max_length`.
    #[inline]
    pub fn with_max_length(self, max_length: usize) -> Self {
        self.set_max_length(Some(max_length))
    }

    /// Consumes the codec and returns it with the maximum frame length check disabled.
    #[inline]
    pub fn without_max_length(self) -> Self {
        self.set_max_length(None)
    }

    /// Shared builder step: replaces the limit and hands the codec back.
    #[inline]
    fn set_max_length(mut self, limit: Option<usize>) -> Self {
        self.max_length = limit;
        self
    }
}

impl Default for CommandCodec {
    /// Same as [`CommandCodec::new`].
    fn default() -> Self {
        Self::new()
    }
}
58
/// Error returned by [`CommandCodec`]'s [`Encoder`] implementations.
#[derive(Debug)]
#[non_exhaustive]
pub enum EncodeError {
    /// Wraps a [`std::io::Error`].
    Io(std::io::Error),
}

impl From<std::io::Error> for EncodeError {
    fn from(e: std::io::Error) -> Self {
        EncodeError::Io(e)
    }
}

impl core::fmt::Display for EncodeError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            EncodeError::Io(e) => write!(f, "I/O error: {e}"),
        }
    }
}

impl core::error::Error for EncodeError {
    /// Returns the wrapped I/O error.
    // The deprecated `cause` is intentionally not overridden: its default implementation
    // already forwards to `source`.
    fn source(&self) -> Option<&(dyn core::error::Error + 'static)> {
        match self {
            EncodeError::Io(e) => Some(e),
        }
    }
}
92
93impl Encoder<&Command> for CommandCodec {
94 type Error = EncodeError;
95
96 fn encode(&mut self, command: &Command, dst: &mut BytesMut) -> Result<(), Self::Error> {
97 let command_length = 4 + command.length();
98
99 dst.reserve(command_length);
100 dst.put_u32(command_length as u32);
101
102 let mut buf = alloc::vec![0; command.length()];
104 let _ = command.encode(buf.as_mut_slice());
105
106 dst.put_slice(&buf);
107
108 debug!(target: "rusmpp::codec::encode", command=?command, "Encoding");
109 debug!(target: "rusmpp::codec::encode", encoded=?crate::formatter::Formatter(&buf), encoded_length=command.length(), command_length, "Encoded");
110
111 Ok(())
112 }
113}
114
115impl Encoder<Command> for CommandCodec {
116 type Error = EncodeError;
117
118 fn encode(&mut self, command: Command, dst: &mut BytesMut) -> Result<(), Self::Error> {
119 self.encode(&command, dst)
120 }
121}
122
123#[derive(Debug)]
125#[non_exhaustive]
126pub enum DecodeError {
127 Io(std::io::Error),
129 Decode(crate::decode::DecodeError),
131 MinLength { actual: usize, min: usize },
133 MaxLength { actual: usize, max: usize },
135 InvalidLength(TryFromIntError),
137}
138
139impl From<std::io::Error> for DecodeError {
140 fn from(e: std::io::Error) -> Self {
141 DecodeError::Io(e)
142 }
143}
144
145impl core::fmt::Display for DecodeError {
146 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
147 match self {
148 DecodeError::Io(e) => write!(f, "I/O error: {e}"),
149 DecodeError::Decode(e) => write!(f, "Decode error: {e}"),
150 DecodeError::MinLength { actual, min } => {
151 write!(
152 f,
153 "Minimum command length not met. actual: {actual}, min: {min}"
154 )
155 }
156 DecodeError::MaxLength { actual, max } => {
157 write!(
158 f,
159 "Maximum command length exceeded. actual: {actual}, max: {max}"
160 )
161 }
162 DecodeError::InvalidLength(e) => {
163 write!(f, "Integral type conversion failed: {e}")
164 }
165 }
166 }
167}
168
169impl core::error::Error for DecodeError {
170 fn source(&self) -> Option<&(dyn core::error::Error + 'static)> {
171 match self {
172 DecodeError::Io(e) => Some(e),
173 DecodeError::Decode(e) => Some(e),
174 DecodeError::MinLength { .. } => None,
175 DecodeError::MaxLength { .. } => None,
176 DecodeError::InvalidLength(e) => Some(e),
177 }
178 }
179
180 fn cause(&self) -> Option<&dyn core::error::Error> {
181 self.source()
182 }
183}
184
185impl Decoder for CommandCodec {
186 type Item = Command;
187 type Error = DecodeError;
188
189 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
190 const HEADER_LENGTH: usize = 16;
191
192 if src.len() < HEADER_LENGTH {
193 trace!(target: "rusmpp::codec::decode", source_length=src.len(), "Not enough bytes to read the header");
194
195 return Ok(None);
196 }
197
198 let command_length = usize::try_from(u32::from_be_bytes([src[0], src[1], src[2], src[3]])).map_err(|err|
199 {
200 error!(target: "rusmpp::codec::decode", ?err, "Failed to convert command length to usize");
201
202 DecodeError::InvalidLength(err)
203 })?;
204
205 trace!(target: "rusmpp::codec::decode", command_length);
206
207 if command_length < HEADER_LENGTH {
208 error!(target: "rusmpp::codec::decode", command_length, min_command_length=HEADER_LENGTH, "Minimum command length not met");
209
210 return Err(DecodeError::MinLength {
211 actual: command_length,
212 min: HEADER_LENGTH,
213 });
214 }
215
216 #[allow(clippy::collapsible_if)]
218 if let Some(max_command_length) = self.max_length {
219 if command_length > max_command_length {
220 error!(target: "rusmpp::codec::decode", command_length, max_command_length, "Maximum command length exceeded");
221
222 return Err(DecodeError::MaxLength {
223 actual: command_length,
224 max: max_command_length,
225 });
226 }
227 }
228
229 if src.len() < command_length {
230 src.reserve(command_length - src.len());
232
233 trace!(target: "rusmpp::codec::decode", command_length, "Not enough bytes to read the entire command");
234
235 return Ok(None);
236 }
237
238 let pdu_len = command_length - 4;
240
241 debug!(target: "rusmpp::codec::decode", decoding=?crate::formatter::Formatter(&src[..command_length]), "Decoding");
242
243 let (command, _size) = match Command::decode(&src[4..command_length], pdu_len) {
244 Ok((command, size)) => {
245 debug!(target: "rusmpp::codec::decode", command=?command, command_length, decoded_length=size, "Decoded");
246
247 (command, size)
248 }
249 Err(err) => {
250 error!(target: "rusmpp::codec::decode", ?err);
251
252 return Err(DecodeError::Decode(err));
253 }
254 };
255
256 src.advance(command_length);
257
258 Ok(Some(command))
259 }
260}