1use core::fmt;
4
5use byteorder::{ByteOrder, NetworkEndian};
6use getset::Getters;
7use num_enum::{TryFromPrimitive, TryFromPrimitiveError};
8
9use super::{
10 Argument, Arguments, AuthenticationContext, AuthenticationMethod, DeserializeError,
11 InvalidArgument, PacketBody, PacketType, Serialize, SerializeError, UserInformation,
12};
13use crate::{Deserialize, FieldText};
14
15#[cfg(test)]
16mod tests;
17
18#[cfg(feature = "std")]
19mod owned;
20
21#[cfg(feature = "std")]
22pub use owned::ReplyOwned;
23
/// The body of an authorization request packet.
///
/// Its packet type is [`PacketType::Authorization`]; it is written to the wire
/// through its [`Serialize`] implementation.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Request<'packet> {
    /// Authentication method; serialized as body byte 0.
    method: AuthenticationMethod,

    /// Authentication context; serialized into body bytes 1..4.
    authentication_context: AuthenticationContext,

    /// User information. Its field lengths occupy body bytes 4..7 and its
    /// values are written immediately after the per-argument length bytes.
    user_information: UserInformation<'packet>,

    /// Request arguments. The argument count sits at body byte 7, followed by
    /// one length byte per argument; the encoded argument values are written
    /// after the user information values.
    arguments: Arguments<'packet>,
}
39
40impl<'packet> Request<'packet> {
41 pub fn new(
43 method: AuthenticationMethod,
44 authentication_context: AuthenticationContext,
45 user_information: UserInformation<'packet>,
46 arguments: Arguments<'packet>,
47 ) -> Self {
48 Self {
49 method,
50 authentication_context,
51 user_information,
52 arguments,
53 }
54 }
55}
56
57impl PacketBody for Request<'_> {
58 const TYPE: PacketType = PacketType::Authorization;
59
60 const REQUIRED_FIELDS_LENGTH: usize =
62 AuthenticationMethod::WIRE_SIZE + AuthenticationContext::WIRE_SIZE + 4;
63}
64
65impl Serialize for Request<'_> {
66 fn wire_size(&self) -> usize {
67 AuthenticationMethod::WIRE_SIZE
68 + AuthenticationContext::WIRE_SIZE
69 + self.user_information.wire_size()
70 + self.arguments.wire_size()
71 }
72
73 fn serialize_into_buffer(&self, buffer: &mut [u8]) -> Result<usize, SerializeError> {
74 let wire_size = self.wire_size();
75
76 if buffer.len() >= wire_size {
77 buffer[0] = self.method as u8;
78 self.authentication_context.serialize(&mut buffer[1..4]);
79 self.user_information
80 .serialize_field_lengths(&mut buffer[4..7])?;
81
82 let argument_count = self.arguments.argument_count() as usize;
83
84 let user_info_start = Self::REQUIRED_FIELDS_LENGTH + argument_count;
86
87 let user_info_written_len = self
89 .user_information
90 .serialize_field_values(&mut buffer[user_info_start..wire_size])?;
91
92 let arguments_wire_len = self.arguments.serialize_count_and_lengths(&mut buffer[7..7 + argument_count + 1])?
95 + self
97 .arguments
98 .serialize_encoded_values(&mut buffer[user_info_start + user_info_written_len..wire_size])?;
99
100 let actual_written_len =
102 (Self::REQUIRED_FIELDS_LENGTH - 1) + user_info_written_len + arguments_wire_len;
103
104 if actual_written_len == wire_size {
105 Ok(actual_written_len)
106 } else {
107 Err(SerializeError::LengthMismatch {
108 expected: wire_size,
109 actual: actual_written_len,
110 })
111 }
112 } else {
113 Err(SerializeError::NotEnoughSpace)
114 }
115 }
116}
117
/// Status of an authorization reply, decoded from the first byte of a reply body.
///
/// Bytes that match no variant are rejected with [`DeserializeError::InvalidStatus`].
#[repr(u8)]
#[derive(PartialEq, Eq, Debug, Clone, Copy, Hash, TryFromPrimitive)]
pub enum Status {
    /// Authorization passed, with arguments added (wire value `0x01`).
    PassAdd = 0x01,

    /// Authorization passed, with arguments replaced (wire value `0x02`).
    PassReplace = 0x02,

    /// Authorization failed (wire value `0x10`).
    Fail = 0x10,

    /// A server-side error occurred (wire value `0x11`).
    Error = 0x11,

    /// Redirect to an alternative daemon (wire value `0x21`).
    #[deprecated = "Forwarding to an alternative daemon was deprecated in RFC 8907."]
    Follow = 0x21,
}
138
139impl Status {
140 const WIRE_SIZE: usize = 1;
142}
143
144impl fmt::Display for Status {
145 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
146 write!(
147 f,
148 "{}",
149 match self {
150 Self::PassAdd => "pass, arguments added",
151 Self::PassReplace => "pass, arguments replaced",
152 Self::Fail => "fail",
153 Self::Error => "server-side error",
154 #[allow(deprecated)]
155 Self::Follow => "redirect to alternative daemon",
156 }
157 )
158 }
159}
160
161#[doc(hidden)]
163impl From<TryFromPrimitiveError<Status>> for DeserializeError {
164 fn from(value: TryFromPrimitiveError<Status>) -> Self {
165 Self::InvalidStatus(value.number)
166 }
167}
168
/// Borrowed view of the argument section of a reply body.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct ArgumentsInfo<'raw> {
    /// Number of arguments in the reply (body byte 1).
    argument_count: u8,
    /// One length byte per argument, in order.
    argument_lengths: &'raw [u8],
    /// The concatenated raw argument values, in the same order as `argument_lengths`.
    arguments_buffer: &'raw [u8],
}
176
/// The body of an authorization reply packet, borrowing its text and
/// argument data from the buffer it was deserialized from.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Getters)]
pub struct Reply<'packet> {
    /// Authorization status of the reply.
    #[getset(get = "pub")]
    status: Status,

    /// Server message field of the reply.
    #[getset(get = "pub")]
    server_message: FieldText<'packet>,

    /// Data field of the reply.
    #[getset(get = "pub")]
    data: FieldText<'packet>,

    /// Location of the reply's arguments; read them via [`Reply::iter_arguments`].
    arguments_info: ArgumentsInfo<'packet>,
}
196
/// Lengths decoded from the leading bytes of a reply body.
struct ReplyFieldLengths {
    /// Length of the data field, read from body bytes 4..6 in network byte order.
    data_length: u16,
    /// Length of the server message, read from body bytes 2..4 in network byte order.
    server_message_length: u16,
    /// Total body length: fixed fields, argument length bytes, server message,
    /// data, and all argument values.
    total_length: u32,
}
203
/// Iterator over the arguments of a [`Reply`], created by [`Reply::iter_arguments`].
#[derive(Debug, Clone)]
pub struct ArgumentsIterator<'iter> {
    /// Argument metadata borrowed from the reply.
    arguments_info: &'iter ArgumentsInfo<'iter>,

    /// Index of the next argument to yield.
    next_argument_number: usize,

    /// Byte offset of the next argument within `arguments_info.arguments_buffer`.
    next_offset: usize,
}
216
217impl<'iter> Iterator for ArgumentsIterator<'iter> {
218 type Item = Argument<'iter>;
219
220 fn next(&mut self) -> Option<Self::Item> {
221 if self.next_argument_number < self.arguments_info.argument_count as usize {
222 let next_length =
224 self.arguments_info.argument_lengths[self.next_argument_number] as usize;
225 let raw_argument = &self.arguments_info.arguments_buffer
226 [self.next_offset..self.next_offset + next_length];
227
228 self.next_argument_number += 1;
230 self.next_offset += next_length;
231
232 Argument::deserialize(raw_argument).ok()
234 } else {
235 None
236 }
237 }
238
239 fn size_hint(&self) -> (usize, Option<usize>) {
241 let total_size = self.arguments_info.argument_count as usize;
242 let remaining_size = total_size - self.next_argument_number;
243
244 (remaining_size, Some(remaining_size))
246 }
247}
248
249impl ExactSizeIterator for ArgumentsIterator<'_> {}
251
252impl<'packet> Reply<'packet> {
253 const ARGUMENT_LENGTHS_START: usize = 6;
254
255 pub fn extract_total_length(buffer: &[u8]) -> Result<u32, DeserializeError> {
257 Self::extract_field_lengths(buffer).map(|lengths| lengths.total_length)
258 }
259
260 fn extract_field_lengths(buffer: &[u8]) -> Result<ReplyFieldLengths, DeserializeError> {
262 if buffer.len() >= Self::REQUIRED_FIELDS_LENGTH {
264 let argument_count = buffer[1];
265
266 if buffer.len() >= Self::REQUIRED_FIELDS_LENGTH + argument_count as usize {
268 let server_message_length = NetworkEndian::read_u16(&buffer[2..4]);
269 let data_length = NetworkEndian::read_u16(&buffer[4..6]);
270
271 let encoded_arguments_length: u32 = buffer[Self::ARGUMENT_LENGTHS_START
272 ..Self::ARGUMENT_LENGTHS_START + argument_count as usize]
273 .iter()
274 .map(|&length| u32::from(length))
275 .sum();
276
277 let total_length = u32::try_from(Self::REQUIRED_FIELDS_LENGTH).unwrap()
279 + u32::from(argument_count) + u32::from(server_message_length)
281 + u32::from(data_length)
282 + encoded_arguments_length;
283
284 Ok(ReplyFieldLengths {
285 data_length,
286 server_message_length,
287 total_length,
288 })
289 } else {
290 Err(DeserializeError::UnexpectedEnd)
291 }
292 } else {
293 Err(DeserializeError::UnexpectedEnd)
294 }
295 }
296
297 fn ensure_arguments_valid(lengths: &[u8], values: &[u8]) -> Result<(), InvalidArgument> {
299 let mut argument_start = 0;
300
301 lengths.iter().try_fold((), |_, &length| {
302 let raw_argument = &values[argument_start..argument_start + length as usize];
303 argument_start += length as usize;
304
305 Argument::deserialize(raw_argument).map(|_| ())
307 })
308 }
309
310 pub fn iter_arguments(&self) -> ArgumentsIterator<'_> {
312 ArgumentsIterator {
313 arguments_info: &self.arguments_info,
314 next_argument_number: 0,
315 next_offset: 0,
316 }
317 }
318}
319
320impl PacketBody for Reply<'_> {
321 const TYPE: PacketType = PacketType::Authorization;
322
323 const REQUIRED_FIELDS_LENGTH: usize = Status::WIRE_SIZE + 1 + 4;
325}
326
327impl<'raw> Deserialize<'raw> for Reply<'raw> {
328 fn deserialize_from_buffer(buffer: &'raw [u8]) -> Result<Self, DeserializeError> {
329 let ReplyFieldLengths {
330 data_length,
331 server_message_length,
332 total_length,
333 } = Self::extract_field_lengths(buffer)?;
334
335 let length_from_header = buffer.len();
337
338 if total_length as usize == length_from_header {
339 let status = Status::try_from(buffer[0])?;
340 let argument_count = buffer[1];
341
342 let body_start = Self::ARGUMENT_LENGTHS_START + argument_count as usize;
344 let data_start = body_start + server_message_length as usize;
345 let arguments_start = data_start + data_length as usize;
346
347 let server_message = FieldText::try_from(&buffer[body_start..data_start])
348 .map_err(|_| DeserializeError::BadText)?;
349 let data = FieldText::try_from(&buffer[data_start..arguments_start])
350 .map_err(|_| DeserializeError::BadText)?;
351
352 let argument_lengths = &buffer[Self::ARGUMENT_LENGTHS_START..body_start];
354 let argument_values = &buffer[arguments_start..total_length as usize];
355
356 Self::ensure_arguments_valid(argument_lengths, argument_values)?;
357
358 let arguments_info = ArgumentsInfo {
360 argument_count,
361 argument_lengths,
362 arguments_buffer: argument_values,
363 };
364
365 Ok(Self {
366 status,
367 server_message,
368 data,
369 arguments_info,
370 })
371 } else {
372 Err(DeserializeError::WrongBodyBufferSize {
373 expected: total_length as usize,
374 buffer_size: length_from_header,
375 })
376 }
377 }
378}