1use crate::ip::IpNextLevelProtocol;
4use xenet_macro_helper::types::u16be;
5
6use core::convert::TryInto;
7use core::u16;
8use core::u8;
9use std::net::{Ipv4Addr, Ipv6Addr};
10
/// Conversion of a value into the bytes (octets) it is made of.
///
/// The concrete container is chosen per implementor through
/// [`Octets::Output`]; the integer implementations in this module return
/// fixed-size arrays ordered from the most significant byte to the least.
pub trait Octets {
    /// Byte container returned by [`Octets::octets`].
    type Output;

    /// Returns the octets that make up `self`.
    fn octets(&self) -> Self::Output;
}
19
20impl Octets for u64 {
21 type Output = [u8; 8];
22
23 fn octets(&self) -> Self::Output {
24 [
25 (*self >> 56) as u8,
26 (*self >> 48) as u8,
27 (*self >> 40) as u8,
28 (*self >> 32) as u8,
29 (*self >> 24) as u8,
30 (*self >> 16) as u8,
31 (*self >> 8) as u8,
32 *self as u8,
33 ]
34 }
35}
36
37impl Octets for u32 {
38 type Output = [u8; 4];
39
40 fn octets(&self) -> Self::Output {
41 [
42 (*self >> 24) as u8,
43 (*self >> 16) as u8,
44 (*self >> 8) as u8,
45 *self as u8,
46 ]
47 }
48}
49
50impl Octets for u16 {
51 type Output = [u8; 2];
52
53 fn octets(&self) -> Self::Output {
54 [(*self >> 8) as u8, *self as u8]
55 }
56}
57
58impl Octets for u8 {
59 type Output = [u8; 1];
60
61 fn octets(&self) -> Self::Output {
62 [*self]
63 }
64}
65
66pub fn checksum(data: &[u8], skipword: usize) -> u16be {
69 if data.len() == 0 {
70 return 0;
71 }
72 let sum = sum_be_words(data, skipword);
73 finalize_checksum(sum)
74}
75
76fn finalize_checksum(mut sum: u32) -> u16be {
77 while sum >> 16 != 0 {
78 sum = (sum >> 16) + (sum & 0xFFFF);
79 }
80 !sum as u16
81}
82
83pub fn ipv4_checksum(
85 data: &[u8],
86 skipword: usize,
87 extra_data: &[u8],
88 source: &Ipv4Addr,
89 destination: &Ipv4Addr,
90 next_level_protocol: IpNextLevelProtocol,
91) -> u16be {
92 let mut sum = 0u32;
93
94 sum += ipv4_word_sum(source);
96 sum += ipv4_word_sum(destination);
97 sum += next_level_protocol as u32;
98
99 let len = data.len() + extra_data.len();
100 sum += len as u32;
101
102 sum += sum_be_words(data, skipword);
104 sum += sum_be_words(extra_data, extra_data.len() / 2);
105
106 finalize_checksum(sum)
107}
108
/// Adds the two big-endian 16-bit halves of an IPv4 address.
fn ipv4_word_sum(ip: &Ipv4Addr) -> u32 {
    let [a, b, c, d] = ip.octets();
    let upper = u16::from_be_bytes([a, b]);
    let lower = u16::from_be_bytes([c, d]);
    u32::from(upper) + u32::from(lower)
}
113
114pub fn ipv6_checksum(
116 data: &[u8],
117 skipword: usize,
118 extra_data: &[u8],
119 source: &Ipv6Addr,
120 destination: &Ipv6Addr,
121 next_level_protocol: IpNextLevelProtocol,
122) -> u16be {
123 let mut sum = 0u32;
124
125 sum += ipv6_word_sum(source);
127 sum += ipv6_word_sum(destination);
128 sum += next_level_protocol as u32;
129
130 let len = data.len() + extra_data.len();
131 sum += len as u32;
132
133 sum += sum_be_words(data, skipword);
135 sum += sum_be_words(extra_data, extra_data.len() / 2);
136
137 finalize_checksum(sum)
138}
139
/// Adds together the eight 16-bit segments of an IPv6 address.
fn ipv6_word_sum(ip: &Ipv6Addr) -> u32 {
    ip.segments()
        .iter()
        .fold(0u32, |total, &segment| total + u32::from(segment))
}
143
/// Sums `data` as big-endian 16-bit words, leaving out the word at index
/// `skipword`.
///
/// When `data` has an odd length, the final byte counts as one more word
/// (index `data.len() / 2`) whose low byte is zero; it is subject to
/// `skipword` like any other word. Empty input sums to `0`.
fn sum_be_words(data: &[u8], skipword: usize) -> u32 {
    let pairs = data.chunks_exact(2);
    // Zero or one byte that did not fit into a complete word.
    let leftover = pairs.remainder();
    let full_words = data.len() / 2;

    let mut total = 0u32;
    for (index, pair) in pairs.enumerate() {
        if index == skipword {
            continue;
        }
        // `chunks_exact(2)` only yields 2-byte slices, so the conversion
        // to `[u8; 2]` cannot fail.
        total += u32::from(u16::from_be_bytes(pair.try_into().unwrap()));
    }

    // The trailing odd byte sits at word index `full_words`.
    if full_words != skipword {
        if let Some(&last) = leftover.first() {
            total += u32::from(last) << 8;
        }
    }

    total
}
170
#[cfg(test)]
mod tests {
    use super::sum_be_words;
    use alloc::{vec, vec::Vec};

    /// Skipping different word indices removes exactly that word from the sum.
    #[test]
    fn sum_be_words_different_skipwords() {
        let data = (0..11).collect::<Vec<u8>>();
        assert_eq!(7190, sum_be_words(&data, 1));
        assert_eq!(6676, sum_be_words(&data, 2));
        // A skipword past the last word index excludes nothing.
        assert_eq!(7705, sum_be_words(&data, 99));
        assert_eq!(7705, sum_be_words(&data, 101));
    }

    /// Inputs of zero to three bytes, including a lone trailing byte.
    #[test]
    fn sum_be_words_small_sizes() {
        let data_zero = vec![0; 0];
        assert_eq!(0, sum_be_words(&data_zero, 0));
        assert_eq!(0, sum_be_words(&data_zero, 10));
        let data_one = vec![1; 1];
        // The single byte is word 0, so skipping word 0 leaves nothing.
        // (This assertion previously re-checked `data_zero` by mistake.)
        assert_eq!(0, sum_be_words(&data_one, 0));
        assert_eq!(256, sum_be_words(&data_one, 1));
        let data_two = vec![1; 2];
        assert_eq!(0, sum_be_words(&data_two, 0));
        assert_eq!(257, sum_be_words(&data_two, 1));
        let data_three = vec![4; 3];
        assert_eq!(1024, sum_be_words(&data_three, 0));
        assert_eq!(1028, sum_be_words(&data_three, 1));
        assert_eq!(2052, sum_be_words(&data_three, 2));
        assert_eq!(2052, sum_be_words(&data_three, 3));
    }

    /// The result does not depend on the slice starting at an even address.
    #[test]
    fn sum_be_words_misaligned_ptr() {
        let mut data = vec![0u8; 13];
        // Choose the start offset that puts the slice on an odd address.
        let offset = if data.as_ptr() as usize % 2 == 0 { 1 } else { 0 };
        let slice_data = &mut data[offset..offset + 12];
        // Fill the first 11 bytes; the 12th stays zero, which gives the
        // same sum as an 11-byte input whose last byte is zero-padded.
        for (i, byte) in slice_data.iter_mut().take(11).enumerate() {
            *byte = i as u8;
        }
        let slice_data: &[u8] = slice_data;
        assert_eq!(7190, sum_be_words(slice_data, 1));
        assert_eq!(6676, sum_be_words(slice_data, 2));
        // A skipword past the last word index excludes nothing.
        assert_eq!(7705, sum_be_words(slice_data, 99));
        assert_eq!(7705, sum_be_words(slice_data, 101));
    }
}
227
#[cfg(all(test, feature = "benchmark"))]
mod checksum_benchmarks {
    use super::checksum;
    use test::{black_box, Bencher};

    /// Benchmarks `checksum` over `len` bytes of value `fill`, skipping word 5.
    fn run_checksum_bench(bencher: &mut Bencher, fill: u8, len: usize) {
        let payload = vec![fill; len];
        bencher.iter(|| checksum(black_box(&payload), 5));
    }

    /// Checksum of a 20-byte buffer.
    #[bench]
    fn bench_checksum_small(b: &mut Bencher) {
        run_checksum_bench(b, 99, 20);
    }

    /// Checksum of a 1024-byte buffer.
    #[bench]
    fn bench_checksum_large(b: &mut Bencher) {
        run_checksum_bench(b, 123, 1024);
    }
}