// sqlx_mysql/connection/establish.rs

1use bytes::buf::Buf;
2use bytes::Bytes;
3
4use crate::collation::{CharSet, Collation};
5use crate::common::StatementCache;
6use crate::connection::{tls, MySqlConnectionInner, MySqlStream, MAX_PACKET_SIZE};
7use crate::error::Error;
8use crate::net::{Socket, WithSocket};
9use crate::protocol::connect::{
10    AuthSwitchRequest, AuthSwitchResponse, Handshake, HandshakeResponse,
11};
12use crate::protocol::Capabilities;
13use crate::{MySqlConnectOptions, MySqlConnection, MySqlSslMode};
14
15impl MySqlConnection {
16    pub(crate) async fn establish(options: &MySqlConnectOptions) -> Result<Self, Error> {
17        let do_handshake = DoHandshake::new(options)?;
18
19        let handshake = match &options.socket {
20            Some(path) => crate::net::connect_uds(path, do_handshake).await?,
21            None => crate::net::connect_tcp(&options.host, options.port, do_handshake).await?,
22        };
23
24        let stream = handshake?;
25
26        Ok(Self {
27            inner: Box::new(MySqlConnectionInner {
28                stream,
29                transaction_depth: 0,
30                status_flags: Default::default(),
31                cache_statement: StatementCache::new(options.statement_cache_capacity),
32                log_settings: options.log_settings.clone(),
33            }),
34        })
35    }
36}
37
38struct DoHandshake<'a> {
39    options: &'a MySqlConnectOptions,
40    charset: CharSet,
41    collation: Collation,
42}
43
44impl<'a> DoHandshake<'a> {
45    fn new(options: &'a MySqlConnectOptions) -> Result<Self, Error> {
46        let charset: CharSet = options.charset.parse()?;
47        let collation: Collation = options
48            .collation
49            .as_deref()
50            .map(|collation| collation.parse())
51            .transpose()?
52            .unwrap_or_else(|| charset.default_collation());
53
54        if options.enable_cleartext_plugin
55            && matches!(
56                options.ssl_mode,
57                MySqlSslMode::Disabled | MySqlSslMode::Preferred
58            )
59        {
60            log::warn!("Security warning: sending cleartext passwords without requiring SSL");
61        }
62
63        Ok(Self {
64            options,
65            charset,
66            collation,
67        })
68    }
69
70    async fn do_handshake<S: Socket>(self, socket: S) -> Result<MySqlStream, Error> {
71        let DoHandshake {
72            options,
73            charset,
74            collation,
75        } = self;
76
77        let mut stream = MySqlStream::with_socket(charset, collation, options, socket);
78
79        // https://dev.mysql.com/doc/internals/en/connection-phase.html
80        // https://mariadb.com/kb/en/connection/
81
82        let handshake: Handshake = stream.recv_packet().await?.decode()?;
83
84        let mut plugin = handshake.auth_plugin;
85        let nonce = handshake.auth_plugin_data;
86
87        // FIXME: server version parse is a bit ugly
88        // expecting MAJOR.MINOR.PATCH
89
90        let mut server_version = handshake.server_version.split('.');
91
92        let server_version_major: u16 = server_version
93            .next()
94            .unwrap_or_default()
95            .parse()
96            .unwrap_or(0);
97
98        let server_version_minor: u16 = server_version
99            .next()
100            .unwrap_or_default()
101            .parse()
102            .unwrap_or(0);
103
104        let server_version_patch: u16 = server_version
105            .next()
106            .unwrap_or_default()
107            .parse()
108            .unwrap_or(0);
109
110        stream.server_version = (
111            server_version_major,
112            server_version_minor,
113            server_version_patch,
114        );
115
116        stream.capabilities &= handshake.server_capabilities;
117        stream.capabilities |= Capabilities::PROTOCOL_41;
118
119        let mut stream = tls::maybe_upgrade(stream, self.options).await?;
120
121        let auth_response = if let (Some(plugin), Some(password)) = (plugin, &options.password) {
122            Some(plugin.scramble(&mut stream, password, &nonce).await?)
123        } else {
124            None
125        };
126
127        stream.write_packet(HandshakeResponse {
128            collation: stream.collation as u8,
129            max_packet_size: MAX_PACKET_SIZE,
130            username: &options.username,
131            database: options.database.as_deref(),
132            auth_plugin: plugin,
133            auth_response: auth_response.as_deref(),
134        })?;
135
136        stream.flush().await?;
137
138        loop {
139            let packet = stream.recv_packet().await?;
140            match packet[0] {
141                0x00 => {
142                    let _ok = packet.ok()?;
143
144                    break;
145                }
146
147                0xfe => {
148                    let switch: AuthSwitchRequest =
149                        packet.decode_with(self.options.enable_cleartext_plugin)?;
150
151                    plugin = Some(switch.plugin);
152                    let nonce = switch.data.chain(Bytes::new());
153
154                    let response = switch
155                        .plugin
156                        .scramble(
157                            &mut stream,
158                            options.password.as_deref().unwrap_or_default(),
159                            &nonce,
160                        )
161                        .await?;
162
163                    stream.write_packet(AuthSwitchResponse(response))?;
164                    stream.flush().await?;
165                }
166
167                id => {
168                    if let (Some(plugin), Some(password)) = (plugin, &options.password) {
169                        if plugin.handle(&mut stream, packet, password, &nonce).await? {
170                            // plugin signaled authentication is ok
171                            break;
172                        }
173
174                        // plugin signaled to continue authentication
175                    } else {
176                        return Err(err_protocol!(
177                            "unexpected packet 0x{:02x} during authentication",
178                            id
179                        ));
180                    }
181                }
182            }
183        }
184
185        Ok(stream)
186    }
187}
188
189impl<'a> WithSocket for DoHandshake<'a> {
190    type Output = Result<MySqlStream, Error>;
191
192    async fn with_socket<S: Socket>(self, socket: S) -> Self::Output {
193        self.do_handshake(socket).await
194    }
195}