sqlx_mysql/connection/
stream.rs1use std::collections::VecDeque;
2use std::ops::{Deref, DerefMut};
3
4use bytes::{Buf, Bytes, BytesMut};
5
6use crate::collation::{CharSet, Collation};
7use crate::error::Error;
8use crate::io::MySqlBufExt;
9use crate::io::{ProtocolDecode, ProtocolEncode};
10use crate::net::{BufferedSocket, Socket};
11use crate::protocol::response::{EofPacket, ErrPacket, OkPacket, Status};
12use crate::protocol::{Capabilities, Packet};
13use crate::{MySqlConnectOptions, MySqlDatabaseError};
14
15pub struct MySqlStream<S = Box<dyn Socket>> {
16 pub(crate) socket: BufferedSocket<S>,
18 pub(crate) server_version: (u16, u16, u16),
19 pub(super) capabilities: Capabilities,
20 pub(crate) sequence_id: u8,
21 pub(crate) waiting: VecDeque<Waiting>,
22 pub(crate) charset: CharSet,
23 pub(crate) collation: Collation,
24 pub(crate) is_tls: bool,
25}
26
/// What an in-flight command is still waiting to receive from the server.
///
/// Entries live in `MySqlStream::waiting` and are drained by `MySqlStream::wait_until_ready`.
#[derive(Debug, Eq, PartialEq)]
pub(crate) enum Waiting {
    /// Waiting for a result set (or the OK packet that stands in for one).
    Result,

    /// Waiting for the rows of a result set, up to its terminating packet.
    Row,
}
35
36impl<S: Socket> MySqlStream<S> {
37 pub(crate) fn with_socket(
38 charset: CharSet,
39 collation: Collation,
40 options: &MySqlConnectOptions,
41 socket: S,
42 ) -> Self {
43 let mut capabilities = Capabilities::PROTOCOL_41
44 | Capabilities::IGNORE_SPACE
45 | Capabilities::DEPRECATE_EOF
46 | Capabilities::FOUND_ROWS
47 | Capabilities::TRANSACTIONS
48 | Capabilities::SECURE_CONNECTION
49 | Capabilities::PLUGIN_AUTH_LENENC_DATA
50 | Capabilities::MULTI_STATEMENTS
51 | Capabilities::MULTI_RESULTS
52 | Capabilities::PLUGIN_AUTH
53 | Capabilities::PS_MULTI_RESULTS
54 | Capabilities::SSL;
55
56 if options.database.is_some() {
57 capabilities |= Capabilities::CONNECT_WITH_DB;
58 }
59
60 Self {
61 waiting: VecDeque::new(),
62 capabilities,
63 server_version: (0, 0, 0),
64 sequence_id: 0,
65 collation,
66 charset,
67 socket: BufferedSocket::new(socket),
68 is_tls: false,
69 }
70 }
71
72 pub(crate) async fn wait_until_ready(&mut self) -> Result<(), Error> {
73 if !self.socket.write_buffer().is_empty() {
74 self.socket.flush().await?;
75 }
76
77 while !self.waiting.is_empty() {
78 while self.waiting.front() == Some(&Waiting::Row) {
79 let packet = self.recv_packet().await?;
80
81 if !packet.is_empty() && packet[0] == 0xfe && packet.len() < 9 {
82 let eof = packet.eof(self.capabilities)?;
83
84 if eof.status.contains(Status::SERVER_MORE_RESULTS_EXISTS) {
85 *self.waiting.front_mut().unwrap() = Waiting::Result;
86 } else {
87 self.waiting.pop_front();
88 };
89 }
90 }
91
92 while self.waiting.front() == Some(&Waiting::Result) {
93 let packet = self.recv_packet().await?;
94
95 if !packet.is_empty() && (packet[0] == 0x00 || packet[0] == 0xff) {
96 let ok = packet.ok()?;
97
98 if !ok.status.contains(Status::SERVER_MORE_RESULTS_EXISTS) {
99 self.waiting.pop_front();
100 }
101 } else {
102 *self.waiting.front_mut().unwrap() = Waiting::Row;
103 self.skip_result_metadata(packet).await?;
104 }
105 }
106 }
107
108 Ok(())
109 }
110
111 pub(crate) async fn send_packet<'en, T>(&mut self, payload: T) -> Result<(), Error>
112 where
113 T: ProtocolEncode<'en, Capabilities>,
114 {
115 self.sequence_id = 0;
116 self.write_packet(payload)?;
117 self.flush().await?;
118 Ok(())
119 }
120
121 pub(crate) fn write_packet<'en, T>(&mut self, payload: T) -> Result<(), Error>
122 where
123 T: ProtocolEncode<'en, Capabilities>,
124 {
125 self.socket
126 .write_with(Packet(payload), (self.capabilities, &mut self.sequence_id))
127 }
128
129 async fn recv_packet_part(&mut self) -> Result<Bytes, Error> {
130 let mut header: Bytes = self.socket.read(4).await?;
134
135 #[allow(clippy::cast_possible_truncation)]
137 let packet_size = header.get_uint_le(3) as usize;
138 let sequence_id = header.get_u8();
139
140 self.sequence_id = sequence_id.wrapping_add(1);
141
142 let payload: Bytes = self.socket.read(packet_size).await?;
143
144 Ok(payload)
147 }
148
149 pub(crate) async fn recv_packet(&mut self) -> Result<Packet<Bytes>, Error> {
152 let payload = self.recv_packet_part().await?;
153 let payload = if payload.len() < 0xFF_FF_FF {
154 payload
155 } else {
156 let mut final_payload = BytesMut::with_capacity(0xFF_FF_FF * 2);
157 final_payload.extend_from_slice(&payload);
158
159 drop(payload); let mut last_read = 0xFF_FF_FF;
162 while last_read == 0xFF_FF_FF {
163 let part = self.recv_packet_part().await?;
164 last_read = part.len();
165 final_payload.extend_from_slice(&part);
166 }
167 final_payload.into()
168 };
169
170 if payload
171 .first()
172 .ok_or(err_protocol!("Packet empty"))?
173 .eq(&0xff)
174 {
175 self.waiting.pop_front();
176
177 return Err(
180 MySqlDatabaseError(ErrPacket::decode_with(payload, self.capabilities)?).into(),
181 );
182 }
183
184 Ok(Packet(payload))
185 }
186
187 pub(crate) async fn recv<'de, T>(&mut self) -> Result<T, Error>
188 where
189 T: ProtocolDecode<'de, Capabilities>,
190 {
191 self.recv_packet().await?.decode_with(self.capabilities)
192 }
193
194 pub(crate) async fn recv_ok(&mut self) -> Result<OkPacket, Error> {
195 self.recv_packet().await?.ok()
196 }
197
198 pub(crate) async fn maybe_recv_eof(&mut self) -> Result<Option<EofPacket>, Error> {
199 if self.capabilities.contains(Capabilities::DEPRECATE_EOF) {
200 Ok(None)
201 } else {
202 self.recv().await.map(Some)
203 }
204 }
205
206 async fn skip_result_metadata(&mut self, mut packet: Packet<Bytes>) -> Result<(), Error> {
207 let num_columns: u64 = packet.get_uint_lenenc(); for _ in 0..num_columns {
210 let _ = self.recv_packet().await?;
211 }
212
213 self.maybe_recv_eof().await?;
214
215 Ok(())
216 }
217
218 pub fn boxed_socket(self) -> MySqlStream {
219 MySqlStream {
220 socket: self.socket.boxed(),
221 server_version: self.server_version,
222 capabilities: self.capabilities,
223 sequence_id: self.sequence_id,
224 waiting: self.waiting,
225 charset: self.charset,
226 collation: self.collation,
227 is_tls: self.is_tls,
228 }
229 }
230}
231
232impl<S> Deref for MySqlStream<S> {
233 type Target = BufferedSocket<S>;
234
235 fn deref(&self) -> &Self::Target {
236 &self.socket
237 }
238}
239
240impl<S> DerefMut for MySqlStream<S> {
241 fn deref_mut(&mut self) -> &mut Self::Target {
242 &mut self.socket
243 }
244}