sqlx_mysql/protocol/connect/
handshake.rs

1use bytes::buf::Chain;
2use bytes::{Buf, Bytes};
3use std::cmp;
4
5use crate::error::Error;
6use crate::io::{BufExt, ProtocolDecode};
7use crate::protocol::auth::AuthPlugin;
8use crate::protocol::response::Status;
9use crate::protocol::Capabilities;
10
11// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
12// https://mariadb.com/kb/en/connection/#initial-handshake-packet
13
14#[derive(Debug)]
15pub(crate) struct Handshake {
16    #[allow(unused)]
17    pub(crate) protocol_version: u8,
18    pub(crate) server_version: String,
19    #[allow(unused)]
20    pub(crate) connection_id: u32,
21    pub(crate) server_capabilities: Capabilities,
22    #[allow(unused)]
23    pub(crate) server_default_collation: u8,
24    #[allow(unused)]
25    pub(crate) status: Status,
26    pub(crate) auth_plugin: Option<AuthPlugin>,
27    pub(crate) auth_plugin_data: Chain<Bytes, Bytes>,
28}
29
30impl ProtocolDecode<'_> for Handshake {
31    fn decode_with(mut buf: Bytes, _: ()) -> Result<Self, Error> {
32        let protocol_version = buf.get_u8(); // int<1>
33        let server_version = buf.get_str_nul()?; // string<NUL>
34        let connection_id = buf.get_u32_le(); // int<4>
35        let auth_plugin_data_1 = buf.get_bytes(8); // string<8>
36
37        buf.advance(1); // reserved: string<1>
38
39        let capabilities_1 = buf.get_u16_le(); // int<2>
40        let mut capabilities = Capabilities::from_bits_truncate(capabilities_1.into());
41
42        let collation = buf.get_u8(); // int<1>
43        let status = Status::from_bits_truncate(buf.get_u16_le());
44
45        let capabilities_2 = buf.get_u16_le(); // int<2>
46        capabilities |= Capabilities::from_bits_truncate(((capabilities_2 as u32) << 16).into());
47
48        let auth_plugin_data_len = if capabilities.contains(Capabilities::PLUGIN_AUTH) {
49            buf.get_u8()
50        } else {
51            buf.advance(1); // int<1>
52            0
53        };
54
55        buf.advance(6); // reserved: string<6>
56
57        if capabilities.contains(Capabilities::MYSQL) {
58            buf.advance(4); // reserved: string<4>
59        } else {
60            let capabilities_3 = buf.get_u32_le(); // int<4>
61            capabilities |= Capabilities::from_bits_truncate((capabilities_3 as u64) << 32);
62        }
63
64        let auth_plugin_data_2 = if capabilities.contains(Capabilities::SECURE_CONNECTION) {
65            let len = cmp::max(auth_plugin_data_len.saturating_sub(9), 12);
66            let v = buf.get_bytes(len as usize);
67            buf.advance(1); // NUL-terminator
68
69            v
70        } else {
71            Bytes::new()
72        };
73
74        let auth_plugin = if capabilities.contains(Capabilities::PLUGIN_AUTH) {
75            Some(buf.get_str_nul()?.parse()?)
76        } else {
77            None
78        };
79
80        Ok(Self {
81            protocol_version,
82            server_version,
83            connection_id,
84            server_default_collation: collation,
85            status,
86            server_capabilities: capabilities,
87            auth_plugin,
88            auth_plugin_data: auth_plugin_data_1.chain(auth_plugin_data_2),
89        })
90    }
91}
92
/// Decodes a MySQL 8.0.18 handshake packet and checks the protocol version,
/// capability set, collation, status, auth plugin and auth plugin data.
#[test]
fn test_decode_handshake_mysql_8_0_18() {
    const RAW_PACKET: &[u8] = b"\n8.0.18\x00\x19\x00\x00\x00\x114aB0c\x06g\x00\xff\xff\xff\x02\x00\xff\xc7\x15\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00tL\x03s\x0f[4\rl4. \x00caching_sha2_password\x00";

    let Handshake {
        protocol_version,
        server_capabilities,
        server_default_collation,
        status,
        auth_plugin,
        auth_plugin_data,
        ..
    } = Handshake::decode(Bytes::from_static(RAW_PACKET)).unwrap();

    assert_eq!(protocol_version, 10);

    let expected_capabilities = Capabilities::MYSQL
        | Capabilities::FOUND_ROWS
        | Capabilities::LONG_FLAG
        | Capabilities::CONNECT_WITH_DB
        | Capabilities::NO_SCHEMA
        | Capabilities::COMPRESS
        | Capabilities::ODBC
        | Capabilities::LOCAL_FILES
        | Capabilities::IGNORE_SPACE
        | Capabilities::PROTOCOL_41
        | Capabilities::INTERACTIVE
        | Capabilities::SSL
        | Capabilities::TRANSACTIONS
        | Capabilities::SECURE_CONNECTION
        | Capabilities::MULTI_STATEMENTS
        | Capabilities::MULTI_RESULTS
        | Capabilities::PS_MULTI_RESULTS
        | Capabilities::PLUGIN_AUTH
        | Capabilities::CONNECT_ATTRS
        | Capabilities::PLUGIN_AUTH_LENENC_DATA
        | Capabilities::CAN_HANDLE_EXPIRED_PASSWORDS
        | Capabilities::SESSION_TRACK
        | Capabilities::DEPRECATE_EOF
        | Capabilities::ZSTD_COMPRESSION_ALGORITHM
        | Capabilities::SSL_VERIFY_SERVER_CERT
        | Capabilities::OPTIONAL_RESULTSET_METADATA
        | Capabilities::REMEMBER_OPTIONS;
    assert_eq!(server_capabilities, expected_capabilities);

    assert_eq!(server_default_collation, 255);
    assert!(status.contains(Status::SERVER_STATUS_AUTOCOMMIT));

    assert!(matches!(
        auth_plugin,
        Some(AuthPlugin::CachingSha2Password)
    ));

    // Both parts of the auth plugin data, flattened into one byte sequence.
    let scramble: Vec<u8> = auth_plugin_data.into_iter().collect();
    assert_eq!(
        scramble,
        [17u8, 52, 97, 66, 48, 99, 6, 103, 116, 76, 3, 115, 15, 91, 52, 13, 108, 52, 46, 32]
    );
}
145
/// Decodes a MariaDB 10.4.7 handshake packet and checks the protocol version,
/// server version, capability set (including the `MARIADB_*` flags),
/// collation, status, auth plugin and auth plugin data.
#[test]
fn test_decode_handshake_mariadb_10_4_7() {
    const RAW_PACKET: &[u8] = b"\n5.5.5-10.4.7-MariaDB-1:10.4.7+maria~bionic\x00\x0b\x00\x00\x00t6L\\j\"dS\x00\xfe\xf7\x08\x02\x00\xff\x81\x15\x00\x00\x00\x00\x00\x00\x07\x00\x00\x00U14Oph9\"<H5n\x00mysql_native_password\x00";

    let Handshake {
        protocol_version,
        server_version,
        server_capabilities,
        server_default_collation,
        status,
        auth_plugin,
        auth_plugin_data,
        ..
    } = Handshake::decode(Bytes::from_static(RAW_PACKET)).unwrap();

    assert_eq!(protocol_version, 10);

    assert_eq!(
        server_version.as_str(),
        "5.5.5-10.4.7-MariaDB-1:10.4.7+maria~bionic"
    );

    let expected_capabilities = Capabilities::FOUND_ROWS
        | Capabilities::LONG_FLAG
        | Capabilities::CONNECT_WITH_DB
        | Capabilities::NO_SCHEMA
        | Capabilities::COMPRESS
        | Capabilities::ODBC
        | Capabilities::LOCAL_FILES
        | Capabilities::IGNORE_SPACE
        | Capabilities::PROTOCOL_41
        | Capabilities::INTERACTIVE
        | Capabilities::TRANSACTIONS
        | Capabilities::SECURE_CONNECTION
        | Capabilities::MULTI_STATEMENTS
        | Capabilities::MULTI_RESULTS
        | Capabilities::PS_MULTI_RESULTS
        | Capabilities::PLUGIN_AUTH
        | Capabilities::CONNECT_ATTRS
        | Capabilities::PLUGIN_AUTH_LENENC_DATA
        | Capabilities::CAN_HANDLE_EXPIRED_PASSWORDS
        | Capabilities::SESSION_TRACK
        | Capabilities::DEPRECATE_EOF
        | Capabilities::REMEMBER_OPTIONS
        | Capabilities::MARIADB_CLIENT_PROGRESS
        | Capabilities::MARIADB_CLIENT_MULTI
        | Capabilities::MARIADB_CLIENT_STMT_BULK_OPERATIONS;
    assert_eq!(server_capabilities, expected_capabilities);

    assert_eq!(server_default_collation, 8);
    assert!(status.contains(Status::SERVER_STATUS_AUTOCOMMIT));
    assert!(matches!(
        auth_plugin,
        Some(AuthPlugin::MySqlNativePassword)
    ));

    // Both parts of the auth plugin data, flattened into one byte sequence.
    let scramble: Vec<u8> = auth_plugin_data.into_iter().collect();
    assert_eq!(
        scramble,
        [116u8, 54, 76, 92, 106, 34, 100, 83, 85, 49, 52, 79, 112, 104, 57, 34, 60, 72, 53, 110]
    );
}