websocket_codec/
mask.rs

1#![allow(clippy::new_without_default)]
2
/// A 32-bit masking key, as used by [`mask_slice`] and [`mask_slice_copy`].
///
/// `Eq` and `Hash` are derived alongside `PartialEq` so masks can be compared exactly and used as
/// keys in sets and maps.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct Mask(u32);
5
6impl Mask {
7    pub fn new() -> Self {
8        rand::random::<u32>().into()
9    }
10}
11
12impl From<u32> for Mask {
13    fn from(data: u32) -> Self {
14        Mask(data)
15    }
16}
17
18impl From<Mask> for u32 {
19    fn from(mask: Mask) -> Self {
20        mask.0
21    }
22}
23
24/// Masks *by copying* data sent by a client, and unmasks data received by a server.
25pub fn mask_slice_copy(buf: &mut [u8], data: &[u8], Mask(mask): Mask) {
26    assert_eq!(buf.len(), data.len());
27
28    let (buf1, buf2, buf3) = unsafe { buf.align_to_mut() };
29    let (data1, data) = data.split_at(buf1.len());
30    let (data_pre, data2, data3) = unsafe { data.align_to() };
31    if data_pre.is_empty() {
32        let mask = mask_u8_copy(buf1, data1, mask);
33        mask_aligned_copy(buf2, data2, mask);
34        mask_u8_copy(buf3, data3, mask);
35    } else {
36        let (data2, data3) = data.split_at(buf2.len() * 4);
37        let mask = mask_u8_copy(buf1, data1, mask);
38        mask_unaligned_copy(buf2, data2, mask);
39        mask_u8_copy(buf3, data3, mask);
40    }
41}
42
/// Writes each word of `data` XORed with `mask` into the corresponding slot of `buf`.
///
/// # Panics
///
/// Panics if `buf` and `data` have different lengths.
fn mask_aligned_copy(buf: &mut [u32], data: &[u32], mask: u32) {
    assert_eq!(buf.len(), data.len());

    buf.iter_mut()
        .zip(data)
        .for_each(|(out, &word)| *out = word ^ mask);
}
50
/// Writes each four-byte group of `data`, read as a native-endian `u32` and XORed with `mask`,
/// into the corresponding word of `buf`. `data` may have any alignment.
///
/// # Panics
///
/// Panics unless `data.len() == buf.len() * 4`.
fn mask_unaligned_copy(buf: &mut [u32], data: &[u8], mask: u32) {
    let chunks = data.chunks_exact(4);
    assert_eq!(chunks.len(), buf.len());
    assert_eq!(chunks.remainder().len(), 0);

    for (dest, chunk) in buf.iter_mut().zip(chunks) {
        // Safe equivalent of an unaligned native-endian load; `chunks_exact(4)` guarantees
        // every chunk has exactly four bytes, so the indexing cannot fail.
        let word = u32::from_ne_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
        *dest = word ^ mask;
    }
}
62
/// XORs a run of bytes from `data` into `buf`, starting with the mask's first byte, and
/// returns the mask rotated so that the next region continues where this run left off.
///
/// The mask's bytes are taken in native memory order (`to_ne_bytes`). This is the same order
/// the word-sized helpers apply when they XOR a whole `u32`, so the byte and word paths agree
/// on both little- and big-endian targets.
///
/// The run may be any length. `align_to` only promises a best-effort split, so an unaligned
/// head or tail is not guaranteed to be shorter than four bytes.
///
/// # Panics
///
/// Panics if `buf` and `data` have different lengths.
fn mask_u8_copy(buf: &mut [u8], data: &[u8], mask: u32) -> u32 {
    assert_eq!(buf.len(), data.len());

    let mut key = mask.to_ne_bytes();
    for ((dest, &src), &k) in buf.iter_mut().zip(data).zip(key.iter().cycle()) {
        *dest = src ^ k;
    }

    // Rotate so the byte that would have masked the next position is first again.
    key.rotate_left(data.len() % 4);
    u32::from_ne_bytes(key)
}
74
75/// Masks data sent by a client, and unmasks data received by a server.
76pub fn mask_slice(data: &mut [u8], Mask(mask): Mask) {
77    let (data1, data2, data3) = unsafe { data.align_to_mut() };
78    let mask = mask_u8_in_place(data1, mask);
79    mask_aligned_in_place(data2, mask);
80    mask_u8_in_place(data3, mask);
81}
82
/// XORs a run of bytes in place, starting with the mask's first byte, and returns the mask
/// rotated so that the next region continues where this run left off.
///
/// The mask's bytes are taken in native memory order (`to_ne_bytes`). This is the same order
/// `mask_aligned_in_place` applies when it XORs a whole `u32`, so the byte and word paths agree
/// on both little- and big-endian targets.
///
/// The run may be any length. `align_to_mut` only promises a best-effort split, so an unaligned
/// head or tail is not guaranteed to be shorter than four bytes.
fn mask_u8_in_place(data: &mut [u8], mask: u32) -> u32 {
    let mut key = mask.to_ne_bytes();
    for (b, &k) in data.iter_mut().zip(key.iter().cycle()) {
        *b ^= k;
    }

    // Rotate so the byte that would have masked the next position is first again.
    key.rotate_left(data.len() % 4);
    u32::from_ne_bytes(key)
}
93
/// XORs every word of `data` with `mask`, in place.
fn mask_aligned_in_place(data: &mut [u32], mask: u32) {
    data.iter_mut().for_each(|word| *word ^= mask);
}
99
#[cfg(test)]
mod tests {
    use assert_allocations::assert_allocated_bytes;
    use bytes::{BufMut, Bytes, BytesMut};

    use crate::mask::{self, Mask};

    // Test data chosen so that:
    //  - It's not a multiple of 4, i.e. the trailing bytes after the last whole word get masked too
    //  - It's longer than bytes::INLINE_CAP = 31 bytes, to force Bytes to make a memory allocation
    //
    // Mask chosen so that, per block of four bytes:
    //  - First byte has all its bits flipped, so it appears in text as an \x sequence higher than \x80
    //  - Second and third bytes are unchanged
    //  - Fourth byte has its bottom bit flipped, so in text it's still a recognisable letter

    pub static DATA: &[u8] = b"abcdefghijklmnopqrstuvwxyz123456789";

    static MASKED_DATA: &[u8] = b"\
        \x9ebce\
        \x9afgi\
        \x96jkm\
        \x92noq\
        \x8ersu\
        \x8avwy\
        \x86z13\
        \xcc457\
        \xc889";

    #[test]
    fn can_mask() {
        let mask = Mask::from(0xff000001u32.to_be());
        let orig_data = Bytes::from_static(DATA);

        let mut data = BytesMut::with_capacity(orig_data.len());
        data.put(orig_data.clone());
        assert_allocated_bytes(0, || mask::mask_slice(&mut data, mask));

        assert_eq!(b'a' ^ 0xff, data[0]);
        assert_eq!(b'd' ^ 0x01, data[3]);
        assert_eq!(MASKED_DATA, &data);

        assert_allocated_bytes(0, || mask::mask_slice(&mut data, mask));
        assert_eq!(orig_data, data);
    }

    #[test]
    fn can_mask_at_every_alignment() {
        let mask = Mask::from(0xff000001u32.to_be());

        // Start the data at four consecutive offsets so every alignment of the first byte
        // relative to a 4-byte word is covered, including non-empty unaligned heads.
        for offset in 0..4 {
            let mut storage = vec![0u8; offset + DATA.len()];
            let data = &mut storage[offset..];
            data.copy_from_slice(DATA);

            mask::mask_slice(data, mask);
            assert_eq!(MASKED_DATA, &*data, "offset = {}", offset);

            mask::mask_slice(data, mask);
            assert_eq!(DATA, &*data, "offset = {}", offset);
        }
    }

    #[test]
    fn can_mask_by_copying_at_every_alignment() {
        let mask = Mask::from(0xff000001u32.to_be());

        // Source and destination alignments vary independently. Trying every combination covers
        // both the word-aligned copy path and the unaligned-read path.
        for src_offset in 0..4 {
            for dest_offset in 0..4 {
                let mut src_storage = vec![0u8; src_offset + DATA.len()];
                let src = &mut src_storage[src_offset..];
                src.copy_from_slice(DATA);

                let mut dest_storage = vec![0u8; dest_offset + DATA.len()];
                let dest = &mut dest_storage[dest_offset..];

                mask::mask_slice_copy(dest, src, mask);
                assert_eq!(
                    MASKED_DATA,
                    &*dest,
                    "src_offset = {}, dest_offset = {}",
                    src_offset,
                    dest_offset
                );

                let mut unmasked = vec![0u8; DATA.len()];
                mask::mask_slice_copy(&mut unmasked, dest, mask);
                assert_eq!(
                    DATA,
                    &unmasked[..],
                    "src_offset = {}, dest_offset = {}",
                    src_offset,
                    dest_offset
                );
            }
        }
    }
}