tacacs_plus_protocol/
authorization.rs

1//! Authorization features/packets of the TACACS+ protocol.
2
3use core::fmt;
4
5use byteorder::{ByteOrder, NetworkEndian};
6use getset::Getters;
7use num_enum::{TryFromPrimitive, TryFromPrimitiveError};
8
9use super::{
10    Argument, Arguments, AuthenticationContext, AuthenticationMethod, DeserializeError,
11    InvalidArgument, PacketBody, PacketType, Serialize, SerializeError, UserInformation,
12};
13use crate::{Deserialize, FieldText};
14
15#[cfg(test)]
16mod tests;
17
18#[cfg(feature = "std")]
19mod owned;
20
21#[cfg(feature = "std")]
22pub use owned::ReplyOwned;
23
/// An authorization request packet body, including arguments.
///
/// Built with [`Request::new`] and encoded through its `Serialize` impl; all
/// fields are private.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Request<'packet> {
    /// Method used to authenticate to the TACACS+ client.
    ///
    /// Encoded as the first byte (index 0) of the request body.
    method: AuthenticationMethod,

    /// Other client authentication information.
    ///
    /// Encoded in body bytes 1..4, directly after the method.
    authentication_context: AuthenticationContext,

    /// Information about the user connected to the TACACS+ client.
    ///
    /// Its field lengths occupy body bytes 4..7; the field values are written
    /// after the per-argument length bytes.
    user_information: UserInformation<'packet>,

    /// Additional arguments to provide as part of an authorization request.
    ///
    /// The argument count is at body byte 7, followed by one length byte per
    /// argument; the encoded argument values come last in the body.
    arguments: Arguments<'packet>,
}
39
40impl<'packet> Request<'packet> {
41    /// Assembles an authorization request packet from its fields.
42    pub fn new(
43        method: AuthenticationMethod,
44        authentication_context: AuthenticationContext,
45        user_information: UserInformation<'packet>,
46        arguments: Arguments<'packet>,
47    ) -> Self {
48        Self {
49            method,
50            authentication_context,
51            user_information,
52            arguments,
53        }
54    }
55}
56
57impl PacketBody for Request<'_> {
58    const TYPE: PacketType = PacketType::Authorization;
59
60    // 4 extra bytes come from user information lengths (user, port, remote address) and argument count
61    const REQUIRED_FIELDS_LENGTH: usize =
62        AuthenticationMethod::WIRE_SIZE + AuthenticationContext::WIRE_SIZE + 4;
63}
64
65impl Serialize for Request<'_> {
66    fn wire_size(&self) -> usize {
67        AuthenticationMethod::WIRE_SIZE
68            + AuthenticationContext::WIRE_SIZE
69            + self.user_information.wire_size()
70            + self.arguments.wire_size()
71    }
72
73    fn serialize_into_buffer(&self, buffer: &mut [u8]) -> Result<usize, SerializeError> {
74        let wire_size = self.wire_size();
75
76        if buffer.len() >= wire_size {
77            buffer[0] = self.method as u8;
78            self.authentication_context.serialize(&mut buffer[1..4]);
79            self.user_information
80                .serialize_field_lengths(&mut buffer[4..7])?;
81
82            let argument_count = self.arguments.argument_count() as usize;
83
84            // the user information fields start after all of the required fields and also the argument lengths, the latter of which take up 1 byte each
85            let user_info_start = Self::REQUIRED_FIELDS_LENGTH + argument_count;
86
87            // cap slice with wire slice to avoid overflowing beyond end of packet body
88            let user_info_written_len = self
89                .user_information
90                .serialize_field_values(&mut buffer[user_info_start..wire_size])?;
91
92            // argument lengths start at index 7, just after the argument count
93            // extra 1 added to allow room for argument count itself
94            let arguments_wire_len = self.arguments.serialize_count_and_lengths(&mut buffer[7..7 + argument_count + 1])?
95                // argument values go after all of the user information, and until the end of the packet
96                + self
97                    .arguments
98                    .serialize_encoded_values(&mut buffer[user_info_start + user_info_written_len..wire_size])?;
99
100            // NOTE: 1 is subtracted from REQUIRED_FIELDS_LENGTH since otherwise the argument count field is double counted (from Arguments::wire_size())
101            let actual_written_len =
102                (Self::REQUIRED_FIELDS_LENGTH - 1) + user_info_written_len + arguments_wire_len;
103
104            if actual_written_len == wire_size {
105                Ok(actual_written_len)
106            } else {
107                Err(SerializeError::LengthMismatch {
108                    expected: wire_size,
109                    actual: actual_written_len,
110                })
111            }
112        } else {
113            Err(SerializeError::NotEnoughSpace)
114        }
115    }
116}
117
/// The status of an authorization operation, as returned by the server.
///
/// Each variant's discriminant is the byte used for it in an encoded reply
/// body (`repr(u8)`); a raw byte is converted with the derived
/// `TryFromPrimitive`, and a byte matching no variant is reported as
/// `DeserializeError::InvalidStatus`.
#[repr(u8)]
#[derive(PartialEq, Eq, Debug, Clone, Copy, Hash, TryFromPrimitive)]
pub enum Status {
    /// Authorization passed; server may have additional arguments for the client.
    PassAdd = 0x01,

    /// Authorization passed; server provides argument values to override those provided in the request.
    PassReplace = 0x02,

    /// Authorization request was denied.
    Fail = 0x10,

    /// An error occurred on the server.
    Error = 0x11,

    /// Forward authorization request to an alternative daemon.
    #[deprecated = "Forwarding to an alternative daemon was deprecated in RFC 8907."]
    Follow = 0x21,
}
138
139impl Status {
140    /// The wire size of an authorization reply status in bytes.
141    const WIRE_SIZE: usize = 1;
142}
143
144impl fmt::Display for Status {
145    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
146        write!(
147            f,
148            "{}",
149            match self {
150                Self::PassAdd => "pass, arguments added",
151                Self::PassReplace => "pass, arguments replaced",
152                Self::Fail => "fail",
153                Self::Error => "server-side error",
154                #[allow(deprecated)]
155                Self::Follow => "redirect to alternative daemon",
156            }
157        )
158    }
159}
160
161// Implementation detail for num_enum, which is why it's hidden
162#[doc(hidden)]
163impl From<TryFromPrimitiveError<Status>> for DeserializeError {
164    fn from(value: TryFromPrimitiveError<Status>) -> Self {
165        Self::InvalidStatus(value.number)
166    }
167}
168
/// Information about a reply packet's arguments.
///
/// Holds borrowed views into a raw reply body so that `ArgumentsIterator` can
/// decode arguments one at a time.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct ArgumentsInfo<'raw> {
    /// Number of arguments in the reply (byte 1 of the reply body).
    argument_count: u8,
    /// One length byte per argument, in argument order.
    argument_lengths: &'raw [u8],
    /// Encoded argument values packed back-to-back, in the same order as `argument_lengths`.
    arguments_buffer: &'raw [u8],
}
176
/// The body of an authorization reply packet.
///
/// When produced by deserialization, the text fields and argument data borrow
/// from the buffer the reply was parsed from; arguments are decoded on demand
/// by [`Reply::iter_arguments`].
#[derive(Debug, Clone, PartialEq, Eq, Hash, Getters)]
pub struct Reply<'packet> {
    /// Gets the status returned in an authorization exchange.
    #[getset(get = "pub")]
    status: Status,

    /// Gets the message sent by the server, to be displayed to the user.
    #[getset(get = "pub")]
    server_message: FieldText<'packet>,

    /// Gets the administrative log message returned from the server.
    #[getset(get = "pub")]
    data: FieldText<'packet>,

    // Intentionally not exposed publicly: callers reach the arguments through
    // `iter_arguments`, which walks the raw lengths/values stored here.
    arguments_info: ArgumentsInfo<'packet>,
}
196
/// The non-argument field lengths of a (raw) authorization reply packet, as well as its total length.
struct ReplyFieldLengths {
    /// Length of the data field, read big-endian from reply body bytes 4..6.
    data_length: u16,
    /// Length of the server message field, read big-endian from reply body bytes 2..4.
    server_message_length: u16,
    /// Total body length implied by the header: fixed fields, argument length
    /// bytes, server message, data, and every argument value.
    total_length: u32,
}
203
/// An iterator over the arguments in an authorization reply packet.
///
/// Created by [`Reply::iter_arguments`]. Both the iterator and the arguments
/// it yields share the lifetime of that borrow of the reply.
#[derive(Debug, Clone)]
pub struct ArgumentsIterator<'iter> {
    /// Argument information, including argument count.
    arguments_info: &'iter ArgumentsInfo<'iter>,

    /// Position of the next argument, as if into a zero-indexed array of complete arguments.
    next_argument_number: usize,

    /// Byte offset of the next argument within `arguments_info.arguments_buffer`.
    next_offset: usize,
}
216
217impl<'iter> Iterator for ArgumentsIterator<'iter> {
218    type Item = Argument<'iter>;
219
220    fn next(&mut self) -> Option<Self::Item> {
221        if self.next_argument_number < self.arguments_info.argument_count as usize {
222            // get encoded argument from buffer based on stored offset into buffer/length
223            let next_length =
224                self.arguments_info.argument_lengths[self.next_argument_number] as usize;
225            let raw_argument = &self.arguments_info.arguments_buffer
226                [self.next_offset..self.next_offset + next_length];
227
228            // update iterator state
229            self.next_argument_number += 1;
230            self.next_offset += next_length;
231
232            // NOTE: this should always be Some, since the validity of arguments is checked in Reply's TryFrom impl
233            Argument::deserialize(raw_argument).ok()
234        } else {
235            None
236        }
237    }
238
239    // required for ExactSizeIterator impl
240    fn size_hint(&self) -> (usize, Option<usize>) {
241        let total_size = self.arguments_info.argument_count as usize;
242        let remaining_size = total_size - self.next_argument_number;
243
244        // these are asserted to be equal in the default ExactSizeIterator::len() implementation
245        (remaining_size, Some(remaining_size))
246    }
247}
248
249// Gives ArgumentsIterator a .len() method
250impl ExactSizeIterator for ArgumentsIterator<'_> {}
251
252impl<'packet> Reply<'packet> {
253    const ARGUMENT_LENGTHS_START: usize = 6;
254
255    /// Determines the length of a reply packet based on encoded lengths at the beginning of the packet body, if possible.
256    pub fn extract_total_length(buffer: &[u8]) -> Result<u32, DeserializeError> {
257        Self::extract_field_lengths(buffer).map(|lengths| lengths.total_length)
258    }
259
260    /// Extracts the server message and data lengths from a raw reply packet, if possible.
261    fn extract_field_lengths(buffer: &[u8]) -> Result<ReplyFieldLengths, DeserializeError> {
262        // data length is the last field in the required part of the header, so we need a full (minimal) header
263        if buffer.len() >= Self::REQUIRED_FIELDS_LENGTH {
264            let argument_count = buffer[1];
265
266            // also ensure that all argument lengths are present
267            if buffer.len() >= Self::REQUIRED_FIELDS_LENGTH + argument_count as usize {
268                let server_message_length = NetworkEndian::read_u16(&buffer[2..4]);
269                let data_length = NetworkEndian::read_u16(&buffer[4..6]);
270
271                let encoded_arguments_length: u32 = buffer[Self::ARGUMENT_LENGTHS_START
272                    ..Self::ARGUMENT_LENGTHS_START + argument_count as usize]
273                    .iter()
274                    .map(|&length| u32::from(length))
275                    .sum();
276
277                // SAFETY: REQUIRED_FIELDS_LENGTH is guaranteed to fit in a u32 by how it's defined
278                let total_length = u32::try_from(Self::REQUIRED_FIELDS_LENGTH).unwrap()
279                    + u32::from(argument_count) // argument lengths in "header"
280                    + u32::from(server_message_length)
281                    + u32::from(data_length)
282                    + encoded_arguments_length;
283
284                Ok(ReplyFieldLengths {
285                    data_length,
286                    server_message_length,
287                    total_length,
288                })
289            } else {
290                Err(DeserializeError::UnexpectedEnd)
291            }
292        } else {
293            Err(DeserializeError::UnexpectedEnd)
294        }
295    }
296
297    /// Ensures a list of argument lengths and their raw values represent a valid set of arguments.
298    fn ensure_arguments_valid(lengths: &[u8], values: &[u8]) -> Result<(), InvalidArgument> {
299        let mut argument_start = 0;
300
301        lengths.iter().try_fold((), |_, &length| {
302            let raw_argument = &values[argument_start..argument_start + length as usize];
303            argument_start += length as usize;
304
305            // we don't care about the actual argument here, but the specific error should be kept
306            Argument::deserialize(raw_argument).map(|_| ())
307        })
308    }
309
310    /// Returns an iterator over the arguments included in this reply packet.
311    pub fn iter_arguments(&self) -> ArgumentsIterator<'_> {
312        ArgumentsIterator {
313            arguments_info: &self.arguments_info,
314            next_argument_number: 0,
315            next_offset: 0,
316        }
317    }
318}
319
320impl PacketBody for Reply<'_> {
321    const TYPE: PacketType = PacketType::Authorization;
322
323    // 1 byte for status, 1 byte for argument count, 2 bytes each for lengths of server message/data
324    const REQUIRED_FIELDS_LENGTH: usize = Status::WIRE_SIZE + 1 + 4;
325}
326
327impl<'raw> Deserialize<'raw> for Reply<'raw> {
328    fn deserialize_from_buffer(buffer: &'raw [u8]) -> Result<Self, DeserializeError> {
329        let ReplyFieldLengths {
330            data_length,
331            server_message_length,
332            total_length,
333        } = Self::extract_field_lengths(buffer)?;
334
335        // buffer argument is sliced to proper length in Packet::deserialize_body(), so we can compare against that header length indirectly like this
336        let length_from_header = buffer.len();
337
338        if total_length as usize == length_from_header {
339            let status = Status::try_from(buffer[0])?;
340            let argument_count = buffer[1];
341
342            // figure out field offsets
343            let body_start = Self::ARGUMENT_LENGTHS_START + argument_count as usize;
344            let data_start = body_start + server_message_length as usize;
345            let arguments_start = data_start + data_length as usize;
346
347            let server_message = FieldText::try_from(&buffer[body_start..data_start])
348                .map_err(|_| DeserializeError::BadText)?;
349            let data = FieldText::try_from(&buffer[data_start..arguments_start])
350                .map_err(|_| DeserializeError::BadText)?;
351
352            // arguments occupy the rest of the buffer
353            let argument_lengths = &buffer[Self::ARGUMENT_LENGTHS_START..body_start];
354            let argument_values = &buffer[arguments_start..total_length as usize];
355
356            Self::ensure_arguments_valid(argument_lengths, argument_values)?;
357
358            // bundle some information about arguments for iterator purposes
359            let arguments_info = ArgumentsInfo {
360                argument_count,
361                argument_lengths,
362                arguments_buffer: argument_values,
363            };
364
365            Ok(Self {
366                status,
367                server_message,
368                data,
369                arguments_info,
370            })
371        } else {
372            Err(DeserializeError::WrongBodyBufferSize {
373                expected: total_length as usize,
374                buffer_size: length_from_header,
375            })
376        }
377    }
378}