tacacs_plus_protocol/
packet.rs1use core::fmt;
2use core::iter::zip;
3
4use bitflags::bitflags;
5use byteorder::{ByteOrder, NetworkEndian};
6use getset::Getters;
7use md5::{Digest, Md5};
8use num_enum::{TryFromPrimitive, TryFromPrimitiveError};
9
10use super::{Deserialize, PacketBody, Serialize};
11use super::{DeserializeError, SerializeError};
12
13pub(super) mod header;
14use header::HeaderInfo;
15
16#[cfg(test)]
17mod tests;
18
19#[repr(transparent)]
21#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
22pub struct PacketFlags(u8);
23
24bitflags! {
25 impl PacketFlags: u8 {
26 const UNENCRYPTED = 0b00000001;
32
33 const SINGLE_CONNECTION = 0b00000100;
35 }
36}
37
38impl fmt::Display for PacketFlags {
39 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
40 if self.is_empty() {
41 write!(f, "no flags set")
42 } else {
43 for (name, _) in self.iter_names() {
44 write!(f, "{name} ")?;
45 }
46
47 Ok(())
48 }
49 }
50}
51
52#[repr(u8)]
54#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash, TryFromPrimitive)]
55pub enum PacketType {
56 Authentication = 0x1,
58
59 Authorization = 0x2,
61
62 Accounting = 0x3,
64}
65
66impl fmt::Display for PacketType {
67 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
68 write!(
69 f,
70 "{}",
71 match self {
72 Self::Authentication => "authentication",
73 Self::Authorization => "authorization",
74 Self::Accounting => "accounting",
75 }
76 )
77 }
78}
79
80#[doc(hidden)]
81impl From<TryFromPrimitiveError<PacketType>> for DeserializeError {
82 fn from(value: TryFromPrimitiveError<PacketType>) -> Self {
83 Self::InvalidPacketType(value.number)
84 }
85}
86
87#[derive(Clone, Debug, PartialEq, Eq, Hash, Getters)]
89#[getset(get = "pub")]
90pub struct Packet<B> {
91 header: HeaderInfo,
93
94 body: B,
96}
97
98impl<B: PacketBody> Packet<B> {
99 pub(super) const BODY_START: usize = 12;
101
102 pub fn new(mut header: HeaderInfo, body: B) -> Self {
109 if let Some(minor) = body.required_minor_version() {
111 header.version_mut().minor = minor;
112 }
113
114 Self { header, body }
115 }
116}
117
118const MD5_OUTPUT_SIZE: usize = 16;
120
121pub(super) fn xor_body_with_pad(header: &HeaderInfo, secret_key: &[u8], body_buffer: &mut [u8]) {
127 let mut pseudo_pad = [0; MD5_OUTPUT_SIZE];
128
129 let mut prefix_hasher = Md5::new();
132 prefix_hasher.update(header.session_id().to_be_bytes());
133 prefix_hasher.update(secret_key);
134
135 prefix_hasher.update(u8::from(header.version()).to_be_bytes());
137 prefix_hasher.update(header.sequence_number().to_be_bytes());
138
139 let mut chunks_iter = body_buffer.chunks_mut(MD5_OUTPUT_SIZE);
140
141 prefix_hasher
143 .clone()
144 .finalize_into((&mut pseudo_pad).into());
145
146 let first_chunk = chunks_iter.next().unwrap();
149
150 xor_slices(first_chunk, &pseudo_pad);
152
153 for chunk in chunks_iter {
154 let mut hasher = prefix_hasher.clone();
156 hasher.update(pseudo_pad);
157 hasher.finalize_into((&mut pseudo_pad).into());
158
159 xor_slices(chunk, &pseudo_pad);
161 }
162}
163
/// XORs each byte of `output` with the byte at the same position in `pseudo_pad`.
///
/// Only the overlapping prefix of the two slices is touched; extra bytes in either are ignored.
fn xor_slices(output: &mut [u8], pseudo_pad: &[u8]) {
    zip(output.iter_mut(), pseudo_pad.iter()).for_each(|(byte, &pad_byte)| *byte ^= pad_byte);
}
170
171impl<B: PacketBody + Serialize> Packet<B> {
173 pub fn wire_size(&self) -> usize {
175 HeaderInfo::HEADER_SIZE_BYTES + self.body.wire_size()
176 }
177
178 pub fn serialize<K: AsRef<[u8]>>(
182 mut self,
183 secret_key: K,
184 buffer: &mut [u8],
185 ) -> Result<usize, SerializeError> {
186 self.header.flags_mut().remove(PacketFlags::UNENCRYPTED);
188
189 let packet_length = self.serialize_packet(buffer)?;
190
191 xor_body_with_pad(
192 &self.header,
193 secret_key.as_ref(),
194 &mut buffer[Self::BODY_START..packet_length],
195 );
196
197 Ok(packet_length)
198 }
199
200 pub fn serialize_unobfuscated(mut self, buffer: &mut [u8]) -> Result<usize, SerializeError> {
208 self.header.flags_mut().insert(PacketFlags::UNENCRYPTED);
210
211 self.serialize_packet(buffer)
212 }
213
214 fn serialize_packet(&self, buffer: &mut [u8]) -> Result<usize, SerializeError> {
215 let wire_size = self.wire_size();
216
217 if buffer.len() >= wire_size {
218 let body_length = self
220 .body
221 .serialize_into_buffer(&mut buffer[Self::BODY_START..wire_size])?;
222
223 let header_bytes = self.header.serialize(
225 &mut buffer[..HeaderInfo::HEADER_SIZE_BYTES],
226 B::TYPE,
227 body_length.try_into()?,
228 )?;
229
230 Ok(header_bytes + body_length)
232 } else {
233 Err(SerializeError::NotEnoughSpace)
234 }
235 }
236}
237
238impl<'raw, B: PacketBody + Deserialize<'raw>> Packet<B> {
239 pub fn deserialize<K: AsRef<[u8]>>(
244 secret_key: K,
245 buffer: &'raw mut [u8],
246 ) -> Result<Self, DeserializeError> {
247 let header = HeaderInfo::try_from(&buffer[..HeaderInfo::HEADER_SIZE_BYTES])?;
248
249 if !header.flags().contains(PacketFlags::UNENCRYPTED) {
251 xor_body_with_pad(
252 &header,
253 secret_key.as_ref(),
254 &mut buffer[Self::BODY_START..],
255 );
256
257 let body = Self::deserialize_body(buffer)?;
258
259 Ok(Self::new(header, body))
260 } else {
261 Err(DeserializeError::IncorrectUnencryptedFlag)
262 }
263 }
264
265 pub fn deserialize_unobfuscated(buffer: &'raw [u8]) -> Result<Self, DeserializeError> {
270 let header = HeaderInfo::try_from(&buffer[..HeaderInfo::HEADER_SIZE_BYTES])?;
271
272 if header.flags().contains(PacketFlags::UNENCRYPTED) {
274 let body = Self::deserialize_body(buffer)?;
275 Ok(Self::new(header, body))
276 } else {
277 Err(DeserializeError::IncorrectUnencryptedFlag)
278 }
279 }
280
281 fn deserialize_body(buffer: &'raw [u8]) -> Result<B, DeserializeError> {
282 if buffer.len() > HeaderInfo::HEADER_SIZE_BYTES {
283 let actual_packet_type = PacketType::try_from(buffer[1])?;
284 if actual_packet_type == B::TYPE {
285 let body_length = NetworkEndian::read_u32(&buffer[8..12]) as usize;
287
288 if buffer[Self::BODY_START..].len() >= body_length {
291 let body = B::deserialize_from_buffer(
292 &buffer[Self::BODY_START..Self::BODY_START + body_length],
293 )?;
294 Ok(body)
295 } else {
296 Err(DeserializeError::UnexpectedEnd)
297 }
298 } else {
299 Err(DeserializeError::PacketTypeMismatch {
300 expected: B::TYPE,
301 actual: actual_packet_type,
302 })
303 }
304 } else {
305 Err(DeserializeError::UnexpectedEnd)
306 }
307 }
308
309 #[cfg(feature = "std")]
311 pub fn to_owned<'b, O: super::owned::FromBorrowedBody<Borrowed<'b> = B>>(&self) -> Packet<O> {
312 Packet {
313 header: self.header,
314 body: O::from_borrowed(&self.body),
315 }
316 }
317}