Skip to main content

sqlmodel_mysql/
auth.rs

1//! MySQL authentication implementations.
2//!
3//! This module implements the MySQL authentication plugins:
4//! - `mysql_native_password`: SHA1-based (legacy, MySQL < 8.0 default)
5//! - `caching_sha2_password`: SHA256-based (MySQL 8.0+ default)
6//!
7//! # mysql_native_password
8//!
9//! Password scramble algorithm:
10//! ```text
11//! SHA1(password) XOR SHA1(seed + SHA1(SHA1(password)))
12//! ```
13//!
14//! # caching_sha2_password
15//!
16//! Fast auth (if cached on server):
17//! ```text
18//! XOR(SHA256(password), SHA256(SHA256(SHA256(password)) + seed))
19//! ```
20//!
21//! Full auth requires TLS or RSA public key encryption.
22
23use sha1::Sha1;
24use sha2::{Digest, Sha256};
25
/// Names of the MySQL authentication plugins handled by this module.
///
/// These strings are what the server and client exchange to identify a plugin.
pub mod plugins {
    /// Legacy SHA1-based scheme (the server default before MySQL 8.0).
    pub const MYSQL_NATIVE_PASSWORD: &str = "mysql_native_password";

    /// SHA256-based scheme with server-side caching (the MySQL 8.0+ default).
    pub const CACHING_SHA2_PASSWORD: &str = "caching_sha2_password";

    /// SHA256-based scheme whose full exchange uses RSA encryption.
    pub const SHA256_PASSWORD: &str = "sha256_password";

    /// Plain-text password scheme; intended for debugging/testing only.
    pub const MYSQL_CLEAR_PASSWORD: &str = "mysql_clear_password";
}
37
/// Single-byte codes exchanged during the `caching_sha2_password` handshake.
///
/// Most are status bytes sent by the server. `REQUEST_PUBLIC_KEY` is the byte
/// the client sends.
pub mod caching_sha2 {
    /// Sent by the client to ask the server for its RSA public key.
    pub const REQUEST_PUBLIC_KEY: u8 = 0x02;

    /// The fast (cached) authentication path succeeded.
    pub const FAST_AUTH_SUCCESS: u8 = 0x03;

    /// The fast path is unavailable; full authentication is required
    /// (over a secure channel or via RSA).
    pub const PERFORM_FULL_AUTH: u8 = 0x04;
}
47
48/// Compute mysql_native_password authentication response.
49///
50/// Algorithm: `SHA1(password) XOR SHA1(seed + SHA1(SHA1(password)))`
51///
52/// # Arguments
53/// * `password` - The user's password (UTF-8)
54/// * `auth_data` - The 20-byte scramble from the server
55///
56/// # Returns
57/// The 20-byte authentication response, or empty vec if password is empty.
58pub fn mysql_native_password(password: &str, auth_data: &[u8]) -> Vec<u8> {
59    if password.is_empty() {
60        return vec![];
61    }
62
63    // Ensure we only use first 20 bytes of auth_data
64    let seed = if auth_data.len() > 20 {
65        &auth_data[..20]
66    } else {
67        auth_data
68    };
69
70    // Stage 1: SHA1(password)
71    let mut hasher = Sha1::new();
72    hasher.update(password.as_bytes());
73    let stage1: [u8; 20] = hasher.finalize().into();
74
75    // Stage 2: SHA1(SHA1(password))
76    let mut hasher = Sha1::new();
77    hasher.update(stage1);
78    let stage2: [u8; 20] = hasher.finalize().into();
79
80    // Stage 3: SHA1(seed + stage2)
81    let mut hasher = Sha1::new();
82    hasher.update(seed);
83    hasher.update(stage2);
84    let stage3: [u8; 20] = hasher.finalize().into();
85
86    // Final: stage1 XOR stage3
87    stage1
88        .iter()
89        .zip(stage3.iter())
90        .map(|(a, b)| a ^ b)
91        .collect()
92}
93
94/// Compute caching_sha2_password fast authentication response.
95///
96/// Algorithm: `XOR(SHA256(password), SHA256(SHA256(SHA256(password)) + seed))`
97///
98/// # Arguments
99/// * `password` - The user's password (UTF-8)
100/// * `auth_data` - The scramble from the server (typically 20 bytes + NUL)
101///
102/// # Returns
103/// The 32-byte authentication response, or empty vec if password is empty.
104pub fn caching_sha2_password(password: &str, auth_data: &[u8]) -> Vec<u8> {
105    if password.is_empty() {
106        return vec![];
107    }
108
109    // Remove trailing NUL if present (MySQL sends 20-byte scramble + NUL = 21 bytes)
110    // Only strip if length is 21 and ends with NUL, to avoid modifying valid 20-byte seeds
111    let seed = if auth_data.len() == 21 && auth_data.last() == Some(&0) {
112        &auth_data[..20]
113    } else {
114        auth_data
115    };
116
117    // SHA256(password)
118    let mut hasher = Sha256::new();
119    hasher.update(password.as_bytes());
120    let password_hash: [u8; 32] = hasher.finalize().into();
121
122    // SHA256(SHA256(password))
123    let mut hasher = Sha256::new();
124    hasher.update(password_hash);
125    let password_hash_hash: [u8; 32] = hasher.finalize().into();
126
127    // SHA256(SHA256(SHA256(password)) + seed)
128    let mut hasher = Sha256::new();
129    hasher.update(password_hash_hash);
130    hasher.update(seed);
131    let scramble: [u8; 32] = hasher.finalize().into();
132
133    // XOR(SHA256(password), scramble)
134    password_hash
135        .iter()
136        .zip(scramble.iter())
137        .map(|(a, b)| a ^ b)
138        .collect()
139}
140
141/// Generate a random nonce for client-side use.
142///
143/// Uses `OsRng` for cryptographically secure random generation.
144pub fn generate_nonce(length: usize) -> Vec<u8> {
145    use rand::RngCore;
146    use rand::rngs::OsRng;
147    let mut bytes = vec![0u8; length];
148    OsRng.fill_bytes(&mut bytes);
149    bytes
150}
151
/// RSA-encrypt the password for `sha256_password` (or for `caching_sha2_password`
/// full authentication) on a connection without TLS.
///
/// **Not implemented.** RSA support would pull in heavyweight dependencies
/// (e.g. the `rsa` crate), so this placeholder always fails. Prefer a TLS
/// connection: over TLS the password can be sent in cleartext without RSA.
///
/// # Arguments
/// * `password` - The user's password
/// * `seed` - The authentication seed from the server
/// * `public_key` - The server's RSA public key (PEM format)
///
/// # Errors
/// Always returns `Err` with a message recommending a TLS connection.
pub fn sha256_password_rsa(
    _password: &str,
    _seed: &[u8],
    _public_key: &[u8],
) -> Result<Vec<u8>, String> {
    const NOT_IMPLEMENTED: &str = "RSA encryption not implemented - use TLS connection instead";
    Err(NOT_IMPLEMENTED.to_owned())
}
177
/// XOR the password bytes with `seed` and append a NUL terminator.
///
/// When the password is longer than the seed, the seed is repeated
/// cyclically. The terminating NUL is appended as a plain `0` and is not
/// XOR'd. If `seed` is empty, the password bytes are copied unchanged rather
/// than panicking.
///
/// This kind of seed-XOR obfuscation is used in MySQL auth exchanges such as
/// the RSA step of `sha256_password` / `caching_sha2_password` full
/// authentication.
///
/// NOTE(review): the MySQL client's RSA path XORs the NUL terminator together
/// with the password bytes. Confirm which behavior callers need before using
/// this function for RSA.
///
/// # Arguments
/// * `password` - The user's password (UTF-8)
/// * `seed` - The authentication seed from the server (may be empty)
///
/// # Returns
/// `password.len() + 1` bytes: the XOR'd password followed by a `0` byte.
pub fn xor_password_with_seed(password: &str, seed: &[u8]) -> Vec<u8> {
    let password_bytes = password.as_bytes();
    let mut result = Vec::with_capacity(password_bytes.len() + 1);

    if seed.is_empty() {
        // Nothing to XOR with; `i % seed.len()` would divide by zero here.
        result.extend_from_slice(password_bytes);
    } else {
        result.extend(
            password_bytes
                .iter()
                .zip(seed.iter().cycle())
                .map(|(p, s)| p ^ s),
        );
    }

    // NUL terminator
    result.push(0);

    result
}
196
#[cfg(test)]
mod tests {
    use super::*;

    /// A realistic 20-byte server scramble.
    const REAL_SEED: [u8; 20] = [
        0x3d, 0x4c, 0x5e, 0x2f, 0x1a, 0x0b, 0x7c, 0x8d, 0x9e, 0xaf, 0x10, 0x21, 0x32, 0x43, 0x54,
        0x65, 0x76, 0x87, 0x98, 0xa9,
    ];

    /// SHA1 over the concatenation of `parts`.
    fn sha1_concat(parts: &[&[u8]]) -> [u8; 20] {
        let mut hasher = Sha1::new();
        for &part in parts {
            hasher.update(part);
        }
        hasher.finalize().into()
    }

    /// SHA256 over the concatenation of `parts`.
    fn sha256_concat(parts: &[&[u8]]) -> [u8; 32] {
        let mut hasher = Sha256::new();
        for &part in parts {
            hasher.update(part);
        }
        hasher.finalize().into()
    }

    /// Byte-wise XOR of two slices (truncated to the shorter one).
    fn xor(a: &[u8], b: &[u8]) -> Vec<u8> {
        a.iter().zip(b).map(|(x, y)| x ^ y).collect()
    }

    #[test]
    fn test_mysql_native_password_empty() {
        let result = mysql_native_password("", &[0; 20]);
        assert!(result.is_empty());
    }

    #[test]
    fn test_mysql_native_password() {
        // Seed: 20 bytes of zeros; password: "secret".
        // Only shape and determinism are checked here; see
        // `test_mysql_native_password_server_verifies` for a correctness check.
        let seed = [0u8; 20];
        let result = mysql_native_password("secret", &seed);

        // Should produce 20 bytes
        assert_eq!(result.len(), 20);

        // The result should be deterministic
        let result2 = mysql_native_password("secret", &seed);
        assert_eq!(result, result2);
    }

    #[test]
    fn test_mysql_native_password_real_seed() {
        let result = mysql_native_password("mypassword", &REAL_SEED);
        assert_eq!(result.len(), 20);

        // Different password should give different result
        let result2 = mysql_native_password("otherpassword", &REAL_SEED);
        assert_ne!(result, result2);
    }

    #[test]
    fn test_mysql_native_password_server_verifies() {
        // Mirror the server: it stores SHA1(SHA1(password)), unmasks the client
        // response with SHA1(seed + stored), and checks that the unmasked value
        // hashes back to the stored digest.
        let password = "mypassword";
        let response = mysql_native_password(password, &REAL_SEED);

        let password_sha1 = sha1_concat(&[password.as_bytes()]);
        let stored = sha1_concat(&[&password_sha1[..]]);
        let mask = sha1_concat(&[&REAL_SEED[..], &stored[..]]);
        let candidate = xor(&response, &mask);

        assert_eq!(candidate, password_sha1.to_vec());
        assert_eq!(sha1_concat(&[&candidate[..]]), stored);
    }

    #[test]
    fn test_mysql_native_password_ignores_bytes_past_20() {
        // Servers commonly append a NUL after the 20-byte scramble.
        let mut seed = REAL_SEED.to_vec();
        seed.push(0);
        assert_eq!(
            mysql_native_password("mypassword", &seed),
            mysql_native_password("mypassword", &REAL_SEED)
        );
    }

    #[test]
    fn test_caching_sha2_password_empty() {
        let result = caching_sha2_password("", &[0; 20]);
        assert!(result.is_empty());
    }

    #[test]
    fn test_caching_sha2_password() {
        let seed = [0u8; 20];
        let result = caching_sha2_password("secret", &seed);

        // Should produce 32 bytes (SHA-256 output)
        assert_eq!(result.len(), 32);

        // Should be deterministic
        let result2 = caching_sha2_password("secret", &seed);
        assert_eq!(result, result2);
    }

    #[test]
    fn test_caching_sha2_password_with_nul() {
        // MySQL often sends seed with trailing NUL
        let mut seed = vec![0u8; 20];
        seed.push(0); // Trailing NUL

        let result = caching_sha2_password("secret", &seed);
        assert_eq!(result.len(), 32);

        // Should be same as without NUL
        let result2 = caching_sha2_password("secret", &seed[..20]);
        assert_eq!(result, result2);
    }

    #[test]
    fn test_caching_sha2_password_server_verifies() {
        // Mirror the server: it caches SHA256(SHA256(password)), unmasks the
        // response with SHA256(cached + seed), and checks that the unmasked
        // value hashes back to the cached digest.
        let password = "mypassword";
        let response = caching_sha2_password(password, &REAL_SEED);

        let password_sha256 = sha256_concat(&[password.as_bytes()]);
        let cached = sha256_concat(&[&password_sha256[..]]);
        let mask = sha256_concat(&[&cached[..], &REAL_SEED[..]]);
        let candidate = xor(&response, &mask);

        assert_eq!(candidate, password_sha256.to_vec());
        assert_eq!(sha256_concat(&[&candidate[..]]), cached);
    }

    #[test]
    fn test_generate_nonce() {
        let nonce1 = generate_nonce(20);
        let nonce2 = generate_nonce(20);

        assert_eq!(nonce1.len(), 20);
        assert_eq!(nonce2.len(), 20);

        // Should be different (extremely high probability)
        assert_ne!(nonce1, nonce2);
    }

    #[test]
    fn test_sha256_password_rsa_unimplemented() {
        assert!(sha256_password_rsa("secret", &REAL_SEED, b"").is_err());
    }

    #[test]
    fn test_xor_password_with_seed() {
        let password = "test";
        let seed = [1, 2, 3, 4, 5, 6, 7, 8];

        let result = xor_password_with_seed(password, &seed);

        // Should be password length + 1 (NUL terminator)
        assert_eq!(result.len(), 5);

        // Last byte should be NUL
        assert_eq!(result[4], 0);

        // XOR is reversible
        let recovered: Vec<u8> = result[..4]
            .iter()
            .enumerate()
            .map(|(i, &b)| b ^ seed[i % seed.len()])
            .collect();
        assert_eq!(recovered, password.as_bytes());
    }

    #[test]
    fn test_plugin_names() {
        assert_eq!(plugins::MYSQL_NATIVE_PASSWORD, "mysql_native_password");
        assert_eq!(plugins::CACHING_SHA2_PASSWORD, "caching_sha2_password");
        assert_eq!(plugins::SHA256_PASSWORD, "sha256_password");
        assert_eq!(plugins::MYSQL_CLEAR_PASSWORD, "mysql_clear_password");
    }

    #[test]
    fn test_caching_sha2_codes() {
        assert_eq!(caching_sha2::REQUEST_PUBLIC_KEY, 0x02);
        assert_eq!(caching_sha2::FAST_AUTH_SUCCESS, 0x03);
        assert_eq!(caching_sha2::PERFORM_FULL_AUTH, 0x04);
    }
}