tacacs_plus_protocol/
arguments.rs1use core::fmt;
2use core::iter::zip;
3
4use getset::{CopyGetters, Getters, Setters};
5
6use super::{DeserializeError, SerializeError};
7use crate::FieldText;
8
9#[cfg(test)]
10mod tests;
11
12#[derive(Clone, Default, PartialEq, Eq, Debug, Hash, Getters, CopyGetters, Setters)]
14#[getset(set = "pub")]
15pub struct Argument<'data> {
16 #[getset(get = "pub")]
18 name: FieldText<'data>,
19
20 #[getset(get = "pub")]
22 value: FieldText<'data>,
23
24 #[getset(get_copy = "pub")]
26 mandatory: bool,
27}
28
29impl fmt::Display for Argument<'_> {
30 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
31 write!(f, "{}{}{}", self.name, self.delimiter(), self.value)
33 }
34}
35
/// The reasons an [`Argument`] can be rejected when it is constructed or decoded.
#[derive(Debug, PartialEq, Eq)]
pub enum InvalidArgument {
    /// The argument name was empty.
    EmptyName,

    /// The argument name contained one of the delimiter characters (`=` or `*`).
    NameContainsDelimiter,

    /// The encoded argument contained neither delimiter character.
    NoDelimiter,

    /// The name, delimiter and value together do not fit in a `u8` length.
    TooLong,

    /// The encoded name or value could not be converted to field text.
    BadText,
}

impl fmt::Display for InvalidArgument {
    /// Writes a fixed, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::EmptyName => "arguments cannot have empty names",
            Self::NameContainsDelimiter => {
                "names cannot contain value delimiter characters (= or *)"
            }
            Self::NoDelimiter => "encoded argument value had no delimiter",
            // The limited quantity is name + delimiter + value; the message previously
            // said "length" where it meant "value".
            Self::TooLong => "the total length of an argument (name + value + delimiter) must not exceed u8::MAX, for encoding reasons",
            Self::BadText => "encoded argument value was not printable ASCII",
        };

        f.write_str(message)
    }
}
71
72impl From<InvalidArgument> for DeserializeError {
73 fn from(value: InvalidArgument) -> Self {
74 Self::InvalidArgument(value)
75 }
76}
77
78impl<'data> Argument<'data> {
79 const MANDATORY_DELIMITER: char = '=';
81
82 const OPTIONAL_DELIMITER: char = '*';
84
85 pub fn new(
87 name: FieldText<'data>,
88 value: FieldText<'data>,
89 mandatory: bool,
90 ) -> Result<Self, InvalidArgument> {
91 if name.is_empty() {
94 Err(InvalidArgument::EmptyName)
96 } else if name.contains_any(&[Self::MANDATORY_DELIMITER, Self::OPTIONAL_DELIMITER]) {
97 Err(InvalidArgument::NameContainsDelimiter)
99 } else if u8::try_from(name.len() + 1 + value.len()).is_err() {
100 Err(InvalidArgument::TooLong)
102 } else {
103 Ok(Argument {
104 name,
105 value,
106 mandatory,
107 })
108 }
109 }
110
111 #[cfg(feature = "std")]
113 pub fn into_owned<'out>(self) -> Argument<'out> {
114 Argument {
115 name: self.name.into_owned(),
116 value: self.value.into_owned(),
117 mandatory: self.mandatory,
118 }
119 }
120
121 fn encoded_length(&self) -> u8 {
123 (self.name.len() + 1 + self.value.len()).try_into().unwrap()
126 }
127
128 fn serialize(&self, buffer: &mut [u8]) -> Result<usize, SerializeError> {
130 let name_len = self.name.len();
131 let value_len = self.value.len();
132
133 let delimiter_index = name_len;
135
136 let encoded_len = name_len + 1 + value_len;
138
139 if buffer.len() >= encoded_len {
141 buffer[..delimiter_index].copy_from_slice(self.name.as_bytes());
142
143 buffer[delimiter_index] = self.delimiter() as u8;
145
146 buffer[delimiter_index + 1..encoded_len].copy_from_slice(self.value.as_bytes());
148
149 Ok(encoded_len)
150 } else {
151 Err(SerializeError::NotEnoughSpace)
152 }
153 }
154
155 fn delimiter(&self) -> char {
158 if self.mandatory {
159 Self::MANDATORY_DELIMITER
160 } else {
161 Self::OPTIONAL_DELIMITER
162 }
163 }
164
165 pub(super) fn deserialize(buffer: &'data [u8]) -> Result<Self, InvalidArgument> {
167 let equals_index = buffer.iter().position(|c| *c == b'=');
169 let star_index = buffer.iter().position(|c| *c == b'*');
170
171 let delimiter_index = match (equals_index, star_index) {
173 (None, star) => star,
174 (equals, None) => equals,
175 (Some(equals), Some(star)) => Some(equals.min(star)),
176 }
177 .ok_or(InvalidArgument::NoDelimiter)?;
178
179 let required = buffer[delimiter_index] == Self::MANDATORY_DELIMITER as u8;
181
182 let name = FieldText::try_from(&buffer[..delimiter_index])
184 .map_err(|_| InvalidArgument::BadText)?;
185 let value = FieldText::try_from(&buffer[delimiter_index + 1..])
186 .map_err(|_| InvalidArgument::BadText)?;
187
188 Self::new(name, value, required)
190 }
191}
192
193#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
195pub struct Arguments<'args>(&'args [Argument<'args>]);
196
197impl<'args> Arguments<'args> {
198 pub fn new<T: AsRef<[Argument<'args>]>>(arguments: &'args T) -> Option<Self> {
202 if u8::try_from(arguments.as_ref().len()).is_ok() {
203 Some(Self(arguments.as_ref()))
204 } else {
205 None
206 }
207 }
208
209 pub fn argument_count(&self) -> u8 {
211 self.0.len().try_into().unwrap()
213 }
214
215 pub(super) fn wire_size(&self) -> usize {
217 let argument_count = self.0.len();
218 let argument_values_len: usize = self
219 .0
220 .iter()
221 .map(|argument| argument.encoded_length() as usize)
222 .sum();
223
224 1 + argument_count + argument_values_len
226 }
227
228 pub(super) fn serialize_count_and_lengths(
230 &self,
231 buffer: &mut [u8],
232 ) -> Result<usize, SerializeError> {
233 let argument_count = self.argument_count();
234
235 if buffer.len() > argument_count as usize {
237 buffer[0] = argument_count;
238
239 for (position, argument) in zip(&mut buffer[1..1 + argument_count as usize], self.0) {
241 *position = argument.encoded_length();
242 }
243
244 Ok(1 + argument_count as usize)
246 } else {
247 Err(SerializeError::NotEnoughSpace)
248 }
249 }
250
251 pub(super) fn serialize_encoded_values(
253 &self,
254 buffer: &mut [u8],
255 ) -> Result<usize, SerializeError> {
256 let full_encoded_length = self
257 .0
258 .iter()
259 .map(|argument| argument.encoded_length() as usize)
260 .sum();
261
262 if buffer.len() >= full_encoded_length {
263 let mut argument_start = 0;
264 let mut total_written = 0;
265
266 for argument in self.0.iter() {
267 let argument_length = argument.encoded_length() as usize;
268 let next_argument_start = argument_start + argument_length;
269 let written_length =
270 argument.serialize(&mut buffer[argument_start..next_argument_start])?;
271
272 argument_start = next_argument_start;
274
275 total_written += written_length;
278 }
279
280 if total_written != full_encoded_length {
283 Err(SerializeError::LengthMismatch {
284 expected: full_encoded_length,
285 actual: total_written,
286 })
287 } else {
288 Ok(total_written)
289 }
290 } else {
291 Err(SerializeError::NotEnoughSpace)
292 }
293 }
294}
295
296impl<'args> AsRef<[Argument<'args>]> for Arguments<'args> {
297 fn as_ref(&self) -> &[Argument<'args>] {
298 self.0
299 }
300}