sqlx_mysql/connection/
auth.rs

1use bytes::buf::Chain;
2use bytes::Bytes;
3use digest::{Digest, OutputSizeUser};
4use generic_array::GenericArray;
5use rand::thread_rng;
6use rsa::{pkcs8::DecodePublicKey, Oaep, RsaPublicKey};
7use sha1::Sha1;
8use sha2::Sha256;
9
10use crate::connection::stream::MySqlStream;
11use crate::error::Error;
12use crate::protocol::auth::AuthPlugin;
13use crate::protocol::Packet;
14
15impl AuthPlugin {
16    pub(super) async fn scramble(
17        self,
18        stream: &mut MySqlStream,
19        password: &str,
20        nonce: &Chain<Bytes, Bytes>,
21    ) -> Result<Vec<u8>, Error> {
22        match self {
23            // https://mariadb.com/kb/en/caching_sha2_password-authentication-plugin/
24            AuthPlugin::CachingSha2Password => Ok(scramble_sha256(password, nonce).to_vec()),
25
26            AuthPlugin::MySqlNativePassword => Ok(scramble_sha1(password, nonce).to_vec()),
27
28            // https://mariadb.com/kb/en/sha256_password-plugin/
29            AuthPlugin::Sha256Password => encrypt_rsa(stream, 0x01, password, nonce).await,
30
31            AuthPlugin::MySqlClearPassword => {
32                let mut pw_bytes = password.as_bytes().to_owned();
33                pw_bytes.push(0); // null terminate
34                Ok(pw_bytes)
35            }
36        }
37    }
38
39    pub(super) async fn handle(
40        self,
41        stream: &mut MySqlStream,
42        packet: Packet<Bytes>,
43        password: &str,
44        nonce: &Chain<Bytes, Bytes>,
45    ) -> Result<bool, Error> {
46        match self {
47            AuthPlugin::CachingSha2Password if packet[0] == 0x01 => {
48                match packet[1] {
49                    // AUTH_OK
50                    0x03 => Ok(true),
51
52                    // AUTH_CONTINUE
53                    0x04 => {
54                        let payload = encrypt_rsa(stream, 0x02, password, nonce).await?;
55
56                        stream.write_packet(&*payload)?;
57                        stream.flush().await?;
58
59                        Ok(false)
60                    }
61
62                    v => {
63                        Err(err_protocol!("unexpected result from fast authentication 0x{:x} when expecting 0x03 (AUTH_OK) or 0x04 (AUTH_CONTINUE)", v))
64                    }
65                }
66            }
67
68            _ => Err(err_protocol!(
69                "unexpected packet 0x{:02x} for auth plugin '{}' during authentication",
70                packet[0],
71                self.name()
72            )),
73        }
74    }
75}
76
77fn scramble_sha1(
78    password: &str,
79    nonce: &Chain<Bytes, Bytes>,
80) -> GenericArray<u8, <Sha1 as OutputSizeUser>::OutputSize> {
81    // SHA1( password ) ^ SHA1( seed + SHA1( SHA1( password ) ) )
82    // https://mariadb.com/kb/en/connection/#mysql_native_password-plugin
83
84    let mut ctx = Sha1::new();
85
86    ctx.update(password);
87
88    let mut pw_hash = ctx.finalize_reset();
89
90    ctx.update(pw_hash);
91
92    let pw_hash_hash = ctx.finalize_reset();
93
94    ctx.update(nonce.first_ref());
95    ctx.update(nonce.last_ref());
96    ctx.update(pw_hash_hash);
97
98    let pw_seed_hash_hash = ctx.finalize();
99
100    xor_eq(&mut pw_hash, &pw_seed_hash_hash);
101
102    pw_hash
103}
104
105fn scramble_sha256(
106    password: &str,
107    nonce: &Chain<Bytes, Bytes>,
108) -> GenericArray<u8, <Sha256 as OutputSizeUser>::OutputSize> {
109    // XOR(SHA256(password), SHA256(seed, SHA256(SHA256(password))))
110    // https://mariadb.com/kb/en/caching_sha2_password-authentication-plugin/#sha-2-encrypted-password
111    let mut ctx = Sha256::new();
112
113    ctx.update(password);
114
115    let mut pw_hash = ctx.finalize_reset();
116
117    ctx.update(pw_hash);
118
119    let pw_hash_hash = ctx.finalize_reset();
120
121    ctx.update(nonce.first_ref());
122    ctx.update(nonce.last_ref());
123    ctx.update(pw_hash_hash);
124
125    let pw_seed_hash_hash = ctx.finalize();
126
127    xor_eq(&mut pw_hash, &pw_seed_hash_hash);
128
129    pw_hash
130}
131
132async fn encrypt_rsa<'s>(
133    stream: &'s mut MySqlStream,
134    public_key_request_id: u8,
135    password: &'s str,
136    nonce: &'s Chain<Bytes, Bytes>,
137) -> Result<Vec<u8>, Error> {
138    // https://mariadb.com/kb/en/caching_sha2_password-authentication-plugin/
139
140    if stream.is_tls {
141        // If in a TLS stream, send the password directly in clear text
142        return Ok(to_asciz(password));
143    }
144
145    // client sends a public key request
146    stream.write_packet(&[public_key_request_id][..])?;
147    stream.flush().await?;
148
149    // server sends a public key response
150    let packet = stream.recv_packet().await?;
151    let rsa_pub_key = &packet[1..];
152
153    // xor the password with the given nonce
154    let mut pass = to_asciz(password);
155
156    let (a, b) = (nonce.first_ref(), nonce.last_ref());
157    let mut nonce = Vec::with_capacity(a.len() + b.len());
158    nonce.extend_from_slice(a);
159    nonce.extend_from_slice(b);
160
161    xor_eq(&mut pass, &nonce);
162
163    // client sends an RSA encrypted password
164    let pkey = parse_rsa_pub_key(rsa_pub_key)?;
165    let padding = Oaep::new::<sha1::Sha1>();
166    pkey.encrypt(&mut thread_rng(), padding, &pass[..])
167        .map_err(Error::protocol)
168}
169
/// XORs `y` into `x` in place, repeating `y` cyclically when it is shorter than `x`.
///
/// Bytes of `y` beyond `x.len()` are ignored. An empty `y` leaves `x` unchanged instead of
/// panicking with a division by zero, which the previous `i % y.len()` indexing did.
fn xor_eq(x: &mut [u8], y: &[u8]) {
    // `cycle()` over an empty iterator yields nothing, so `zip` stops immediately.
    for (dst, key) in x.iter_mut().zip(y.iter().cycle()) {
        *dst ^= *key;
    }
}
179
/// Encodes `s` as a NUL-terminated byte string: the UTF-8 bytes of `s` followed by a
/// single `0` byte.
fn to_asciz(s: &str) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(s.len() + 1);
    bytes.extend_from_slice(s.as_bytes());
    bytes.push(b'\0');
    bytes
}
187
188// https://docs.rs/rsa/0.3.0/rsa/struct.RSAPublicKey.html?search=#example-1
189fn parse_rsa_pub_key(key: &[u8]) -> Result<RsaPublicKey, Error> {
190    let pem = std::str::from_utf8(key).map_err(Error::protocol)?;
191
192    // This takes advantage of the knowledge that we know
193    // we are receiving a PKCS#8 RSA Public Key at all
194    // times from MySQL
195
196    RsaPublicKey::from_public_key_pem(pem).map_err(Error::protocol)
197}