sqlx_core/mysql/connection/
auth.rs1use bytes::buf::Chain;
2use bytes::Bytes;
3use digest::{Digest, OutputSizeUser};
4use generic_array::GenericArray;
5use rand::thread_rng;
6use rsa::{pkcs8::DecodePublicKey, PaddingScheme, PublicKey, RsaPublicKey};
7use sha1::Sha1;
8use sha2::Sha256;
9
10use crate::error::Error;
11use crate::mysql::connection::stream::MySqlStream;
12use crate::mysql::protocol::auth::AuthPlugin;
13use crate::mysql::protocol::Packet;
14
15impl AuthPlugin {
16 pub(super) async fn scramble(
17 self,
18 stream: &mut MySqlStream,
19 password: &str,
20 nonce: &Chain<Bytes, Bytes>,
21 ) -> Result<Vec<u8>, Error> {
22 match self {
23 AuthPlugin::CachingSha2Password => Ok(scramble_sha256(password, nonce).to_vec()),
25
26 AuthPlugin::MySqlNativePassword => Ok(scramble_sha1(password, nonce).to_vec()),
27
28 AuthPlugin::Sha256Password => encrypt_rsa(stream, 0x01, password, nonce).await,
30 }
31 }
32
33 pub(super) async fn handle(
34 self,
35 stream: &mut MySqlStream,
36 packet: Packet<Bytes>,
37 password: &str,
38 nonce: &Chain<Bytes, Bytes>,
39 ) -> Result<bool, Error> {
40 match self {
41 AuthPlugin::CachingSha2Password if packet[0] == 0x01 => {
42 match packet[1] {
43 0x03 => Ok(true),
45
46 0x04 => {
48 let payload = encrypt_rsa(stream, 0x02, password, nonce).await?;
49
50 stream.write_packet(&*payload);
51 stream.flush().await?;
52
53 Ok(false)
54 }
55
56 v => {
57 Err(err_protocol!("unexpected result from fast authentication 0x{:x} when expecting 0x03 (AUTH_OK) or 0x04 (AUTH_CONTINUE)", v))
58 }
59 }
60 }
61
62 _ => Err(err_protocol!(
63 "unexpected packet 0x{:02x} for auth plugin '{}' during authentication",
64 packet[0],
65 self.name()
66 )),
67 }
68 }
69}
70
71fn scramble_sha1(
72 password: &str,
73 nonce: &Chain<Bytes, Bytes>,
74) -> GenericArray<u8, <Sha1 as OutputSizeUser>::OutputSize> {
75 let mut ctx = Sha1::new();
79
80 ctx.update(password);
81
82 let mut pw_hash = ctx.finalize_reset();
83
84 ctx.update(&pw_hash);
85
86 let pw_hash_hash = ctx.finalize_reset();
87
88 ctx.update(nonce.first_ref());
89 ctx.update(nonce.last_ref());
90 ctx.update(pw_hash_hash);
91
92 let pw_seed_hash_hash = ctx.finalize();
93
94 xor_eq(&mut pw_hash, &pw_seed_hash_hash);
95
96 pw_hash
97}
98
99fn scramble_sha256(
100 password: &str,
101 nonce: &Chain<Bytes, Bytes>,
102) -> GenericArray<u8, <Sha256 as OutputSizeUser>::OutputSize> {
103 let mut ctx = Sha256::new();
106
107 ctx.update(password);
108
109 let mut pw_hash = ctx.finalize_reset();
110
111 ctx.update(&pw_hash);
112
113 let pw_hash_hash = ctx.finalize_reset();
114
115 ctx.update(nonce.first_ref());
116 ctx.update(nonce.last_ref());
117 ctx.update(pw_hash_hash);
118
119 let pw_seed_hash_hash = ctx.finalize();
120
121 xor_eq(&mut pw_hash, &pw_seed_hash_hash);
122
123 pw_hash
124}
125
126async fn encrypt_rsa<'s>(
127 stream: &'s mut MySqlStream,
128 public_key_request_id: u8,
129 password: &'s str,
130 nonce: &'s Chain<Bytes, Bytes>,
131) -> Result<Vec<u8>, Error> {
132 if stream.is_tls() {
135 return Ok(to_asciz(password));
137 }
138
139 stream.write_packet(&[public_key_request_id][..]);
141 stream.flush().await?;
142
143 let packet = stream.recv_packet().await?;
145 let rsa_pub_key = &packet[1..];
146
147 let mut pass = to_asciz(password);
149
150 let (a, b) = (nonce.first_ref(), nonce.last_ref());
151 let mut nonce = Vec::with_capacity(a.len() + b.len());
152 nonce.extend_from_slice(&*a);
153 nonce.extend_from_slice(&*b);
154
155 xor_eq(&mut pass, &*nonce);
156
157 let pkey = parse_rsa_pub_key(rsa_pub_key)?;
159 let padding = PaddingScheme::new_oaep::<sha1::Sha1>();
160 pkey.encrypt(&mut thread_rng(), padding, &pass[..])
161 .map_err(Error::protocol)
162}
163
/// XORs `x` in place with `y`, repeating `y` from its start whenever `x` is longer.
///
/// If `y` is empty, `x` is left unchanged. The previous indexed form computed
/// `i % y.len()` and panicked with a division by zero in that case.
fn xor_eq(x: &mut [u8], y: &[u8]) {
    // `cycle()` over an empty iterator yields nothing, so `zip` stops immediately.
    for (a, b) in x.iter_mut().zip(y.iter().cycle()) {
        *a ^= *b;
    }
}
173
/// Returns the UTF-8 bytes of `s` followed by a single NUL (`0x00`) terminator.
fn to_asciz(s: &str) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(s.len() + 1);
    bytes.extend_from_slice(s.as_bytes());
    bytes.push(0);
    bytes
}
181
182fn parse_rsa_pub_key(key: &[u8]) -> Result<RsaPublicKey, Error> {
184 let pem = std::str::from_utf8(key).map_err(Error::protocol)?;
185
186 RsaPublicKey::from_public_key_pem(&pem).map_err(Error::protocol)
191}