xenet_packet/
util.rs

//! Utilities for working with packets, e.g. checksumming.
2
3use crate::ip::IpNextLevelProtocol;
4use xenet_macro_helper::types::u16be;
5
6use core::convert::TryInto;
7use core::u16;
8use core::u8;
9use std::net::{Ipv4Addr, Ipv6Addr};
10
/// Convert a value to a byte array.
///
/// Implementations in this file exist for `u8`, `u16`, `u32` and `u64`; each yields a
/// fixed-size array whose length equals the byte width of the integer.
pub trait Octets {
    /// Output type - bytes array.
    type Output;

    /// Return a value as bytes (big-endian order), i.e. most significant byte first.
    fn octets(&self) -> Self::Output;
}
19
20impl Octets for u64 {
21    type Output = [u8; 8];
22
23    fn octets(&self) -> Self::Output {
24        [
25            (*self >> 56) as u8,
26            (*self >> 48) as u8,
27            (*self >> 40) as u8,
28            (*self >> 32) as u8,
29            (*self >> 24) as u8,
30            (*self >> 16) as u8,
31            (*self >> 8) as u8,
32            *self as u8,
33        ]
34    }
35}
36
37impl Octets for u32 {
38    type Output = [u8; 4];
39
40    fn octets(&self) -> Self::Output {
41        [
42            (*self >> 24) as u8,
43            (*self >> 16) as u8,
44            (*self >> 8) as u8,
45            *self as u8,
46        ]
47    }
48}
49
50impl Octets for u16 {
51    type Output = [u8; 2];
52
53    fn octets(&self) -> Self::Output {
54        [(*self >> 8) as u8, *self as u8]
55    }
56}
57
58impl Octets for u8 {
59    type Output = [u8; 1];
60
61    fn octets(&self) -> Self::Output {
62        [*self]
63    }
64}
65
66/// Calculates a checksum. Used by ipv4 and icmp. The two bytes starting at `skipword * 2` will be
67/// ignored. Supposed to be the checksum field, which is regarded as zero during calculation.
68pub fn checksum(data: &[u8], skipword: usize) -> u16be {
69    if data.len() == 0 {
70        return 0;
71    }
72    let sum = sum_be_words(data, skipword);
73    finalize_checksum(sum)
74}
75
76fn finalize_checksum(mut sum: u32) -> u16be {
77    while sum >> 16 != 0 {
78        sum = (sum >> 16) + (sum & 0xFFFF);
79    }
80    !sum as u16
81}
82
83/// Calculate the checksum for a packet built on IPv4. Used by UDP and TCP.
84pub fn ipv4_checksum(
85    data: &[u8],
86    skipword: usize,
87    extra_data: &[u8],
88    source: &Ipv4Addr,
89    destination: &Ipv4Addr,
90    next_level_protocol: IpNextLevelProtocol,
91) -> u16be {
92    let mut sum = 0u32;
93
94    // Checksum pseudo-header
95    sum += ipv4_word_sum(source);
96    sum += ipv4_word_sum(destination);
97    sum += next_level_protocol as u32;
98
99    let len = data.len() + extra_data.len();
100    sum += len as u32;
101
102    // Checksum packet header and data
103    sum += sum_be_words(data, skipword);
104    sum += sum_be_words(extra_data, extra_data.len() / 2);
105
106    finalize_checksum(sum)
107}
108
/// Sums the two big-endian 16-bit halves of an IPv4 address.
fn ipv4_word_sum(ip: &Ipv4Addr) -> u32 {
    let [a, b, c, d] = ip.octets();
    let upper = u32::from(u16::from_be_bytes([a, b]));
    let lower = u32::from(u16::from_be_bytes([c, d]));
    upper + lower
}
113
114/// Calculate the checksum for a packet built on IPv6.
115pub fn ipv6_checksum(
116    data: &[u8],
117    skipword: usize,
118    extra_data: &[u8],
119    source: &Ipv6Addr,
120    destination: &Ipv6Addr,
121    next_level_protocol: IpNextLevelProtocol,
122) -> u16be {
123    let mut sum = 0u32;
124
125    // Checksum pseudo-header
126    sum += ipv6_word_sum(source);
127    sum += ipv6_word_sum(destination);
128    sum += next_level_protocol as u32;
129
130    let len = data.len() + extra_data.len();
131    sum += len as u32;
132
133    // Checksum packet header and data
134    sum += sum_be_words(data, skipword);
135    sum += sum_be_words(extra_data, extra_data.len() / 2);
136
137    finalize_checksum(sum)
138}
139
/// Sums the eight 16-bit segments of an IPv6 address.
fn ipv6_word_sum(ip: &Ipv6Addr) -> u32 {
    ip.segments()
        .iter()
        .fold(0u32, |total, &segment| total + u32::from(segment))
}
143
/// Adds up every 16-bit big-endian word of `data`, leaving out the word at word offset
/// `skipword`.
///
/// When `data` has an odd length, the final byte is treated as the high byte of one more
/// word (low byte zero). That partial word has word offset `data.len() / 2` and is skipped
/// like any other word if `skipword` equals that offset. Empty input sums to zero.
fn sum_be_words(data: &[u8], skipword: usize) -> u32 {
    let pairs = data.chunks_exact(2);
    // At most one leftover byte: present only for odd-length input.
    let leftover = pairs.remainder();
    let full_words = data.len() / 2;

    let mut total = 0u32;
    for (index, pair) in pairs.enumerate() {
        if index == skipword {
            continue;
        }
        total += u32::from(u16::from_be_bytes([pair[0], pair[1]]));
    }

    // The trailing byte, if any, forms the partial word right after the last full word.
    if full_words != skipword {
        if let Some(&last) = leftover.first() {
            total += u32::from(last) << 8;
        }
    }

    total
}
170
#[cfg(test)]
mod tests {
    use super::sum_be_words;
    use alloc::{vec, vec::Vec};
    use core::slice;

    /// Skipping different words of an 11-byte buffer removes exactly that word from the sum.
    #[test]
    fn sum_be_words_different_skipwords() {
        let data = (0..11).collect::<Vec<u8>>();
        assert_eq!(7190, sum_be_words(&data, 1));
        assert_eq!(6676, sum_be_words(&data, 2));
        // Assert having the skipword outside the range gives correct and equal
        // results
        assert_eq!(7705, sum_be_words(&data, 99));
        assert_eq!(7705, sum_be_words(&data, 101));
    }

    /// Buffers of zero to three bytes, including skipping the trailing partial word.
    #[test]
    fn sum_be_words_small_sizes() {
        let data_zero = vec![0; 0];
        assert_eq!(0, sum_be_words(&data_zero, 0));
        assert_eq!(0, sum_be_words(&data_zero, 10));
        let data_one = vec![1; 1];
        // The single byte is the partial word at offset 0, so skipword 0 drops it.
        // (This assertion previously re-tested `data_zero` by mistake.)
        assert_eq!(0, sum_be_words(&data_one, 0));
        assert_eq!(256, sum_be_words(&data_one, 1));
        let data_two = vec![1; 2];
        assert_eq!(0, sum_be_words(&data_two, 0));
        assert_eq!(257, sum_be_words(&data_two, 1));
        let data_three = vec![4; 3];
        assert_eq!(1024, sum_be_words(&data_three, 0));
        assert_eq!(1028, sum_be_words(&data_three, 1));
        assert_eq!(2052, sum_be_words(&data_three, 2));
        assert_eq!(2052, sum_be_words(&data_three, 3));
    }

    /// The sum must not depend on the buffer starting at a 2-byte aligned address.
    #[test]
    fn sum_be_words_misaligned_ptr() {
        let mut data = vec![0; 13];
        // Pick a start address that is odd: step one byte in if the buffer is even-aligned.
        let ptr = match data.as_ptr() as usize % 2 {
            // SAFETY: `data` holds 13 bytes, so one byte past the start is still in bounds.
            0 => unsafe { data.as_mut_ptr().offset(1) },
            _ => data.as_mut_ptr(),
        };
        // SAFETY: `ptr` is at most one byte into the 13-byte buffer, so the 12 bytes starting
        // there lie within `data`, which is alive and not otherwise accessed in this block.
        unsafe {
            let slice_data = slice::from_raw_parts_mut(ptr, 12);
            // Only the first 11 bytes are written; the 12th stays zero, which yields the
            // same sums as the 11-byte buffer in `sum_be_words_different_skipwords`.
            for i in 0..11 {
                slice_data[i] = i as u8;
            }
            assert_eq!(7190, sum_be_words(&slice_data, 1));
            assert_eq!(6676, sum_be_words(&slice_data, 2));
            // Assert having the skipword outside the range gives correct and equal
            // results
            assert_eq!(7705, sum_be_words(&slice_data, 99));
            assert_eq!(7705, sum_be_words(&slice_data, 101));
        }
    }
}
227
// Benchmarks for `checksum`. Only compiled for test builds with the `benchmark`
// feature enabled.
#[cfg(all(test, feature = "benchmark"))]
mod checksum_benchmarks {
    use super::checksum;
    use test::{black_box, Bencher};

    // Measures `checksum` over a 20-byte buffer, skipping word 5.
    #[bench]
    fn bench_checksum_small(b: &mut Bencher) {
        let data = vec![99u8; 20];
        b.iter(|| checksum(black_box(&data), 5));
    }

    // Measures `checksum` over a 1024-byte buffer, skipping word 5.
    #[bench]
    fn bench_checksum_large(b: &mut Bencher) {
        let data = vec![123u8; 1024];
        b.iter(|| checksum(black_box(&data), 5));
    }
}