sqlx_core_guts/mysql/connection/establish.rs

1use bytes::buf::Buf;
2use bytes::Bytes;
3
4use crate::common::StatementCache;
5use crate::error::Error;
6use crate::mysql::connection::{tls, MySqlStream, MAX_PACKET_SIZE};
7use crate::mysql::protocol::connect::{
8    AuthSwitchRequest, AuthSwitchResponse, Handshake, HandshakeResponse,
9};
10use crate::mysql::protocol::Capabilities;
11use crate::mysql::{MySqlConnectOptions, MySqlConnection, MySqlSslMode};
12
13impl MySqlConnection {
14    pub(crate) async fn establish(options: &MySqlConnectOptions) -> Result<Self, Error> {
15        let mut stream: MySqlStream = MySqlStream::connect(options).await?;
16
17        // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase.html
18        // https://mariadb.com/kb/en/connection/
19
20        let handshake: Handshake = stream.recv_packet().await?.decode()?;
21
22        let mut plugin = handshake.auth_plugin;
23        let mut nonce = handshake.auth_plugin_data;
24
25        // FIXME: server version parse is a bit ugly
26        // expecting MAJOR.MINOR.PATCH
27
28        let mut server_version = handshake.server_version.split('.');
29
30        let server_version_major: u16 = server_version
31            .next()
32            .unwrap_or_default()
33            .parse()
34            .unwrap_or(0);
35
36        let server_version_minor: u16 = server_version
37            .next()
38            .unwrap_or_default()
39            .parse()
40            .unwrap_or(0);
41
42        let server_version_patch: u16 = server_version
43            .next()
44            .unwrap_or_default()
45            .parse()
46            .unwrap_or(0);
47
48        stream.server_version = (
49            server_version_major,
50            server_version_minor,
51            server_version_patch,
52        );
53
54        stream.capabilities &= handshake.server_capabilities;
55        stream.capabilities |= Capabilities::PROTOCOL_41;
56
57        if matches!(options.ssl_mode, MySqlSslMode::Disabled) {
58            // remove the SSL capability if SSL has been explicitly disabled
59            stream.capabilities.remove(Capabilities::SSL);
60        }
61
62        // Upgrade to TLS if we were asked to and the server supports it
63        tls::maybe_upgrade(&mut stream, options).await?;
64
65        let auth_response = if let (Some(plugin), Some(password)) = (plugin, &options.password) {
66            Some(plugin.scramble(&mut stream, password, &nonce).await?)
67        } else {
68            None
69        };
70
71        stream.write_packet(HandshakeResponse {
72            collation: stream.collation as u8,
73            max_packet_size: MAX_PACKET_SIZE,
74            username: &options.username,
75            database: options.database.as_deref(),
76            auth_plugin: plugin,
77            auth_response: auth_response.as_deref(),
78        });
79
80        stream.flush().await?;
81
82        loop {
83            let packet = stream.recv_packet().await?;
84            match packet[0] {
85                0x00 => {
86                    let _ok = packet.ok()?;
87
88                    break;
89                }
90
91                0xfe => {
92                    let switch: AuthSwitchRequest = packet.decode()?;
93
94                    plugin = Some(switch.plugin);
95                    nonce = switch.data.chain(Bytes::new());
96
97                    let response = switch
98                        .plugin
99                        .scramble(
100                            &mut stream,
101                            options.password.as_deref().unwrap_or_default(),
102                            &nonce,
103                        )
104                        .await?;
105
106                    stream.write_packet(AuthSwitchResponse(response));
107                    stream.flush().await?;
108                }
109
110                id => {
111                    if let (Some(plugin), Some(password)) = (plugin, &options.password) {
112                        if plugin.handle(&mut stream, packet, password, &nonce).await? {
113                            // plugin signaled authentication is ok
114                            break;
115                        }
116
117                    // plugin signaled to continue authentication
118                    } else {
119                        return Err(err_protocol!(
120                            "unexpected packet 0x{:02x} during authentication",
121                            id
122                        ));
123                    }
124                }
125            }
126        }
127
128        Ok(Self {
129            stream,
130            transaction_depth: 0,
131            cache_statement: StatementCache::new(options.statement_cache_capacity),
132            log_settings: options.log_settings.clone(),
133        })
134    }
135}