zero_mysql/protocol/connection/handshake.rs

1use std::hint::cold_path;
2use zerocopy::byteorder::little_endian::{U16 as U16LE, U32 as U32LE};
3use zerocopy::{FromBytes, Immutable, KnownLayout};
4
5use crate::buffer::BufferSet;
6use crate::constant::{
7    CAPABILITIES_ALWAYS_ENABLED, CAPABILITIES_CONFIGURABLE, CapabilityFlags,
8    MARIADB_CAPABILITIES_ENABLED, MAX_ALLOWED_PACKET, MariadbCapabilityFlags, UTF8MB4_GENERAL_CI,
9};
10use crate::error::{Error, Result, eyre};
11use crate::opts::Opts;
12use crate::protocol::primitive::*;
13use crate::protocol::response::ErrPayloadBytes;
14
15#[derive(Debug, Clone, Copy, FromBytes, KnownLayout, Immutable)]
16#[repr(C, packed)]
17struct HandshakeFixedFields {
18    connection_id: U32LE,
19    auth_data_part1: [u8; 8],
20    _filler1: u8,
21    capability_flags_lower: U16LE,
22    charset: u8,
23    status_flags: U16LE,
24    capability_flags_upper: U16LE,
25    auth_data_len: u8,
26    _fillter2: [u8; 6],
27    mariadb_capabilities: U32LE,
28}
29
30#[derive(Debug, Clone)]
31pub struct InitialHandshake {
32    pub protocol_version: u8,
33    pub server_version: std::ops::Range<usize>,
34    pub connection_id: u32,
35    pub auth_plugin_data: Vec<u8>,
36    pub capability_flags: CapabilityFlags,
37    pub mariadb_capabilities: MariadbCapabilityFlags,
38    pub charset: u8,
39    pub status_flags: crate::constant::ServerStatusFlags,
40    pub auth_plugin_name: std::ops::Range<usize>,
41}
42
43/// Read initial handshake packet from server
44pub fn read_initial_handshake(payload: &[u8]) -> Result<InitialHandshake> {
45    let (protocol_version, data) = read_int_1(payload)?;
46
47    if protocol_version == 0xFF {
48        cold_path();
49        Err(ErrPayloadBytes(payload))?
50    }
51
52    let server_version_start = payload.len() - data.len();
53    let (server_version_bytes, data) = read_string_null(data)?;
54    let server_version = server_version_start..server_version_start + server_version_bytes.len();
55
56    let (fixed, data) = HandshakeFixedFields::ref_from_prefix(data)?;
57
58    let connection_id = fixed.connection_id.get();
59    let charset = fixed.charset;
60    let status_flags = fixed.status_flags.get();
61    let capability_flags = CapabilityFlags::from_bits(
62        ((fixed.capability_flags_upper.get() as u32) << 16)
63            | (fixed.capability_flags_lower.get() as u32),
64    )
65    .ok_or_else(|| Error::LibraryBug(eyre!("invalid capability flags from server")))?;
66    let mariadb_capabilities = MariadbCapabilityFlags::from_bits(fixed.mariadb_capabilities.get())
67        .ok_or_else(|| Error::LibraryBug(eyre!("invalid mariadb capability flags from server")))?;
68    let auth_data_len = fixed.auth_data_len;
69
70    let auth_data_2_len = (auth_data_len as usize).saturating_sub(9).max(12);
71    let (auth_data_2, data) = read_string_fix(data, auth_data_2_len)?;
72    let (_reserved, data) = read_int_1(data)?;
73
74    let mut auth_plugin_data = Vec::new();
75    auth_plugin_data.extend_from_slice(&fixed.auth_data_part1);
76    auth_plugin_data.extend_from_slice(auth_data_2);
77
78    let auth_plugin_name_start = payload.len() - data.len();
79    let (auth_plugin_name_bytes, rest) = read_string_null(data)?;
80    let auth_plugin_name =
81        auth_plugin_name_start..auth_plugin_name_start + auth_plugin_name_bytes.len();
82
83    if !rest.is_empty() {
84        return Err(Error::LibraryBug(eyre!(
85            "unexpected trailing data in handshake packet: {} bytes",
86            rest.len()
87        )));
88    }
89
90    Ok(InitialHandshake {
91        protocol_version,
92        server_version,
93        connection_id,
94        auth_plugin_data,
95        capability_flags,
96        mariadb_capabilities,
97        charset,
98        status_flags: crate::constant::ServerStatusFlags::from_bits_truncate(status_flags),
99        auth_plugin_name,
100    })
101}
102
/// An AuthSwitchRequest received from the server during authentication.
///
/// Both slices borrow from the packet payload they were parsed from.
#[derive(Debug, Clone)]
pub struct AuthSwitchRequest<'a> {
    /// Name of the authentication plugin the server asks the client to switch to.
    pub plugin_name: &'a [u8],
    /// Plugin-specific data sent with the request (typically the new challenge), as
    /// extracted by `read_auth_switch_request`.
    pub plugin_data: &'a [u8],
}
109
110/// Read auth switch request (0xFE with length >= 9)
111pub fn read_auth_switch_request(payload: &[u8]) -> Result<AuthSwitchRequest<'_>> {
112    let (header, mut data) = read_int_1(payload)?;
113    if header != 0xFE {
114        return Err(Error::LibraryBug(eyre!(
115            "expected auth switch header 0xFE, got 0x{:02X}",
116            header
117        )));
118    }
119
120    let (plugin_name, rest) = read_string_null(data)?;
121    data = rest;
122
123    if let Some(0) = data.last() {
124        Ok(AuthSwitchRequest {
125            plugin_name,
126            plugin_data: &data[..data.len() - 1],
127        })
128    } else {
129        Err(Error::LibraryBug(eyre!(
130            "auth switch request plugin data not null-terminated"
131        )))
132    }
133}
134
/// Append an AuthSwitchResponse payload to `out`.
///
/// The payload is exactly the authentication data produced by the requested plugin; no
/// length prefix or terminator is added, and existing contents of `out` are kept.
pub fn write_auth_switch_response(out: &mut Vec<u8>, auth_data: &[u8]) {
    out.extend(auth_data.iter().copied());
}
141
142// ============================================================================
143// Authentication Plugins
144// ============================================================================
145
146/// mysql_native_password authentication
147///
148/// This is the traditional MySQL authentication method using SHA1.
149/// Formula: SHA1(password) XOR SHA1(challenge + SHA1(SHA1(password)))
150///
151/// # Arguments
152/// * `password` - Plain text password
153/// * `challenge` - 20-byte challenge from server (auth_plugin_data)
154///
155/// # Returns
156/// 20-byte authentication response
157pub fn auth_mysql_native_password(password: &str, challenge: &[u8]) -> [u8; 20] {
158    use sha1::{Digest, Sha1};
159
160    if password.is_empty() {
161        return [0_u8; 20];
162    }
163
164    // stage1_hash = SHA1(password)
165    let stage1_hash = Sha1::digest(password.as_bytes());
166
167    // stage2_hash = SHA1(stage1_hash)
168    let stage2_hash = Sha1::digest(stage1_hash);
169
170    // token_hash = SHA1(challenge + stage2_hash)
171    let mut hasher = Sha1::new();
172    hasher.update(challenge);
173    hasher.update(stage2_hash);
174    let token_hash = hasher.finalize();
175
176    // result = stage1_hash XOR token_hash
177    let mut result = [0_u8; 20];
178    for i in 0..20 {
179        result[i] = stage1_hash[i] ^ token_hash[i];
180    }
181
182    result
183}
184
185/// caching_sha2_password authentication - initial response
186///
187/// This is the default authentication method in MySQL 8.0+.
188/// Uses SHA256 hashing instead of SHA1.
189/// Formula: XOR(SHA256(password), SHA256(SHA256(SHA256(password)), challenge))
190///
191/// # Arguments
192/// * `password` - Plain text password
193/// * `challenge` - 20-byte challenge from server (auth_plugin_data)
194///
195/// # Returns
196/// 32-byte authentication response
197pub fn auth_caching_sha2_password(password: &str, challenge: &[u8]) -> [u8; 32] {
198    use sha2::{Digest, Sha256};
199
200    if password.is_empty() {
201        return [0_u8; 32];
202    }
203
204    // stage1 = SHA256(password)
205    let stage1 = Sha256::digest(password.as_bytes());
206
207    // stage2 = SHA256(stage1)
208    let stage2 = Sha256::digest(stage1);
209
210    // scramble = SHA256(stage2 + challenge)
211    let mut hasher = Sha256::new();
212    hasher.update(stage2);
213    hasher.update(challenge);
214    let scramble = hasher.finalize();
215
216    // result = stage1 XOR scramble
217    let mut result = [0_u8; 32];
218    for i in 0..32 {
219        result[i] = stage1[i] ^ scramble[i];
220    }
221
222    result
223}
224
/// Outcome of caching_sha2_password "fast" authentication.
///
/// After the client sends its scrambled password, the server replies with a status byte:
/// - `0x03`: the server's cache matched and the client is authenticated;
/// - `0x04`: the server needs the password itself, sent via RSA or cleartext.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CachingSha2PasswordFastAuthResult {
    /// Status `0x03`: cached authentication succeeded.
    Success,
    /// Status `0x04`: full authentication is required.
    FullAuthRequired,
}
235
236/// Read caching_sha2_password fast auth result
237pub fn read_caching_sha2_password_fast_auth_result(
238    payload: &[u8],
239) -> Result<CachingSha2PasswordFastAuthResult> {
240    if payload.is_empty() {
241        return Err(Error::LibraryBug(eyre!(
242            "empty payload for caching_sha2_password fast auth result"
243        )));
244    }
245
246    match payload[0] {
247        0x03 => Ok(CachingSha2PasswordFastAuthResult::Success),
248        0x04 => Ok(CachingSha2PasswordFastAuthResult::FullAuthRequired),
249        _ => Err(Error::LibraryBug(eyre!(
250            "unexpected caching_sha2_password fast auth result: 0x{:02X}",
251            payload[0]
252        ))),
253    }
254}
255
256// ============================================================================
257// State Machine API for Handshake
258// ============================================================================
259
260/// Write SSL request packet (sent before HandshakeResponse when TLS is enabled)
261fn write_ssl_request(
262    out: &mut Vec<u8>,
263    capability_flags: CapabilityFlags,
264    mariadb_capabilities: MariadbCapabilityFlags,
265) {
266    // capability flags (4 bytes)
267    write_int_4(out, capability_flags.bits());
268
269    // max packet size (4 bytes)
270    write_int_4(out, MAX_ALLOWED_PACKET);
271
272    // charset (1 byte)
273    write_int_1(out, UTF8MB4_GENERAL_CI);
274
275    // reserved (23 bytes of 0x00)
276    out.extend_from_slice(&[0_u8; 19]);
277
278    if capability_flags.is_mariadb() {
279        write_int_4(out, mariadb_capabilities.bits());
280    } else {
281        write_int_4(out, 0);
282    }
283}
284
/// I/O request produced by [`Handshake::step`].
///
/// The caller performs the requested operation and then calls `step` again. `Finished`
/// is the exception: after it, `Handshake::finish` returns the negotiated results.
pub enum HandshakeAction<'a> {
    /// Read the next server packet's payload into this buffer.
    ReadPacket(&'a mut Vec<u8>),

    /// Send the prepared packet with `sequence_id`, then read the server's response.
    WritePacket { sequence_id: u8 },

    /// Send the prepared SSLRequest with `sequence_id`, then upgrade the stream to TLS.
    UpgradeTls { sequence_id: u8 },

    /// The handshake is complete; call `finish()` to obtain the results.
    Finished,
}
299
/// Position of the handshake state machine.
///
/// Each variant describes what has already been requested from the caller, and therefore
/// what the next call to `Handshake::step` will process.
enum HandshakeState {
    /// Nothing requested yet; the next step asks for the server's initial handshake.
    Start,
    /// The caller was asked to read the server's initial handshake packet.
    WaitingInitialHandshake,
    /// An SSLRequest was prepared; the caller is upgrading the stream to TLS.
    WaitingTlsUpgrade,
    /// A HandshakeResponse was prepared; the server's authentication reply is awaited.
    WaitingAuthResult,
    /// An AuthSwitchResponse was prepared; the server's final reply is awaited.
    WaitingFinalAuthResult,
    /// Authentication succeeded; this state is terminal.
    Connected,
}
315
316/// State machine for MySQL handshake
317///
318/// Pure parsing and packet generation state machine without I/O dependencies.
319pub struct Handshake<'a> {
320    state: HandshakeState,
321    opts: &'a Opts,
322    initial_handshake: Option<InitialHandshake>,
323    next_sequence_id: u8,
324    capability_flags: Option<CapabilityFlags>,
325    mariadb_capabilities: Option<MariadbCapabilityFlags>,
326}
327
328impl<'a> Handshake<'a> {
329    /// Create a new handshake state machine
330    pub fn new(opts: &'a Opts) -> Self {
331        Self {
332            state: HandshakeState::Start,
333            opts,
334            initial_handshake: None,
335            next_sequence_id: 1,
336            capability_flags: None,
337            mariadb_capabilities: None,
338        }
339    }
340
341    /// Drive the state machine forward
342    ///
343    /// Returns an action indicating what I/O operation the caller should perform.
344    pub fn step<'buf>(&mut self, buffer_set: &'buf mut BufferSet) -> Result<HandshakeAction<'buf>> {
345        match &mut self.state {
346            HandshakeState::Start => {
347                self.state = HandshakeState::WaitingInitialHandshake;
348                Ok(HandshakeAction::ReadPacket(
349                    &mut buffer_set.initial_handshake,
350                ))
351            }
352
353            HandshakeState::WaitingInitialHandshake => {
354                let handshake = read_initial_handshake(&buffer_set.initial_handshake)?;
355
356                let mut client_caps = CAPABILITIES_ALWAYS_ENABLED
357                    | (self.opts.capabilities & CAPABILITIES_CONFIGURABLE);
358                if self.opts.db.is_some() {
359                    client_caps |= CapabilityFlags::CLIENT_CONNECT_WITH_DB;
360                }
361                if self.opts.tls {
362                    client_caps |= CapabilityFlags::CLIENT_SSL;
363                }
364
365                let negotiated_caps = client_caps & handshake.capability_flags;
366                let mariadb_caps = if negotiated_caps.is_mariadb() {
367                    if !handshake
368                        .mariadb_capabilities
369                        .contains(MARIADB_CAPABILITIES_ENABLED)
370                    {
371                        return Err(Error::Unsupported(format!(
372                            "MariaDB server does not support the required capabilities. Server: {:?} Required: {:?}",
373                            handshake.mariadb_capabilities, MARIADB_CAPABILITIES_ENABLED
374                        )));
375                    }
376                    MARIADB_CAPABILITIES_ENABLED
377                } else {
378                    MariadbCapabilityFlags::empty()
379                };
380
381                // Store capabilities and initial handshake
382                self.capability_flags = Some(negotiated_caps);
383                self.mariadb_capabilities = Some(mariadb_caps);
384                self.initial_handshake = Some(handshake);
385
386                // TLS: SSLRequest + HandshakeResponse
387                if self.opts.tls && negotiated_caps.contains(CapabilityFlags::CLIENT_SSL) {
388                    write_ssl_request(buffer_set.new_write_buffer(), negotiated_caps, mariadb_caps);
389
390                    let seq = self.next_sequence_id;
391                    self.next_sequence_id = self.next_sequence_id.wrapping_add(1);
392                    self.state = HandshakeState::WaitingTlsUpgrade;
393
394                    Ok(HandshakeAction::UpgradeTls { sequence_id: seq })
395                } else {
396                    // No TLS: HandshakeResponse
397                    self.write_handshake_response(buffer_set)?;
398                    let seq = self.next_sequence_id;
399                    self.next_sequence_id = self.next_sequence_id.wrapping_add(1);
400                    self.state = HandshakeState::WaitingAuthResult;
401
402                    Ok(HandshakeAction::WritePacket { sequence_id: seq })
403                }
404            }
405
406            HandshakeState::WaitingTlsUpgrade => {
407                // TLS upgrade completed, now send handshake response
408                self.write_handshake_response(buffer_set)?;
409
410                let seq = self.next_sequence_id;
411                self.next_sequence_id = self.next_sequence_id.wrapping_add(1);
412                self.state = HandshakeState::WaitingAuthResult;
413
414                Ok(HandshakeAction::WritePacket { sequence_id: seq })
415            }
416
417            HandshakeState::WaitingAuthResult => {
418                let payload = &buffer_set.read_buffer[..];
419                if payload.is_empty() {
420                    return Err(Error::LibraryBug(eyre!(
421                        "empty payload while waiting for auth result"
422                    )));
423                }
424
425                // Get initial plugin name from stored handshake
426                let initial_handshake = self.initial_handshake.as_ref().ok_or_else(|| {
427                    Error::LibraryBug(eyre!("initial_handshake not set in WaitingAuthResult"))
428                })?;
429                let initial_plugin =
430                    &buffer_set.initial_handshake[initial_handshake.auth_plugin_name.clone()];
431
432                match payload[0] {
433                    0x00 => {
434                        // OK packet - authentication succeeded
435                        self.state = HandshakeState::Connected;
436                        Ok(HandshakeAction::Finished)
437                    }
438                    0xFF => {
439                        // ERR packet - authentication failed
440                        Err(ErrPayloadBytes(payload).into())
441                    }
442                    0xFE => {
443                        // Could be auth switch or fast auth result
444                        if initial_plugin == b"caching_sha2_password" && payload.len() == 2 {
445                            // Fast auth result
446                            let result = read_caching_sha2_password_fast_auth_result(payload)?;
447                            match result {
448                                CachingSha2PasswordFastAuthResult::Success => {
449                                    // Need to read final OK packet
450                                    Ok(HandshakeAction::ReadPacket(&mut buffer_set.read_buffer))
451                                }
452                                CachingSha2PasswordFastAuthResult::FullAuthRequired => {
453                                    Err(Error::Unsupported(
454                                        "caching_sha2_password full auth (requires SSL/RSA)"
455                                            .to_string(),
456                                    ))
457                                }
458                            }
459                        } else {
460                            // Auth switch request
461                            let auth_switch = read_auth_switch_request(payload)?;
462
463                            // Compute auth response for new plugin
464                            let auth_response = match auth_switch.plugin_name {
465                                b"mysql_native_password" => auth_mysql_native_password(
466                                    &self.opts.password,
467                                    auth_switch.plugin_data,
468                                )
469                                .to_vec(),
470                                b"caching_sha2_password" => auth_caching_sha2_password(
471                                    &self.opts.password,
472                                    auth_switch.plugin_data,
473                                )
474                                .to_vec(),
475                                plugin => {
476                                    return Err(Error::Unsupported(
477                                        String::from_utf8_lossy(plugin).to_string(),
478                                    ));
479                                }
480                            };
481
482                            write_auth_switch_response(
483                                buffer_set.new_write_buffer(),
484                                &auth_response,
485                            );
486
487                            let seq = self.next_sequence_id;
488                            self.next_sequence_id = self.next_sequence_id.wrapping_add(1);
489                            self.state = HandshakeState::WaitingFinalAuthResult;
490
491                            Ok(HandshakeAction::WritePacket { sequence_id: seq })
492                        }
493                    }
494                    header => Err(Error::LibraryBug(eyre!(
495                        "unexpected packet header 0x{:02X} while waiting for auth result",
496                        header
497                    ))),
498                }
499            }
500
501            HandshakeState::WaitingFinalAuthResult => {
502                let payload = &buffer_set.read_buffer[..];
503                if payload.is_empty() {
504                    return Err(Error::LibraryBug(eyre!(
505                        "empty payload while waiting for final auth result"
506                    )));
507                }
508
509                match payload[0] {
510                    0x00 => {
511                        // OK packet - authentication succeeded
512                        self.state = HandshakeState::Connected;
513                        Ok(HandshakeAction::Finished)
514                    }
515                    0xFF => {
516                        // ERR packet - authentication failed
517                        Err(ErrPayloadBytes(payload).into())
518                    }
519                    header => Err(Error::LibraryBug(eyre!(
520                        "unexpected packet header 0x{:02X} while waiting for final auth result",
521                        header
522                    ))),
523                }
524            }
525
526            HandshakeState::Connected => Err(Error::LibraryBug(eyre!(
527                "step() called after handshake completed"
528            ))),
529        }
530    }
531
532    /// Consume the state machine and return the connection info
533    ///
534    /// Returns an error if called before handshake is complete (before Finished action)
535    pub fn finish(self) -> Result<(InitialHandshake, CapabilityFlags, MariadbCapabilityFlags)> {
536        if !matches!(self.state, HandshakeState::Connected) {
537            return Err(Error::LibraryBug(eyre!(
538                "finish() called before handshake completed"
539            )));
540        }
541
542        let initial_handshake = self.initial_handshake.ok_or_else(|| {
543            Error::LibraryBug(eyre!("initial_handshake not set in Connected state"))
544        })?;
545        let capability_flags = self.capability_flags.ok_or_else(|| {
546            Error::LibraryBug(eyre!("capability_flags not set in Connected state"))
547        })?;
548        let mariadb_capabilities = self.mariadb_capabilities.ok_or_else(|| {
549            Error::LibraryBug(eyre!("mariadb_capabilities not set in Connected state"))
550        })?;
551
552        Ok((initial_handshake, capability_flags, mariadb_capabilities))
553    }
554
555    /// Write handshake response packet (HandshakeResponse41)
556    fn write_handshake_response(&self, buffer_set: &mut BufferSet) -> Result<()> {
557        buffer_set.new_write_buffer();
558
559        let handshake = self.initial_handshake.as_ref().ok_or_else(|| {
560            Error::LibraryBug(eyre!(
561                "initial_handshake not set in write_handshake_response"
562            ))
563        })?;
564        let capability_flags = self.capability_flags.ok_or_else(|| {
565            Error::LibraryBug(eyre!(
566                "capability_flags not set in write_handshake_response"
567            ))
568        })?;
569        let mariadb_capabilities = self.mariadb_capabilities.ok_or_else(|| {
570            Error::LibraryBug(eyre!(
571                "mariadb_capabilities not set in write_handshake_response"
572            ))
573        })?;
574
575        // Compute auth response based on plugin name
576        let auth_plugin_name = &buffer_set.initial_handshake[handshake.auth_plugin_name.clone()];
577        let auth_response = {
578            match auth_plugin_name {
579                b"mysql_native_password" => {
580                    auth_mysql_native_password(&self.opts.password, &handshake.auth_plugin_data)
581                        .to_vec()
582                }
583                b"caching_sha2_password" => {
584                    auth_caching_sha2_password(&self.opts.password, &handshake.auth_plugin_data)
585                        .to_vec()
586                }
587                plugin => {
588                    return Err(Error::Unsupported(
589                        String::from_utf8_lossy(plugin).to_string(),
590                    ));
591                }
592            }
593        };
594
595        let out = &mut buffer_set.write_buffer;
596        // capability flags (4 bytes)
597        write_int_4(out, capability_flags.bits());
598        // max packet size (4 bytes)
599        write_int_4(out, MAX_ALLOWED_PACKET);
600        // charset (1 byte)
601        write_int_1(out, UTF8MB4_GENERAL_CI);
602        // reserved (19 bytes) + MariaDB capabilities (4 bytes) = 23 bytes
603        out.extend_from_slice(&[0_u8; 19]);
604        write_int_4(out, mariadb_capabilities.bits());
605        // username (null-terminated)
606        write_string_null(out, self.opts.user.as_bytes());
607        // auth response (length-encoded)
608        if capability_flags.contains(CapabilityFlags::CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA) {
609            write_bytes_lenenc(out, &auth_response);
610        } else {
611            write_int_1(out, auth_response.len() as u8);
612            out.extend_from_slice(&auth_response);
613        }
614        // database name (null-terminated, if CLIENT_CONNECT_WITH_DB)
615        if let Some(db) = &self.opts.db {
616            write_string_null(out, db.as_bytes());
617        }
618
619        // auth plugin name (null-terminated, if CLIENT_PLUGIN_AUTH)
620        if capability_flags.contains(CapabilityFlags::CLIENT_PLUGIN_AUTH) {
621            write_string_null(out, auth_plugin_name);
622        }
623
624        Ok(())
625    }
626}
627
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn handshake_fixed_fields_has_alignment_of_1() {
        assert_eq!(std::mem::align_of::<HandshakeFixedFields>(), 1);
    }

    /// Pins the wire layout after the server version: 4 + 8 + 1 + 2 + 1 + 2 + 2 + 1 + 10
    /// bytes, where the last 4 reserved bytes carry the MariaDB capabilities.
    #[test]
    fn handshake_fixed_fields_matches_wire_size() {
        assert_eq!(std::mem::size_of::<HandshakeFixedFields>(), 31);
    }

    #[test]
    fn parses_minimal_initial_handshake() {
        let mut p = vec![0x0A];
        p.extend_from_slice(b"8.0.0\0");
        p.extend_from_slice(&[1, 0, 0, 0]); // connection id
        p.extend_from_slice(b"abcdefgh"); // auth data part 1
        p.push(0); // filler
        p.extend_from_slice(&[0x00, 0x00]); // capabilities (lower)
        p.push(0x2D); // charset
        p.extend_from_slice(&[0x02, 0x00]); // status flags
        p.extend_from_slice(&[0x08, 0x00]); // capabilities (upper): CLIENT_PLUGIN_AUTH
        p.push(21); // auth data length
        p.extend_from_slice(&[0; 6]); // reserved
        p.extend_from_slice(&[0; 4]); // MariaDB capabilities
        p.extend_from_slice(b"ijklmnopqrst\0"); // auth data part 2 + trailing byte
        p.extend_from_slice(b"mysql_native_password\0");

        let Ok(h) = read_initial_handshake(&p) else {
            panic!("failed to parse handshake");
        };
        assert_eq!(h.protocol_version, 0x0A);
        assert_eq!(h.connection_id, 1);
        assert_eq!(h.charset, 0x2D);
        assert_eq!(&p[h.server_version.clone()], b"8.0.0");
        assert_eq!(&p[h.auth_plugin_name.clone()], b"mysql_native_password");
        assert_eq!(h.auth_plugin_data, b"abcdefghijklmnopqrst".to_vec());
        assert!(h.capability_flags.contains(CapabilityFlags::CLIENT_PLUGIN_AUTH));
    }

    #[test]
    fn auth_switch_request_strips_trailing_nul() {
        let payload = b"\xFEmysql_native_password\0abc\0";
        let Ok(req) = read_auth_switch_request(payload) else {
            panic!("failed to parse auth switch request");
        };
        assert_eq!(req.plugin_name, b"mysql_native_password");
        assert_eq!(req.plugin_data, b"abc");
    }

    #[test]
    fn fast_auth_result_status_bytes() {
        assert!(matches!(
            read_caching_sha2_password_fast_auth_result(&[0x03]),
            Ok(CachingSha2PasswordFastAuthResult::Success)
        ));
        assert!(matches!(
            read_caching_sha2_password_fast_auth_result(&[0x04]),
            Ok(CachingSha2PasswordFastAuthResult::FullAuthRequired)
        ));
        assert!(read_caching_sha2_password_fast_auth_result(&[]).is_err());
        assert!(read_caching_sha2_password_fast_auth_result(&[0x01]).is_err());
    }

    #[test]
    fn empty_password_scrambles_are_zero() {
        assert_eq!(auth_mysql_native_password("", b"challenge"), [0_u8; 20]);
        assert_eq!(auth_caching_sha2_password("", b"challenge"), [0_u8; 32]);
    }
}