1use core::{fmt, hash::Hasher, num::Wrapping};
5
6#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
7mod x86;
8
9#[inline]
11pub fn checksum(data: &[u8]) -> u16 {
12 let mut checksum = Checksum::default();
13 checksum.write(data);
14 checksum.finish()
15}
16
/// Minimum number of remaining bytes at which `Hasher::write` hands the input to the
/// checksum's `write_large` routine instead of the small generic passes.
const LARGE_WRITE_LEN: usize = 32;

/// Integer type that the 16-bit and 32-bit chunks are summed into before being folded to 16 bits.
type Accumulator = u64;

/// Running, not-yet-folded checksum state.
///
/// Being `Wrapping`, additions to it wrap at the accumulator's width instead of panicking in
/// debug builds.
type State = Wrapping<Accumulator>;

/// Signature of a bulk summing routine: it adds as much of `bytes` as it handles to the state and
/// returns the unprocessed suffix.
///
/// NOTE(review): the pointer type is `unsafe fn`, presumably because platform-specific routines
/// (such as those returned by `x86::probe`) depend on CPU features checked at runtime. Confirm the
/// exact safety contract in that module.
type LargeWriteFn = for<'buf> unsafe fn(&mut State, bytes: &'buf [u8]) -> &'buf [u8];
25
26#[inline(always)]
27fn write_sized_generic<'a, const MAX_LEN: usize, const CHUNK_LEN: usize>(
28 state: &mut State,
29 mut bytes: &'a [u8],
30 on_chunk: impl Fn(&[u8; CHUNK_LEN], &mut Accumulator),
31) -> &'a [u8] {
32 while bytes.len() >= MAX_LEN {
61 let chunks = unsafe { bytes.get_unchecked(..MAX_LEN) };
63 bytes = unsafe { bytes.get_unchecked(MAX_LEN..) };
64
65 let mut sum = 0;
66 for chunk in chunks.chunks_exact(CHUNK_LEN) {
68 let chunk = unsafe {
69 debug_assert_eq!(chunk.len(), CHUNK_LEN);
71 &*(chunk.as_ptr() as *const [u8; CHUNK_LEN])
72 };
73 on_chunk(chunk, &mut sum);
74 }
75 *state += sum;
76 }
77
78 bytes
79}
80
81#[inline(always)]
83fn write_sized_generic_u16<'a, const LEN: usize>(state: &mut State, bytes: &'a [u8]) -> &'a [u8] {
84 write_sized_generic::<LEN, 2>(
85 state,
86 bytes,
87 #[inline(always)]
88 |&bytes, acc| {
89 *acc += u16::from_ne_bytes(bytes) as Accumulator;
90 },
91 )
92}
93
94#[inline(always)]
95fn write_sized_generic_u32<'a, const LEN: usize>(state: &mut State, bytes: &'a [u8]) -> &'a [u8] {
96 write_sized_generic::<LEN, 4>(
97 state,
98 bytes,
99 #[inline(always)]
100 |&bytes, acc| {
101 *acc += u32::from_ne_bytes(bytes) as Accumulator;
102 },
103 )
104}
105
/// Returns the bulk summing routine for this process, choosing it once on first use.
///
/// On x86/x86_64 the routine returned by `x86::probe` is preferred when it yields one. Otherwise
/// the portable 16-byte `u32` routine is used.
#[inline]
#[cfg(all(feature = "once_cell", not(any(kani, miri))))]
fn probe_write_large() -> LargeWriteFn {
    fn select() -> LargeWriteFn {
        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            if let Some(accelerated) = x86::probe() {
                return accelerated;
            }
        }

        write_sized_generic_u32::<16>
    }

    static SELECTED: once_cell::sync::Lazy<LargeWriteFn> = once_cell::sync::Lazy::new(select);

    *SELECTED
}
123
124#[inline]
125#[cfg(not(all(feature = "once_cell", not(any(kani, miri)))))]
126fn probe_write_large() -> LargeWriteFn {
127 write_sized_generic_u32::<16>
128}
129
130#[derive(Clone, Copy)]
132pub struct Checksum {
133 state: State,
134 partial_write: bool,
135 write_large: LargeWriteFn,
136}
137
138impl Default for Checksum {
139 fn default() -> Self {
140 Self {
141 state: Default::default(),
142 partial_write: false,
143 write_large: probe_write_large(),
144 }
145 }
146}
147
148impl fmt::Debug for Checksum {
149 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
150 let mut v = *self;
151 v.carry();
152 f.debug_tuple("Checksum").field(&v.finish()).finish()
153 }
154}
155
156impl Checksum {
157 #[inline]
159 pub fn generic() -> Self {
160 Self {
161 state: Default::default(),
162 partial_write: false,
163 write_large: write_sized_generic_u32::<16>,
164 }
165 }
166
167 #[inline]
169 fn write_byte(&mut self, byte: u8, shift: bool) {
170 if shift {
171 self.state += (byte as Accumulator) << 8;
172 } else {
173 self.state += byte as Accumulator;
174 }
175 }
176
177 #[inline]
179 fn carry(&mut self) {
180 #[cfg(kani)]
181 self.carry_rfc();
182 #[cfg(not(kani))]
183 self.carry_optimized();
184 }
185
186 #[inline]
190 #[allow(dead_code)]
191 fn carry_rfc(&mut self) {
192 let mut state = self.state.0;
193
194 for _ in 0..core::mem::size_of::<Accumulator>() {
195 state = (state & 0xffff) + (state >> 16);
196 }
197
198 self.state.0 = state;
199 }
200
201 #[inline]
206 #[allow(dead_code)]
207 fn carry_optimized(&mut self) {
208 let values: [u16; core::mem::size_of::<Accumulator>() / 2] = unsafe {
209 debug_assert!(core::mem::align_of::<State>() >= core::mem::align_of::<u16>());
211 core::mem::transmute(self.state.0)
212 };
213
214 let mut sum = 0u16;
215
216 for value in values {
217 let (res, overflowed) = sum.overflowing_add(value);
218 sum = res;
219 if overflowed {
220 sum += 1;
221 }
222 }
223
224 self.state.0 = sum as _;
225 }
226
227 #[inline]
229 pub fn write_padded(&mut self, bytes: &[u8]) {
230 self.write(bytes);
231
232 if core::mem::take(&mut self.partial_write) {
234 self.write_byte(0, cfg!(target_endian = "little"));
235 }
236 }
237
238 #[inline]
240 pub fn finish(self) -> u16 {
241 self.finish_be().to_be()
242 }
243
244 #[inline]
245 pub fn finish_be(mut self) -> u16 {
246 self.carry();
247
248 let value = self.state.0 as u16;
249 let value = !value;
250
251 if value == 0 {
254 return 0xffff;
255 }
256
257 value
258 }
259}
260
261impl Hasher for Checksum {
262 #[inline]
263 fn write(&mut self, mut bytes: &[u8]) {
264 if bytes.is_empty() {
265 return;
266 }
267
268 if core::mem::take(&mut self.partial_write) {
270 let (chunk, remaining) = bytes.split_at(1);
271 bytes = remaining;
272
273 self.write_byte(chunk[0], cfg!(target_endian = "little"));
275 }
276
277 if bytes.len() >= LARGE_WRITE_LEN {
279 bytes = unsafe { (self.write_large)(&mut self.state, bytes) };
280 }
281
282 #[cfg(not(kani))]
288 {
289 bytes = write_sized_generic_u32::<4>(&mut self.state, bytes);
290 }
291
292 bytes = write_sized_generic_u16::<2>(&mut self.state, bytes);
293
294 if let Some(byte) = bytes.first().copied() {
296 self.partial_write = true;
297 self.write_byte(byte, cfg!(target_endian = "big"));
298 }
299 }
300
301 #[inline]
302 fn finish(&self) -> u64 {
303 Self::finish(*self) as _
304 }
305}
306
#[cfg(test)]
mod tests {
    use super::*;
    use bolero::check;

    /// Numerical example from RFC 1071 section 3: the words below sum to `0xddf2`.
    #[test]
    fn rfc_example_test() {
        let bytes = [0x00, 0x01, 0xf2, 0x03, 0xf4, 0xf5, 0xf6, 0xf7];

        // The one-shot helper yields the complemented sum as a host integer.
        assert_eq!(super::checksum(&bytes), !0xddf2u16);

        let mut checksum = Checksum::default();
        checksum.write(&bytes);
        checksum.carry();

        // Words are summed native-endian, so the folded sum's native byte order equals the
        // network-order bytes on every target. `to_le_bytes` would only pass on little-endian
        // hosts.
        assert_eq!((checksum.state.0 as u16).to_ne_bytes(), [0xdd, 0xf2]);
        assert_eq!((!rfc_c_port(&bytes)).to_be_bytes(), [0xdd, 0xf2]);
    }

    /// Port of the reference C implementation in RFC 1071 section 4.1.
    ///
    /// It sums big-endian 16-bit words into a 32-bit accumulator and zero-pads an odd trailing
    /// byte. It then folds the carries back into 16 bits and returns the one's complement.
    fn rfc_c_port(data: &[u8]) -> u16 {
        let mut addr = data.as_ptr();
        let mut count = data.len();

        // SAFETY: `addr` advances by two only while at least two bytes remain. The single-byte
        // read happens only when exactly one byte remains, so every dereference stays within
        // `data`.
        unsafe {
            let mut sum = 0u32;

            while count > 1 {
                let value = u16::from_be_bytes([*addr, *addr.add(1)]);
                sum = sum.wrapping_add(value as u32);
                addr = addr.add(2);
                count -= 2;
            }

            if count > 0 {
                let value = u16::from_be_bytes([*addr, 0]);
                sum = sum.wrapping_add(value as u32);
            }

            while sum >> 16 != 0 {
                sum = (sum & 0xffff) + (sum >> 16);
            }

            !(sum as u16)
        }
    }

    #[cfg(any(kani, miri))]
    const LEN: usize = if cfg!(kani) { 16 } else { 32 };

    /// Splits arbitrary input at any point in `0..=len` and writes both halves.
    ///
    /// The result must match the RFC reference, with a reference result of 0 mapped to 0xffff.
    #[test]
    #[cfg_attr(kani, kani::proof, kani::unwind(17), kani::solver(cadical))]
    fn differential() {
        #[cfg(any(kani, miri))]
        type Bytes = crate::testing::InlineVec<u8, LEN>;
        #[cfg(not(any(kani, miri)))]
        type Bytes = Vec<u8>;

        check!()
            .with_type::<(usize, Bytes)>()
            .for_each(|(index, bytes)| {
                // `len + 1` makes every split point reachable, including an empty second write,
                // and handles empty input without a special case.
                let index = *index % (bytes.len() + 1);
                let (a, b) = bytes.split_at(index);
                let mut cs = Checksum::default();
                cs.write(a);
                cs.write(b);

                let mut rfc_value = rfc_c_port(bytes);
                if rfc_value == 0 {
                    rfc_value = 0xffff;
                }

                assert_eq!(rfc_value.to_be_bytes(), cs.finish().to_be_bytes());
            });
    }

    /// Checks that a 4-byte pass followed by a 2-byte pass gives the same checksum as the 2-byte
    /// pass alone.
    #[test]
    #[cfg_attr(kani, kani::proof, kani::unwind(9), kani::solver(kissat))]
    fn u32_u16_differential() {
        #[cfg(any(kani, miri))]
        type Bytes = crate::testing::InlineVec<u8, 8>;
        #[cfg(not(any(kani, miri)))]
        type Bytes = Vec<u8>;

        check!().with_type::<Bytes>().for_each(|bytes| {
            let a = {
                let mut cs = Checksum::generic();
                let bytes = write_sized_generic_u32::<4>(&mut cs.state, bytes);
                write_sized_generic_u16::<2>(&mut cs.state, bytes);
                cs.finish()
            };

            let b = {
                let mut cs = Checksum::generic();
                write_sized_generic_u16::<2>(&mut cs.state, bytes);
                cs.finish()
            };

            assert_eq!(a, b);
        });
    }

    /// Checks that the lane-wise fold and the RFC reference fold agree on every accumulator
    /// value.
    #[test]
    #[cfg_attr(kani, kani::proof, kani::unwind(9), kani::solver(kissat))]
    fn carry_differential() {
        check!().with_type::<u64>().cloned().for_each(|state| {
            let mut opt = Checksum::generic();
            opt.state.0 = state;
            opt.carry_optimized();

            let mut rfc = Checksum::generic();
            rfc.state.0 = state;
            rfc.carry_rfc();

            assert_eq!(opt.state.0, rfc.state.0);
        });
    }
}