sqlx_core/mysql/connection/
stream.rs1use std::collections::VecDeque;
2use std::ops::{Deref, DerefMut};
3
4use bytes::{Buf, Bytes};
5
6use crate::error::Error;
7use crate::io::{BufStream, Decode, Encode};
8use crate::mysql::collation::{CharSet, Collation};
9use crate::mysql::io::MySqlBufExt;
10use crate::mysql::protocol::response::{EofPacket, ErrPacket, OkPacket, Status};
11use crate::mysql::protocol::{Capabilities, Packet};
12use crate::mysql::{MySqlConnectOptions, MySqlDatabaseError};
13use crate::net::{MaybeTlsStream, Socket};
14
15pub struct MySqlStream {
16 stream: BufStream<MaybeTlsStream<Socket>>,
17 pub(crate) server_version: (u16, u16, u16),
18 pub(super) capabilities: Capabilities,
19 pub(crate) sequence_id: u8,
20 pub(crate) waiting: VecDeque<Waiting>,
21 pub(crate) charset: CharSet,
22 pub(crate) collation: Collation,
23}
24
/// Describes what the connection still expects to receive from the server for a
/// command that was sent but whose response has not been fully read yet.
///
/// `MySqlStream::wait_until_ready` consumes and discards these pending responses.
/// The enum is a plain fieldless tag, so it is `Copy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum Waiting {
    /// Waiting for the start of a result: either an OK packet or the header of a
    /// result set, whose column metadata is then skipped.
    Result,

    /// Waiting for the rows of a result set that has already started. Rows are read
    /// until the end-of-rows marker arrives.
    Row,
}

34impl MySqlStream {
35 pub(super) async fn connect(options: &MySqlConnectOptions) -> Result<Self, Error> {
36 let charset: CharSet = options.charset.parse()?;
37 let collation: Collation = options
38 .collation
39 .as_deref()
40 .map(|collation| collation.parse())
41 .transpose()?
42 .unwrap_or_else(|| charset.default_collation());
43
44 let socket = match options.socket {
45 Some(ref path) => Socket::connect_uds(path).await?,
46 None => Socket::connect_tcp(&options.host, options.port).await?,
47 };
48
49 let mut capabilities = Capabilities::PROTOCOL_41
50 | Capabilities::IGNORE_SPACE
51 | Capabilities::DEPRECATE_EOF
52 | Capabilities::FOUND_ROWS
53 | Capabilities::TRANSACTIONS
54 | Capabilities::SECURE_CONNECTION
55 | Capabilities::PLUGIN_AUTH_LENENC_DATA
56 | Capabilities::MULTI_STATEMENTS
57 | Capabilities::MULTI_RESULTS
58 | Capabilities::PLUGIN_AUTH
59 | Capabilities::PS_MULTI_RESULTS
60 | Capabilities::SSL;
61
62 if options.database.is_some() {
63 capabilities |= Capabilities::CONNECT_WITH_DB;
64 }
65
66 Ok(Self {
67 waiting: VecDeque::new(),
68 capabilities,
69 server_version: (0, 0, 0),
70 sequence_id: 0,
71 collation,
72 charset,
73 stream: BufStream::new(MaybeTlsStream::Raw(socket)),
74 })
75 }
76
77 pub(crate) async fn wait_until_ready(&mut self) -> Result<(), Error> {
78 if !self.stream.wbuf.is_empty() {
79 self.stream.flush().await?;
80 }
81
82 while !self.waiting.is_empty() {
83 while self.waiting.front() == Some(&Waiting::Row) {
84 let packet = self.recv_packet().await?;
85
86 if !packet.is_empty() && packet[0] == 0xfe && packet.len() < 9 {
87 let eof = packet.eof(self.capabilities)?;
88
89 if eof.status.contains(Status::SERVER_MORE_RESULTS_EXISTS) {
90 *self.waiting.front_mut().unwrap() = Waiting::Result;
91 } else {
92 self.waiting.pop_front();
93 };
94 }
95 }
96
97 while self.waiting.front() == Some(&Waiting::Result) {
98 let packet = self.recv_packet().await?;
99
100 if !packet.is_empty() && (packet[0] == 0x00 || packet[0] == 0xff) {
101 let ok = packet.ok()?;
102
103 if !ok.status.contains(Status::SERVER_MORE_RESULTS_EXISTS) {
104 self.waiting.pop_front();
105 }
106 } else {
107 *self.waiting.front_mut().unwrap() = Waiting::Row;
108 self.skip_result_metadata(packet).await?;
109 }
110 }
111 }
112
113 Ok(())
114 }
115
116 pub(crate) async fn send_packet<'en, T>(&mut self, payload: T) -> Result<(), Error>
117 where
118 T: Encode<'en, Capabilities>,
119 {
120 self.sequence_id = 0;
121 self.write_packet(payload);
122 self.flush().await
123 }
124
125 pub(crate) fn write_packet<'en, T>(&mut self, payload: T)
126 where
127 T: Encode<'en, Capabilities>,
128 {
129 self.stream
130 .write_with(Packet(payload), (self.capabilities, &mut self.sequence_id));
131 }
132
133 pub(crate) async fn recv_packet(&mut self) -> Result<Packet<Bytes>, Error> {
136 let mut header: Bytes = self.stream.read(4).await?;
140
141 let packet_size = header.get_uint_le(3) as usize;
142 let sequence_id = header.get_u8();
143
144 self.sequence_id = sequence_id.wrapping_add(1);
145
146 let payload: Bytes = self.stream.read(packet_size).await?;
147
148 if payload
152 .get(0)
153 .ok_or(err_protocol!("Packet empty"))?
154 .eq(&0xff)
155 {
156 self.waiting.pop_front();
157
158 return Err(
161 MySqlDatabaseError(ErrPacket::decode_with(payload, self.capabilities)?).into(),
162 );
163 }
164
165 Ok(Packet(payload))
166 }
167
168 pub(crate) async fn recv<'de, T>(&mut self) -> Result<T, Error>
169 where
170 T: Decode<'de, Capabilities>,
171 {
172 self.recv_packet().await?.decode_with(self.capabilities)
173 }
174
175 pub(crate) async fn recv_ok(&mut self) -> Result<OkPacket, Error> {
176 self.recv_packet().await?.ok()
177 }
178
179 pub(crate) async fn maybe_recv_eof(&mut self) -> Result<Option<EofPacket>, Error> {
180 if self.capabilities.contains(Capabilities::DEPRECATE_EOF) {
181 Ok(None)
182 } else {
183 self.recv().await.map(Some)
184 }
185 }
186
187 async fn skip_result_metadata(&mut self, mut packet: Packet<Bytes>) -> Result<(), Error> {
188 let num_columns: u64 = packet.get_uint_lenenc(); for _ in 0..num_columns {
191 let _ = self.recv_packet().await?;
192 }
193
194 self.maybe_recv_eof().await?;
195
196 Ok(())
197 }
198}
199
200impl Deref for MySqlStream {
201 type Target = BufStream<MaybeTlsStream<Socket>>;
202
203 fn deref(&self) -> &Self::Target {
204 &self.stream
205 }
206}
207
208impl DerefMut for MySqlStream {
209 fn deref_mut(&mut self) -> &mut Self::Target {
210 &mut self.stream
211 }
212}