Skip to main content

simple_someip/protocol/
header.rs

1use crate::{
2    protocol::{Error, MessageId, MessageTypeField, ReturnCode, byte_order::WriteBytesExt},
3    traits::WireFormat,
4};
5
/// SOME/IP header
///
/// Owned representation of the fixed 16-byte SOME/IP message header. On the wire
/// (see the `WireFormat` impl) the fields are written big-endian in declaration
/// order: message ID (4), length (4), request ID (4), protocol version (1),
/// interface version (1), message type (1), return code (1).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Header {
    /// Message ID, encoding service ID and method ID
    message_id: MessageId,
    /// Length of the message in bytes, starting at the request Id
    /// Total length of the message is therefore length + 8
    length: u32,
    /// SOME/IP Request ID (4 bytes): Client ID \[31:16\] + Session ID \[15:0\].
    request_id: u32,
    /// SOME/IP protocol version; `HeaderView::parse` only accepts `0x01`.
    protocol_version: u8,
    /// Version of the service interface the message belongs to.
    interface_version: u8,
    /// Message type byte (e.g. request, notification, SD).
    message_type: MessageTypeField,
    /// Return code byte carried in the last header octet.
    return_code: ReturnCode,
}
21
22impl Header {
23    /// Returns the message ID (service ID + method ID).
24    #[must_use]
25    pub const fn message_id(&self) -> MessageId {
26        self.message_id
27    }
28
29    /// Returns the length field (payload size + 8).
30    #[must_use]
31    pub const fn length(&self) -> u32 {
32        self.length
33    }
34
35    /// Returns the request ID (client ID + session ID).
36    #[must_use]
37    pub const fn request_id(&self) -> u32 {
38        self.request_id
39    }
40
41    /// Returns the protocol version.
42    #[must_use]
43    pub const fn protocol_version(&self) -> u8 {
44        self.protocol_version
45    }
46
47    /// Returns the interface version.
48    #[must_use]
49    pub const fn interface_version(&self) -> u8 {
50        self.interface_version
51    }
52
53    /// Returns the message type field.
54    #[must_use]
55    pub const fn message_type(&self) -> MessageTypeField {
56        self.message_type
57    }
58
59    /// Returns the return code.
60    #[must_use]
61    pub const fn return_code(&self) -> ReturnCode {
62        self.return_code
63    }
64
65    /// Return the 8-byte "upper header" used by E2E UPPER-HEADER-BITS-TO-SHIFT.
66    ///
67    /// Layout (big-endian): `request_id(4)` + `protocol_version(1)` + `interface_version(1)`
68    ///                      + `message_type(1)` + `return_code(1)`
69    ///
70    /// Note: `request_id` is the full 4-byte SOME/IP Request ID field
71    /// (Client ID \[31:16\] + Session ID \[15:0\]), not just the 2-byte Session ID.
72    #[must_use]
73    pub const fn upper_header_bytes(&self) -> [u8; 8] {
74        let rid = self.request_id.to_be_bytes();
75        [
76            rid[0],
77            rid[1],
78            rid[2],
79            rid[3],
80            self.protocol_version,
81            self.interface_version,
82            self.message_type.as_u8(),
83            self.return_code.as_u8(),
84        ]
85    }
86
87    /// Creates a header from raw field values.
88    ///
89    /// Unlike [`new`](Self::new), the `length` field is taken directly rather
90    /// than being computed from a payload size.  This is the inverse of the
91    /// accessor methods and is useful for FFI or any context where the caller
92    /// already has the raw on-wire field values.
93    #[must_use]
94    pub const fn from_fields(
95        message_id: MessageId,
96        length: u32,
97        request_id: u32,
98        protocol_version: u8,
99        interface_version: u8,
100        message_type: MessageTypeField,
101        return_code: ReturnCode,
102    ) -> Self {
103        Self {
104            message_id,
105            length,
106            request_id,
107            protocol_version,
108            interface_version,
109            message_type,
110            return_code,
111        }
112    }
113
114    /// Creates a new header with the given fields.
115    ///
116    /// # Panics
117    ///
118    /// Panics if `payload_len` exceeds `u32::MAX - 8`.
119    #[must_use]
120    #[allow(clippy::cast_possible_truncation)]
121    pub const fn new(
122        message_id: MessageId,
123        request_id: u32,
124        protocol_version: u8,
125        interface_version: u8,
126        message_type: MessageTypeField,
127        return_code: ReturnCode,
128        payload_len: usize,
129    ) -> Self {
130        assert!(payload_len <= u32::MAX as usize - 8);
131        Self {
132            message_id,
133            length: 8 + payload_len as u32,
134            request_id,
135            protocol_version,
136            interface_version,
137            message_type,
138            return_code,
139        }
140    }
141
142    /// Creates a new SOME/IP-SD header with standard SD field values.
143    ///
144    /// # Panics
145    ///
146    /// Panics if `sd_header_size` exceeds `u32::MAX - 8`.
147    #[must_use]
148    #[allow(clippy::cast_possible_truncation)]
149    pub const fn new_sd(request_id: u32, sd_header_size: usize) -> Self {
150        assert!(sd_header_size <= u32::MAX as usize - 8);
151        Self {
152            message_id: MessageId::SD,
153            length: 8 + sd_header_size as u32,
154            request_id,
155            protocol_version: 0x01,
156            interface_version: 0x01,
157            message_type: MessageTypeField::new_sd(),
158            return_code: ReturnCode::Ok,
159        }
160    }
161
162    /// Creates a new header for a SOME/IP event notification.
163    ///
164    /// # Panics
165    ///
166    /// Panics if `payload_len` exceeds `u32::MAX - 8`.
167    #[must_use]
168    #[allow(clippy::cast_possible_truncation)]
169    pub const fn new_event(
170        service_id: u16,
171        event_id: u16,
172        request_id: u32,
173        protocol_version: u8,
174        interface_version: u8,
175        payload_len: usize,
176    ) -> Self {
177        assert!(payload_len <= u32::MAX as usize - 8);
178        Self {
179            message_id: MessageId::new_from_service_and_method(service_id, event_id),
180            length: 8 + payload_len as u32,
181            request_id,
182            protocol_version,
183            interface_version,
184            message_type: MessageTypeField::new(crate::protocol::MessageType::Notification, false),
185            return_code: ReturnCode::Ok,
186        }
187    }
188
189    /// Returns `true` if this is a SOME/IP-SD message.
190    #[must_use]
191    pub const fn is_sd(&self) -> bool {
192        self.message_id.is_sd()
193    }
194
195    /// Returns the payload size in bytes (`length - 8`).
196    #[must_use]
197    pub const fn payload_size(&self) -> usize {
198        self.length as usize - 8
199    }
200
201    /// Sets the request ID field.
202    pub const fn set_request_id(&mut self, request_id: u32) {
203        self.request_id = request_id;
204    }
205}
206
/// Zero-copy view into a 16-byte SOME/IP header in a buffer.
///
/// The wrapped array is private, so outside this module a view can only be
/// obtained from [`HeaderView::parse`]. That function checks the protocol
/// version, message type and return code bytes, and the accessors rely on
/// those checks having been done.
#[derive(Clone, Copy, Debug)]
pub struct HeaderView<'a>(&'a [u8; 16]);
210
211impl<'a> HeaderView<'a> {
212    /// Parse and validate a SOME/IP header from the beginning of `buf`.
213    /// Returns `(view, remaining_bytes)` on success.
214    ///
215    /// # Errors
216    ///
217    /// Returns an error if `buf` is shorter than 16 bytes, the protocol version is
218    /// not `0x01`, the message type byte is unrecognized, or the return code is invalid.
219    ///
220    /// # Panics
221    ///
222    /// Cannot panic — the `expect` is guarded by a length check above it.
223    pub fn parse(buf: &'a [u8]) -> Result<(Self, &'a [u8]), Error> {
224        if buf.len() < 16 {
225            return Err(Error::UnexpectedEof);
226        }
227        let header_bytes: &[u8; 16] = buf[..16].try_into().expect("length checked above");
228        let view = Self(header_bytes);
229
230        // Validate protocol version
231        let pv = view.protocol_version();
232        if pv != 0x01 {
233            return Err(Error::InvalidProtocolVersion(pv));
234        }
235        // Validate message type
236        MessageTypeField::try_from(header_bytes[14])?;
237        // Validate return code
238        ReturnCode::try_from(header_bytes[15])?;
239
240        Ok((view, &buf[16..]))
241    }
242
243    /// Returns the message ID (service ID + method ID).
244    #[must_use]
245    pub fn message_id(&self) -> MessageId {
246        MessageId::from(u32::from_be_bytes([
247            self.0[0], self.0[1], self.0[2], self.0[3],
248        ]))
249    }
250
251    /// Returns the length field (payload size + 8).
252    #[must_use]
253    pub fn length(&self) -> u32 {
254        u32::from_be_bytes([self.0[4], self.0[5], self.0[6], self.0[7]])
255    }
256
257    /// Returns the request ID (client ID + session ID).
258    #[must_use]
259    pub fn request_id(&self) -> u32 {
260        u32::from_be_bytes([self.0[8], self.0[9], self.0[10], self.0[11]])
261    }
262
263    /// Returns the payload size in bytes (`length - 8`).
264    #[must_use]
265    pub fn payload_size(&self) -> usize {
266        self.length() as usize - 8
267    }
268
269    /// Returns the protocol version.
270    #[must_use]
271    pub fn protocol_version(&self) -> u8 {
272        self.0[12]
273    }
274
275    /// Returns the interface version.
276    #[must_use]
277    pub fn interface_version(&self) -> u8 {
278        self.0[13]
279    }
280
281    /// Returns the message type field.
282    ///
283    /// # Panics
284    ///
285    /// Cannot panic — the value is validated during [`Self::parse`].
286    #[must_use]
287    pub fn message_type(&self) -> MessageTypeField {
288        // Safe: validated in parse()
289        MessageTypeField::try_from(self.0[14]).expect("validated in parse")
290    }
291
292    /// Returns the return code.
293    ///
294    /// # Panics
295    ///
296    /// Cannot panic — the value is validated during [`Self::parse`].
297    #[must_use]
298    pub fn return_code(&self) -> ReturnCode {
299        // Safe: validated in parse()
300        ReturnCode::try_from(self.0[15]).expect("validated in parse")
301    }
302
303    /// Returns `true` if this is a SOME/IP-SD message.
304    #[must_use]
305    pub fn is_sd(&self) -> bool {
306        self.message_id().is_sd()
307    }
308
309    /// Copies the view into an owned [`Header`].
310    #[must_use]
311    pub fn to_owned(&self) -> Header {
312        Header {
313            message_id: self.message_id(),
314            length: self.length(),
315            request_id: self.request_id(),
316            protocol_version: self.protocol_version(),
317            interface_version: self.interface_version(),
318            message_type: self.message_type(),
319            return_code: self.return_code(),
320        }
321    }
322}
323
324#[cfg(test)]
325mod tests {
326    use super::*;
327    use crate::protocol::{Error, MessageId, MessageTypeField, ReturnCode};
328
329    fn make_header() -> Header {
330        Header {
331            message_id: MessageId::new_from_service_and_method(0x1234, 0x0001),
332            length: 16,
333            request_id: 0xABCD_0042,
334            protocol_version: 0x01,
335            interface_version: 0x03,
336            message_type: MessageTypeField::try_from(0x00).unwrap(), // Request
337            return_code: ReturnCode::Ok,
338        }
339    }
340
341    fn encode_header(h: &Header) -> [u8; 16] {
342        let mut buf = [0u8; 16];
343        h.encode(&mut buf.as_mut_slice()).unwrap();
344        buf
345    }
346
347    // --- upper_header_bytes ---
348
349    #[test]
350    fn upper_header_bytes_layout() {
351        let h = make_header();
352        let ub = h.upper_header_bytes();
353        let rid = h.request_id().to_be_bytes();
354        assert_eq!(ub[0..4], rid);
355        assert_eq!(ub[4], h.protocol_version());
356        assert_eq!(ub[5], h.interface_version());
357        assert_eq!(ub[6], u8::from(h.message_type()));
358        assert_eq!(ub[7], u8::from(h.return_code()));
359    }
360
361    // --- new_sd ---
362
363    #[test]
364    fn new_sd_fields() {
365        let h = Header::new_sd(0x0000_0001, 28);
366        assert_eq!(h.message_id(), MessageId::SD);
367        assert_eq!(h.length(), 8 + 28);
368        assert_eq!(h.request_id(), 0x0000_0001);
369        assert_eq!(h.protocol_version(), 0x01);
370        assert_eq!(h.interface_version(), 0x01);
371        assert_eq!(h.return_code(), ReturnCode::Ok);
372    }
373
374    // --- is_sd ---
375
376    #[test]
377    fn is_sd_true_for_sd_header() {
378        let h = Header::new_sd(0, 12);
379        assert!(h.is_sd());
380    }
381
382    #[test]
383    fn is_sd_false_for_non_sd_header() {
384        let h = make_header();
385        assert!(!h.is_sd());
386    }
387
388    // --- payload_size ---
389
390    #[test]
391    fn payload_size_returns_length_minus_8() {
392        let h = Header {
393            length: 24,
394            ..make_header()
395        };
396        assert_eq!(h.payload_size(), 16);
397    }
398
399    // --- set_request_id ---
400
401    #[test]
402    fn set_request_id_updates_value() {
403        let mut h = make_header();
404        h.set_request_id(0xDEAD_BEEF);
405        assert_eq!(h.request_id(), 0xDEAD_BEEF);
406    }
407
408    // --- required_size ---
409
410    #[test]
411    fn required_size_is_16() {
412        assert_eq!(make_header().required_size(), 16);
413    }
414
415    // --- encode / parse round-trip ---
416
417    #[test]
418    fn encode_parse_round_trip() {
419        let h = make_header();
420        let buf = encode_header(&h);
421        let (view, remaining) = HeaderView::parse(&buf[..]).unwrap();
422        assert_eq!(view.to_owned(), h);
423        assert!(remaining.is_empty());
424    }
425
426    #[test]
427    fn encode_returns_16() {
428        let h = make_header();
429        let mut buf = [0u8; 16];
430        let n = h.encode(&mut buf.as_mut_slice()).unwrap();
431        assert_eq!(n, 16);
432    }
433
434    #[test]
435    fn sd_header_round_trips() {
436        let h = Header::new_sd(0x0000_0042, 28);
437        let buf = encode_header(&h);
438        let (view, _) = HeaderView::parse(&buf[..]).unwrap();
439        assert_eq!(view.to_owned(), h);
440    }
441
442    // --- parse with exactly-sized slice ---
443
444    #[test]
445    fn parse_exact_size_slice_returns_empty_remainder() {
446        let h = make_header();
447        let buf = encode_header(&h);
448        // buf is exactly 16 bytes — no extra data
449        let (view, remaining) = HeaderView::parse(&buf).unwrap();
450        assert_eq!(view.to_owned(), h);
451        assert!(remaining.is_empty());
452    }
453
454    // --- parse error paths ---
455
456    #[test]
457    fn parse_invalid_protocol_version_returns_error() {
458        let mut h = make_header();
459        h.protocol_version = 0x02;
460        // Manually encode with wrong protocol version
461        let mid = h.message_id.message_id().to_be_bytes();
462        let len = h.length.to_be_bytes();
463        let rid = h.request_id.to_be_bytes();
464        let buf: [u8; 16] = [
465            mid[0], mid[1], mid[2], mid[3], len[0], len[1], len[2], len[3], rid[0], rid[1], rid[2],
466            rid[3], 0x02, // bad protocol version
467            0x03, 0x00, 0x00,
468        ];
469        assert!(matches!(
470            HeaderView::parse(&buf[..]),
471            Err(Error::InvalidProtocolVersion(0x02))
472        ));
473    }
474
475    #[test]
476    fn parse_invalid_message_type_returns_error() {
477        let h = make_header();
478        let mut buf = encode_header(&h);
479        buf[14] = 0xFF; // invalid message type
480        assert!(matches!(
481            HeaderView::parse(&buf[..]),
482            Err(Error::InvalidMessageTypeField(0xFF))
483        ));
484    }
485
486    #[test]
487    fn parse_invalid_return_code_returns_error() {
488        let h = make_header();
489        let mut buf = encode_header(&h);
490        buf[15] = 0x5F; // invalid return code
491        assert!(matches!(
492            HeaderView::parse(&buf[..]),
493            Err(Error::InvalidReturnCode(0x5F))
494        ));
495    }
496
497    #[test]
498    fn parse_truncated_input_returns_eof() {
499        let buf: [u8; 4] = [0x00, 0x00, 0x00, 0x00];
500        assert!(matches!(
501            HeaderView::parse(&buf[..]),
502            Err(Error::UnexpectedEof)
503        ));
504    }
505
506    // --- from_fields ---
507
508    #[test]
509    fn from_fields_round_trip() {
510        let h = make_header();
511        let h2 = Header::from_fields(
512            h.message_id(),
513            h.length(),
514            h.request_id(),
515            h.protocol_version(),
516            h.interface_version(),
517            h.message_type(),
518            h.return_code(),
519        );
520        assert_eq!(h, h2);
521    }
522
523    // --- new_event ---
524
525    #[test]
526    fn new_event_fields() {
527        let h = Header::new_event(0x5B, 0x8001, 0x0001, 0x01, 0x03, 10);
528        assert_eq!(h.message_id().service_id(), 0x5B);
529        assert_eq!(h.message_id().method_id(), 0x8001);
530        assert_eq!(h.request_id(), 0x0001);
531        assert_eq!(h.protocol_version(), 0x01);
532        assert_eq!(h.interface_version(), 0x03);
533        assert_eq!(h.length(), 18); // 8 + 10
534        assert_eq!(h.return_code(), ReturnCode::Ok);
535    }
536
537    // --- new constructor ---
538
539    #[test]
540    fn new_constructor_sets_length() {
541        let h = Header::new(
542            MessageId::new_from_service_and_method(0x1234, 0x0001),
543            0x0001,
544            0x01,
545            0x01,
546            MessageTypeField::try_from(0x00).unwrap(),
547            ReturnCode::Ok,
548            100,
549        );
550        assert_eq!(h.length(), 108); // 8 + 100
551        assert_eq!(h.payload_size(), 100);
552    }
553
554    // --- HeaderView accessors ---
555
556    #[test]
557    fn header_view_accessors() {
558        let h = make_header();
559        let buf = encode_header(&h);
560        let (view, _) = HeaderView::parse(&buf[..]).unwrap();
561        assert_eq!(view.message_id(), h.message_id());
562        assert_eq!(view.length(), h.length());
563        assert_eq!(view.request_id(), h.request_id());
564        assert_eq!(view.payload_size(), h.payload_size());
565        assert_eq!(view.protocol_version(), h.protocol_version());
566        assert_eq!(view.interface_version(), h.interface_version());
567        assert_eq!(view.message_type(), h.message_type());
568        assert_eq!(view.return_code(), h.return_code());
569        assert_eq!(view.is_sd(), h.is_sd());
570    }
571
572    // --- WireFormat default methods ---
573
574    #[test]
575    fn encode_to_slice_works() {
576        let h = make_header();
577        let mut buf = [0u8; 16];
578        let n = h.encode_to_slice(&mut buf).unwrap();
579        assert_eq!(n, 16);
580        let (view, _) = HeaderView::parse(&buf).unwrap();
581        assert_eq!(view.to_owned(), h);
582    }
583
584    #[cfg(feature = "std")]
585    #[test]
586    fn encode_to_vec_works() {
587        let h = make_header();
588        let buf = h.encode_to_vec().unwrap();
589        assert_eq!(buf.len(), 16);
590        let (view, _) = HeaderView::parse(&buf).unwrap();
591        assert_eq!(view.to_owned(), h);
592    }
593}
594
595impl WireFormat for Header {
596    fn required_size(&self) -> usize {
597        16
598    }
599
600    fn encode<T: embedded_io::Write>(&self, writer: &mut T) -> Result<usize, Error> {
601        writer.write_u32_be(self.message_id.message_id())?;
602        writer.write_u32_be(self.length)?;
603        writer.write_u32_be(self.request_id)?;
604        writer.write_u8(self.protocol_version)?;
605        writer.write_u8(self.interface_version)?;
606        writer.write_u8(u8::from(self.message_type))?;
607        writer.write_u8(u8::from(self.return_code))?;
608        Ok(16)
609    }
610}