1use bitflags::bitflags;
4use byteorder::{ByteOrder, NetworkEndian};
5use core::fmt;
6use getset::Getters;
7use num_enum::{TryFromPrimitive, TryFromPrimitiveError};
8
9use super::{
10 Arguments, AuthenticationContext, AuthenticationMethod, Deserialize, DeserializeError,
11 PacketBody, PacketType, Serialize, SerializeError, UserInformation,
12};
13use crate::FieldText;
14
15#[cfg(test)]
16mod tests;
17
18#[cfg(feature = "std")]
19mod owned;
20
21#[cfg(feature = "std")]
22pub use owned::ReplyOwned;
23
24bitflags! {
25 struct RawFlags: u8 {
27 const START = 0b00000010;
28 const STOP = 0b00000100;
29 const WATCHDOG = 0b00001000;
30 }
31}
32
/// The flag combinations that can be set on an accounting [`Request`].
#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)]
pub enum Flags {
    /// Marks the start of a record.
    StartRecord,

    /// Marks the end of a record.
    StopRecord,

    /// Watchdog update that carries no new information.
    WatchdogNoUpdate,

    /// Watchdog update that carries new information.
    WatchdogUpdate,
}

impl fmt::Display for Flags {
    /// Writes a short human-readable description of the flag.
    ///
    /// The text is emitted through [`fmt::Formatter::pad`], so any width, fill,
    /// alignment or precision requested by the caller (e.g. `{:>20}`) is applied
    /// exactly as it would be for a plain `&str`. With a bare `{}` the output is
    /// just the description.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let description = match self {
            Self::StartRecord => "start of record",
            Self::StopRecord => "end of record",
            Self::WatchdogUpdate => "update with new information",
            Self::WatchdogNoUpdate => "update with no new information",
        };

        // `write!(f, "{}", ..)` would format through a fresh default spec and
        // silently drop the caller's formatting options; `pad` honors them.
        f.pad(description)
    }
}
63
64impl From<Flags> for RawFlags {
65 fn from(value: Flags) -> Self {
66 match value {
67 Flags::StartRecord => RawFlags::START,
68 Flags::StopRecord => RawFlags::STOP,
69 Flags::WatchdogNoUpdate => RawFlags::WATCHDOG,
70 Flags::WatchdogUpdate => RawFlags::WATCHDOG | RawFlags::START,
71 }
72 }
73}
74
75impl Flags {
76 pub(super) const WIRE_SIZE: usize = 1;
78}
79
80#[derive(PartialEq, Eq, Clone, Debug, Hash)]
82pub struct Request<'packet> {
83 flags: Flags,
85
86 authentication_method: AuthenticationMethod,
88
89 authentication: AuthenticationContext,
91
92 user_information: UserInformation<'packet>,
94
95 arguments: Arguments<'packet>,
97}
98
99impl<'packet> Request<'packet> {
100 const ARGUMENT_LENGTHS_OFFSET: usize = 9;
102
103 pub fn new(
105 flags: Flags,
106 authentication_method: AuthenticationMethod,
107 authentication: AuthenticationContext,
108 user_information: UserInformation<'packet>,
109 arguments: Arguments<'packet>,
110 ) -> Self {
111 Self {
112 flags,
113 authentication_method,
114 authentication,
115 user_information,
116 arguments,
117 }
118 }
119}
120
121impl PacketBody for Request<'_> {
122 const TYPE: PacketType = PacketType::Accounting;
123
124 const REQUIRED_FIELDS_LENGTH: usize =
126 Flags::WIRE_SIZE + AuthenticationMethod::WIRE_SIZE + AuthenticationContext::WIRE_SIZE + 4;
127}
128
129impl Serialize for Request<'_> {
130 fn wire_size(&self) -> usize {
131 Flags::WIRE_SIZE
132 + AuthenticationMethod::WIRE_SIZE
133 + AuthenticationContext::WIRE_SIZE
134 + self.user_information.wire_size()
135 + self.arguments.wire_size()
136 }
137
138 fn serialize_into_buffer(&self, buffer: &mut [u8]) -> Result<usize, SerializeError> {
139 let wire_size = self.wire_size();
140
141 if buffer.len() >= wire_size {
142 buffer[0] = RawFlags::from(self.flags).bits();
143 buffer[1] = self.authentication_method as u8;
144
145 self.authentication.serialize(&mut buffer[2..5]);
147 self.user_information
148 .serialize_field_lengths(&mut buffer[5..8])?;
149
150 let argument_count = self.arguments.argument_count() as usize;
151
152 let body_start = Self::ARGUMENT_LENGTHS_OFFSET + argument_count;
154
155 let user_information_len = self
158 .user_information
159 .serialize_field_values(&mut buffer[body_start..wire_size])?;
160
161 let arguments_serialized_len =
162 self.arguments.serialize_count_and_lengths(&mut buffer[8..8 + argument_count + 1])?
165 + self
167 .arguments
168 .serialize_encoded_values(&mut buffer[body_start + user_information_len..wire_size])?;
169
170 let actual_written_len = (Self::REQUIRED_FIELDS_LENGTH - 1)
172 + user_information_len
173 + arguments_serialized_len;
174
175 if actual_written_len == wire_size {
177 Ok(actual_written_len)
178 } else {
179 Err(SerializeError::LengthMismatch {
180 expected: wire_size,
181 actual: actual_written_len,
182 })
183 }
184 } else {
185 Err(SerializeError::NotEnoughSpace)
186 }
187 }
188}
189
190#[repr(u8)]
192#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash, TryFromPrimitive)]
193pub enum Status {
194 Success = 0x01,
196
197 Error = 0x02,
199
200 #[deprecated = "Forwarding to an alternative daemon was deprecated in RFC-8907."]
202 Follow = 0x21,
203}
204
205impl Status {
206 pub(super) const WIRE_SIZE: usize = 1;
208}
209
210impl fmt::Display for Status {
211 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
212 write!(
213 f,
214 "{}",
215 match self {
216 Self::Success => "success",
217 Self::Error => "error",
218 #[allow(deprecated)]
219 Self::Follow => "follow",
220 }
221 )
222 }
223}
224
225#[doc(hidden)]
226impl From<TryFromPrimitiveError<Status>> for DeserializeError {
227 fn from(value: TryFromPrimitiveError<Status>) -> Self {
228 Self::InvalidStatus(value.number)
229 }
230}
231
232#[derive(Clone, PartialEq, Eq, Debug, Hash, Getters)]
234pub struct Reply<'packet> {
235 #[getset(get = "pub")]
237 status: Status,
238
239 #[getset(get = "pub")]
241 server_message: FieldText<'packet>,
242
243 #[getset(get = "pub")]
245 data: FieldText<'packet>,
246}
247
/// Lengths read from the leading bytes of a serialized accounting reply body.
struct ReplyFieldLengths {
    /// Length of the server message, read from the first two bytes.
    server_message_length: u16,
    /// Length of the data field, read from the following two bytes.
    data_length: u16,
    /// Required fields length plus both variable field lengths.
    total_length: u32,
}
254
255impl Reply<'_> {
256 const SERVER_MESSAGE_OFFSET: usize = 5;
258
259 pub fn extract_total_length(buffer: &[u8]) -> Result<u32, DeserializeError> {
261 if buffer.len() >= Self::REQUIRED_FIELDS_LENGTH {
262 Self::extract_field_lengths(buffer).map(|lengths| lengths.total_length)
263 } else {
264 Err(DeserializeError::UnexpectedEnd)
265 }
266 }
267
268 fn extract_field_lengths(buffer: &[u8]) -> Result<ReplyFieldLengths, DeserializeError> {
270 if buffer.len() >= Self::REQUIRED_FIELDS_LENGTH {
272 let server_message_length = NetworkEndian::read_u16(&buffer[..2]);
274
275 let data_length = NetworkEndian::read_u16(&buffer[2..4]);
277
278 let total_length = u32::try_from(Self::REQUIRED_FIELDS_LENGTH).unwrap()
281 + u32::from(server_message_length)
282 + u32::from(data_length);
283
284 Ok(ReplyFieldLengths {
285 server_message_length,
286 data_length,
287 total_length,
288 })
289 } else {
290 Err(DeserializeError::UnexpectedEnd)
291 }
292 }
293}
294
295impl PacketBody for Reply<'_> {
296 const TYPE: PacketType = PacketType::Accounting;
297
298 const REQUIRED_FIELDS_LENGTH: usize = Status::WIRE_SIZE + 4;
300}
301
302impl<'raw> Deserialize<'raw> for Reply<'raw> {
303 fn deserialize_from_buffer(buffer: &'raw [u8]) -> Result<Self, DeserializeError> {
304 let extracted_lengths = Self::extract_field_lengths(buffer)?;
305
306 let length_from_header = buffer.len();
309
310 if extracted_lengths.total_length as usize == length_from_header {
312 let status = Status::try_from(buffer[4])?;
314
315 let data_offset =
316 Self::SERVER_MESSAGE_OFFSET + extracted_lengths.server_message_length as usize;
317
318 let server_message =
319 FieldText::try_from(&buffer[Self::SERVER_MESSAGE_OFFSET..data_offset])
320 .map_err(|_| DeserializeError::BadText)?;
321 let data = FieldText::try_from(
322 &buffer[data_offset..data_offset + extracted_lengths.data_length as usize],
323 )
324 .map_err(|_| DeserializeError::BadText)?;
325
326 Ok(Self {
327 status,
328 server_message,
329 data,
330 })
331 } else {
332 Err(DeserializeError::WrongBodyBufferSize {
333 expected: extracted_lengths.total_length as usize,
334 buffer_size: length_from_header,
335 })
336 }
337 }
338}