tacacs_plus_protocol/
accounting.rs

1//! Accounting protocol packet (de)serialization.
2
3use bitflags::bitflags;
4use byteorder::{ByteOrder, NetworkEndian};
5use core::fmt;
6use getset::Getters;
7use num_enum::{TryFromPrimitive, TryFromPrimitiveError};
8
9use super::{
10    Arguments, AuthenticationContext, AuthenticationMethod, Deserialize, DeserializeError,
11    PacketBody, PacketType, Serialize, SerializeError, UserInformation,
12};
13use crate::FieldText;
14
15#[cfg(test)]
16mod tests;
17
18#[cfg(feature = "std")]
19mod owned;
20
21#[cfg(feature = "std")]
22pub use owned::ReplyOwned;
23
24bitflags! {
25    /// Raw bitflags for accounting request packet.
26    struct RawFlags: u8 {
27        const START    = 0b00000010;
28        const STOP     = 0b00000100;
29        const WATCHDOG = 0b00001000;
30    }
31}
32
/// The flag combinations that are valid in a TACACS+ accounting REQUEST packet.
#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)]
pub enum Flags {
    /// Marks the start of a task.
    StartRecord,

    /// Marks that a task has completed.
    StopRecord,

    /// Reports that a task is still running, without any extra arguments.
    WatchdogNoUpdate,

    /// Reports progress on a long-running task, carrying new or updated argument values.
    WatchdogUpdate,
}

impl fmt::Display for Flags {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let description = match self {
            Self::StartRecord => "start of record",
            Self::StopRecord => "end of record",
            Self::WatchdogUpdate => "update with new information",
            Self::WatchdogNoUpdate => "update with no new information",
        };

        f.write_str(description)
    }
}
63
64impl From<Flags> for RawFlags {
65    fn from(value: Flags) -> Self {
66        match value {
67            Flags::StartRecord => RawFlags::START,
68            Flags::StopRecord => RawFlags::STOP,
69            Flags::WatchdogNoUpdate => RawFlags::WATCHDOG,
70            Flags::WatchdogUpdate => RawFlags::WATCHDOG | RawFlags::START,
71        }
72    }
73}
74
75impl Flags {
76    /// The number of bytes occupied by a flag set on the wire.
77    pub(super) const WIRE_SIZE: usize = 1;
78}
79
80/// An accounting request packet, used to start, stop, or provide progress on a running job.
81#[derive(PartialEq, Eq, Clone, Debug, Hash)]
82pub struct Request<'packet> {
83    /// Flags to indicate what kind of accounting record this packet includes.
84    flags: Flags,
85
86    /// Method used to authenticate to TACACS+ client.
87    authentication_method: AuthenticationMethod,
88
89    /// Other information about authentication to TACACS+ client.
90    authentication: AuthenticationContext,
91
92    /// Information about the user connected to the client.
93    user_information: UserInformation<'packet>,
94
95    /// Arguments to provide additional information to the server.
96    arguments: Arguments<'packet>,
97}
98
99impl<'packet> Request<'packet> {
100    /// Argument lengths in a request packet start at index 9, if present.
101    const ARGUMENT_LENGTHS_OFFSET: usize = 9;
102
103    /// Assembles a new accounting request packet body.
104    pub fn new(
105        flags: Flags,
106        authentication_method: AuthenticationMethod,
107        authentication: AuthenticationContext,
108        user_information: UserInformation<'packet>,
109        arguments: Arguments<'packet>,
110    ) -> Self {
111        Self {
112            flags,
113            authentication_method,
114            authentication,
115            user_information,
116            arguments,
117        }
118    }
119}
120
121impl PacketBody for Request<'_> {
122    const TYPE: PacketType = PacketType::Accounting;
123
124    // 4 extra bytes come from user information lengths (user, port, remote address) & argument count
125    const REQUIRED_FIELDS_LENGTH: usize =
126        Flags::WIRE_SIZE + AuthenticationMethod::WIRE_SIZE + AuthenticationContext::WIRE_SIZE + 4;
127}
128
129impl Serialize for Request<'_> {
130    fn wire_size(&self) -> usize {
131        Flags::WIRE_SIZE
132            + AuthenticationMethod::WIRE_SIZE
133            + AuthenticationContext::WIRE_SIZE
134            + self.user_information.wire_size()
135            + self.arguments.wire_size()
136    }
137
138    fn serialize_into_buffer(&self, buffer: &mut [u8]) -> Result<usize, SerializeError> {
139        let wire_size = self.wire_size();
140
141        if buffer.len() >= wire_size {
142            buffer[0] = RawFlags::from(self.flags).bits();
143            buffer[1] = self.authentication_method as u8;
144
145            // header information (lengths, etc.)
146            self.authentication.serialize(&mut buffer[2..5]);
147            self.user_information
148                .serialize_field_lengths(&mut buffer[5..8])?;
149
150            let argument_count = self.arguments.argument_count() as usize;
151
152            // body starts after the required fields & the argument lengths (1 byte per argument)
153            let body_start = Self::ARGUMENT_LENGTHS_OFFSET + argument_count;
154
155            // actual request content
156            // as below, slice bounds are capped to end of packet body to avoid overflowing
157            let user_information_len = self
158                .user_information
159                .serialize_field_values(&mut buffer[body_start..wire_size])?;
160
161            let arguments_serialized_len =
162                // argument lengths start at index 8
163                // extra byte is included in slice for argument count itself
164                self.arguments.serialize_count_and_lengths(&mut buffer[8..8 + argument_count + 1])?
165                    // argument values go after the user information values in the body
166                    + self
167                        .arguments
168                        .serialize_encoded_values(&mut buffer[body_start + user_information_len..wire_size])?;
169
170            // NOTE: as with authorization, 1 is subtracted from REQUIRED_FIELDS_LENGTH as the argument count would be double counted otherwise
171            let actual_written_len = (Self::REQUIRED_FIELDS_LENGTH - 1)
172                + user_information_len
173                + arguments_serialized_len;
174
175            // ensure expected/actual sizes match
176            if actual_written_len == wire_size {
177                Ok(actual_written_len)
178            } else {
179                Err(SerializeError::LengthMismatch {
180                    expected: wire_size,
181                    actual: actual_written_len,
182                })
183            }
184        } else {
185            Err(SerializeError::NotEnoughSpace)
186        }
187    }
188}
189
190/// The server's reply status in an accounting session.
191#[repr(u8)]
192#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash, TryFromPrimitive)]
193pub enum Status {
194    /// Task logging succeeded.
195    Success = 0x01,
196
197    /// Something went wrong when logging the task.
198    Error = 0x02,
199
200    /// Forward accounting request to an alternative daemon.
201    #[deprecated = "Forwarding to an alternative daemon was deprecated in RFC-8907."]
202    Follow = 0x21,
203}
204
205impl Status {
206    /// The number of bytes an accounting reply status occupies on the wire.
207    pub(super) const WIRE_SIZE: usize = 1;
208}
209
210impl fmt::Display for Status {
211    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
212        write!(
213            f,
214            "{}",
215            match self {
216                Self::Success => "success",
217                Self::Error => "error",
218                #[allow(deprecated)]
219                Self::Follow => "follow",
220            }
221        )
222    }
223}
224
225#[doc(hidden)]
226impl From<TryFromPrimitiveError<Status>> for DeserializeError {
227    fn from(value: TryFromPrimitiveError<Status>) -> Self {
228        Self::InvalidStatus(value.number)
229    }
230}
231
232/// An accounting reply packet received from a TACACS+ server.
233#[derive(Clone, PartialEq, Eq, Debug, Hash, Getters)]
234pub struct Reply<'packet> {
235    /// Gets the status of an accounting reply.
236    #[getset(get = "pub")]
237    status: Status,
238
239    /// Gets the server message, which may be presented to a user connected to a client.
240    #[getset(get = "pub")]
241    server_message: FieldText<'packet>,
242
243    /// Gets the administrative/log data received from the server.
244    #[getset(get = "pub")]
245    data: FieldText<'packet>,
246}
247
/// Lengths of a reply body's variable-length fields, plus the resulting total body length.
struct ReplyFieldLengths {
    /// Length of the server message field, in bytes.
    server_message_length: u16,

    /// Length of the data field, in bytes.
    data_length: u16,

    /// Total body length: the fixed-size fields plus both variable-length fields.
    total_length: u32,
}
254
255impl Reply<'_> {
256    /// Offset of the server message in an accounting reply packet body, if present.
257    const SERVER_MESSAGE_OFFSET: usize = 5;
258
259    /// Determines how long a raw reply packet is, if applicable, based on various lengths stored in the body "header."
260    pub fn extract_total_length(buffer: &[u8]) -> Result<u32, DeserializeError> {
261        if buffer.len() >= Self::REQUIRED_FIELDS_LENGTH {
262            Self::extract_field_lengths(buffer).map(|lengths| lengths.total_length)
263        } else {
264            Err(DeserializeError::UnexpectedEnd)
265        }
266    }
267
268    /// Extracts the server message and data field lengths from a buffer, treating it as if it were a serialized reply packet body.
269    fn extract_field_lengths(buffer: &[u8]) -> Result<ReplyFieldLengths, DeserializeError> {
270        // ensure buffer is large enough to comprise a valid reply packet
271        if buffer.len() >= Self::REQUIRED_FIELDS_LENGTH {
272            // server message length is at the beginning of the packet
273            let server_message_length = NetworkEndian::read_u16(&buffer[..2]);
274
275            // data length is just after the server message length
276            let data_length = NetworkEndian::read_u16(&buffer[2..4]);
277
278            // full packet has required fields/lengths as well as the field values themselves
279            // SAFETY: REQUIRED_FIELDS_LENGTH is guaranteed to fit in a u32 based on its defined value
280            let total_length = u32::try_from(Self::REQUIRED_FIELDS_LENGTH).unwrap()
281                + u32::from(server_message_length)
282                + u32::from(data_length);
283
284            Ok(ReplyFieldLengths {
285                server_message_length,
286                data_length,
287                total_length,
288            })
289        } else {
290            Err(DeserializeError::UnexpectedEnd)
291        }
292    }
293}
294
295impl PacketBody for Reply<'_> {
296    const TYPE: PacketType = PacketType::Accounting;
297
298    // 4 extra bytes are 2 bytes each for lengths of server message/data
299    const REQUIRED_FIELDS_LENGTH: usize = Status::WIRE_SIZE + 4;
300}
301
302impl<'raw> Deserialize<'raw> for Reply<'raw> {
303    fn deserialize_from_buffer(buffer: &'raw [u8]) -> Result<Self, DeserializeError> {
304        let extracted_lengths = Self::extract_field_lengths(buffer)?;
305
306        // the provided buffer is sliced to the length reported in the packet header in Packet::deserialize_body(),
307        // so we can compare against it this way
308        let length_from_header = buffer.len();
309
310        // ensure buffer length & calculated length from body fields match
311        if extracted_lengths.total_length as usize == length_from_header {
312            // SAFETY: extract_field_lengths() performs a check against REQUIRED_FIELDS_LENGTH (5), so this will not panic
313            let status = Status::try_from(buffer[4])?;
314
315            let data_offset =
316                Self::SERVER_MESSAGE_OFFSET + extracted_lengths.server_message_length as usize;
317
318            let server_message =
319                FieldText::try_from(&buffer[Self::SERVER_MESSAGE_OFFSET..data_offset])
320                    .map_err(|_| DeserializeError::BadText)?;
321            let data = FieldText::try_from(
322                &buffer[data_offset..data_offset + extracted_lengths.data_length as usize],
323            )
324            .map_err(|_| DeserializeError::BadText)?;
325
326            Ok(Self {
327                status,
328                server_message,
329                data,
330            })
331        } else {
332            Err(DeserializeError::WrongBodyBufferSize {
333                expected: extracted_lengths.total_length as usize,
334                buffer_size: length_from_header,
335            })
336        }
337    }
338}