sqlx_mysql/protocol/packet.rs

1use std::cmp::min;
2use std::ops::{Deref, DerefMut};
3
4use bytes::Bytes;
5
6use crate::error::Error;
7use crate::io::{ProtocolDecode, ProtocolEncode};
8use crate::protocol::response::{EofPacket, OkPacket};
9use crate::protocol::Capabilities;
10
/// A protocol packet wrapping a payload of type `T`.
///
/// With a `Bytes` payload the packet can be decoded into a typed value; with an
/// encodable payload it can be written out together with its packet header(s).
#[derive(Debug)]
pub struct Packet<T>(pub(in crate) T);
13
14impl<'en, 'stream, T> ProtocolEncode<'stream, (Capabilities, &'stream mut u8)> for Packet<T>
15where
16    T: ProtocolEncode<'en, Capabilities>,
17{
18    fn encode_with(
19        &self,
20        buf: &mut Vec<u8>,
21        (capabilities, sequence_id): (Capabilities, &'stream mut u8),
22    ) -> Result<(), Error> {
23        let mut next_header = |len: u32| {
24            let mut buf = len.to_le_bytes();
25            buf[3] = *sequence_id;
26            *sequence_id = sequence_id.wrapping_add(1);
27
28            buf
29        };
30
31        // reserve space to write the prefixed length
32        let offset = buf.len();
33        buf.extend(&[0_u8; 4]);
34
35        // encode the payload
36        self.0.encode_with(buf, capabilities)?;
37
38        // determine the length of the encoded payload
39        // and write to our reserved space
40        let len = buf.len() - offset - 4;
41        let header = &mut buf[offset..];
42
43        // // `min(.., 0xFF_FF_FF)` cannot overflow
44        #[allow(clippy::cast_possible_truncation)]
45        header[..4].copy_from_slice(&next_header(min(len, 0xFF_FF_FF) as u32));
46
47        // add more packets if we need to split the data
48        if len >= 0xFF_FF_FF {
49            let rest = buf.split_off(offset + 4 + 0xFF_FF_FF);
50            let mut chunks = rest.chunks_exact(0xFF_FF_FF);
51
52            for chunk in chunks.by_ref() {
53                buf.reserve(chunk.len() + 4);
54
55                // `chunk.len() == 0xFF_FF_FF`
56                #[allow(clippy::cast_possible_truncation)]
57                buf.extend(&next_header(chunk.len() as u32));
58                buf.extend(chunk);
59            }
60
61            // this will also handle adding a zero sized packet if the data size is a multiple of 0xFF_FF_FF
62            let remainder = chunks.remainder();
63            buf.reserve(remainder.len() + 4);
64
65            // `remainder.len() < 0xFF_FF_FF`
66            #[allow(clippy::cast_possible_truncation)]
67            buf.extend(&next_header(remainder.len() as u32));
68            buf.extend(remainder);
69        }
70
71        Ok(())
72    }
73}
74
75impl Packet<Bytes> {
76    pub(crate) fn decode<'de, T>(self) -> Result<T, Error>
77    where
78        T: ProtocolDecode<'de, ()>,
79    {
80        self.decode_with(())
81    }
82
83    pub(crate) fn decode_with<'de, T, C>(self, context: C) -> Result<T, Error>
84    where
85        T: ProtocolDecode<'de, C>,
86    {
87        T::decode_with(self.0, context)
88    }
89
90    pub(crate) fn ok(self) -> Result<OkPacket, Error> {
91        self.decode()
92    }
93
94    pub(crate) fn eof(self, capabilities: Capabilities) -> Result<EofPacket, Error> {
95        if capabilities.contains(Capabilities::DEPRECATE_EOF) {
96            let ok = self.ok()?;
97
98            Ok(EofPacket {
99                warnings: ok.warnings,
100                status: ok.status,
101            })
102        } else {
103            self.decode_with(capabilities)
104        }
105    }
106}
107
108impl Deref for Packet<Bytes> {
109    type Target = Bytes;
110
111    fn deref(&self) -> &Bytes {
112        &self.0
113    }
114}
115
116impl DerefMut for Packet<Bytes> {
117    fn deref_mut(&mut self) -> &mut Bytes {
118        &mut self.0
119    }
120}