1use crate::{
2 protocol::{Error, MessageId, MessageTypeField, ReturnCode, byte_order::WriteBytesExt},
3 traits::WireFormat,
4};
5
/// Owned, fully decoded 16-byte message header.
///
/// Field order matches the order in which `WireFormat::encode` writes them:
/// message id, length, request id (each as a big-endian `u32`), followed by
/// protocol version, interface version, message type and return code (one
/// byte each).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Header {
    /// Message identifier, serialized as a big-endian `u32`.
    message_id: MessageId,
    /// Raw length field; the constructors set it to `8 + payload length`.
    length: u32,
    /// Request identifier.
    request_id: u32,
    /// Protocol version byte (`HeaderView::parse` only accepts `0x01`).
    protocol_version: u8,
    /// Interface version byte.
    interface_version: u8,
    /// Message type byte.
    message_type: MessageTypeField,
    /// Return code byte.
    return_code: ReturnCode,
}
21
22impl Header {
23 #[must_use]
25 pub const fn message_id(&self) -> MessageId {
26 self.message_id
27 }
28
29 #[must_use]
31 pub const fn length(&self) -> u32 {
32 self.length
33 }
34
35 #[must_use]
37 pub const fn request_id(&self) -> u32 {
38 self.request_id
39 }
40
41 #[must_use]
43 pub const fn protocol_version(&self) -> u8 {
44 self.protocol_version
45 }
46
47 #[must_use]
49 pub const fn interface_version(&self) -> u8 {
50 self.interface_version
51 }
52
53 #[must_use]
55 pub const fn message_type(&self) -> MessageTypeField {
56 self.message_type
57 }
58
59 #[must_use]
61 pub const fn return_code(&self) -> ReturnCode {
62 self.return_code
63 }
64
65 #[must_use]
73 pub const fn upper_header_bytes(&self) -> [u8; 8] {
74 let rid = self.request_id.to_be_bytes();
75 [
76 rid[0],
77 rid[1],
78 rid[2],
79 rid[3],
80 self.protocol_version,
81 self.interface_version,
82 self.message_type.as_u8(),
83 self.return_code.as_u8(),
84 ]
85 }
86
87 #[must_use]
94 pub const fn from_fields(
95 message_id: MessageId,
96 length: u32,
97 request_id: u32,
98 protocol_version: u8,
99 interface_version: u8,
100 message_type: MessageTypeField,
101 return_code: ReturnCode,
102 ) -> Self {
103 Self {
104 message_id,
105 length,
106 request_id,
107 protocol_version,
108 interface_version,
109 message_type,
110 return_code,
111 }
112 }
113
114 #[must_use]
120 #[allow(clippy::cast_possible_truncation)]
121 pub const fn new(
122 message_id: MessageId,
123 request_id: u32,
124 protocol_version: u8,
125 interface_version: u8,
126 message_type: MessageTypeField,
127 return_code: ReturnCode,
128 payload_len: usize,
129 ) -> Self {
130 assert!(payload_len <= u32::MAX as usize - 8);
131 Self {
132 message_id,
133 length: 8 + payload_len as u32,
134 request_id,
135 protocol_version,
136 interface_version,
137 message_type,
138 return_code,
139 }
140 }
141
142 #[must_use]
148 #[allow(clippy::cast_possible_truncation)]
149 pub const fn new_sd(request_id: u32, sd_header_size: usize) -> Self {
150 assert!(sd_header_size <= u32::MAX as usize - 8);
151 Self {
152 message_id: MessageId::SD,
153 length: 8 + sd_header_size as u32,
154 request_id,
155 protocol_version: 0x01,
156 interface_version: 0x01,
157 message_type: MessageTypeField::new_sd(),
158 return_code: ReturnCode::Ok,
159 }
160 }
161
162 #[must_use]
168 #[allow(clippy::cast_possible_truncation)]
169 pub const fn new_event(
170 service_id: u16,
171 event_id: u16,
172 request_id: u32,
173 protocol_version: u8,
174 interface_version: u8,
175 payload_len: usize,
176 ) -> Self {
177 assert!(payload_len <= u32::MAX as usize - 8);
178 Self {
179 message_id: MessageId::new_from_service_and_method(service_id, event_id),
180 length: 8 + payload_len as u32,
181 request_id,
182 protocol_version,
183 interface_version,
184 message_type: MessageTypeField::new(crate::protocol::MessageType::Notification, false),
185 return_code: ReturnCode::Ok,
186 }
187 }
188
189 #[must_use]
191 pub const fn is_sd(&self) -> bool {
192 self.message_id.is_sd()
193 }
194
195 #[must_use]
197 pub const fn payload_size(&self) -> usize {
198 self.length as usize - 8
199 }
200
201 pub const fn set_request_id(&mut self, request_id: u32) {
203 self.request_id = request_id;
204 }
205}
206
/// Zero-copy view borrowing the 16 raw header bytes.
///
/// The inner array is private, so within this module a view is only
/// produced by [`HeaderView::parse`]. That function checks the protocol
/// version, message type and return code bytes. The accessors decode fields
/// on demand from the borrowed bytes.
#[derive(Clone, Copy, Debug)]
pub struct HeaderView<'a>(&'a [u8; 16]);
210
211impl<'a> HeaderView<'a> {
212 pub fn parse(buf: &'a [u8]) -> Result<(Self, &'a [u8]), Error> {
224 if buf.len() < 16 {
225 return Err(Error::UnexpectedEof);
226 }
227 let header_bytes: &[u8; 16] = buf[..16].try_into().expect("length checked above");
228 let view = Self(header_bytes);
229
230 let pv = view.protocol_version();
232 if pv != 0x01 {
233 return Err(Error::InvalidProtocolVersion(pv));
234 }
235 MessageTypeField::try_from(header_bytes[14])?;
237 ReturnCode::try_from(header_bytes[15])?;
239
240 Ok((view, &buf[16..]))
241 }
242
243 #[must_use]
245 pub fn message_id(&self) -> MessageId {
246 MessageId::from(u32::from_be_bytes([
247 self.0[0], self.0[1], self.0[2], self.0[3],
248 ]))
249 }
250
251 #[must_use]
253 pub fn length(&self) -> u32 {
254 u32::from_be_bytes([self.0[4], self.0[5], self.0[6], self.0[7]])
255 }
256
257 #[must_use]
259 pub fn request_id(&self) -> u32 {
260 u32::from_be_bytes([self.0[8], self.0[9], self.0[10], self.0[11]])
261 }
262
263 #[must_use]
265 pub fn payload_size(&self) -> usize {
266 self.length() as usize - 8
267 }
268
269 #[must_use]
271 pub fn protocol_version(&self) -> u8 {
272 self.0[12]
273 }
274
275 #[must_use]
277 pub fn interface_version(&self) -> u8 {
278 self.0[13]
279 }
280
281 #[must_use]
287 pub fn message_type(&self) -> MessageTypeField {
288 MessageTypeField::try_from(self.0[14]).expect("validated in parse")
290 }
291
292 #[must_use]
298 pub fn return_code(&self) -> ReturnCode {
299 ReturnCode::try_from(self.0[15]).expect("validated in parse")
301 }
302
303 #[must_use]
305 pub fn is_sd(&self) -> bool {
306 self.message_id().is_sd()
307 }
308
309 #[must_use]
311 pub fn to_owned(&self) -> Header {
312 Header {
313 message_id: self.message_id(),
314 length: self.length(),
315 request_id: self.request_id(),
316 protocol_version: self.protocol_version(),
317 interface_version: self.interface_version(),
318 message_type: self.message_type(),
319 return_code: self.return_code(),
320 }
321 }
322}
323
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::{Error, MessageId, MessageTypeField, ReturnCode};

    /// Representative non-SD header (length 16, i.e. an 8-byte payload).
    fn make_header() -> Header {
        Header {
            message_id: MessageId::new_from_service_and_method(0x1234, 0x0001),
            length: 16,
            request_id: 0xABCD_0042,
            protocol_version: 0x01,
            interface_version: 0x03,
            message_type: MessageTypeField::try_from(0x00).unwrap(),
            return_code: ReturnCode::Ok,
        }
    }

    /// Serializes `h` into a fresh 16-byte buffer.
    fn encode_header(h: &Header) -> [u8; 16] {
        let mut buf = [0u8; 16];
        h.encode(&mut buf.as_mut_slice()).unwrap();
        buf
    }

    #[test]
    fn upper_header_bytes_layout() {
        let h = make_header();
        let ub = h.upper_header_bytes();
        let rid = h.request_id().to_be_bytes();
        assert_eq!(ub[0..4], rid);
        assert_eq!(ub[4], h.protocol_version());
        assert_eq!(ub[5], h.interface_version());
        assert_eq!(ub[6], u8::from(h.message_type()));
        assert_eq!(ub[7], u8::from(h.return_code()));
    }

    #[test]
    fn new_sd_fields() {
        let h = Header::new_sd(0x0000_0001, 28);
        assert_eq!(h.message_id(), MessageId::SD);
        assert_eq!(h.length(), 8 + 28);
        assert_eq!(h.request_id(), 0x0000_0001);
        assert_eq!(h.protocol_version(), 0x01);
        assert_eq!(h.interface_version(), 0x01);
        assert_eq!(h.return_code(), ReturnCode::Ok);
    }

    #[test]
    fn is_sd_true_for_sd_header() {
        let h = Header::new_sd(0, 12);
        assert!(h.is_sd());
    }

    #[test]
    fn is_sd_false_for_non_sd_header() {
        let h = make_header();
        assert!(!h.is_sd());
    }

    #[test]
    fn payload_size_returns_length_minus_8() {
        let h = Header {
            length: 24,
            ..make_header()
        };
        assert_eq!(h.payload_size(), 16);
    }

    #[test]
    fn set_request_id_updates_value() {
        let mut h = make_header();
        h.set_request_id(0xDEAD_BEEF);
        assert_eq!(h.request_id(), 0xDEAD_BEEF);
    }

    #[test]
    fn required_size_is_16() {
        assert_eq!(make_header().required_size(), 16);
    }

    #[test]
    fn encode_parse_round_trip() {
        let h = make_header();
        let buf = encode_header(&h);
        let (view, remaining) = HeaderView::parse(&buf[..]).unwrap();
        assert_eq!(view.to_owned(), h);
        assert!(remaining.is_empty());
    }

    #[test]
    fn encode_returns_16() {
        let h = make_header();
        let mut buf = [0u8; 16];
        let n = h.encode(&mut buf.as_mut_slice()).unwrap();
        assert_eq!(n, 16);
    }

    #[test]
    fn sd_header_round_trips() {
        let h = Header::new_sd(0x0000_0042, 28);
        let buf = encode_header(&h);
        let (view, _) = HeaderView::parse(&buf[..]).unwrap();
        assert_eq!(view.to_owned(), h);
    }

    #[test]
    fn parse_exact_size_slice_returns_empty_remainder() {
        let h = make_header();
        let buf = encode_header(&h);
        let (view, remaining) = HeaderView::parse(&buf).unwrap();
        assert_eq!(view.to_owned(), h);
        assert!(remaining.is_empty());
    }

    #[test]
    fn parse_returns_trailing_bytes_as_remainder() {
        let h = make_header();
        let mut buf = [0u8; 20];
        buf[..16].copy_from_slice(&encode_header(&h));
        buf[16..].copy_from_slice(&[0xDE, 0xAD, 0xBE, 0xEF]);
        let (view, remaining) = HeaderView::parse(&buf).unwrap();
        assert_eq!(view.to_owned(), h);
        assert_eq!(remaining, &[0xDEu8, 0xAD, 0xBE, 0xEF]);
    }

    #[test]
    fn parse_invalid_protocol_version_returns_error() {
        // Corrupt only the protocol-version byte of an otherwise valid header.
        let mut buf = encode_header(&make_header());
        buf[12] = 0x02;
        assert!(matches!(
            HeaderView::parse(&buf[..]),
            Err(Error::InvalidProtocolVersion(0x02))
        ));
    }

    #[test]
    fn parse_invalid_message_type_returns_error() {
        let mut buf = encode_header(&make_header());
        buf[14] = 0xFF;
        assert!(matches!(
            HeaderView::parse(&buf[..]),
            Err(Error::InvalidMessageTypeField(0xFF))
        ));
    }

    #[test]
    fn parse_invalid_return_code_returns_error() {
        let mut buf = encode_header(&make_header());
        buf[15] = 0x5F;
        assert!(matches!(
            HeaderView::parse(&buf[..]),
            Err(Error::InvalidReturnCode(0x5F))
        ));
    }

    #[test]
    fn parse_truncated_input_returns_eof() {
        let buf: [u8; 4] = [0x00, 0x00, 0x00, 0x00];
        assert!(matches!(
            HeaderView::parse(&buf[..]),
            Err(Error::UnexpectedEof)
        ));
    }

    #[test]
    fn parse_one_byte_short_returns_eof() {
        let buf = encode_header(&make_header());
        assert!(matches!(
            HeaderView::parse(&buf[..15]),
            Err(Error::UnexpectedEof)
        ));
    }

    #[test]
    fn from_fields_round_trip() {
        let h = make_header();
        let h2 = Header::from_fields(
            h.message_id(),
            h.length(),
            h.request_id(),
            h.protocol_version(),
            h.interface_version(),
            h.message_type(),
            h.return_code(),
        );
        assert_eq!(h, h2);
    }

    #[test]
    fn new_event_fields() {
        let h = Header::new_event(0x5B, 0x8001, 0x0001, 0x01, 0x03, 10);
        assert_eq!(h.message_id().service_id(), 0x5B);
        assert_eq!(h.message_id().method_id(), 0x8001);
        assert_eq!(h.request_id(), 0x0001);
        assert_eq!(h.protocol_version(), 0x01);
        assert_eq!(h.interface_version(), 0x03);
        assert_eq!(h.length(), 18);
        assert_eq!(h.return_code(), ReturnCode::Ok);
    }

    #[test]
    fn new_event_is_notification_without_tp() {
        let h = Header::new_event(0x5B, 0x8001, 0x0001, 0x01, 0x03, 10);
        assert_eq!(
            h.message_type(),
            MessageTypeField::new(crate::protocol::MessageType::Notification, false)
        );
    }

    #[test]
    fn new_constructor_sets_length() {
        let h = Header::new(
            MessageId::new_from_service_and_method(0x1234, 0x0001),
            0x0001,
            0x01,
            0x01,
            MessageTypeField::try_from(0x00).unwrap(),
            ReturnCode::Ok,
            100,
        );
        assert_eq!(h.length(), 108);
        assert_eq!(h.payload_size(), 100);
    }

    #[test]
    fn header_view_accessors() {
        let h = make_header();
        let buf = encode_header(&h);
        let (view, _) = HeaderView::parse(&buf[..]).unwrap();
        assert_eq!(view.message_id(), h.message_id());
        assert_eq!(view.length(), h.length());
        assert_eq!(view.request_id(), h.request_id());
        assert_eq!(view.payload_size(), h.payload_size());
        assert_eq!(view.protocol_version(), h.protocol_version());
        assert_eq!(view.interface_version(), h.interface_version());
        assert_eq!(view.message_type(), h.message_type());
        assert_eq!(view.return_code(), h.return_code());
        assert_eq!(view.is_sd(), h.is_sd());
    }

    #[test]
    fn encode_to_slice_works() {
        let h = make_header();
        let mut buf = [0u8; 16];
        let n = h.encode_to_slice(&mut buf).unwrap();
        assert_eq!(n, 16);
        let (view, _) = HeaderView::parse(&buf).unwrap();
        assert_eq!(view.to_owned(), h);
    }

    #[cfg(feature = "std")]
    #[test]
    fn encode_to_vec_works() {
        let h = make_header();
        let buf = h.encode_to_vec().unwrap();
        assert_eq!(buf.len(), 16);
        let (view, _) = HeaderView::parse(&buf).unwrap();
        assert_eq!(view.to_owned(), h);
    }
}
594
595impl WireFormat for Header {
596 fn required_size(&self) -> usize {
597 16
598 }
599
600 fn encode<T: embedded_io::Write>(&self, writer: &mut T) -> Result<usize, Error> {
601 writer.write_u32_be(self.message_id.message_id())?;
602 writer.write_u32_be(self.length)?;
603 writer.write_u32_be(self.request_id)?;
604 writer.write_u8(self.protocol_version)?;
605 writer.write_u8(self.interface_version)?;
606 writer.write_u8(u8::from(self.message_type))?;
607 writer.write_u8(u8::from(self.return_code))?;
608 Ok(16)
609 }
610}