// sandlock_core/netlink/proto.rs

use std::mem::size_of;

/// Alignment, in bytes, applied to netlink messages and route attributes.
pub const NLMSG_ALIGN_TO: usize = 4;

/// Rounds `len` up to the next multiple of [`NLMSG_ALIGN_TO`]
/// (the Rust counterpart of the C `NLMSG_ALIGN` macro).
pub const fn nlmsg_align(len: usize) -> usize {
    let mask = NLMSG_ALIGN_TO - 1;
    (len + mask) & !mask
}

// Generic netlink control message types.

/// Error report or acknowledgement message.
pub const NLMSG_ERROR: u16 = 0x02;
/// Terminates a multipart (`NLM_F_MULTI`) reply.
pub const NLMSG_DONE: u16 = 0x03;

// rtnetlink message types.

/// Request for link (network interface) information.
pub const RTM_GETLINK: u16 = 0x12;
/// Link description message.
pub const RTM_NEWLINK: u16 = 0x10;
/// Request for interface address information.
pub const RTM_GETADDR: u16 = 0x16;
/// Interface address description message.
pub const RTM_NEWADDR: u16 = 0x14;

/// Marks a message as a request.
pub const NLM_F_REQUEST: u16 = 1 << 0;
/// Marks a message as one part of a multipart reply.
pub const NLM_F_MULTI: u16 = 1 << 1;
/// Dump request: `NLM_F_ROOT` (0x100) combined with `NLM_F_MATCH` (0x200).
pub const NLM_F_DUMP: u16 = 0x100 | 0x200;

/// Netlink message header, laid out to match the kernel's `struct nlmsghdr`.
///
/// Values are kept in host byte order.
#[repr(C)]
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct NlMsgHdr {
    /// Total message length in bytes, including this header.
    pub nlmsg_len: u32,
    /// Message type (e.g. `RTM_GETLINK`, `NLMSG_DONE`).
    pub nlmsg_type: u16,
    /// `NLM_F_*` flags.
    pub nlmsg_flags: u16,
    /// Sequence number.
    pub nlmsg_seq: u32,
    /// Sending port ID.
    pub nlmsg_pid: u32,
}

// The header is reinterpreted as raw wire bytes, so its size must equal the
// kernel's 16-byte `struct nlmsghdr` with no hidden padding.
const _: () = assert!(size_of::<NlMsgHdr>() == 16);

/// Link message payload, laid out to match the kernel's `struct ifinfomsg`.
#[repr(C)]
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct IfInfoMsg {
    /// Address family (`AF_*`).
    pub ifi_family: u8,
    /// Unused padding byte.
    pub _pad: u8,
    /// Device type (`ARPHRD_*`).
    pub ifi_type: u16,
    /// Interface index.
    pub ifi_index: i32,
    /// Device flags (`IFF_*`).
    pub ifi_flags: u32,
    /// Change mask for `ifi_flags`.
    pub ifi_change: u32,
}

// Must match the kernel's 16-byte `struct ifinfomsg` exactly.
const _: () = assert!(size_of::<IfInfoMsg>() == 16);

/// Address message payload, laid out to match the kernel's `struct ifaddrmsg`.
#[repr(C)]
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct IfAddrMsg {
    /// Address family (`AF_*`).
    pub ifa_family: u8,
    /// Prefix length of the address.
    pub ifa_prefixlen: u8,
    /// Address flags (`IFA_F_*`).
    pub ifa_flags: u8,
    /// Address scope.
    pub ifa_scope: u8,
    /// Index of the interface the address belongs to.
    pub ifa_index: u32,
}

// Must match the kernel's 8-byte `struct ifaddrmsg` exactly.
const _: () = assert!(size_of::<IfAddrMsg>() == 8);

/// Route attribute header, laid out to match the kernel's `struct rtattr`.
///
/// The attribute payload follows the header and is zero-padded to a 4-byte
/// boundary.
#[repr(C)]
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct RtAttr {
    /// Length of header plus payload, excluding trailing padding.
    pub rta_len: u16,
    /// Attribute type (e.g. `IFLA_IFNAME`).
    pub rta_type: u16,
}

// Must match the kernel's 4-byte `struct rtattr` exactly.
const _: () = assert!(size_of::<RtAttr>() == 4);

57pub const NLMSG_HDRLEN: usize = size_of::<NlMsgHdr>();
58pub const RTA_HDRLEN:   usize = size_of::<RtAttr>();
59
/// Append-only encoder for netlink messages and route attributes.
///
/// Multi-byte header fields are written in host byte order, and attributes are
/// zero-padded to 4-byte boundaries. `Writer::default()` is equivalent to
/// `Writer::new()`.
#[derive(Debug, Default)]
pub struct Writer {
    buf: Vec<u8>,
}

62impl Writer {
63    pub fn new() -> Self { Self { buf: Vec::new() } }
64    pub fn into_vec(self) -> Vec<u8> { self.buf }
65
66    pub fn write_aligned(&mut self, bytes: &[u8]) {
67        self.buf.extend_from_slice(bytes);
68        let pad = nlmsg_align(bytes.len()) - bytes.len();
69        self.buf.resize(self.buf.len() + pad, 0);
70    }
71
72    pub fn write_attr(&mut self, rta_type: u16, payload: &[u8]) {
73        let total = RTA_HDRLEN + payload.len();
74        let hdr = RtAttr { rta_len: total as u16, rta_type };
75        let hdr_bytes = unsafe {
76            std::slice::from_raw_parts(&hdr as *const _ as *const u8, RTA_HDRLEN)
77        };
78        self.buf.extend_from_slice(hdr_bytes);
79        self.buf.extend_from_slice(payload);
80        let pad = nlmsg_align(total) - total;
81        self.buf.resize(self.buf.len() + pad, 0);
82    }
83
84    pub fn begin_msg(&mut self, nlmsg_type: u16, flags: u16, seq: u32, pid: u32) -> usize {
85        let start = self.buf.len();
86        let hdr = NlMsgHdr {
87            nlmsg_len: 0,
88            nlmsg_type, nlmsg_flags: flags, nlmsg_seq: seq, nlmsg_pid: pid,
89        };
90        let hdr_bytes = unsafe {
91            std::slice::from_raw_parts(&hdr as *const _ as *const u8, NLMSG_HDRLEN)
92        };
93        self.buf.extend_from_slice(hdr_bytes);
94        start
95    }
96
97    pub fn finish_msg(&mut self, start: usize) {
98        let total = self.buf.len() - start;
99        let len_bytes = (total as u32).to_ne_bytes();
100        self.buf[start..start + 4].copy_from_slice(&len_bytes);
101        let pad = nlmsg_align(total) - total;
102        self.buf.resize(self.buf.len() + pad, 0);
103    }
104}
105
/// Header fields extracted from an incoming netlink message by [`parse_request`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ParsedRequest {
    /// Message type, e.g. `RTM_GETLINK` or `RTM_GETADDR`.
    pub nlmsg_type: u16,
    /// `NLM_F_*` flags from the header.
    pub nlmsg_flags: u16,
    /// Sequence number from the header.
    pub nlmsg_seq: u32,
    /// Port ID from the header.
    pub nlmsg_pid: u32,
}

114pub fn parse_request(buf: &[u8]) -> Option<ParsedRequest> {
115    if buf.len() < NLMSG_HDRLEN { return None; }
116    let hdr: NlMsgHdr = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const _) };
117    if (hdr.nlmsg_len as usize) > buf.len() { return None; }
118    Some(ParsedRequest {
119        nlmsg_type: hdr.nlmsg_type,
120        nlmsg_flags: hdr.nlmsg_flags,
121        nlmsg_seq: hdr.nlmsg_seq,
122        nlmsg_pid: hdr.nlmsg_pid,
123    })
124}
125
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn align_rounds_to_4() {
        let cases: [(usize, usize); 5] = [(0, 0), (1, 4), (4, 4), (5, 8), (16, 16)];
        for &(len, expected) in cases.iter() {
            assert_eq!(nlmsg_align(len), expected);
        }
    }

    #[test]
    fn writer_msg_round_trip() {
        let mut writer = Writer::new();
        let msg_start = writer.begin_msg(RTM_NEWLINK, NLM_F_MULTI, 42, 0);
        writer.write_attr(3 /* IFLA_IFNAME */, b"lo\0");
        writer.finish_msg(msg_start);
        let bytes = writer.into_vec();

        let req = parse_request(&bytes).expect("round-tripped header should parse");
        assert_eq!(req.nlmsg_type, RTM_NEWLINK);
        assert_eq!(req.nlmsg_seq, 42);

        // nlmsg_len sits in the first four bytes, in host byte order.
        let nlmsg_len = u32::from_ne_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize;
        assert!(nlmsg_len >= NLMSG_HDRLEN + RTA_HDRLEN + 3);
    }

    #[test]
    fn parse_request_rejects_short_buffer() {
        let too_short = [0u8; 4];
        let parsed = parse_request(&too_short);
        assert!(parsed.is_none());
    }
}