// zero_packet/network/ipv4.rs

1use super::checksum::internet_checksum;
2use crate::misc::IpFormatter;
3use core::fmt;
4
/// Minimum size of an IPv4 header in bytes, i.e. a header that carries no options.
///
/// Expressed as five 4-byte words, matching the way `header_len` derives a byte
/// count from the IHL field.
pub const IPV4_MIN_HEADER_LENGTH: usize = 5 * 4;
7
8/// Writes IPv4 header fields.
9pub struct IPv4Writer<'a> {
10    pub bytes: &'a mut [u8],
11}
12
13impl<'a> IPv4Writer<'a> {
14    /// Creates a new `IPv4Writer` from the given slice.
15    #[inline]
16    pub fn new(bytes: &'a mut [u8]) -> Result<Self, &'static str> {
17        if bytes.len() < IPV4_MIN_HEADER_LENGTH {
18            return Err("Slice is too short to contain an IPv4 header.");
19        }
20
21        Ok(Self { bytes })
22    }
23
24    /// Returns the actual header length in bytes.
25    #[inline]
26    pub fn header_len(&self) -> usize {
27        (self.bytes[0] & 0x0f) as usize * 4
28    }
29
30    /// Sets the version field.
31    ///
32    /// Indicates the IP version number. Should be set to 4.
33    #[inline]
34    pub fn set_version(&mut self, version: u8) {
35        self.bytes[0] = (self.bytes[0] & 0x0F) | (version << 4);
36    }
37
38    /// Sets the IHL field.
39    #[inline]
40    pub fn set_ihl(&mut self, ihl: u8) {
41        self.bytes[0] = (self.bytes[0] & 0xF0) | (ihl & 0x0F);
42    }
43
44    /// Sets the DSCP field.
45    #[inline]
46    pub fn set_dscp(&mut self, dscp: u8) {
47        self.bytes[1] = (self.bytes[1] & 0x03) | (dscp << 2);
48    }
49
50    /// Sets the ECN field.
51    #[inline]
52    pub fn set_ecn(&mut self, ecn: u8) {
53        self.bytes[1] = (self.bytes[1] & 0xFC) | (ecn & 0x03);
54    }
55
56    /// Sets the total length field.
57    ///
58    /// Should include the entire packet (header and payload).
59    #[inline]
60    pub fn set_total_length(&mut self, total_length: u16) {
61        self.bytes[2] = (total_length >> 8) as u8;
62        self.bytes[3] = (total_length & 0xFF) as u8;
63    }
64
65    /// Sets the identification field.
66    #[inline]
67    pub fn set_id(&mut self, identification: u16) {
68        self.bytes[4] = (identification >> 8) as u8;
69        self.bytes[5] = (identification & 0xFF) as u8;
70    }
71
72    /// Sets the flags field.
73    #[inline]
74    pub fn set_flags(&mut self, flags: u8) {
75        self.bytes[6] = (self.bytes[6] & 0x1F) | ((flags << 5) & 0xE0);
76    }
77
78    /// Sets the fragment offset field.
79    #[inline]
80    pub fn set_fragment_offset(&mut self, fragment_offset: u16) {
81        self.bytes[6] = (self.bytes[6] & 0xE0) | ((fragment_offset >> 8) & 0x1F) as u8;
82        self.bytes[7] = (fragment_offset & 0xFF) as u8;
83    }
84
85    /// Sets the TTL field.
86    #[inline]
87    pub fn set_ttl(&mut self, ttl: u8) {
88        self.bytes[8] = ttl;
89    }
90
91    /// Sets the protocol field.
92    #[inline]
93    pub fn set_protocol(&mut self, protocol: u8) {
94        self.bytes[9] = protocol;
95    }
96
97    /// Sets the source IP address field.
98    #[inline]
99    pub fn set_src_ip(&mut self, src_ip: &[u8; 4]) {
100        self.bytes[12] = src_ip[0];
101        self.bytes[13] = src_ip[1];
102        self.bytes[14] = src_ip[2];
103        self.bytes[15] = src_ip[3];
104    }
105
106    /// Sets the destination IP address field.
107    #[inline]
108    pub fn set_dest_ip(&mut self, dest_ip: &[u8; 4]) {
109        self.bytes[16] = dest_ip[0];
110        self.bytes[17] = dest_ip[1];
111        self.bytes[18] = dest_ip[2];
112        self.bytes[19] = dest_ip[3];
113    }
114
115    /// Calculates the checksum field.
116    ///
117    /// Only includes the header.
118    #[inline]
119    pub fn set_checksum(&mut self) {
120        self.bytes[10] = 0;
121        self.bytes[11] = 0;
122        let header_len = self.header_len();
123        let checksum = internet_checksum(&self.bytes[..header_len], 0);
124        self.bytes[10] = (checksum >> 8) as u8;
125        self.bytes[11] = (checksum & 0xff) as u8;
126    }
127}
128
129/// Reads IPv4 header fields.
130#[derive(PartialEq)]
131pub struct IPv4Reader<'a> {
132    pub bytes: &'a [u8],
133}
134
135impl<'a> IPv4Reader<'a> {
136    /// Creates a new `IPv4Reader` from the given slice.
137    #[inline]
138    pub fn new(bytes: &'a [u8]) -> Result<Self, &'static str> {
139        if bytes.len() < IPV4_MIN_HEADER_LENGTH {
140            return Err("Slice is too short to contain an IPv4 header.");
141        }
142
143        Ok(Self { bytes })
144    }
145
146    /// Returns the version field.
147    #[inline]
148    pub fn version(&self) -> u8 {
149        self.bytes[0] >> 4
150    }
151
152    /// Returns the IHL field.
153    #[inline]
154    pub fn ihl(&self) -> u8 {
155        self.bytes[0] & 0x0f
156    }
157
158    /// Returns the DSCP field.
159    #[inline]
160    pub fn dscp(&self) -> u8 {
161        self.bytes[1] >> 2
162    }
163
164    /// Returns the ECN field.
165    #[inline]
166    pub fn ecn(&self) -> u8 {
167        self.bytes[1] & 0x03
168    }
169
170    /// Returns the total length field.
171    ///
172    /// Includes the entire packet (header and payload).
173    #[inline]
174    pub fn total_length(&self) -> u16 {
175        ((self.bytes[2] as u16) << 8) | (self.bytes[3] as u16)
176    }
177
178    /// Returns the identification field.
179    #[inline]
180    pub fn id(&self) -> u16 {
181        ((self.bytes[4] as u16) << 8) | (self.bytes[5] as u16)
182    }
183
184    /// Returns the flags field.
185    #[inline]
186    pub fn flags(&self) -> u8 {
187        self.bytes[6] >> 5
188    }
189
190    /// Returns the fragment offset field.
191    #[inline]
192    pub fn fragment_offset(&self) -> u16 {
193        ((self.bytes[6] as u16 & 0x1F) << 8) | (self.bytes[7] as u16)
194    }
195
196    /// Returns the TTL field.
197    #[inline]
198    pub fn ttl(&self) -> u8 {
199        self.bytes[8]
200    }
201
202    /// Returns the protocol field.
203    #[inline]
204    pub fn protocol(&self) -> u8 {
205        self.bytes[9]
206    }
207
208    /// Returns a reference to the source IP address field.
209    #[inline]
210    pub fn src_ip(&self) -> &[u8] {
211        &self.bytes[12..16]
212    }
213
214    /// Returns a reference to the destination IP address field.
215    #[inline]
216    pub fn dest_ip(&self) -> &[u8] {
217        &self.bytes[16..20]
218    }
219
220    /// Returns the checksum field.
221    #[inline]
222    pub fn checksum(&self) -> u16 {
223        ((self.bytes[10] as u16) << 8) | (self.bytes[11] as u16)
224    }
225
226    /// Returns the indicated header length in bytes.
227    #[inline]
228    pub fn header_len(&self) -> usize {
229        self.ihl() as usize * 4
230    }
231
232    /// Returns a reference to the header.
233    /// 
234    /// May fail if the indicated header length exceeds the allocated buffer.
235    #[inline]
236    pub fn header(&self) -> Result<&'a [u8], &'static str> {
237        let end = self.header_len();
238
239        if end > self.bytes.len() {
240            return Err("Indicated IPv4 header length exceeds the allocated buffer.");
241        }
242
243        Ok(&self.bytes[..self.header_len()])
244    }
245
246    /// Returns a reference to the payload.
247    /// 
248    /// May fail if the indicated header length exceeds the allocated buffer.
249    #[inline]
250    pub fn payload(&self) -> Result<&'a [u8], &'static str> {
251        let start = self.header_len();
252
253        if start > self.bytes.len() {
254            return Err("Indicated IPv4 header length exceeds the allocated buffer.");
255        }
256
257        Ok(&self.bytes[self.header_len()..])
258    }
259
260    /// Verifies the checksum field.
261    #[inline]
262    pub fn valid_checksum(&self) -> Result<bool, &'static str> {
263        Ok(internet_checksum(self.header()?, 0) == 0)
264    }
265}
266
267impl fmt::Debug for IPv4Reader<'_> {
268    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
269        let src_ip = self.src_ip();
270        let dest_ip = self.dest_ip();
271        f.debug_struct("IPv4Packet")
272            .field("version", &self.version())
273            .field("ihl", &self.ihl())
274            .field("dscp", &self.dscp())
275            .field("ecn", &self.ecn())
276            .field("total_length", &self.total_length())
277            .field("identification", &self.id())
278            .field("flags", &self.flags())
279            .field("fragment_offset", &self.fragment_offset())
280            .field("ttl", &self.ttl())
281            .field("protocol", &self.protocol())
282            .field("checksum", &self.checksum())
283            .field("src_ip", &IpFormatter(src_ip))
284            .field("dest_ip", &IpFormatter(dest_ip))
285            .finish()
286    }
287}
288
#[cfg(test)]
mod tests {
    use super::*;

    /// Writes every header field with `IPv4Writer`, reads it back with `IPv4Reader`,
    /// and checks that the checksum written by `set_checksum` verifies.
    #[test]
    fn getters_and_setters() {
        // Raw packet: a zeroed buffer exactly one option-less header long.
        let mut bytes = [0u8; 20];

        // Fixed sample values.
        let version = 4;
        let ihl = 5;
        let src_ip = [192, 168, 1, 1];
        let dest_ip = [192, 168, 1, 2];
        let dscp = 0;
        let ecn = 0;
        let total_length = 40;
        let identification = 0;
        let flags = 2;
        let fragment_offset = 0;
        let ttl = 64;
        let protocol = 0x06;

        // Create a IPv4 packet writer.
        let mut writer = IPv4Writer::new(&mut bytes).unwrap();

        // Set the fields.
        writer.set_version(version);
        writer.set_ihl(ihl);
        writer.set_dscp(dscp);
        writer.set_ecn(ecn);
        writer.set_total_length(total_length);
        writer.set_id(identification);
        writer.set_flags(flags);
        writer.set_fragment_offset(fragment_offset);
        writer.set_ttl(ttl);
        writer.set_protocol(protocol);
        writer.set_src_ip(&src_ip);
        writer.set_dest_ip(&dest_ip);
        // Must come last: the checksum covers every field written above.
        writer.set_checksum();

        // Create a IPv4 packet reader.
        let reader = IPv4Reader::new(&bytes).unwrap();

        // Ensure the fields are set and retrieved correctly.
        assert_eq!(reader.version(), version);
        assert_eq!(reader.ihl(), ihl);
        assert_eq!(reader.dscp(), dscp);
        assert_eq!(reader.ecn(), ecn);
        assert_eq!(reader.total_length(), total_length);
        assert_eq!(reader.id(), identification);
        assert_eq!(reader.flags(), flags);
        assert_eq!(reader.fragment_offset(), fragment_offset);
        assert_eq!(reader.ttl(), ttl);
        assert_eq!(reader.protocol(), protocol);
        assert_eq!(reader.src_ip(), src_ip);
        assert_eq!(reader.dest_ip(), dest_ip);

        // Ensure the derived header length and the written checksum are consistent.
        assert_eq!(reader.header_len(), IPV4_MIN_HEADER_LENGTH);
        assert_eq!(reader.valid_checksum(), Ok(true));
    }
}