Skip to main content

zero_mysql/protocol/connection/
handshake.rs

1use crate::nightly::cold_path;
2use zerocopy::byteorder::little_endian::{U16 as U16LE, U32 as U32LE};
3use zerocopy::{FromBytes, Immutable, KnownLayout};
4
5use crate::buffer::BufferSet;
6use crate::constant::{
7    CAPABILITIES_ALWAYS_ENABLED, CAPABILITIES_CONFIGURABLE, CapabilityFlags,
8    MARIADB_CAPABILITIES_ENABLED, MAX_ALLOWED_PACKET, MariadbCapabilityFlags, UTF8MB4_GENERAL_CI,
9};
10use crate::error::{Error, Result, eyre};
11use crate::opts::Opts;
12use crate::protocol::primitive::*;
13use crate::protocol::response::ErrPayloadBytes;
14
15#[derive(Debug, Clone, Copy, FromBytes, KnownLayout, Immutable)]
16#[repr(C, packed)]
17struct HandshakeFixedFields {
18    connection_id: U32LE,
19    auth_data_part1: [u8; 8],
20    _filler1: u8,
21    capability_flags_lower: U16LE,
22    charset: u8,
23    status_flags: U16LE,
24    capability_flags_upper: U16LE,
25    auth_data_len: u8,
26    _filler2: [u8; 6],
27    mariadb_capabilities: U32LE,
28}
29
30#[derive(Debug, Clone)]
31pub struct InitialHandshake {
32    pub protocol_version: u8,
33    pub server_version: std::ops::Range<usize>,
34    pub connection_id: u32,
35    pub auth_plugin_data: Vec<u8>,
36    pub capability_flags: CapabilityFlags,
37    pub mariadb_capabilities: MariadbCapabilityFlags,
38    pub charset: u8,
39    pub status_flags: crate::constant::ServerStatusFlags,
40    pub auth_plugin_name: std::ops::Range<usize>,
41}
42
43/// Read initial handshake packet from server
44pub fn read_initial_handshake(payload: &[u8]) -> Result<InitialHandshake> {
45    let (protocol_version, data) = read_int_1(payload)?;
46
47    if protocol_version == 0xFF {
48        cold_path();
49        Err(ErrPayloadBytes(payload))?
50    }
51
52    let server_version_start = payload.len() - data.len();
53    let (server_version_bytes, data) = read_string_null(data)?;
54    let server_version = server_version_start..server_version_start + server_version_bytes.len();
55
56    let (fixed, data) = HandshakeFixedFields::ref_from_prefix(data)?;
57
58    let connection_id = fixed.connection_id.get();
59    let charset = fixed.charset;
60    let status_flags = fixed.status_flags.get();
61    let capability_flags = CapabilityFlags::from_bits(
62        ((fixed.capability_flags_upper.get() as u32) << 16)
63            | (fixed.capability_flags_lower.get() as u32),
64    )
65    .ok_or_else(|| Error::LibraryBug(eyre!("invalid capability flags from server")))?;
66    let mariadb_capabilities = MariadbCapabilityFlags::from_bits(fixed.mariadb_capabilities.get())
67        .ok_or_else(|| Error::LibraryBug(eyre!("invalid mariadb capability flags from server")))?;
68    let auth_data_len = fixed.auth_data_len;
69
70    let auth_data_2_len = (auth_data_len as usize).saturating_sub(9).max(12);
71    let (auth_data_2, data) = read_string_fix(data, auth_data_2_len)?;
72    let (_reserved, data) = read_int_1(data)?;
73
74    let mut auth_plugin_data = Vec::new();
75    auth_plugin_data.extend_from_slice(&fixed.auth_data_part1);
76    auth_plugin_data.extend_from_slice(auth_data_2);
77
78    let auth_plugin_name_start = payload.len() - data.len();
79    let (auth_plugin_name_bytes, rest) = read_string_null(data)?;
80    let auth_plugin_name =
81        auth_plugin_name_start..auth_plugin_name_start + auth_plugin_name_bytes.len();
82
83    if !rest.is_empty() {
84        return Err(Error::LibraryBug(eyre!(
85            "unexpected trailing data in handshake packet: {} bytes",
86            rest.len()
87        )));
88    }
89
90    Ok(InitialHandshake {
91        protocol_version,
92        server_version,
93        connection_id,
94        auth_plugin_data,
95        capability_flags,
96        mariadb_capabilities,
97        charset,
98        status_flags: crate::constant::ServerStatusFlags::from_bits_truncate(status_flags),
99        auth_plugin_name,
100    })
101}
102
103/// Auth switch request from server
104#[derive(Debug, Clone)]
105pub struct AuthSwitchRequest<'buf> {
106    pub plugin_name: &'buf [u8],
107    pub plugin_data: &'buf [u8],
108}
109
110/// Read auth switch request (0xFE with length >= 9)
111pub fn read_auth_switch_request(payload: &[u8]) -> Result<AuthSwitchRequest<'_>> {
112    let (header, mut data) = read_int_1(payload)?;
113    if header != 0xFE {
114        return Err(Error::LibraryBug(eyre!(
115            "expected auth switch header 0xFE, got 0x{:02X}",
116            header
117        )));
118    }
119
120    let (plugin_name, rest) = read_string_null(data)?;
121    data = rest;
122
123    if let Some(0) = data.last() {
124        Ok(AuthSwitchRequest {
125            plugin_name,
126            plugin_data: &data[..data.len() - 1],
127        })
128    } else {
129        Err(Error::LibraryBug(eyre!(
130            "auth switch request plugin data not null-terminated"
131        )))
132    }
133}
134
135/// Write auth switch response
136///
137/// Client sends the authentication data computed using the requested plugin.
138pub fn write_auth_switch_response(out: &mut Vec<u8>, auth_data: &[u8]) {
139    out.extend_from_slice(auth_data);
140}
141
142// ============================================================================
143// Authentication Plugins
144// ============================================================================
145
146/// mysql_native_password authentication
147///
148/// This is the traditional MySQL authentication method using SHA1.
149/// Formula: SHA1(password) XOR SHA1(challenge + SHA1(SHA1(password)))
150///
151/// # Arguments
152/// * `password` - Plain text password
153/// * `challenge` - 20-byte challenge from server (auth_plugin_data)
154///
155/// # Returns
156/// 20-byte authentication response
157pub fn auth_mysql_native_password(password: &str, challenge: &[u8]) -> [u8; 20] {
158    use sha1::{Digest, Sha1};
159
160    if password.is_empty() {
161        return [0_u8; 20];
162    }
163
164    // stage1_hash = SHA1(password)
165    let stage1_hash = Sha1::digest(password.as_bytes());
166
167    // stage2_hash = SHA1(stage1_hash)
168    let stage2_hash = Sha1::digest(stage1_hash);
169
170    // token_hash = SHA1(challenge + stage2_hash)
171    let mut hasher = Sha1::new();
172    hasher.update(challenge);
173    hasher.update(stage2_hash);
174    let token_hash = hasher.finalize();
175
176    // result = stage1_hash XOR token_hash
177    let mut result = [0_u8; 20];
178    for i in 0..20 {
179        result[i] = stage1_hash[i] ^ token_hash[i];
180    }
181
182    result
183}
184
185/// caching_sha2_password authentication - initial response
186///
187/// This is the default authentication method in MySQL 8.0+.
188/// Uses SHA256 hashing instead of SHA1.
189/// Formula: XOR(SHA256(password), SHA256(SHA256(SHA256(password)), challenge))
190///
191/// # Arguments
192/// * `password` - Plain text password
193/// * `challenge` - 20-byte challenge from server (auth_plugin_data)
194///
195/// # Returns
196/// 32-byte authentication response
197pub fn auth_caching_sha2_password(password: &str, challenge: &[u8]) -> [u8; 32] {
198    use sha2::{Digest, Sha256};
199
200    if password.is_empty() {
201        return [0_u8; 32];
202    }
203
204    // stage1 = SHA256(password)
205    let stage1 = Sha256::digest(password.as_bytes());
206
207    // stage2 = SHA256(stage1)
208    let stage2 = Sha256::digest(stage1);
209
210    // scramble = SHA256(stage2 + challenge)
211    let mut hasher = Sha256::new();
212    hasher.update(stage2);
213    hasher.update(challenge);
214    let scramble = hasher.finalize();
215
216    // result = stage1 XOR scramble
217    let mut result = [0_u8; 32];
218    for i in 0..32 {
219        result[i] = stage1[i] ^ scramble[i];
220    }
221
222    result
223}
224
225/// caching_sha2_password fast auth result
226///
227/// After sending the initial auth response, server may respond with:
228/// - 0x03 (fast auth success) - cached authentication succeeded
229/// - 0x04 (full auth required) - need to send password via RSA or cleartext
230#[derive(Debug, Clone, Copy, PartialEq, Eq)]
231pub enum CachingSha2PasswordFastAuthResult {
232    Success,
233    FullAuthRequired,
234}
235
236/// Read caching_sha2_password fast auth result
237pub fn read_caching_sha2_password_fast_auth_result(
238    payload: &[u8],
239) -> Result<CachingSha2PasswordFastAuthResult> {
240    if payload.is_empty() {
241        return Err(Error::LibraryBug(eyre!(
242            "empty payload for caching_sha2_password fast auth result"
243        )));
244    }
245
246    match payload[0] {
247        0x03 => Ok(CachingSha2PasswordFastAuthResult::Success),
248        0x04 => Ok(CachingSha2PasswordFastAuthResult::FullAuthRequired),
249        _ => Err(Error::LibraryBug(eyre!(
250            "unexpected caching_sha2_password fast auth result: 0x{:02X}",
251            payload[0]
252        ))),
253    }
254}
255
256/// Encrypt password for caching_sha2_password full auth via RSA.
257///
258/// XORs the null-terminated password with the scramble (cyclically),
259/// then RSA-OAEP-SHA1 encrypts with the server's public key.
260fn rsa_encrypt_password(password: &str, scramble: &[u8], pem_str: &str) -> Result<Vec<u8>> {
261    use aws_lc_rs::rsa::{OAEP_SHA1_MGF1SHA1, OaepPublicEncryptingKey, PublicEncryptingKey};
262
263    let pem_data = pem::parse(pem_str)
264        .map_err(|e| Error::LibraryBug(eyre!("failed to parse RSA public key PEM: {}", e)))?;
265
266    let public_key = PublicEncryptingKey::from_der(pem_data.contents())
267        .map_err(|e| Error::LibraryBug(eyre!("failed to parse RSA public key DER: {}", e)))?;
268
269    let oaep_key = OaepPublicEncryptingKey::new(public_key)
270        .map_err(|e| Error::LibraryBug(eyre!("failed to create OAEP key: {}", e)))?;
271
272    if scramble.is_empty() {
273        return Err(Error::LibraryBug(eyre!(
274            "empty scramble in rsa_encrypt_password"
275        )));
276    }
277
278    // XOR (password + '\0') with scramble repeated cyclically
279    let mut buf = Vec::with_capacity(password.len() + 1);
280    buf.extend_from_slice(password.as_bytes());
281    buf.push(0);
282
283    for (byte, key) in buf.iter_mut().zip(scramble.iter().cycle()) {
284        *byte ^= key;
285    }
286
287    let mut ciphertext = vec![0u8; oaep_key.ciphertext_size()];
288    let encrypted = oaep_key
289        .encrypt(&OAEP_SHA1_MGF1SHA1, &buf, &mut ciphertext, None)
290        .map_err(|e| Error::LibraryBug(eyre!("RSA encryption failed: {}", e)))?;
291
292    Ok(encrypted.to_vec())
293}
294
295// ============================================================================
296// State Machine API for Handshake
297// ============================================================================
298
299/// Write SSL request packet (sent before HandshakeResponse when TLS is enabled)
300fn write_ssl_request(
301    out: &mut Vec<u8>,
302    capability_flags: CapabilityFlags,
303    mariadb_capabilities: MariadbCapabilityFlags,
304) {
305    // capability flags (4 bytes)
306    write_int_4(out, capability_flags.bits());
307
308    // max packet size (4 bytes)
309    write_int_4(out, MAX_ALLOWED_PACKET);
310
311    // charset (1 byte)
312    write_int_1(out, UTF8MB4_GENERAL_CI);
313
314    // reserved (23 bytes of 0x00)
315    out.extend_from_slice(&[0_u8; 19]);
316
317    if capability_flags.is_mariadb() {
318        write_int_4(out, mariadb_capabilities.bits());
319    } else {
320        write_int_4(out, 0);
321    }
322}
323
324/// Action returned by the Handshake state machine indicating what I/O operation is needed next
325pub enum HandshakeAction<'buf> {
326    /// Read a packet into the provided buffer
327    ReadPacket(&'buf mut Vec<u8>),
328
329    /// Write the prepared packet with given sequence_id, then read next response
330    WritePacket { sequence_id: u8 },
331
332    /// Write SSL request, then upgrade stream to TLS
333    UpgradeTls { sequence_id: u8 },
334
335    /// Handshake complete - call finish() to get results
336    Finished,
337}
338
339/// Internal state of the handshake state machine
340enum HandshakeState {
341    /// Initial state - need to read initial handshake from server
342    Start,
343    /// Waiting for initial handshake packet to be read
344    WaitingInitialHandshake,
345    /// SSL request written, waiting for TLS upgrade to complete
346    WaitingTlsUpgrade,
347    /// Handshake response written, waiting for auth result
348    WaitingAuthResult,
349    /// After auth switch response. `caching_sha2` = whether we might receive AuthMoreData (0x01).
350    WaitingFinalAuthResult { caching_sha2: bool },
351    /// After caching_sha2 fast auth success (0x03) — waiting for the OK packet.
352    WaitingCachingSha2FastAuthOk,
353    /// After requesting RSA public key (0x02) — waiting for AuthMoreData with PEM key.
354    WaitingRsaPublicKey,
355    /// Connected (terminal state)
356    Connected,
357}
358
359/// State machine for MySQL handshake
360///
361/// Pure parsing and packet generation state machine without I/O dependencies.
362pub struct Handshake<'a> {
363    state: HandshakeState,
364    opts: &'a Opts,
365    initial_handshake: Option<InitialHandshake>,
366    next_sequence_id: u8,
367    capability_flags: Option<CapabilityFlags>,
368    mariadb_capabilities: Option<MariadbCapabilityFlags>,
369}
370
371impl<'a> Handshake<'a> {
372    /// Create a new handshake state machine
373    pub fn new(opts: &'a Opts) -> Self {
374        Self {
375            state: HandshakeState::Start,
376            opts,
377            initial_handshake: None,
378            next_sequence_id: 1,
379            capability_flags: None,
380            mariadb_capabilities: None,
381        }
382    }
383
384    /// Drive the state machine forward
385    ///
386    /// Returns an action indicating what I/O operation the caller should perform.
387    pub fn step<'buf>(&mut self, buffer_set: &'buf mut BufferSet) -> Result<HandshakeAction<'buf>> {
388        match &mut self.state {
389            HandshakeState::Start => {
390                self.state = HandshakeState::WaitingInitialHandshake;
391                Ok(HandshakeAction::ReadPacket(
392                    &mut buffer_set.initial_handshake,
393                ))
394            }
395
396            HandshakeState::WaitingInitialHandshake => {
397                let handshake = read_initial_handshake(&buffer_set.initial_handshake)?;
398
399                let mut client_caps = CAPABILITIES_ALWAYS_ENABLED
400                    | (self.opts.capabilities & CAPABILITIES_CONFIGURABLE);
401                if self.opts.db.is_some() {
402                    client_caps |= CapabilityFlags::CLIENT_CONNECT_WITH_DB;
403                }
404                if self.opts.tls {
405                    client_caps |= CapabilityFlags::CLIENT_SSL;
406                }
407
408                let negotiated_caps = client_caps & handshake.capability_flags;
409                let mariadb_caps = if negotiated_caps.is_mariadb() {
410                    if !handshake
411                        .mariadb_capabilities
412                        .contains(MARIADB_CAPABILITIES_ENABLED)
413                    {
414                        return Err(Error::Unsupported(format!(
415                            "MariaDB server does not support the required capabilities. Server: {:?} Required: {:?}",
416                            handshake.mariadb_capabilities, MARIADB_CAPABILITIES_ENABLED
417                        )));
418                    }
419                    MARIADB_CAPABILITIES_ENABLED
420                } else {
421                    MariadbCapabilityFlags::empty()
422                };
423
424                // Store capabilities and initial handshake
425                self.capability_flags = Some(negotiated_caps);
426                self.mariadb_capabilities = Some(mariadb_caps);
427                self.initial_handshake = Some(handshake);
428
429                // TLS: SSLRequest + HandshakeResponse
430                if self.opts.tls && negotiated_caps.contains(CapabilityFlags::CLIENT_SSL) {
431                    write_ssl_request(buffer_set.new_write_buffer(), negotiated_caps, mariadb_caps);
432
433                    let seq = self.next_sequence_id;
434                    self.next_sequence_id = self.next_sequence_id.wrapping_add(1);
435                    self.state = HandshakeState::WaitingTlsUpgrade;
436
437                    Ok(HandshakeAction::UpgradeTls { sequence_id: seq })
438                } else {
439                    // No TLS: HandshakeResponse
440                    self.write_handshake_response(buffer_set)?;
441                    let seq = self.next_sequence_id;
442                    self.next_sequence_id = self.next_sequence_id.wrapping_add(2);
443                    self.state = HandshakeState::WaitingAuthResult;
444
445                    Ok(HandshakeAction::WritePacket { sequence_id: seq })
446                }
447            }
448
449            HandshakeState::WaitingTlsUpgrade => {
450                // TLS upgrade completed, now send handshake response
451                self.write_handshake_response(buffer_set)?;
452
453                let seq = self.next_sequence_id;
454                self.next_sequence_id = self.next_sequence_id.wrapping_add(2);
455                self.state = HandshakeState::WaitingAuthResult;
456
457                Ok(HandshakeAction::WritePacket { sequence_id: seq })
458            }
459
460            HandshakeState::WaitingAuthResult => {
461                let payload = &buffer_set.read_buffer[..];
462                if payload.is_empty() {
463                    return Err(Error::LibraryBug(eyre!(
464                        "empty payload while waiting for auth result"
465                    )));
466                }
467
468                // Get initial plugin name from stored handshake
469                let initial_handshake = self.initial_handshake.as_ref().ok_or_else(|| {
470                    Error::LibraryBug(eyre!("initial_handshake not set in WaitingAuthResult"))
471                })?;
472                let initial_plugin =
473                    &buffer_set.initial_handshake[initial_handshake.auth_plugin_name.clone()];
474
475                match payload[0] {
476                    0x00 => {
477                        // OK packet - authentication succeeded
478                        self.state = HandshakeState::Connected;
479                        Ok(HandshakeAction::Finished)
480                    }
481                    0xFF => {
482                        // ERR packet - authentication failed
483                        Err(ErrPayloadBytes(payload).into())
484                    }
485                    0x01 => {
486                        // AuthMoreData — caching_sha2_password fast auth result
487                        if initial_plugin == b"caching_sha2_password" {
488                            self.handle_auth_more_data(buffer_set)
489                        } else {
490                            Err(Error::LibraryBug(eyre!(
491                                "unexpected AuthMoreData (0x01) for plugin {:?}",
492                                String::from_utf8_lossy(initial_plugin)
493                            )))
494                        }
495                    }
496                    0xFE => {
497                        // Auth switch request
498                        let auth_switch = read_auth_switch_request(payload)?;
499
500                        // Compute auth response for new plugin
501                        let (auth_response, is_caching_sha2) = match auth_switch.plugin_name {
502                            b"mysql_native_password" => (
503                                auth_mysql_native_password(
504                                    &self.opts.password,
505                                    auth_switch.plugin_data,
506                                )
507                                .to_vec(),
508                                false,
509                            ),
510                            b"caching_sha2_password" => (
511                                auth_caching_sha2_password(
512                                    &self.opts.password,
513                                    auth_switch.plugin_data,
514                                )
515                                .to_vec(),
516                                true,
517                            ),
518                            plugin => {
519                                return Err(Error::Unsupported(
520                                    String::from_utf8_lossy(plugin).to_string(),
521                                ));
522                            }
523                        };
524
525                        write_auth_switch_response(buffer_set.new_write_buffer(), &auth_response);
526
527                        let seq = self.next_sequence_id;
528                        self.next_sequence_id = self.next_sequence_id.wrapping_add(2);
529                        self.state = HandshakeState::WaitingFinalAuthResult {
530                            caching_sha2: is_caching_sha2,
531                        };
532
533                        Ok(HandshakeAction::WritePacket { sequence_id: seq })
534                    }
535                    header => Err(Error::LibraryBug(eyre!(
536                        "unexpected packet header 0x{:02X} while waiting for auth result",
537                        header
538                    ))),
539                }
540            }
541
542            HandshakeState::WaitingFinalAuthResult { caching_sha2 } => {
543                let payload = &buffer_set.read_buffer[..];
544                if payload.is_empty() {
545                    return Err(Error::LibraryBug(eyre!(
546                        "empty payload while waiting for final auth result"
547                    )));
548                }
549
550                match payload[0] {
551                    0x00 => {
552                        // OK packet - authentication succeeded
553                        self.state = HandshakeState::Connected;
554                        Ok(HandshakeAction::Finished)
555                    }
556                    0xFF => {
557                        // ERR packet - authentication failed
558                        Err(ErrPayloadBytes(payload).into())
559                    }
560                    0x01 if *caching_sha2 => self.handle_auth_more_data(buffer_set),
561                    header => Err(Error::LibraryBug(eyre!(
562                        "unexpected packet header 0x{:02X} while waiting for final auth result",
563                        header
564                    ))),
565                }
566            }
567
568            HandshakeState::WaitingCachingSha2FastAuthOk => {
569                let payload = &buffer_set.read_buffer[..];
570                if payload.is_empty() {
571                    return Err(Error::LibraryBug(eyre!(
572                        "empty payload while waiting for caching_sha2 OK"
573                    )));
574                }
575
576                match payload[0] {
577                    0x00 => {
578                        self.state = HandshakeState::Connected;
579                        Ok(HandshakeAction::Finished)
580                    }
581                    0xFF => Err(ErrPayloadBytes(payload).into()),
582                    header => Err(Error::LibraryBug(eyre!(
583                        "unexpected packet header 0x{:02X} while waiting for caching_sha2 OK",
584                        header
585                    ))),
586                }
587            }
588
589            HandshakeState::WaitingRsaPublicKey => {
590                let payload = &buffer_set.read_buffer[..];
591                if payload.is_empty() {
592                    return Err(Error::LibraryBug(eyre!(
593                        "empty payload while waiting for RSA public key"
594                    )));
595                }
596
597                match payload[0] {
598                    0xFF => return Err(ErrPayloadBytes(payload).into()),
599                    0x01 if payload.len() >= 2 => {}
600                    header => {
601                        return Err(Error::LibraryBug(eyre!(
602                            "expected AuthMoreData (0x01) with RSA public key, got 0x{:02X}",
603                            header
604                        )));
605                    }
606                }
607
608                let pem = std::str::from_utf8(&payload[1..]).map_err(|e| {
609                    Error::LibraryBug(eyre!("RSA public key is not valid UTF-8: {}", e))
610                })?;
611
612                let handshake = self
613                    .initial_handshake
614                    .as_ref()
615                    .ok_or_else(|| Error::LibraryBug(eyre!("initial_handshake not set")))?;
616
617                let encrypted =
618                    rsa_encrypt_password(&self.opts.password, &handshake.auth_plugin_data, pem)?;
619
620                let out = buffer_set.new_write_buffer();
621                out.extend_from_slice(&encrypted);
622
623                let seq = self.next_sequence_id;
624                self.next_sequence_id = self.next_sequence_id.wrapping_add(2);
625                self.state = HandshakeState::WaitingFinalAuthResult {
626                    caching_sha2: false,
627                };
628
629                Ok(HandshakeAction::WritePacket { sequence_id: seq })
630            }
631
632            HandshakeState::Connected => Err(Error::LibraryBug(eyre!(
633                "step() called after handshake completed"
634            ))),
635        }
636    }
637
638    /// Consume the state machine and return the connection info
639    ///
640    /// Returns an error if called before handshake is complete (before Finished action)
641    pub fn finish(self) -> Result<(InitialHandshake, CapabilityFlags, MariadbCapabilityFlags)> {
642        if !matches!(self.state, HandshakeState::Connected) {
643            return Err(Error::LibraryBug(eyre!(
644                "finish() called before handshake completed"
645            )));
646        }
647
648        let initial_handshake = self.initial_handshake.ok_or_else(|| {
649            Error::LibraryBug(eyre!("initial_handshake not set in Connected state"))
650        })?;
651        let capability_flags = self.capability_flags.ok_or_else(|| {
652            Error::LibraryBug(eyre!("capability_flags not set in Connected state"))
653        })?;
654        let mariadb_capabilities = self.mariadb_capabilities.ok_or_else(|| {
655            Error::LibraryBug(eyre!("mariadb_capabilities not set in Connected state"))
656        })?;
657
658        Ok((initial_handshake, capability_flags, mariadb_capabilities))
659    }
660
661    /// Write handshake response packet (HandshakeResponse41)
662    fn write_handshake_response(&self, buffer_set: &mut BufferSet) -> Result<()> {
663        buffer_set.new_write_buffer();
664
665        let handshake = self.initial_handshake.as_ref().ok_or_else(|| {
666            Error::LibraryBug(eyre!(
667                "initial_handshake not set in write_handshake_response"
668            ))
669        })?;
670        let capability_flags = self.capability_flags.ok_or_else(|| {
671            Error::LibraryBug(eyre!(
672                "capability_flags not set in write_handshake_response"
673            ))
674        })?;
675        let mariadb_capabilities = self.mariadb_capabilities.ok_or_else(|| {
676            Error::LibraryBug(eyre!(
677                "mariadb_capabilities not set in write_handshake_response"
678            ))
679        })?;
680
681        // Compute auth response based on plugin name
682        let auth_plugin_name = &buffer_set.initial_handshake[handshake.auth_plugin_name.clone()];
683        let auth_response = {
684            match auth_plugin_name {
685                b"mysql_native_password" => {
686                    auth_mysql_native_password(&self.opts.password, &handshake.auth_plugin_data)
687                        .to_vec()
688                }
689                b"caching_sha2_password" => {
690                    auth_caching_sha2_password(&self.opts.password, &handshake.auth_plugin_data)
691                        .to_vec()
692                }
693                plugin => {
694                    return Err(Error::Unsupported(
695                        String::from_utf8_lossy(plugin).to_string(),
696                    ));
697                }
698            }
699        };
700
701        let out = &mut buffer_set.write_buffer;
702        // capability flags (4 bytes)
703        write_int_4(out, capability_flags.bits());
704        // max packet size (4 bytes)
705        write_int_4(out, MAX_ALLOWED_PACKET);
706        // charset (1 byte)
707        write_int_1(out, UTF8MB4_GENERAL_CI);
708        // reserved (19 bytes) + MariaDB capabilities (4 bytes) = 23 bytes
709        out.extend_from_slice(&[0_u8; 19]);
710        write_int_4(out, mariadb_capabilities.bits());
711        // username (null-terminated)
712        write_string_null(out, self.opts.user.as_bytes());
713        // auth response (length-encoded)
714        if capability_flags.contains(CapabilityFlags::CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA) {
715            write_bytes_lenenc(out, &auth_response);
716        } else {
717            write_int_1(out, auth_response.len() as u8);
718            out.extend_from_slice(&auth_response);
719        }
720        // database name (null-terminated, if CLIENT_CONNECT_WITH_DB)
721        if let Some(db) = &self.opts.db {
722            write_string_null(out, db.as_bytes());
723        }
724
725        // auth plugin name (null-terminated, if CLIENT_PLUGIN_AUTH)
726        if capability_flags.contains(CapabilityFlags::CLIENT_PLUGIN_AUTH) {
727            write_string_null(out, auth_plugin_name);
728        }
729
730        Ok(())
731    }
732
733    /// Handle AuthMoreData (0x01) packet for caching_sha2_password.
734    ///
735    /// Called from both `WaitingAuthResult` and `WaitingFinalAuthResult { caching_sha2: true }`.
736    fn handle_auth_more_data<'buf>(
737        &mut self,
738        buffer_set: &'buf mut BufferSet,
739    ) -> Result<HandshakeAction<'buf>> {
740        let payload = &buffer_set.read_buffer[..];
741        if payload.len() < 2 {
742            return Err(Error::LibraryBug(eyre!(
743                "AuthMoreData packet too short: {} bytes",
744                payload.len()
745            )));
746        }
747
748        let result = read_caching_sha2_password_fast_auth_result(&payload[1..])?;
749
750        match result {
751            CachingSha2PasswordFastAuthResult::Success => {
752                // Fast auth succeeded — server will send OK next
753                self.state = HandshakeState::WaitingCachingSha2FastAuthOk;
754                Ok(HandshakeAction::ReadPacket(&mut buffer_set.read_buffer))
755            }
756            CachingSha2PasswordFastAuthResult::FullAuthRequired => {
757                let capability_flags = self
758                    .capability_flags
759                    .ok_or_else(|| Error::LibraryBug(eyre!("capability_flags not set")))?;
760
761                if capability_flags.contains(CapabilityFlags::CLIENT_SSL) {
762                    // TLS is active — send cleartext password (null-terminated)
763                    let out = buffer_set.new_write_buffer();
764                    out.extend_from_slice(self.opts.password.as_bytes());
765                    out.push(0);
766
767                    let seq = self.next_sequence_id;
768                    self.next_sequence_id = self.next_sequence_id.wrapping_add(2);
769                    self.state = HandshakeState::WaitingFinalAuthResult {
770                        caching_sha2: false,
771                    };
772
773                    Ok(HandshakeAction::WritePacket { sequence_id: seq })
774                } else {
775                    // No TLS — request server's RSA public key
776                    let out = buffer_set.new_write_buffer();
777                    out.push(0x02);
778
779                    let seq = self.next_sequence_id;
780                    self.next_sequence_id = self.next_sequence_id.wrapping_add(2);
781                    self.state = HandshakeState::WaitingRsaPublicKey;
782
783                    Ok(HandshakeAction::WritePacket { sequence_id: seq })
784                }
785            }
786        }
787    }
788}
789
790#[cfg(test)]
791mod tests {
792    use super::*;
793    use crate::test_macros::{check_eq, check_err};
794
795    #[test]
796    fn handshake_fixed_fields_has_alignment_of_1() {
797        assert_eq!(std::mem::align_of::<HandshakeFixedFields>(), 1);
798    }
799
800    #[test]
801    #[expect(clippy::unwrap_used)]
802    fn rsa_encrypt_password_xors_and_encrypts() {
803        use aws_lc_rs::encoding::AsDer;
804        use aws_lc_rs::rsa::{
805            KeySize, OAEP_SHA1_MGF1SHA1, OaepPrivateDecryptingKey, PrivateDecryptingKey,
806        };
807        use aws_lc_rs::signature::KeyPair;
808
809        let key_pair = aws_lc_rs::rsa::KeyPair::generate(KeySize::Rsa2048).unwrap();
810        let private_key_pkcs8 = key_pair.as_der().unwrap();
811        let public_key_der = key_pair.public_key().as_der().unwrap();
812
813        let pem_data = pem::Pem::new("PUBLIC KEY", public_key_der.as_ref().to_vec());
814        let pem_string = pem::encode(&pem_data);
815
816        let password = "test_password";
817        let scramble = b"01234567890123456789";
818
819        let encrypted = super::rsa_encrypt_password(password, scramble, &pem_string).unwrap();
820
821        // Decrypt and verify
822        let private_key = PrivateDecryptingKey::from_pkcs8(private_key_pkcs8.as_ref()).unwrap();
823        let oaep_key = OaepPrivateDecryptingKey::new(private_key).unwrap();
824        let mut plaintext = vec![0u8; encrypted.len()];
825        let decrypted = oaep_key
826            .decrypt(&OAEP_SHA1_MGF1SHA1, &encrypted, &mut plaintext, None)
827            .unwrap();
828
829        // Decrypted should be XOR(password + '\0', scramble)
830        let mut expected = password.as_bytes().to_vec();
831        expected.push(0);
832        for (byte, key) in expected.iter_mut().zip(scramble.iter().cycle()) {
833            *byte ^= key;
834        }
835        assert_eq!(decrypted, expected);
836    }
837
838    #[test]
839    fn fast_auth_result_parsing() -> crate::error::Result<()> {
840        check_eq!(
841            read_caching_sha2_password_fast_auth_result(&[0x03])?,
842            CachingSha2PasswordFastAuthResult::Success,
843        );
844        check_eq!(
845            read_caching_sha2_password_fast_auth_result(&[0x04])?,
846            CachingSha2PasswordFastAuthResult::FullAuthRequired,
847        );
848        check_err!(read_caching_sha2_password_fast_auth_result(&[0x05]));
849        check_err!(read_caching_sha2_password_fast_auth_result(&[]));
850        Ok(())
851    }
852}