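//! `zero_mysql/protocol/connection/handshake.rs`
//!
//! Connection-phase handshake for MySQL/MariaDB: parsing of the server's initial handshake
//! packet, the authentication scrambles, and a sans-I/O handshake state machine.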

use zerocopy::byteorder::little_endian::{U16 as U16LE, U32 as U32LE};
use zerocopy::{FromBytes, Immutable, KnownLayout};

use crate::buffer::BufferSet;
use crate::constant::{
    CAPABILITIES_ALWAYS_ENABLED, CAPABILITIES_CONFIGURABLE, CapabilityFlags,
    MARIADB_CAPABILITIES_ENABLED, MAX_ALLOWED_PACKET, MariadbCapabilityFlags, UTF8MB4_GENERAL_CI,
};
use crate::error::{Error, Result, eyre};
use crate::nightly::cold_path;
use crate::opts::Opts;
use crate::protocol::primitive::*;
use crate::protocol::response::ErrPayloadBytes;

15#[derive(Debug, Clone, Copy, FromBytes, KnownLayout, Immutable)]
16#[repr(C, packed)]
17struct HandshakeFixedFields {
18    connection_id: U32LE,
19    auth_data_part1: [u8; 8],
20    _filler1: u8,
21    capability_flags_lower: U16LE,
22    charset: u8,
23    status_flags: U16LE,
24    capability_flags_upper: U16LE,
25    auth_data_len: u8,
26    _filler2: [u8; 6],
27    mariadb_capabilities: U32LE,
28}
29
30#[derive(Debug, Clone)]
31pub struct InitialHandshake {
32    pub protocol_version: u8,
33    pub server_version: std::ops::Range<usize>,
34    pub connection_id: u32,
35    pub auth_plugin_data: Vec<u8>,
36    pub capability_flags: CapabilityFlags,
37    pub mariadb_capabilities: MariadbCapabilityFlags,
38    pub charset: u8,
39    pub status_flags: crate::constant::ServerStatusFlags,
40    pub auth_plugin_name: std::ops::Range<usize>,
41}
42
43/// Read initial handshake packet from server
44pub fn read_initial_handshake(payload: &[u8]) -> Result<InitialHandshake> {
45    let (protocol_version, data) = read_int_1(payload)?;
46
47    if protocol_version == 0xFF {
48        cold_path();
49        Err(ErrPayloadBytes(payload))?
50    }
51
52    let server_version_start = payload.len() - data.len();
53    let (server_version_bytes, data) = read_string_null(data)?;
54    let server_version = server_version_start..server_version_start + server_version_bytes.len();
55
56    let (fixed, data) = HandshakeFixedFields::ref_from_prefix(data)?;
57
58    let connection_id = fixed.connection_id.get();
59    let charset = fixed.charset;
60    let status_flags = fixed.status_flags.get();
61    let capability_flags = CapabilityFlags::from_bits(
62        ((fixed.capability_flags_upper.get() as u32) << 16)
63            | (fixed.capability_flags_lower.get() as u32),
64    )
65    .ok_or_else(|| Error::LibraryBug(eyre!("invalid capability flags from server")))?;
66    let mariadb_capabilities = MariadbCapabilityFlags::from_bits(fixed.mariadb_capabilities.get())
67        .ok_or_else(|| Error::LibraryBug(eyre!("invalid mariadb capability flags from server")))?;
68    let auth_data_len = fixed.auth_data_len;
69
70    let auth_data_2_len = (auth_data_len as usize).saturating_sub(9).max(12);
71    let (auth_data_2, data) = read_string_fix(data, auth_data_2_len)?;
72    let (_reserved, data) = read_int_1(data)?;
73
74    let mut auth_plugin_data = Vec::new();
75    auth_plugin_data.extend_from_slice(&fixed.auth_data_part1);
76    auth_plugin_data.extend_from_slice(auth_data_2);
77
78    let auth_plugin_name_start = payload.len() - data.len();
79    let (auth_plugin_name_bytes, rest) = read_string_null(data)?;
80    let auth_plugin_name =
81        auth_plugin_name_start..auth_plugin_name_start + auth_plugin_name_bytes.len();
82
83    if !rest.is_empty() {
84        return Err(Error::LibraryBug(eyre!(
85            "unexpected trailing data in handshake packet: {} bytes",
86            rest.len()
87        )));
88    }
89
90    Ok(InitialHandshake {
91        protocol_version,
92        server_version,
93        connection_id,
94        auth_plugin_data,
95        capability_flags,
96        mariadb_capabilities,
97        charset,
98        status_flags: crate::constant::ServerStatusFlags::from_bits_truncate(status_flags),
99        auth_plugin_name,
100    })
101}
102
/// Auth switch request from server.
///
/// Both fields borrow from the packet payload. The type is a pair of slices, so it is
/// `Copy` and comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AuthSwitchRequest<'buf> {
    /// Name of the authentication plugin the server wants to use (no terminator).
    pub plugin_name: &'buf [u8],
    /// Plugin-specific data (the new scramble), without its trailing NUL byte.
    pub plugin_data: &'buf [u8],
}

110/// Read auth switch request (0xFE with length >= 9)
111pub fn read_auth_switch_request(payload: &[u8]) -> Result<AuthSwitchRequest<'_>> {
112    let (header, mut data) = read_int_1(payload)?;
113    if header != 0xFE {
114        return Err(Error::LibraryBug(eyre!(
115            "expected auth switch header 0xFE, got 0x{:02X}",
116            header
117        )));
118    }
119
120    let (plugin_name, rest) = read_string_null(data)?;
121    data = rest;
122
123    if let Some(0) = data.last() {
124        Ok(AuthSwitchRequest {
125            plugin_name,
126            plugin_data: &data[..data.len() - 1],
127        })
128    } else {
129        Err(Error::LibraryBug(eyre!(
130            "auth switch request plugin data not null-terminated"
131        )))
132    }
133}
134
/// Serialise an auth switch response.
///
/// The packet body is simply the authentication data computed with the plugin the server
/// requested; it is appended to `out` unchanged, with no length prefix or terminator.
pub fn write_auth_switch_response(out: &mut Vec<u8>, auth_data: &[u8]) {
    out.extend(auth_data.iter().copied());
}

// ============================================================================
// Authentication Plugins
// ============================================================================

146/// mysql_native_password authentication
147///
148/// This is the traditional MySQL authentication method using SHA1.
149/// Formula: SHA1(password) XOR SHA1(challenge + SHA1(SHA1(password)))
150///
151/// # Arguments
152/// * `password` - Plain text password
153/// * `challenge` - 20-byte challenge from server (auth_plugin_data)
154///
155/// # Returns
156/// 20-byte authentication response
157pub fn auth_mysql_native_password(password: &str, challenge: &[u8]) -> [u8; 20] {
158    use sha1::{Digest, Sha1};
159
160    if password.is_empty() {
161        return [0_u8; 20];
162    }
163
164    // stage1_hash = SHA1(password)
165    let stage1_hash = Sha1::digest(password.as_bytes());
166
167    // stage2_hash = SHA1(stage1_hash)
168    let stage2_hash = Sha1::digest(stage1_hash);
169
170    // token_hash = SHA1(challenge + stage2_hash)
171    let mut hasher = Sha1::new();
172    hasher.update(challenge);
173    hasher.update(stage2_hash);
174    let token_hash = hasher.finalize();
175
176    // result = stage1_hash XOR token_hash
177    let mut result = [0_u8; 20];
178    for i in 0..20 {
179        result[i] = stage1_hash[i] ^ token_hash[i];
180    }
181
182    result
183}
184
185/// caching_sha2_password authentication - initial response
186///
187/// This is the default authentication method in MySQL 8.0+.
188/// Uses SHA256 hashing instead of SHA1.
189/// Formula: XOR(SHA256(password), SHA256(SHA256(SHA256(password)), challenge))
190///
191/// # Arguments
192/// * `password` - Plain text password
193/// * `challenge` - 20-byte challenge from server (auth_plugin_data)
194///
195/// # Returns
196/// 32-byte authentication response
197pub fn auth_caching_sha2_password(password: &str, challenge: &[u8]) -> [u8; 32] {
198    use sha2::{Digest, Sha256};
199
200    if password.is_empty() {
201        return [0_u8; 32];
202    }
203
204    // stage1 = SHA256(password)
205    let stage1 = Sha256::digest(password.as_bytes());
206
207    // stage2 = SHA256(stage1)
208    let stage2 = Sha256::digest(stage1);
209
210    // scramble = SHA256(stage2 + challenge)
211    let mut hasher = Sha256::new();
212    hasher.update(stage2);
213    hasher.update(challenge);
214    let scramble = hasher.finalize();
215
216    // result = stage1 XOR scramble
217    let mut result = [0_u8; 32];
218    for i in 0..32 {
219        result[i] = stage1[i] ^ scramble[i];
220    }
221
222    result
223}
224
/// Outcome of the `caching_sha2_password` fast-auth attempt.
///
/// After the client sends its initial scramble the server replies with one status byte:
/// - `0x03` (fast auth success) - the cached credentials matched;
/// - `0x04` (full auth required) - the password must be sent via RSA or cleartext.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CachingSha2PasswordFastAuthResult {
    /// Status byte `0x03`.
    Success,
    /// Status byte `0x04`.
    FullAuthRequired,
}

236/// Read caching_sha2_password fast auth result
237pub fn read_caching_sha2_password_fast_auth_result(
238    payload: &[u8],
239) -> Result<CachingSha2PasswordFastAuthResult> {
240    if payload.is_empty() {
241        return Err(Error::LibraryBug(eyre!(
242            "empty payload for caching_sha2_password fast auth result"
243        )));
244    }
245
246    match payload[0] {
247        0x03 => Ok(CachingSha2PasswordFastAuthResult::Success),
248        0x04 => Ok(CachingSha2PasswordFastAuthResult::FullAuthRequired),
249        _ => Err(Error::LibraryBug(eyre!(
250            "unexpected caching_sha2_password fast auth result: 0x{:02X}",
251            payload[0]
252        ))),
253    }
254}
255
256/// Encrypt password for caching_sha2_password full auth via RSA.
257///
258/// XORs the null-terminated password with the scramble (cyclically),
259/// then RSA-OAEP-SHA1 encrypts with the server's public key.
260fn rsa_encrypt_password(password: &str, scramble: &[u8], pem: &str) -> Result<Vec<u8>> {
261    use rsa::pkcs8::DecodePublicKey;
262    use rsa::{Oaep, RsaPublicKey};
263
264    let public_key = RsaPublicKey::from_public_key_pem(pem)
265        .map_err(|e| Error::LibraryBug(eyre!("failed to parse RSA public key: {}", e)))?;
266
267    if scramble.is_empty() {
268        return Err(Error::LibraryBug(eyre!(
269            "empty scramble in rsa_encrypt_password"
270        )));
271    }
272
273    // XOR (password + '\0') with scramble repeated cyclically
274    let mut buf = Vec::with_capacity(password.len() + 1);
275    buf.extend_from_slice(password.as_bytes());
276    buf.push(0);
277
278    for (byte, key) in buf.iter_mut().zip(scramble.iter().cycle()) {
279        *byte ^= key;
280    }
281
282    let padding = Oaep::new::<sha1::Sha1>();
283    public_key
284        .encrypt(&mut rsa::rand_core::OsRng, padding, &buf)
285        .map_err(|e| Error::LibraryBug(eyre!("RSA encryption failed: {}", e)))
286}

// ============================================================================
// State Machine API for Handshake
// ============================================================================

292/// Write SSL request packet (sent before HandshakeResponse when TLS is enabled)
293fn write_ssl_request(
294    out: &mut Vec<u8>,
295    capability_flags: CapabilityFlags,
296    mariadb_capabilities: MariadbCapabilityFlags,
297) {
298    // capability flags (4 bytes)
299    write_int_4(out, capability_flags.bits());
300
301    // max packet size (4 bytes)
302    write_int_4(out, MAX_ALLOWED_PACKET);
303
304    // charset (1 byte)
305    write_int_1(out, UTF8MB4_GENERAL_CI);
306
307    // reserved (23 bytes of 0x00)
308    out.extend_from_slice(&[0_u8; 19]);
309
310    if capability_flags.is_mariadb() {
311        write_int_4(out, mariadb_capabilities.bits());
312    } else {
313        write_int_4(out, 0);
314    }
315}
316
/// Action returned by the Handshake state machine indicating what I/O operation is needed next.
///
/// The state machine performs no I/O itself; the caller carries out the requested operation
/// and then drives the machine again.
#[derive(Debug)]
pub enum HandshakeAction<'buf> {
    /// Read a packet into the provided buffer.
    ReadPacket(&'buf mut Vec<u8>),

    /// Write the prepared packet with the given `sequence_id`, then read the next response.
    WritePacket { sequence_id: u8 },

    /// Write the SSL request with the given `sequence_id`, then upgrade the stream to TLS.
    UpgradeTls { sequence_id: u8 },

    /// Handshake complete - call `finish()` to get the results.
    Finished,
}

/// Internal state of the handshake state machine.
///
/// Every variant is plain data, so the enum derives `Copy` and the comparison/debug traits
/// to make states cheap to inspect, log, and assert on.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum HandshakeState {
    /// Initial state - need to read initial handshake from server
    Start,
    /// Waiting for initial handshake packet to be read
    WaitingInitialHandshake,
    /// SSL request written, waiting for TLS upgrade to complete
    WaitingTlsUpgrade,
    /// Handshake response written, waiting for auth result
    WaitingAuthResult,
    /// After auth switch response. `caching_sha2` = whether we might receive AuthMoreData (0x01).
    WaitingFinalAuthResult { caching_sha2: bool },
    /// After caching_sha2 fast auth success (0x03) — waiting for the OK packet.
    WaitingCachingSha2FastAuthOk,
    /// After requesting RSA public key (0x02) — waiting for AuthMoreData with PEM key.
    WaitingRsaPublicKey,
    /// Connected (terminal state)
    Connected,
}

352/// State machine for MySQL handshake
353///
354/// Pure parsing and packet generation state machine without I/O dependencies.
355pub struct Handshake<'a> {
356    state: HandshakeState,
357    opts: &'a Opts,
358    initial_handshake: Option<InitialHandshake>,
359    next_sequence_id: u8,
360    capability_flags: Option<CapabilityFlags>,
361    mariadb_capabilities: Option<MariadbCapabilityFlags>,
362}
363
364impl<'a> Handshake<'a> {
365    /// Create a new handshake state machine
366    pub fn new(opts: &'a Opts) -> Self {
367        Self {
368            state: HandshakeState::Start,
369            opts,
370            initial_handshake: None,
371            next_sequence_id: 1,
372            capability_flags: None,
373            mariadb_capabilities: None,
374        }
375    }
376
377    /// Drive the state machine forward
378    ///
379    /// Returns an action indicating what I/O operation the caller should perform.
380    pub fn step<'buf>(&mut self, buffer_set: &'buf mut BufferSet) -> Result<HandshakeAction<'buf>> {
381        match &mut self.state {
382            HandshakeState::Start => {
383                self.state = HandshakeState::WaitingInitialHandshake;
384                Ok(HandshakeAction::ReadPacket(
385                    &mut buffer_set.initial_handshake,
386                ))
387            }
388
389            HandshakeState::WaitingInitialHandshake => {
390                let handshake = read_initial_handshake(&buffer_set.initial_handshake)?;
391
392                let mut client_caps = CAPABILITIES_ALWAYS_ENABLED
393                    | (self.opts.capabilities & CAPABILITIES_CONFIGURABLE);
394                if self.opts.db.is_some() {
395                    client_caps |= CapabilityFlags::CLIENT_CONNECT_WITH_DB;
396                }
397                if self.opts.tls {
398                    client_caps |= CapabilityFlags::CLIENT_SSL;
399                }
400
401                let negotiated_caps = client_caps & handshake.capability_flags;
402                let mariadb_caps = if negotiated_caps.is_mariadb() {
403                    if !handshake
404                        .mariadb_capabilities
405                        .contains(MARIADB_CAPABILITIES_ENABLED)
406                    {
407                        return Err(Error::Unsupported(format!(
408                            "MariaDB server does not support the required capabilities. Server: {:?} Required: {:?}",
409                            handshake.mariadb_capabilities, MARIADB_CAPABILITIES_ENABLED
410                        )));
411                    }
412                    MARIADB_CAPABILITIES_ENABLED
413                } else {
414                    MariadbCapabilityFlags::empty()
415                };
416
417                // Store capabilities and initial handshake
418                self.capability_flags = Some(negotiated_caps);
419                self.mariadb_capabilities = Some(mariadb_caps);
420                self.initial_handshake = Some(handshake);
421
422                // TLS: SSLRequest + HandshakeResponse
423                if self.opts.tls && negotiated_caps.contains(CapabilityFlags::CLIENT_SSL) {
424                    write_ssl_request(buffer_set.new_write_buffer(), negotiated_caps, mariadb_caps);
425
426                    let seq = self.next_sequence_id;
427                    self.next_sequence_id = self.next_sequence_id.wrapping_add(1);
428                    self.state = HandshakeState::WaitingTlsUpgrade;
429
430                    Ok(HandshakeAction::UpgradeTls { sequence_id: seq })
431                } else {
432                    // No TLS: HandshakeResponse
433                    self.write_handshake_response(buffer_set)?;
434                    let seq = self.next_sequence_id;
435                    self.next_sequence_id = self.next_sequence_id.wrapping_add(2);
436                    self.state = HandshakeState::WaitingAuthResult;
437
438                    Ok(HandshakeAction::WritePacket { sequence_id: seq })
439                }
440            }
441
442            HandshakeState::WaitingTlsUpgrade => {
443                // TLS upgrade completed, now send handshake response
444                self.write_handshake_response(buffer_set)?;
445
446                let seq = self.next_sequence_id;
447                self.next_sequence_id = self.next_sequence_id.wrapping_add(2);
448                self.state = HandshakeState::WaitingAuthResult;
449
450                Ok(HandshakeAction::WritePacket { sequence_id: seq })
451            }
452
453            HandshakeState::WaitingAuthResult => {
454                let payload = &buffer_set.read_buffer[..];
455                if payload.is_empty() {
456                    return Err(Error::LibraryBug(eyre!(
457                        "empty payload while waiting for auth result"
458                    )));
459                }
460
461                // Get initial plugin name from stored handshake
462                let initial_handshake = self.initial_handshake.as_ref().ok_or_else(|| {
463                    Error::LibraryBug(eyre!("initial_handshake not set in WaitingAuthResult"))
464                })?;
465                let initial_plugin =
466                    &buffer_set.initial_handshake[initial_handshake.auth_plugin_name.clone()];
467
468                match payload[0] {
469                    0x00 => {
470                        // OK packet - authentication succeeded
471                        self.state = HandshakeState::Connected;
472                        Ok(HandshakeAction::Finished)
473                    }
474                    0xFF => {
475                        // ERR packet - authentication failed
476                        Err(ErrPayloadBytes(payload).into())
477                    }
478                    0x01 => {
479                        // AuthMoreData — caching_sha2_password fast auth result
480                        if initial_plugin == b"caching_sha2_password" {
481                            self.handle_auth_more_data(buffer_set)
482                        } else {
483                            Err(Error::LibraryBug(eyre!(
484                                "unexpected AuthMoreData (0x01) for plugin {:?}",
485                                String::from_utf8_lossy(initial_plugin)
486                            )))
487                        }
488                    }
489                    0xFE => {
490                        // Auth switch request
491                        let auth_switch = read_auth_switch_request(payload)?;
492
493                        // Compute auth response for new plugin
494                        let (auth_response, is_caching_sha2) = match auth_switch.plugin_name {
495                            b"mysql_native_password" => (
496                                auth_mysql_native_password(
497                                    &self.opts.password,
498                                    auth_switch.plugin_data,
499                                )
500                                .to_vec(),
501                                false,
502                            ),
503                            b"caching_sha2_password" => (
504                                auth_caching_sha2_password(
505                                    &self.opts.password,
506                                    auth_switch.plugin_data,
507                                )
508                                .to_vec(),
509                                true,
510                            ),
511                            plugin => {
512                                return Err(Error::Unsupported(
513                                    String::from_utf8_lossy(plugin).to_string(),
514                                ));
515                            }
516                        };
517
518                        write_auth_switch_response(buffer_set.new_write_buffer(), &auth_response);
519
520                        let seq = self.next_sequence_id;
521                        self.next_sequence_id = self.next_sequence_id.wrapping_add(2);
522                        self.state = HandshakeState::WaitingFinalAuthResult {
523                            caching_sha2: is_caching_sha2,
524                        };
525
526                        Ok(HandshakeAction::WritePacket { sequence_id: seq })
527                    }
528                    header => Err(Error::LibraryBug(eyre!(
529                        "unexpected packet header 0x{:02X} while waiting for auth result",
530                        header
531                    ))),
532                }
533            }
534
535            HandshakeState::WaitingFinalAuthResult { caching_sha2 } => {
536                let payload = &buffer_set.read_buffer[..];
537                if payload.is_empty() {
538                    return Err(Error::LibraryBug(eyre!(
539                        "empty payload while waiting for final auth result"
540                    )));
541                }
542
543                match payload[0] {
544                    0x00 => {
545                        // OK packet - authentication succeeded
546                        self.state = HandshakeState::Connected;
547                        Ok(HandshakeAction::Finished)
548                    }
549                    0xFF => {
550                        // ERR packet - authentication failed
551                        Err(ErrPayloadBytes(payload).into())
552                    }
553                    0x01 if *caching_sha2 => self.handle_auth_more_data(buffer_set),
554                    header => Err(Error::LibraryBug(eyre!(
555                        "unexpected packet header 0x{:02X} while waiting for final auth result",
556                        header
557                    ))),
558                }
559            }
560
561            HandshakeState::WaitingCachingSha2FastAuthOk => {
562                let payload = &buffer_set.read_buffer[..];
563                if payload.is_empty() {
564                    return Err(Error::LibraryBug(eyre!(
565                        "empty payload while waiting for caching_sha2 OK"
566                    )));
567                }
568
569                match payload[0] {
570                    0x00 => {
571                        self.state = HandshakeState::Connected;
572                        Ok(HandshakeAction::Finished)
573                    }
574                    0xFF => Err(ErrPayloadBytes(payload).into()),
575                    header => Err(Error::LibraryBug(eyre!(
576                        "unexpected packet header 0x{:02X} while waiting for caching_sha2 OK",
577                        header
578                    ))),
579                }
580            }
581
582            HandshakeState::WaitingRsaPublicKey => {
583                let payload = &buffer_set.read_buffer[..];
584                if payload.is_empty() {
585                    return Err(Error::LibraryBug(eyre!(
586                        "empty payload while waiting for RSA public key"
587                    )));
588                }
589
590                match payload[0] {
591                    0xFF => return Err(ErrPayloadBytes(payload).into()),
592                    0x01 if payload.len() >= 2 => {}
593                    header => {
594                        return Err(Error::LibraryBug(eyre!(
595                            "expected AuthMoreData (0x01) with RSA public key, got 0x{:02X}",
596                            header
597                        )));
598                    }
599                }
600
601                let pem = std::str::from_utf8(&payload[1..]).map_err(|e| {
602                    Error::LibraryBug(eyre!("RSA public key is not valid UTF-8: {}", e))
603                })?;
604
605                let handshake = self
606                    .initial_handshake
607                    .as_ref()
608                    .ok_or_else(|| Error::LibraryBug(eyre!("initial_handshake not set")))?;
609
610                let encrypted =
611                    rsa_encrypt_password(&self.opts.password, &handshake.auth_plugin_data, pem)?;
612
613                let out = buffer_set.new_write_buffer();
614                out.extend_from_slice(&encrypted);
615
616                let seq = self.next_sequence_id;
617                self.next_sequence_id = self.next_sequence_id.wrapping_add(2);
618                self.state = HandshakeState::WaitingFinalAuthResult {
619                    caching_sha2: false,
620                };
621
622                Ok(HandshakeAction::WritePacket { sequence_id: seq })
623            }
624
625            HandshakeState::Connected => Err(Error::LibraryBug(eyre!(
626                "step() called after handshake completed"
627            ))),
628        }
629    }
630
631    /// Consume the state machine and return the connection info
632    ///
633    /// Returns an error if called before handshake is complete (before Finished action)
634    pub fn finish(self) -> Result<(InitialHandshake, CapabilityFlags, MariadbCapabilityFlags)> {
635        if !matches!(self.state, HandshakeState::Connected) {
636            return Err(Error::LibraryBug(eyre!(
637                "finish() called before handshake completed"
638            )));
639        }
640
641        let initial_handshake = self.initial_handshake.ok_or_else(|| {
642            Error::LibraryBug(eyre!("initial_handshake not set in Connected state"))
643        })?;
644        let capability_flags = self.capability_flags.ok_or_else(|| {
645            Error::LibraryBug(eyre!("capability_flags not set in Connected state"))
646        })?;
647        let mariadb_capabilities = self.mariadb_capabilities.ok_or_else(|| {
648            Error::LibraryBug(eyre!("mariadb_capabilities not set in Connected state"))
649        })?;
650
651        Ok((initial_handshake, capability_flags, mariadb_capabilities))
652    }
653
654    /// Write handshake response packet (HandshakeResponse41)
655    fn write_handshake_response(&self, buffer_set: &mut BufferSet) -> Result<()> {
656        buffer_set.new_write_buffer();
657
658        let handshake = self.initial_handshake.as_ref().ok_or_else(|| {
659            Error::LibraryBug(eyre!(
660                "initial_handshake not set in write_handshake_response"
661            ))
662        })?;
663        let capability_flags = self.capability_flags.ok_or_else(|| {
664            Error::LibraryBug(eyre!(
665                "capability_flags not set in write_handshake_response"
666            ))
667        })?;
668        let mariadb_capabilities = self.mariadb_capabilities.ok_or_else(|| {
669            Error::LibraryBug(eyre!(
670                "mariadb_capabilities not set in write_handshake_response"
671            ))
672        })?;
673
674        // Compute auth response based on plugin name
675        let auth_plugin_name = &buffer_set.initial_handshake[handshake.auth_plugin_name.clone()];
676        let auth_response = {
677            match auth_plugin_name {
678                b"mysql_native_password" => {
679                    auth_mysql_native_password(&self.opts.password, &handshake.auth_plugin_data)
680                        .to_vec()
681                }
682                b"caching_sha2_password" => {
683                    auth_caching_sha2_password(&self.opts.password, &handshake.auth_plugin_data)
684                        .to_vec()
685                }
686                plugin => {
687                    return Err(Error::Unsupported(
688                        String::from_utf8_lossy(plugin).to_string(),
689                    ));
690                }
691            }
692        };
693
694        let out = &mut buffer_set.write_buffer;
695        // capability flags (4 bytes)
696        write_int_4(out, capability_flags.bits());
697        // max packet size (4 bytes)
698        write_int_4(out, MAX_ALLOWED_PACKET);
699        // charset (1 byte)
700        write_int_1(out, UTF8MB4_GENERAL_CI);
701        // reserved (19 bytes) + MariaDB capabilities (4 bytes) = 23 bytes
702        out.extend_from_slice(&[0_u8; 19]);
703        write_int_4(out, mariadb_capabilities.bits());
704        // username (null-terminated)
705        write_string_null(out, self.opts.user.as_bytes());
706        // auth response (length-encoded)
707        if capability_flags.contains(CapabilityFlags::CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA) {
708            write_bytes_lenenc(out, &auth_response);
709        } else {
710            write_int_1(out, auth_response.len() as u8);
711            out.extend_from_slice(&auth_response);
712        }
713        // database name (null-terminated, if CLIENT_CONNECT_WITH_DB)
714        if let Some(db) = &self.opts.db {
715            write_string_null(out, db.as_bytes());
716        }
717
718        // auth plugin name (null-terminated, if CLIENT_PLUGIN_AUTH)
719        if capability_flags.contains(CapabilityFlags::CLIENT_PLUGIN_AUTH) {
720            write_string_null(out, auth_plugin_name);
721        }
722
723        Ok(())
724    }
725
726    /// Handle AuthMoreData (0x01) packet for caching_sha2_password.
727    ///
728    /// Called from both `WaitingAuthResult` and `WaitingFinalAuthResult { caching_sha2: true }`.
729    fn handle_auth_more_data<'buf>(
730        &mut self,
731        buffer_set: &'buf mut BufferSet,
732    ) -> Result<HandshakeAction<'buf>> {
733        let payload = &buffer_set.read_buffer[..];
734        if payload.len() < 2 {
735            return Err(Error::LibraryBug(eyre!(
736                "AuthMoreData packet too short: {} bytes",
737                payload.len()
738            )));
739        }
740
741        let result = read_caching_sha2_password_fast_auth_result(&payload[1..])?;
742
743        match result {
744            CachingSha2PasswordFastAuthResult::Success => {
745                // Fast auth succeeded — server will send OK next
746                self.state = HandshakeState::WaitingCachingSha2FastAuthOk;
747                Ok(HandshakeAction::ReadPacket(&mut buffer_set.read_buffer))
748            }
749            CachingSha2PasswordFastAuthResult::FullAuthRequired => {
750                let capability_flags = self
751                    .capability_flags
752                    .ok_or_else(|| Error::LibraryBug(eyre!("capability_flags not set")))?;
753
754                if capability_flags.contains(CapabilityFlags::CLIENT_SSL) {
755                    // TLS is active — send cleartext password (null-terminated)
756                    let out = buffer_set.new_write_buffer();
757                    out.extend_from_slice(self.opts.password.as_bytes());
758                    out.push(0);
759
760                    let seq = self.next_sequence_id;
761                    self.next_sequence_id = self.next_sequence_id.wrapping_add(2);
762                    self.state = HandshakeState::WaitingFinalAuthResult {
763                        caching_sha2: false,
764                    };
765
766                    Ok(HandshakeAction::WritePacket { sequence_id: seq })
767                } else {
768                    // No TLS — request server's RSA public key
769                    let out = buffer_set.new_write_buffer();
770                    out.push(0x02);
771
772                    let seq = self.next_sequence_id;
773                    self.next_sequence_id = self.next_sequence_id.wrapping_add(2);
774                    self.state = HandshakeState::WaitingRsaPublicKey;
775
776                    Ok(HandshakeAction::WritePacket { sequence_id: seq })
777                }
778            }
779        }
780    }
781}
782
/// Unit tests for the handshake parsing and authentication helpers.
#[cfg(test)]
// Tests assert on known-good values; unwrapping keeps failures loud and the tests short.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// The zero-copy view must be readable at any payload offset.
    #[test]
    fn handshake_fixed_fields_has_alignment_of_1() {
        assert_eq!(std::mem::align_of::<HandshakeFixedFields>(), 1);
    }

    /// Round-trips the RSA-encrypted password through a freshly generated key pair.
    #[test]
    fn rsa_encrypt_password_xors_and_encrypts() {
        use rsa::Oaep;
        use rsa::RsaPrivateKey;
        use rsa::pkcs8::{EncodePublicKey, LineEnding};

        let mut rng = rsa::rand_core::OsRng;
        let private_key = RsaPrivateKey::new(&mut rng, 2048).unwrap();
        let public_key = rsa::RsaPublicKey::from(&private_key);
        let pem = public_key.to_public_key_pem(LineEnding::LF).unwrap();

        let password = "test_password";
        let scramble = b"01234567890123456789";

        let encrypted = super::rsa_encrypt_password(password, scramble, &pem).unwrap();

        // Decrypt and verify
        let padding = Oaep::new::<sha1::Sha1>();
        let decrypted = private_key.decrypt(padding, &encrypted).unwrap();

        // Decrypted should be XOR(password + '\0', scramble)
        let mut expected = password.as_bytes().to_vec();
        expected.push(0);
        for (byte, key) in expected.iter_mut().zip(scramble.iter().cycle()) {
            *byte ^= key;
        }
        assert_eq!(decrypted, expected);
    }

    /// An empty scramble is rejected instead of producing an un-masked password.
    #[test]
    fn rsa_encrypt_password_rejects_empty_scramble() {
        use rsa::RsaPrivateKey;
        use rsa::pkcs8::{EncodePublicKey, LineEnding};

        let mut rng = rsa::rand_core::OsRng;
        let private_key = RsaPrivateKey::new(&mut rng, 2048).unwrap();
        let pem = rsa::RsaPublicKey::from(&private_key)
            .to_public_key_pem(LineEnding::LF)
            .unwrap();

        assert!(super::rsa_encrypt_password("pw", b"", &pem).is_err());
    }

    #[test]
    fn fast_auth_result_parsing() {
        assert_eq!(
            read_caching_sha2_password_fast_auth_result(&[0x03]).unwrap(),
            CachingSha2PasswordFastAuthResult::Success,
        );
        assert_eq!(
            read_caching_sha2_password_fast_auth_result(&[0x04]).unwrap(),
            CachingSha2PasswordFastAuthResult::FullAuthRequired,
        );
        assert!(read_caching_sha2_password_fast_auth_result(&[0x05]).is_err());
        assert!(read_caching_sha2_password_fast_auth_result(&[]).is_err());
    }

    /// Plugin name and scramble are split out, and the scramble's trailing NUL is dropped.
    #[test]
    fn auth_switch_request_parsing() {
        let payload = b"\xFEmysql_native_password\0abcdefghijklmnopqrst\0";
        let request = read_auth_switch_request(payload).unwrap();
        assert_eq!(request.plugin_name, &b"mysql_native_password"[..]);
        assert_eq!(request.plugin_data, &b"abcdefghijklmnopqrst"[..]);

        // Wrong header byte.
        assert!(read_auth_switch_request(b"\x00plugin\0data\0").is_err());
        // Plugin data without its terminating NUL.
        assert!(read_auth_switch_request(b"\xFEplugin\0data").is_err());
        // No plugin data at all.
        assert!(read_auth_switch_request(b"\xFEplugin\0").is_err());
    }

    /// Both scramble helpers short-circuit to zeros for an empty password and otherwise
    /// depend on the challenge.
    #[test]
    fn scramble_helpers_basic_properties() {
        let challenge_a = b"01234567890123456789";
        let challenge_b = b"98765432109876543210";

        assert_eq!(auth_mysql_native_password("", challenge_a), [0_u8; 20]);
        assert_eq!(auth_caching_sha2_password("", challenge_a), [0_u8; 32]);

        assert_eq!(
            auth_mysql_native_password("secret", challenge_a),
            auth_mysql_native_password("secret", challenge_a),
        );
        assert_ne!(
            auth_mysql_native_password("secret", challenge_a),
            auth_mysql_native_password("secret", challenge_b),
        );
        assert_ne!(
            auth_caching_sha2_password("secret", challenge_a),
            auth_caching_sha2_password("secret", challenge_b),
        );
    }
}