tacacs_plus_protocol/
arguments.rs

1use core::fmt;
2use core::iter::zip;
3
4use getset::{CopyGetters, Getters, Setters};
5
6use super::{DeserializeError, SerializeError};
7use crate::FieldText;
8
9#[cfg(test)]
10mod tests;
11
12/// An argument in the TACACS+ protocol, which exists for extensibility.
13#[derive(Clone, Default, PartialEq, Eq, Debug, Hash, Getters, CopyGetters, Setters)]
14#[getset(set = "pub")]
15pub struct Argument<'data> {
16    /// The name of the argument.
17    #[getset(get = "pub")]
18    name: FieldText<'data>,
19
20    /// The value of the argument.
21    #[getset(get = "pub")]
22    value: FieldText<'data>,
23
24    /// Whether processing this argument is mandatory.
25    #[getset(get_copy = "pub")]
26    mandatory: bool,
27}
28
29impl fmt::Display for Argument<'_> {
30    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
31        // just write as encoded form (name + delimiter + value)
32        write!(f, "{}{}{}", self.name, self.delimiter(), self.value)
33    }
34}
35
/// Reasons an [`Argument`] can be rejected, either when constructing one with
/// [`Argument::new`] or when decoding one from its wire encoding.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum InvalidArgument {
    /// Argument had an empty name.
    EmptyName,

    /// Argument name contained a delimiter (`=` or `*`).
    NameContainsDelimiter,

    /// Argument encoding did not contain a delimiter.
    NoDelimiter,

    /// Argument was too long to be encodable: name + delimiter + value must fit in a `u8`.
    TooLong,

    /// Argument wasn't valid printable ASCII, as specified in [RFC8907 section 6.1]
    /// (see also the general text-string rules in [RFC8907 section 3.7]).
    ///
    /// [RFC8907 section 6.1]: https://www.rfc-editor.org/rfc/rfc8907.html#section-6.1-18
    /// [RFC8907 section 3.7]: https://www.rfc-editor.org/rfc/rfc8907.html#section-3.7
    BadText,
}

impl fmt::Display for InvalidArgument {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::EmptyName => "arguments cannot have empty names",
            Self::NameContainsDelimiter => {
                "names cannot contain value delimiter characters (= or *)"
            }
            Self::NoDelimiter => "encoded argument value had no delimiter",
            // the encoded form is name + delimiter + value (the original message said "length")
            Self::TooLong => "the total length of an argument (name + value + delimiter) must not exceed u8::MAX, for encoding reasons",
            Self::BadText => "encoded argument value was not printable ASCII",
        };
        f.write_str(message)
    }
}
71
72impl From<InvalidArgument> for DeserializeError {
73    fn from(value: InvalidArgument) -> Self {
74        Self::InvalidArgument(value)
75    }
76}
77
78impl<'data> Argument<'data> {
79    /// The delimiter used for a required argument.
80    const MANDATORY_DELIMITER: char = '=';
81
82    /// The delimiter used for an optional argument.
83    const OPTIONAL_DELIMITER: char = '*';
84
85    /// Constructs an argument, enforcing a maximum combined name + value + delimiter length of `u8::MAX` (as it must fit in a single byte for encoding reasons).
86    pub fn new(
87        name: FieldText<'data>,
88        value: FieldText<'data>,
89        mandatory: bool,
90    ) -> Result<Self, InvalidArgument> {
91        // NOTE: since both name/value are already `FieldText`s, we don't have to check if they are ASCII
92
93        if name.is_empty() {
94            // name must be nonempty (?)
95            Err(InvalidArgument::EmptyName)
96        } else if name.contains_any(&[Self::MANDATORY_DELIMITER, Self::OPTIONAL_DELIMITER]) {
97            // "An argument name MUST NOT contain either of the separators." [RFC 8907]
98            Err(InvalidArgument::NameContainsDelimiter)
99        } else if u8::try_from(name.len() + 1 + value.len()).is_err() {
100            // length of encoded argument (i.e., including delimiter) must also fit in a u8 to be encodeable
101            Err(InvalidArgument::TooLong)
102        } else {
103            Ok(Argument {
104                name,
105                value,
106                mandatory,
107            })
108        }
109    }
110
111    /// Converts this `Argument` to one which owns its fields.
112    #[cfg(feature = "std")]
113    pub fn into_owned<'out>(self) -> Argument<'out> {
114        Argument {
115            name: self.name.into_owned(),
116            value: self.value.into_owned(),
117            mandatory: self.mandatory,
118        }
119    }
120
121    /// The encoded length of an argument, including the name/value/delimiter but not the byte holding its length earlier on in a packet.
122    fn encoded_length(&self) -> u8 {
123        // SAFETY: this should never panic due to length checks in new()
124        // length includes delimiter
125        (self.name.len() + 1 + self.value.len()).try_into().unwrap()
126    }
127
128    /// Serializes an argument's name-value encoding, as done in the body of a packet.
129    fn serialize(&self, buffer: &mut [u8]) -> Result<usize, SerializeError> {
130        let name_len = self.name.len();
131        let value_len = self.value.len();
132
133        // delimiter is placed just after name, meaning its index is exactly the name length
134        let delimiter_index = name_len;
135
136        // name + value + 1 extra byte for delimiter
137        let encoded_len = name_len + 1 + value_len;
138
139        // buffer must be large enough to hold name, value, and delimiter
140        if buffer.len() >= encoded_len {
141            buffer[..delimiter_index].copy_from_slice(self.name.as_bytes());
142
143            // choose delimiter based on whether argument is required
144            buffer[delimiter_index] = self.delimiter() as u8;
145
146            // value goes just after delimiter
147            buffer[delimiter_index + 1..encoded_len].copy_from_slice(self.value.as_bytes());
148
149            Ok(encoded_len)
150        } else {
151            Err(SerializeError::NotEnoughSpace)
152        }
153    }
154
155    /// Returns the delimiter that will be used for this argument when it's encoded on the wire,
156    /// based on whether it's mandatory or not.
157    fn delimiter(&self) -> char {
158        if self.mandatory {
159            Self::MANDATORY_DELIMITER
160        } else {
161            Self::OPTIONAL_DELIMITER
162        }
163    }
164
165    /// Attempts to deserialize a packet from its name-value encoding on the wire.
166    pub(super) fn deserialize(buffer: &'data [u8]) -> Result<Self, InvalidArgument> {
167        // note: these are guaranteed to be unequal, since a single index cannot contain two characters at once
168        let equals_index = buffer.iter().position(|c| *c == b'=');
169        let star_index = buffer.iter().position(|c| *c == b'*');
170
171        // determine first delimiter that appears, which is the actual delimiter as names MUST NOT (RFC 8907) contain either delimiter character
172        let delimiter_index = match (equals_index, star_index) {
173            (None, star) => star,
174            (equals, None) => equals,
175            (Some(equals), Some(star)) => Some(equals.min(star)),
176        }
177        .ok_or(InvalidArgument::NoDelimiter)?;
178
179        // at this point, delimiter_index was non-None and must contain one of {*, =}
180        let required = buffer[delimiter_index] == Self::MANDATORY_DELIMITER as u8;
181
182        // ensure name/value are valid text values per RFC 8907 (i.e., fully printable ASCII)
183        let name = FieldText::try_from(&buffer[..delimiter_index])
184            .map_err(|_| InvalidArgument::BadText)?;
185        let value = FieldText::try_from(&buffer[delimiter_index + 1..])
186            .map_err(|_| InvalidArgument::BadText)?;
187
188        // use constructor here to perform checks on fields to avoid diverging code paths
189        Self::new(name, value, required)
190    }
191}
192
193/// A set of arguments known to be of valid length for use in a TACACS+ packet.
194#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
195pub struct Arguments<'args>(&'args [Argument<'args>]);
196
197impl<'args> Arguments<'args> {
198    /// Constructs a new `Arguments`, returning `Some` if the provided slice has less than `u8::MAX` and None otherwise.
199    ///
200    /// The `u8::MAX` restriction is due to the argument count being required to fit into a single byte when encoding.
201    pub fn new<T: AsRef<[Argument<'args>]>>(arguments: &'args T) -> Option<Self> {
202        if u8::try_from(arguments.as_ref().len()).is_ok() {
203            Some(Self(arguments.as_ref()))
204        } else {
205            None
206        }
207    }
208
209    /// Returns the number of arguments an `Arguments` object contains.
210    pub fn argument_count(&self) -> u8 {
211        // SAFETY: this should not panic as the argument count is verified to fit in a u8 in the constructor
212        self.0.len().try_into().unwrap()
213    }
214
215    /// Returns the size of this set of arguments on the wire, including encoded values as well as lengths & the argument count.
216    pub(super) fn wire_size(&self) -> usize {
217        let argument_count = self.0.len();
218        let argument_values_len: usize = self
219            .0
220            .iter()
221            .map(|argument| argument.encoded_length() as usize)
222            .sum();
223
224        // number of arguments itself takes up extra byte when serializing
225        1 + argument_count + argument_values_len
226    }
227
228    /// Serializes the argument count & lengths of the stored arguments into a buffer.
229    pub(super) fn serialize_count_and_lengths(
230        &self,
231        buffer: &mut [u8],
232    ) -> Result<usize, SerializeError> {
233        let argument_count = self.argument_count();
234
235        // strict greater than to allow room for encoded argument count itself
236        if buffer.len() > argument_count as usize {
237            buffer[0] = argument_count;
238
239            // fill in argument lengths after argument count
240            for (position, argument) in zip(&mut buffer[1..1 + argument_count as usize], self.0) {
241                *position = argument.encoded_length();
242            }
243
244            // total bytes written: number of arguments + one extra byte for argument count itself
245            Ok(1 + argument_count as usize)
246        } else {
247            Err(SerializeError::NotEnoughSpace)
248        }
249    }
250
251    /// Serializes the stored arguments in their proper encoding to a buffer.
252    pub(super) fn serialize_encoded_values(
253        &self,
254        buffer: &mut [u8],
255    ) -> Result<usize, SerializeError> {
256        let full_encoded_length = self
257            .0
258            .iter()
259            .map(|argument| argument.encoded_length() as usize)
260            .sum();
261
262        if buffer.len() >= full_encoded_length {
263            let mut argument_start = 0;
264            let mut total_written = 0;
265
266            for argument in self.0.iter() {
267                let argument_length = argument.encoded_length() as usize;
268                let next_argument_start = argument_start + argument_length;
269                let written_length =
270                    argument.serialize(&mut buffer[argument_start..next_argument_start])?;
271
272                // update loop state
273                argument_start = next_argument_start;
274
275                // this is technically redundant with the initial full_encoded_length calculation above
276                // but better to be safe than sorry right?
277                total_written += written_length;
278            }
279
280            // this case shouldn't happen since argument serialization is basically just direct slice copying
281            // but on the off chance that it does this makes it easier to debug
282            if total_written != full_encoded_length {
283                Err(SerializeError::LengthMismatch {
284                    expected: full_encoded_length,
285                    actual: total_written,
286                })
287            } else {
288                Ok(total_written)
289            }
290        } else {
291            Err(SerializeError::NotEnoughSpace)
292        }
293    }
294}
295
296impl<'args> AsRef<[Argument<'args>]> for Arguments<'args> {
297    fn as_ref(&self) -> &[Argument<'args>] {
298        self.0
299    }
300}