// sqlx_core/mysql/connection/stream.rs

1use std::collections::VecDeque;
2use std::ops::{Deref, DerefMut};
3
4use bytes::{Buf, Bytes};
5
6use crate::error::Error;
7use crate::io::{BufStream, Decode, Encode};
8use crate::mysql::collation::{CharSet, Collation};
9use crate::mysql::io::MySqlBufExt;
10use crate::mysql::protocol::response::{EofPacket, ErrPacket, OkPacket, Status};
11use crate::mysql::protocol::{Capabilities, Packet};
12use crate::mysql::{MySqlConnectOptions, MySqlDatabaseError};
13use crate::net::{MaybeTlsStream, Socket};
14
15pub struct MySqlStream {
16    stream: BufStream<MaybeTlsStream<Socket>>,
17    pub(crate) server_version: (u16, u16, u16),
18    pub(super) capabilities: Capabilities,
19    pub(crate) sequence_id: u8,
20    pub(crate) waiting: VecDeque<Waiting>,
21    pub(crate) charset: CharSet,
22    pub(crate) collation: Collation,
23}
24
/// Kind of response the stream still has to consume for an outstanding command.
#[derive(Debug, PartialEq, Eq)]
pub(crate) enum Waiting {
    /// A result set (or an OK packet) has not been read yet.
    Result,
    /// Positioned inside a result set; rows follow until the terminating EOF packet.
    Row,
}
33
34impl MySqlStream {
35    pub(super) async fn connect(options: &MySqlConnectOptions) -> Result<Self, Error> {
36        let charset: CharSet = options.charset.parse()?;
37        let collation: Collation = options
38            .collation
39            .as_deref()
40            .map(|collation| collation.parse())
41            .transpose()?
42            .unwrap_or_else(|| charset.default_collation());
43
44        let socket = match options.socket {
45            Some(ref path) => Socket::connect_uds(path).await?,
46            None => Socket::connect_tcp(&options.host, options.port).await?,
47        };
48
49        let mut capabilities = Capabilities::PROTOCOL_41
50            | Capabilities::IGNORE_SPACE
51            | Capabilities::DEPRECATE_EOF
52            | Capabilities::FOUND_ROWS
53            | Capabilities::TRANSACTIONS
54            | Capabilities::SECURE_CONNECTION
55            | Capabilities::PLUGIN_AUTH_LENENC_DATA
56            | Capabilities::MULTI_STATEMENTS
57            | Capabilities::MULTI_RESULTS
58            | Capabilities::PLUGIN_AUTH
59            | Capabilities::PS_MULTI_RESULTS
60            | Capabilities::SSL;
61
62        if options.database.is_some() {
63            capabilities |= Capabilities::CONNECT_WITH_DB;
64        }
65
66        Ok(Self {
67            waiting: VecDeque::new(),
68            capabilities,
69            server_version: (0, 0, 0),
70            sequence_id: 0,
71            collation,
72            charset,
73            stream: BufStream::new(MaybeTlsStream::Raw(socket)),
74        })
75    }
76
77    pub(crate) async fn wait_until_ready(&mut self) -> Result<(), Error> {
78        if !self.stream.wbuf.is_empty() {
79            self.stream.flush().await?;
80        }
81
82        while !self.waiting.is_empty() {
83            while self.waiting.front() == Some(&Waiting::Row) {
84                let packet = self.recv_packet().await?;
85
86                if !packet.is_empty() && packet[0] == 0xfe && packet.len() < 9 {
87                    let eof = packet.eof(self.capabilities)?;
88
89                    if eof.status.contains(Status::SERVER_MORE_RESULTS_EXISTS) {
90                        *self.waiting.front_mut().unwrap() = Waiting::Result;
91                    } else {
92                        self.waiting.pop_front();
93                    };
94                }
95            }
96
97            while self.waiting.front() == Some(&Waiting::Result) {
98                let packet = self.recv_packet().await?;
99
100                if !packet.is_empty() && (packet[0] == 0x00 || packet[0] == 0xff) {
101                    let ok = packet.ok()?;
102
103                    if !ok.status.contains(Status::SERVER_MORE_RESULTS_EXISTS) {
104                        self.waiting.pop_front();
105                    }
106                } else {
107                    *self.waiting.front_mut().unwrap() = Waiting::Row;
108                    self.skip_result_metadata(packet).await?;
109                }
110            }
111        }
112
113        Ok(())
114    }
115
116    pub(crate) async fn send_packet<'en, T>(&mut self, payload: T) -> Result<(), Error>
117    where
118        T: Encode<'en, Capabilities>,
119    {
120        self.sequence_id = 0;
121        self.write_packet(payload);
122        self.flush().await
123    }
124
125    pub(crate) fn write_packet<'en, T>(&mut self, payload: T)
126    where
127        T: Encode<'en, Capabilities>,
128    {
129        self.stream
130            .write_with(Packet(payload), (self.capabilities, &mut self.sequence_id));
131    }
132
133    // receive the next packet from the database server
134    // may block (async) on more data from the server
135    pub(crate) async fn recv_packet(&mut self) -> Result<Packet<Bytes>, Error> {
136        // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_packets.html
137        // https://mariadb.com/kb/en/library/0-packet/#standard-packet
138
139        let mut header: Bytes = self.stream.read(4).await?;
140
141        let packet_size = header.get_uint_le(3) as usize;
142        let sequence_id = header.get_u8();
143
144        self.sequence_id = sequence_id.wrapping_add(1);
145
146        let payload: Bytes = self.stream.read(packet_size).await?;
147
148        // TODO: packet compression
149        // TODO: packet joining
150
151        if payload
152            .get(0)
153            .ok_or(err_protocol!("Packet empty"))?
154            .eq(&0xff)
155        {
156            self.waiting.pop_front();
157
158            // instead of letting this packet be looked at everywhere, we check here
159            // and emit a proper Error
160            return Err(
161                MySqlDatabaseError(ErrPacket::decode_with(payload, self.capabilities)?).into(),
162            );
163        }
164
165        Ok(Packet(payload))
166    }
167
168    pub(crate) async fn recv<'de, T>(&mut self) -> Result<T, Error>
169    where
170        T: Decode<'de, Capabilities>,
171    {
172        self.recv_packet().await?.decode_with(self.capabilities)
173    }
174
175    pub(crate) async fn recv_ok(&mut self) -> Result<OkPacket, Error> {
176        self.recv_packet().await?.ok()
177    }
178
179    pub(crate) async fn maybe_recv_eof(&mut self) -> Result<Option<EofPacket>, Error> {
180        if self.capabilities.contains(Capabilities::DEPRECATE_EOF) {
181            Ok(None)
182        } else {
183            self.recv().await.map(Some)
184        }
185    }
186
187    async fn skip_result_metadata(&mut self, mut packet: Packet<Bytes>) -> Result<(), Error> {
188        let num_columns: u64 = packet.get_uint_lenenc(); // column count
189
190        for _ in 0..num_columns {
191            let _ = self.recv_packet().await?;
192        }
193
194        self.maybe_recv_eof().await?;
195
196        Ok(())
197    }
198}
199
200impl Deref for MySqlStream {
201    type Target = BufStream<MaybeTlsStream<Socket>>;
202
203    fn deref(&self) -> &Self::Target {
204        &self.stream
205    }
206}
207
208impl DerefMut for MySqlStream {
209    fn deref_mut(&mut self) -> &mut Self::Target {
210        &mut self.stream
211    }
212}