sqlx_mysql/connection/stream.rs

1use std::collections::VecDeque;
2use std::ops::{Deref, DerefMut};
3
4use bytes::{Buf, Bytes, BytesMut};
5
6use crate::collation::{CharSet, Collation};
7use crate::error::Error;
8use crate::io::MySqlBufExt;
9use crate::io::{ProtocolDecode, ProtocolEncode};
10use crate::net::{BufferedSocket, Socket};
11use crate::protocol::response::{EofPacket, ErrPacket, OkPacket, Status};
12use crate::protocol::{Capabilities, Packet};
13use crate::{MySqlConnectOptions, MySqlDatabaseError};
14
15pub struct MySqlStream<S = Box<dyn Socket>> {
16    // Wrapping the socket in `Box` allows us to unsize in-place.
17    pub(crate) socket: BufferedSocket<S>,
18    pub(crate) server_version: (u16, u16, u16),
19    pub(super) capabilities: Capabilities,
20    pub(crate) sequence_id: u8,
21    pub(crate) waiting: VecDeque<Waiting>,
22    pub(crate) charset: CharSet,
23    pub(crate) collation: Collation,
24    pub(crate) is_tls: bool,
25}
26
/// A server response the stream still has to consume before the connection can be reused.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum Waiting {
    /// Waiting for a result set.
    Result,

    /// Waiting for a row within a result set.
    Row,
}
35
36impl<S: Socket> MySqlStream<S> {
37    pub(crate) fn with_socket(
38        charset: CharSet,
39        collation: Collation,
40        options: &MySqlConnectOptions,
41        socket: S,
42    ) -> Self {
43        let mut capabilities = Capabilities::PROTOCOL_41
44            | Capabilities::IGNORE_SPACE
45            | Capabilities::DEPRECATE_EOF
46            | Capabilities::FOUND_ROWS
47            | Capabilities::TRANSACTIONS
48            | Capabilities::SECURE_CONNECTION
49            | Capabilities::PLUGIN_AUTH_LENENC_DATA
50            | Capabilities::MULTI_STATEMENTS
51            | Capabilities::MULTI_RESULTS
52            | Capabilities::PLUGIN_AUTH
53            | Capabilities::PS_MULTI_RESULTS
54            | Capabilities::SSL;
55
56        if options.database.is_some() {
57            capabilities |= Capabilities::CONNECT_WITH_DB;
58        }
59
60        Self {
61            waiting: VecDeque::new(),
62            capabilities,
63            server_version: (0, 0, 0),
64            sequence_id: 0,
65            collation,
66            charset,
67            socket: BufferedSocket::new(socket),
68            is_tls: false,
69        }
70    }
71
72    pub(crate) async fn wait_until_ready(&mut self) -> Result<(), Error> {
73        if !self.socket.write_buffer().is_empty() {
74            self.socket.flush().await?;
75        }
76
77        while !self.waiting.is_empty() {
78            while self.waiting.front() == Some(&Waiting::Row) {
79                let packet = self.recv_packet().await?;
80
81                if !packet.is_empty() && packet[0] == 0xfe && packet.len() < 9 {
82                    let eof = packet.eof(self.capabilities)?;
83
84                    if eof.status.contains(Status::SERVER_MORE_RESULTS_EXISTS) {
85                        *self.waiting.front_mut().unwrap() = Waiting::Result;
86                    } else {
87                        self.waiting.pop_front();
88                    };
89                }
90            }
91
92            while self.waiting.front() == Some(&Waiting::Result) {
93                let packet = self.recv_packet().await?;
94
95                if !packet.is_empty() && (packet[0] == 0x00 || packet[0] == 0xff) {
96                    let ok = packet.ok()?;
97
98                    if !ok.status.contains(Status::SERVER_MORE_RESULTS_EXISTS) {
99                        self.waiting.pop_front();
100                    }
101                } else {
102                    *self.waiting.front_mut().unwrap() = Waiting::Row;
103                    self.skip_result_metadata(packet).await?;
104                }
105            }
106        }
107
108        Ok(())
109    }
110
111    pub(crate) async fn send_packet<'en, T>(&mut self, payload: T) -> Result<(), Error>
112    where
113        T: ProtocolEncode<'en, Capabilities>,
114    {
115        self.sequence_id = 0;
116        self.write_packet(payload)?;
117        self.flush().await?;
118        Ok(())
119    }
120
121    pub(crate) fn write_packet<'en, T>(&mut self, payload: T) -> Result<(), Error>
122    where
123        T: ProtocolEncode<'en, Capabilities>,
124    {
125        self.socket
126            .write_with(Packet(payload), (self.capabilities, &mut self.sequence_id))
127    }
128
129    async fn recv_packet_part(&mut self) -> Result<Bytes, Error> {
130        // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_packets.html
131        // https://mariadb.com/kb/en/library/0-packet/#standard-packet
132
133        let mut header: Bytes = self.socket.read(4).await?;
134
135        // cannot overflow
136        #[allow(clippy::cast_possible_truncation)]
137        let packet_size = header.get_uint_le(3) as usize;
138        let sequence_id = header.get_u8();
139
140        self.sequence_id = sequence_id.wrapping_add(1);
141
142        let payload: Bytes = self.socket.read(packet_size).await?;
143
144        // TODO: packet compression
145
146        Ok(payload)
147    }
148
149    // receive the next packet from the database server
150    // may block (async) on more data from the server
151    pub(crate) async fn recv_packet(&mut self) -> Result<Packet<Bytes>, Error> {
152        let payload = self.recv_packet_part().await?;
153        let payload = if payload.len() < 0xFF_FF_FF {
154            payload
155        } else {
156            let mut final_payload = BytesMut::with_capacity(0xFF_FF_FF * 2);
157            final_payload.extend_from_slice(&payload);
158
159            drop(payload); // we don't need the allocation anymore
160
161            let mut last_read = 0xFF_FF_FF;
162            while last_read == 0xFF_FF_FF {
163                let part = self.recv_packet_part().await?;
164                last_read = part.len();
165                final_payload.extend_from_slice(&part);
166            }
167            final_payload.into()
168        };
169
170        if payload
171            .first()
172            .ok_or(err_protocol!("Packet empty"))?
173            .eq(&0xff)
174        {
175            self.waiting.pop_front();
176
177            // instead of letting this packet be looked at everywhere, we check here
178            // and emit a proper Error
179            return Err(
180                MySqlDatabaseError(ErrPacket::decode_with(payload, self.capabilities)?).into(),
181            );
182        }
183
184        Ok(Packet(payload))
185    }
186
187    pub(crate) async fn recv<'de, T>(&mut self) -> Result<T, Error>
188    where
189        T: ProtocolDecode<'de, Capabilities>,
190    {
191        self.recv_packet().await?.decode_with(self.capabilities)
192    }
193
194    pub(crate) async fn recv_ok(&mut self) -> Result<OkPacket, Error> {
195        self.recv_packet().await?.ok()
196    }
197
198    pub(crate) async fn maybe_recv_eof(&mut self) -> Result<Option<EofPacket>, Error> {
199        if self.capabilities.contains(Capabilities::DEPRECATE_EOF) {
200            Ok(None)
201        } else {
202            self.recv().await.map(Some)
203        }
204    }
205
206    async fn skip_result_metadata(&mut self, mut packet: Packet<Bytes>) -> Result<(), Error> {
207        let num_columns: u64 = packet.get_uint_lenenc(); // column count
208
209        for _ in 0..num_columns {
210            let _ = self.recv_packet().await?;
211        }
212
213        self.maybe_recv_eof().await?;
214
215        Ok(())
216    }
217
218    pub fn boxed_socket(self) -> MySqlStream {
219        MySqlStream {
220            socket: self.socket.boxed(),
221            server_version: self.server_version,
222            capabilities: self.capabilities,
223            sequence_id: self.sequence_id,
224            waiting: self.waiting,
225            charset: self.charset,
226            collation: self.collation,
227            is_tls: self.is_tls,
228        }
229    }
230}
231
232impl<S> Deref for MySqlStream<S> {
233    type Target = BufferedSocket<S>;
234
235    fn deref(&self) -> &Self::Target {
236        &self.socket
237    }
238}
239
240impl<S> DerefMut for MySqlStream<S> {
241    fn deref_mut(&mut self) -> &mut Self::Target {
242        &mut self.socket
243    }
244}