s2n_quic_core/inet/
checksum.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use core::{fmt, hash::Hasher, num::Wrapping};
5
6#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
7mod x86;
8
9/// Computes the [IP checksum](https://www.rfc-editor.org/rfc/rfc1071) over the given slice of bytes
10#[inline]
11pub fn checksum(data: &[u8]) -> u16 {
12    let mut checksum = Checksum::default();
13    checksum.write(data);
14    checksum.finish()
15}
16
/// Minimum size for a payload to be considered for platform-specific code
///
/// `Checksum::write` only hands its input to the `LargeWriteFn` when at least this many bytes
/// remain; shorter inputs go straight to the generic routines.
const LARGE_WRITE_LEN: usize = 32;

/// Integer type used for the running sum
type Accumulator = u64;
/// Checksum state; `Wrapping` makes additions wrap around on overflow
type State = Wrapping<Accumulator>;

/// Platform-specific function for computing a checksum
///
/// The function adds the bytes it consumes to the state and returns the tail of the slice that
/// it did not process, which the caller then finishes with the generic routines.
type LargeWriteFn = for<'a> unsafe fn(&mut State, bytes: &'a [u8]) -> &'a [u8];
25
26#[inline(always)]
27fn write_sized_generic<'a, const MAX_LEN: usize, const CHUNK_LEN: usize>(
28    state: &mut State,
29    mut bytes: &'a [u8],
30    on_chunk: impl Fn(&[u8; CHUNK_LEN], &mut Accumulator),
31) -> &'a [u8] {
32    //= https://www.rfc-editor.org/rfc/rfc1071#section-4.1
33    //# The following "C" code algorithm computes the checksum with an inner
34    //# loop that sums 16-bits at a time in a 32-bit accumulator.
35    //#
36    //# in 6
37    //#    {
38    //#        /* Compute Internet Checksum for "count" bytes
39    //#         *         beginning at location "addr".
40    //#         */
41    //#    register long sum = 0;
42    //#
43    //#     while( count > 1 )  {
44    //#        /*  This is the inner loop */
45    //#            sum += * (unsigned short) addr++;
46    //#            count -= 2;
47    //#    }
48    //#
49    //#        /*  Add left-over byte, if any */
50    //#    if( count > 0 )
51    //#            sum += * (unsigned char *) addr;
52    //#
53    //#        /*  Fold 32-bit sum to 16 bits */
54    //#    while (sum>>16)
55    //#        sum = (sum & 0xffff) + (sum >> 16);
56    //#
57    //#    checksum = ~sum;
58    //# }
59
60    while bytes.len() >= MAX_LEN {
61        // use `get_unchecked` to make it easier for kani to analyze
62        let chunks = unsafe { bytes.get_unchecked(..MAX_LEN) };
63        bytes = unsafe { bytes.get_unchecked(MAX_LEN..) };
64
65        let mut sum = 0;
66        // for each pair of bytes, interpret them as integers and sum them up
67        for chunk in chunks.chunks_exact(CHUNK_LEN) {
68            let chunk = unsafe {
69                // SAFETY: chunks_exact always produces a slice of CHUNK_LEN
70                debug_assert_eq!(chunk.len(), CHUNK_LEN);
71                &*(chunk.as_ptr() as *const [u8; CHUNK_LEN])
72            };
73            on_chunk(chunk, &mut sum);
74        }
75        *state += sum;
76    }
77
78    bytes
79}
80
81/// Generic implementation of a function that computes a checksum over the given slice
82#[inline(always)]
83fn write_sized_generic_u16<'a, const LEN: usize>(state: &mut State, bytes: &'a [u8]) -> &'a [u8] {
84    write_sized_generic::<LEN, 2>(
85        state,
86        bytes,
87        #[inline(always)]
88        |&bytes, acc| {
89            *acc += u16::from_ne_bytes(bytes) as Accumulator;
90        },
91    )
92}
93
94#[inline(always)]
95fn write_sized_generic_u32<'a, const LEN: usize>(state: &mut State, bytes: &'a [u8]) -> &'a [u8] {
96    write_sized_generic::<LEN, 4>(
97        state,
98        bytes,
99        #[inline(always)]
100        |&bytes, acc| {
101            *acc += u32::from_ne_bytes(bytes) as Accumulator;
102        },
103    )
104}
105
/// Returns the most optimized function implementation for the current platform
///
/// Detection runs once, on first use; the selected function pointer is cached in a lazily
/// initialized static and copied out on every call.
#[inline]
#[cfg(all(feature = "once_cell", not(any(kani, miri))))]
fn probe_write_large() -> LargeWriteFn {
    use once_cell::sync::Lazy;

    /// Picks the x86 implementation when the probe offers one, otherwise the portable routine
    fn detect() -> LargeWriteFn {
        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            if let Some(native) = x86::probe() {
                return native;
            }
        }

        write_sized_generic_u32::<16>
    }

    static LARGE_WRITE_FN: Lazy<LargeWriteFn> = Lazy::new(detect);

    *LARGE_WRITE_FN
}
123
/// Returns the portable implementation without probing the platform
///
/// This definition is compiled whenever the probing variant is not: the `once_cell` feature is
/// disabled, or the build is for kani or miri.
#[inline]
#[cfg(not(all(feature = "once_cell", not(any(kani, miri)))))]
fn probe_write_large() -> LargeWriteFn {
    write_sized_generic_u32::<16>
}
129
/// Computes the [IP checksum](https://www.rfc-editor.org/rfc/rfc1071) over an arbitrary set of inputs
///
/// Bytes are fed in through the [`Hasher`] implementation and may be split across any number of
/// `write` calls; `finish` produces the final 16-bit value.
#[derive(Clone, Copy)]
pub struct Checksum {
    /// Running sum of everything written so far; folded down to 16 bits by `carry`
    state: State,
    /// Set when the last write ended on an odd byte, so that the next byte written completes
    /// that 16-bit word
    partial_write: bool,
    /// Routine used by `write` for inputs of at least `LARGE_WRITE_LEN` bytes
    write_large: LargeWriteFn,
}
137
138impl Default for Checksum {
139    fn default() -> Self {
140        Self {
141            state: Default::default(),
142            partial_write: false,
143            write_large: probe_write_large(),
144        }
145    }
146}
147
148impl fmt::Debug for Checksum {
149    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
150        let mut v = *self;
151        v.carry();
152        f.debug_tuple("Checksum").field(&v.finish()).finish()
153    }
154}
155
156impl Checksum {
157    /// Creates a checksum instance without enabling the native implementation
158    #[inline]
159    pub fn generic() -> Self {
160        Self {
161            state: Default::default(),
162            partial_write: false,
163            write_large: write_sized_generic_u32::<16>,
164        }
165    }
166
167    /// Writes a single byte to the checksum state
168    #[inline]
169    fn write_byte(&mut self, byte: u8, shift: bool) {
170        if shift {
171            self.state += (byte as Accumulator) << 8;
172        } else {
173            self.state += byte as Accumulator;
174        }
175    }
176
177    /// Carries all of the bits into a single 16 bit range
178    #[inline]
179    fn carry(&mut self) {
180        #[cfg(kani)]
181        self.carry_rfc();
182        #[cfg(not(kani))]
183        self.carry_optimized();
184    }
185
186    /// Carries all of the bits into a single 16 bit range
187    ///
188    /// This implementation is very similar to the way the RFC is written.
189    #[inline]
190    #[allow(dead_code)]
191    fn carry_rfc(&mut self) {
192        let mut state = self.state.0;
193
194        for _ in 0..core::mem::size_of::<Accumulator>() {
195            state = (state & 0xffff) + (state >> 16);
196        }
197
198        self.state.0 = state;
199    }
200
201    /// Carries all of the bits into a single 16 bit range
202    ///
203    /// This implementation was written after some optimization on the RFC version. It results in
204    /// about half the instructions needed as the RFC.
205    #[inline]
206    #[allow(dead_code)]
207    fn carry_optimized(&mut self) {
208        let values: [u16; core::mem::size_of::<Accumulator>() / 2] = unsafe {
209            // SAFETY: alignment of the State is >= of u16
210            debug_assert!(core::mem::align_of::<State>() >= core::mem::align_of::<u16>());
211            core::mem::transmute(self.state.0)
212        };
213
214        let mut sum = 0u16;
215
216        for value in values {
217            let (res, overflowed) = sum.overflowing_add(value);
218            sum = res;
219            if overflowed {
220                sum += 1;
221            }
222        }
223
224        self.state.0 = sum as _;
225    }
226
227    /// Writes bytes to the checksum and ensures any single byte remainders are padded
228    #[inline]
229    pub fn write_padded(&mut self, bytes: &[u8]) {
230        self.write(bytes);
231
232        // write a null byte if `bytes` wasn't 16-bit aligned
233        if core::mem::take(&mut self.partial_write) {
234            self.write_byte(0, cfg!(target_endian = "little"));
235        }
236    }
237
238    /// Computes the final checksum
239    #[inline]
240    pub fn finish(self) -> u16 {
241        self.finish_be().to_be()
242    }
243
244    #[inline]
245    pub fn finish_be(mut self) -> u16 {
246        self.carry();
247
248        let value = self.state.0 as u16;
249        let value = !value;
250
251        // if value is 0, we need to set it to the max value to indicate the checksum was actually
252        // computed
253        if value == 0 {
254            return 0xffff;
255        }
256
257        value
258    }
259}
260
261impl Hasher for Checksum {
262    #[inline]
263    fn write(&mut self, mut bytes: &[u8]) {
264        if bytes.is_empty() {
265            return;
266        }
267
268        // Check to see if we have a partial write to flush
269        if core::mem::take(&mut self.partial_write) {
270            let (chunk, remaining) = bytes.split_at(1);
271            bytes = remaining;
272
273            // shift the byte if we're on little endian
274            self.write_byte(chunk[0], cfg!(target_endian = "little"));
275        }
276
277        // Only delegate to the optimized platform function if the payload is big enough
278        if bytes.len() >= LARGE_WRITE_LEN {
279            bytes = unsafe { (self.write_large)(&mut self.state, bytes) };
280        }
281
282        // Fall back on the generic implementation to wrap things up
283        //
284        // NOTE: We don't use the u32 version with kani as it causes the verification time to
285        // increase by quite a bit. We have a separate proof for the functional equivalence of
286        // these two configurations.
287        #[cfg(not(kani))]
288        {
289            bytes = write_sized_generic_u32::<4>(&mut self.state, bytes);
290        }
291
292        bytes = write_sized_generic_u16::<2>(&mut self.state, bytes);
293
294        // if we only have a single byte left, write it to the state and mark it as a partial write
295        if let Some(byte) = bytes.first().copied() {
296            self.partial_write = true;
297            self.write_byte(byte, cfg!(target_endian = "big"));
298        }
299    }
300
301    #[inline]
302    fn finish(&self) -> u64 {
303        Self::finish(*self) as _
304    }
305}
306
#[cfg(test)]
mod tests {
    use super::*;
    use bolero::check;

    /// Checks the implementation and the RFC port against the worked example from the RFC
    #[test]
    fn rfc_example_test() {
        //= https://www.rfc-editor.org/rfc/rfc1071#section-3
        //= type=test
        //# We now present explicit examples of calculating a simple 1's
        //# complement sum on a 2's complement machine.  The examples show the
        //# same sum calculated byte by bye, by 16-bits words in normal and
        //# swapped order, and 32 bits at a time in 3 different orders.  All
        //# numbers are in hex.
        //#
        //#               Byte-by-byte    "Normal"  Swapped
        //#                                 Order    Order
        //#
        //#     Byte 0/1:    00   01        0001      0100
        //#     Byte 2/3:    f2   03        f203      03f2
        //#     Byte 4/5:    f4   f5        f4f5      f5f4
        //#     Byte 6/7:    f6   f7        f6f7      f7f6
        //#                 ---  ---       -----     -----
        //#     Sum1:       2dc  1f0       2ddf0     1f2dc
        //#
        //#                  dc   f0        ddf0      f2dc
        //#     Carrys:       1    2           2         1
        //#                  --   --        ----      ----
        //#     Sum2:        dd   f2        ddf2      f2dd
        //#
        //#     Final Swap:  dd   f2        ddf2      ddf2
        let bytes = [0x00, 0x01, 0xf2, 0x03, 0xf4, 0xf5, 0xf6, 0xf7];

        let mut checksum = Checksum::default();
        checksum.write(&bytes);
        checksum.carry();

        // The state is summed from native-endian words, so it is its native in-memory byte
        // order that matches the RFC's result on every target. Comparing the little-endian
        // encoding would only hold on little-endian hosts.
        assert_eq!((checksum.state.0 as u16).to_ne_bytes(), [0xdd, 0xf2]);
        assert_eq!((!rfc_c_port(&bytes)).to_be_bytes(), [0xdd, 0xf2]);
    }

    /// Port of the reference C code from the RFC, used as the oracle for the differential test
    ///
    /// Words are read as big-endian and a trailing odd byte is treated as the high byte of a
    /// zero-padded word.
    fn rfc_c_port(data: &[u8]) -> u16 {
        //= https://www.rfc-editor.org/rfc/rfc1071#section-4.1
        //= type=test
        //# The following "C" code algorithm computes the checksum with an inner
        //# loop that sums 16-bits at a time in a 32-bit accumulator.
        //#
        //# in 6
        //#    {
        //#        /* Compute Internet Checksum for "count" bytes
        //#         *         beginning at location "addr".
        //#         */
        //#    register long sum = 0;
        //#
        //#     while( count > 1 )  {
        //#        /*  This is the inner loop */
        //#            sum += * (unsigned short) addr++;
        //#            count -= 2;
        //#    }
        //#
        //#        /*  Add left-over byte, if any */
        //#    if( count > 0 )
        //#            sum += * (unsigned char *) addr;
        //#
        //#        /*  Fold 32-bit sum to 16 bits */
        //#    while (sum>>16)
        //#        sum = (sum & 0xffff) + (sum >> 16);
        //#
        //#    checksum = ~sum;
        //# }

        let mut addr = data.as_ptr();
        let mut count = data.len();

        // SAFETY: `addr` starts at the beginning of `data` and `count` tracks the number of
        // bytes left, so every read below stays inside the slice
        unsafe {
            let mut sum = 0u32;

            while count > 1 {
                let value = u16::from_be_bytes([*addr, *addr.add(1)]);
                sum = sum.wrapping_add(value as u32);
                addr = addr.add(2);
                count -= 2;
            }

            if count > 0 {
                let value = u16::from_be_bytes([*addr, 0]);
                sum = sum.wrapping_add(value as u32);
            }

            while sum >> 16 != 0 {
                sum = (sum & 0xffff) + (sum >> 16);
            }

            !(sum as u16)
        }
    }

    /// Input size bound for the bounded (kani / miri) runs of the differential test
    #[cfg(any(kani, miri))]
    const LEN: usize = if cfg!(kani) { 16 } else { 32 };

    /// * Compares the implementation to a port of the C code defined in the RFC
    /// * Ensures partial writes are correctly handled, even if they're not at a 16 bit boundary
    #[test]
    #[cfg_attr(kani, kani::proof, kani::unwind(17), kani::solver(cadical))]
    fn differential() {
        #[cfg(any(kani, miri))]
        type Bytes = crate::testing::InlineVec<u8, LEN>;
        #[cfg(not(any(kani, miri)))]
        type Bytes = Vec<u8>;

        check!()
            .with_type::<(usize, Bytes)>()
            .for_each(|(index, bytes)| {
                // clamp the split position into the input
                let index = if bytes.is_empty() {
                    0
                } else {
                    *index % bytes.len()
                };
                let (a, b) = bytes.split_at(index);
                let mut cs = Checksum::default();
                cs.write(a);
                cs.write(b);

                // the implementation reports a zero checksum as 0xffff
                let mut rfc_value = rfc_c_port(bytes);
                if rfc_value == 0 {
                    rfc_value = 0xffff;
                }

                assert_eq!(rfc_value.to_be_bytes(), cs.finish().to_be_bytes());
            });
    }

    /// Shows that using the u32+u16 methods is the same as only using u16
    #[test]
    #[cfg_attr(kani, kani::proof, kani::unwind(9), kani::solver(kissat))]
    fn u32_u16_differential() {
        #[cfg(any(kani, miri))]
        type Bytes = crate::testing::InlineVec<u8, 8>;
        #[cfg(not(any(kani, miri)))]
        type Bytes = Vec<u8>;

        check!().with_type::<Bytes>().for_each(|bytes| {
            let a = {
                let mut cs = Checksum::generic();
                let bytes = write_sized_generic_u32::<4>(&mut cs.state, bytes);
                write_sized_generic_u16::<2>(&mut cs.state, bytes);
                cs.finish()
            };

            let b = {
                let mut cs = Checksum::generic();
                write_sized_generic_u16::<2>(&mut cs.state, bytes);
                cs.finish()
            };

            assert_eq!(a, b);
        });
    }

    /// Shows that RFC carry implementation is the same as the optimized version
    #[test]
    #[cfg_attr(kani, kani::proof, kani::unwind(9), kani::solver(kissat))]
    fn carry_differential() {
        check!().with_type::<u64>().cloned().for_each(|state| {
            let mut opt = Checksum::generic();
            opt.state.0 = state;
            opt.carry_optimized();

            let mut rfc = Checksum::generic();
            rfc.state.0 = state;
            rfc.carry_rfc();

            assert_eq!(opt.state.0, rfc.state.0);
        });
    }
}