use std::convert::TryInto;
use std::io;
use async_std::net::{Shutdown, TcpStream};
use byteorder::{ByteOrder, LittleEndian};
use futures_core::future::BoxFuture;
use sha1::Sha1;
use crate::cache::StatementCache;
use crate::connection::Connection;
use crate::io::{Buf, BufMut, BufStream};
use crate::mysql::error::MySqlError;
use crate::mysql::protocol::{
AuthPlugin, AuthSwitch, Capabilities, Decode, Encode, EofPacket, ErrPacket, Handshake,
HandshakeResponse, OkPacket,
};
use crate::mysql::rsa;
use crate::mysql::util::xor_eq;
use crate::url::Url;
/// Value sent as `max_packet_size` in the handshake response (see `send_handshake_response`).
///
/// NOTE(review): 1024 bytes is a small limit; confirm how the server treats this value.
const MAX_PACKET_SIZE: u32 = 1024;
/// Collation id sent as `client_collation` in the handshake response. Going by its name this
/// is `utf8mb4_unicode_ci`, the same collation selected by `SET NAMES ... COLLATE
/// utf8mb4_unicode_ci` in `initialize`.
const COLLATE_UTF8MB4_UNICODE_CI: u8 = 224;
/// A connection to a MySQL server over a buffered TCP stream.
pub struct MySqlConnection {
    /// Buffered socket; outgoing packets are encoded into its write buffer by `write`.
    pub(super) stream: BufStream<TcpStream>,
    /// Capability flags negotiated with the server in `receive_handshake`.
    pub(super) capabilities: Capabilities,
    /// Prepared-statement cache; the `u32` values are presumably server statement ids (confirm).
    pub(super) statement_cache: StatementCache<u32>,
    /// Payload (header excluded) of the most recently received packet.
    pub(super) packet: Vec<u8>,
    /// Payload length taken from the most recently received packet header.
    packet_len: usize,
    /// Sequence id to place in the header of the next outgoing packet.
    pub(super) next_seq_no: u8,
}
impl MySqlConnection {
pub(crate) fn write(&mut self, packet: impl Encode) {
let buf = self.stream.buffer_mut();
let header_offset = buf.len();
buf.advance(4);
packet.encode(buf, self.capabilities);
let len = buf.len() - header_offset - 4;
let mut header = &mut buf[header_offset..];
LittleEndian::write_u32(&mut header, len as u32);
header[3] = self.next_seq_no;
self.next_seq_no = self.next_seq_no.wrapping_add(1);
}
pub(crate) async fn send(&mut self, packet: impl Encode) -> crate::Result<()> {
self.write(packet);
self.stream.flush().await?;
Ok(())
}
pub(crate) async fn send_handshake_response(
&mut self,
url: &Url,
auth_plugin: &AuthPlugin,
auth_response: &[u8],
) -> crate::Result<()> {
self.send(HandshakeResponse {
client_collation: COLLATE_UTF8MB4_UNICODE_CI,
max_packet_size: MAX_PACKET_SIZE,
username: url.username().unwrap_or("root"),
database: url.database(),
auth_plugin,
auth_response,
})
.await
}
pub(crate) async fn try_receive(&mut self) -> crate::Result<Option<()>> {
self.packet.clear();
let mut header = ret_if_none!(self.stream.peek(4).await?);
self.packet_len = header.get_uint::<LittleEndian>(3)? as usize;
self.next_seq_no = header.get_u8()?.wrapping_add(1);
self.stream.consume(4);
let payload = ret_if_none!(self.stream.peek(self.packet_len).await?);
self.packet.extend_from_slice(payload);
self.stream.consume(self.packet_len);
Ok(Some(()))
}
pub(crate) async fn receive(&mut self) -> crate::Result<&mut Self> {
self.try_receive()
.await?
.ok_or(io::ErrorKind::ConnectionAborted)?;
Ok(self)
}
#[inline]
pub(crate) fn packet(&self) -> &[u8] {
&self.packet[..self.packet_len]
}
pub(crate) async fn receive_eof(&mut self) -> crate::Result<()> {
if !self.capabilities.contains(Capabilities::DEPRECATE_EOF) {
let _eof = EofPacket::decode(self.receive().await?.packet())?;
}
Ok(())
}
pub(crate) async fn receive_handshake(&mut self, url: &Url) -> crate::Result<Handshake> {
let handshake = Handshake::decode(self.receive().await?.packet())?;
let mut client_capabilities = Capabilities::PROTOCOL_41
| Capabilities::IGNORE_SPACE
| Capabilities::FOUND_ROWS
| Capabilities::TRANSACTIONS
| Capabilities::SECURE_CONNECTION
| Capabilities::PLUGIN_AUTH_LENENC_DATA
| Capabilities::PLUGIN_AUTH;
if url.database().is_some() {
client_capabilities |= Capabilities::CONNECT_WITH_DB;
}
self.capabilities =
(client_capabilities & handshake.server_capabilities) | Capabilities::PROTOCOL_41;
Ok(handshake)
}
pub(crate) fn receive_auth_ok<'a>(
&'a mut self,
plugin: &'a AuthPlugin,
password: &'a str,
nonce: &'a [u8],
) -> BoxFuture<'a, crate::Result<()>> {
Box::pin(async move {
self.receive().await?;
match self.packet[0] {
0x00 => self.handle_ok().map(drop),
0xfe => self.handle_auth_switch(password).await,
0xff => self.handle_err(),
_ => self.handle_auth_continue(plugin, password, nonce).await,
}
})
}
}
impl MySqlConnection {
    /// Decodes the current packet as an `OkPacket` and resets `next_seq_no` to 0.
    pub(crate) fn handle_ok(&mut self) -> crate::Result<OkPacket> {
        let ok = OkPacket::decode(self.packet())?;
        self.next_seq_no = 0;

        Ok(ok)
    }

    /// Decodes the current packet as an `ErrPacket` and resets `next_seq_no` to 0.
    ///
    /// # Errors
    /// Always returns `Err`: either the decode error, or the server error wrapped in
    /// `MySqlError`.
    pub(crate) fn handle_err<T>(&mut self) -> crate::Result<T> {
        let err = ErrPacket::decode(self.packet())?;
        self.next_seq_no = 0;

        Err(MySqlError(err).into())
    }

    /// Builds a protocol error for a packet whose leading identifier byte `id` was not
    /// expected.
    pub(crate) fn handle_unexpected_packet<T>(&self, id: u8) -> crate::Result<T> {
        Err(protocol_err!("unexpected packet identifier 0x{:X?}", id).into())
    }

    /// Handles an authentication packet that is not OK, ERR, or an auth switch request.
    ///
    /// For `caching_sha2_password` the packet is `0x01` followed by a status byte:
    /// - `0x03` means the fast path is OK.
    /// - `0x04` means CONTINUE. The password is then RSA-encrypted (public key request
    ///   id `0x02`) and sent.
    ///
    /// Either way, the final result is then awaited through `receive_auth_ok`.
    ///
    /// # Errors
    /// Returns a protocol error in these cases, all of which previously panicked or were
    /// only partly handled:
    /// - an empty or truncated packet;
    /// - an unknown status byte;
    /// - any such packet for a plugin other than `caching_sha2_password`.
    pub(crate) async fn handle_auth_continue(
        &mut self,
        plugin: &AuthPlugin,
        password: &str,
        nonce: &[u8],
    ) -> crate::Result<()> {
        // Copy the bytes out so no borrow of `self` is held across the awaits below.
        let (id, status) = {
            let packet = self.packet();
            (packet.first().copied(), packet.get(1).copied())
        };

        let id = match id {
            Some(id) => id,
            None => {
                return Err(protocol_err!("received an empty packet during authentication").into());
            }
        };

        // These bytes are server-controlled: report anything unexpected as a protocol error
        // instead of panicking.
        let is_caching_sha2_continue = match plugin {
            AuthPlugin::CachingSha2Password => id == 0x01,
            _ => false,
        };

        if !is_caching_sha2_continue {
            return self.handle_unexpected_packet(id);
        }

        match status {
            Some(0x03) => {}

            Some(0x04) => {
                let ct = self.rsa_encrypt(0x02, password, nonce).await?;
                self.send(&*ct).await?;
            }

            Some(auth) => {
                return Err(protocol_err!("unexpected result from 'fast' authentication 0x{:x} when expecting OK (0x03) or CONTINUE (0x04)", auth).into());
            }

            None => {
                return Err(
                    protocol_err!("truncated 'fast' authentication result packet").into(),
                );
            }
        }

        self.receive_auth_ok(plugin, password, nonce).await
    }

    /// Decodes an `AuthSwitch` request, answers it with an initial auth response for the
    /// requested plugin, and then awaits the authentication result.
    pub(crate) async fn handle_auth_switch(&mut self, password: &str) -> crate::Result<()> {
        let auth = AuthSwitch::decode(self.packet())?;

        let auth_response = self
            .make_auth_initial_response(&auth.auth_plugin, password, &auth.auth_plugin_data)
            .await?;

        self.send(&*auth_response).await?;

        self.receive_auth_ok(&auth.auth_plugin, password, &auth.auth_plugin_data)
            .await
    }

    /// Computes the initial auth response for `plugin`.
    ///
    /// `caching_sha2_password` and `mysql_native_password` use `AuthPlugin::scramble`.
    /// `sha256_password` performs an RSA exchange via `rsa_encrypt` (request id `0x01`).
    ///
    /// NOTE(review): the `Sha256Password` branch sends a packet and waits for a reply. When
    /// this is called from `open` before the handshake response has been sent, confirm the
    /// server expects that ordering.
    pub(crate) async fn make_auth_initial_response(
        &mut self,
        plugin: &AuthPlugin,
        password: &str,
        nonce: &[u8],
    ) -> crate::Result<Vec<u8>> {
        match plugin {
            AuthPlugin::CachingSha2Password | AuthPlugin::MySqlNativePassword => {
                Ok(plugin.scramble(password, nonce))
            }

            AuthPlugin::Sha256Password => {
                Ok(self.rsa_encrypt(0x01, password, nonce).await?.into_vec())
            }
        }
    }

    /// Requests the server's RSA public key and encrypts the password with it.
    ///
    /// The key is requested by sending the single byte `public_key_request_id`. The reply's
    /// first byte is skipped and the rest is used as the key. The encrypted value is the
    /// NUL-terminated password XOR-ed with `nonce`, passed to `rsa::encrypt::<Sha1>`
    /// (presumably OAEP with SHA-1; confirm in `crate::mysql::rsa`).
    ///
    /// # Errors
    /// - A server ERR reply is returned as the server's error instead of being parsed as a key.
    /// - An empty reply is a protocol error.
    /// - Encryption failures propagate.
    pub(crate) async fn rsa_encrypt(
        &mut self,
        public_key_request_id: u8,
        password: &str,
        nonce: &[u8],
    ) -> crate::Result<Box<[u8]>> {
        self.send(&[public_key_request_id][..]).await?;
        self.receive().await?;

        let id = self.packet().first().copied();

        match id {
            Some(0xff) => return self.handle_err(),
            None => {
                return Err(protocol_err!(
                    "received an empty packet when requesting the RSA public key"
                )
                .into());
            }
            Some(_) => {}
        }

        // Non-empty was checked above, so this slice cannot panic.
        let rsa_pub_key = &self.packet()[1..];

        let mut pass = password.as_bytes().to_vec();
        pass.push(0);
        xor_eq(&mut pass, nonce);

        rsa::encrypt::<Sha1>(rsa_pub_key, &pass)
    }
}
impl MySqlConnection {
    /// Opens a TCP socket to the URL's host and port (3306 when unspecified).
    ///
    /// The returned connection starts with empty capabilities, an 8 KiB packet buffer, and a
    /// zero sequence id. No MySQL handshake happens here.
    async fn new(url: &Url) -> crate::Result<Self> {
        let host = url.host();
        let port = url.port(3306);
        let socket = TcpStream::connect((host, port)).await?;

        Ok(Self {
            stream: BufStream::new(socket),
            capabilities: Capabilities::empty(),
            packet: Vec::with_capacity(8192),
            packet_len: 0,
            next_seq_no: 0,
            statement_cache: StatementCache::new(),
        })
    }

    /// Runs the session setup statements, in order, stopping at the first failure.
    ///
    /// The statements:
    /// - append extra flags to `sql_mode`;
    /// - pin `time_zone` to UTC (`+00:00`);
    /// - select utf8mb4 with `utf8mb4_unicode_ci`.
    async fn initialize(&mut self) -> crate::Result<()> {
        let setup = [
            "SET sql_mode=(SELECT CONCAT(@@sql_mode, ',PIPES_AS_CONCAT,NO_ENGINE_SUBSTITUTION,NO_ZERO_DATE,NO_ZERO_IN_DATE'))",
            "SET time_zone = '+00:00'",
            "SET NAMES utf8mb4 COLLATE utf8mb4_unicode_ci",
        ];

        for statement in setup.iter().copied() {
            self.execute_raw(statement).await?;
        }

        Ok(())
    }
}
impl MySqlConnection {
    /// Connects to and authenticates with the server described by `url`, then runs the
    /// session setup from `initialize`.
    ///
    /// A missing password in the URL is treated as the empty string.
    pub(super) async fn open(url: crate::Result<Url>) -> crate::Result<Self> {
        let url = url?;
        let mut conn = Self::new(&url).await?;

        // Server greeting; this also fixes the negotiated capability set.
        let handshake = conn.receive_handshake(&url).await?;

        let password = url.password().unwrap_or_default();
        let plugin = &handshake.auth_plugin;
        let nonce: &[u8] = &handshake.auth_plugin_data;

        let auth_response = conn
            .make_auth_initial_response(plugin, password, nonce)
            .await?;

        conn.send_handshake_response(&url, plugin, &auth_response)
            .await?;
        conn.receive_auth_ok(plugin, password, nonce).await?;

        conn.initialize().await?;

        Ok(conn)
    }

    /// Flushes any buffered output, then shuts down both halves of the TCP socket.
    async fn close(mut self) -> crate::Result<()> {
        self.stream.flush().await?;

        self.stream
            .stream
            .shutdown(Shutdown::Both)
            .map_err(Into::into)
    }
}
impl Connection for MySqlConnection {
    /// Converts `url` and returns a boxed future that opens the connection.
    fn open<T>(url: T) -> BoxFuture<'static, crate::Result<Self>>
    where
        T: TryInto<Url, Error = crate::Error>,
        Self: Sized,
    {
        // Convert before boxing so the `'static` future does not need to own `T`; any
        // conversion error is reported when the future runs.
        let url: crate::Result<Url> = url.try_into();

        Box::pin(MySqlConnection::open(url))
    }

    /// Returns a boxed future that flushes and shuts down the connection.
    fn close(self) -> BoxFuture<'static, crate::Result<()>> {
        let closing = self.close();

        Box::pin(closing)
    }
}