zero_mysql/protocol/connection/
handshake.rs1use std::hint::cold_path;
2use zerocopy::byteorder::little_endian::{U16 as U16LE, U32 as U32LE};
3use zerocopy::{FromBytes, Immutable, KnownLayout};
4
5use crate::buffer::BufferSet;
6use crate::constant::{
7 CAPABILITIES_ALWAYS_ENABLED, CAPABILITIES_CONFIGURABLE, CapabilityFlags,
8 MARIADB_CAPABILITIES_ENABLED, MAX_ALLOWED_PACKET, MariadbCapabilityFlags, UTF8MB4_GENERAL_CI,
9};
10use crate::error::{Error, Result, eyre};
11use crate::opts::Opts;
12use crate::protocol::primitive::*;
13use crate::protocol::response::ErrPayloadBytes;
14
15#[derive(Debug, Clone, Copy, FromBytes, KnownLayout, Immutable)]
16#[repr(C, packed)]
17struct HandshakeFixedFields {
18 connection_id: U32LE,
19 auth_data_part1: [u8; 8],
20 _filler1: u8,
21 capability_flags_lower: U16LE,
22 charset: u8,
23 status_flags: U16LE,
24 capability_flags_upper: U16LE,
25 auth_data_len: u8,
26 _fillter2: [u8; 6],
27 mariadb_capabilities: U32LE,
28}
29
30#[derive(Debug, Clone)]
31pub struct InitialHandshake {
32 pub protocol_version: u8,
33 pub server_version: std::ops::Range<usize>,
34 pub connection_id: u32,
35 pub auth_plugin_data: Vec<u8>,
36 pub capability_flags: CapabilityFlags,
37 pub mariadb_capabilities: MariadbCapabilityFlags,
38 pub charset: u8,
39 pub status_flags: crate::constant::ServerStatusFlags,
40 pub auth_plugin_name: std::ops::Range<usize>,
41}
42
43pub fn read_initial_handshake(payload: &[u8]) -> Result<InitialHandshake> {
45 let (protocol_version, data) = read_int_1(payload)?;
46
47 if protocol_version == 0xFF {
48 cold_path();
49 Err(ErrPayloadBytes(payload))?
50 }
51
52 let server_version_start = payload.len() - data.len();
53 let (server_version_bytes, data) = read_string_null(data)?;
54 let server_version = server_version_start..server_version_start + server_version_bytes.len();
55
56 let (fixed, data) = HandshakeFixedFields::ref_from_prefix(data)?;
57
58 let connection_id = fixed.connection_id.get();
59 let charset = fixed.charset;
60 let status_flags = fixed.status_flags.get();
61 let capability_flags = CapabilityFlags::from_bits(
62 ((fixed.capability_flags_upper.get() as u32) << 16)
63 | (fixed.capability_flags_lower.get() as u32),
64 )
65 .ok_or_else(|| Error::LibraryBug(eyre!("invalid capability flags from server")))?;
66 let mariadb_capabilities = MariadbCapabilityFlags::from_bits(fixed.mariadb_capabilities.get())
67 .ok_or_else(|| Error::LibraryBug(eyre!("invalid mariadb capability flags from server")))?;
68 let auth_data_len = fixed.auth_data_len;
69
70 let auth_data_2_len = (auth_data_len as usize).saturating_sub(9).max(12);
71 let (auth_data_2, data) = read_string_fix(data, auth_data_2_len)?;
72 let (_reserved, data) = read_int_1(data)?;
73
74 let mut auth_plugin_data = Vec::new();
75 auth_plugin_data.extend_from_slice(&fixed.auth_data_part1);
76 auth_plugin_data.extend_from_slice(auth_data_2);
77
78 let auth_plugin_name_start = payload.len() - data.len();
79 let (auth_plugin_name_bytes, rest) = read_string_null(data)?;
80 let auth_plugin_name =
81 auth_plugin_name_start..auth_plugin_name_start + auth_plugin_name_bytes.len();
82
83 if !rest.is_empty() {
84 return Err(Error::LibraryBug(eyre!(
85 "unexpected trailing data in handshake packet: {} bytes",
86 rest.len()
87 )));
88 }
89
90 Ok(InitialHandshake {
91 protocol_version,
92 server_version,
93 connection_id,
94 auth_plugin_data,
95 capability_flags,
96 mariadb_capabilities,
97 charset,
98 status_flags: crate::constant::ServerStatusFlags::from_bits_truncate(status_flags),
99 auth_plugin_name,
100 })
101}
102
/// An Auth Switch Request sent by the server, borrowing from the packet payload.
#[derive(Debug, Clone)]
pub struct AuthSwitchRequest<'buf> {
    /// Name of the plugin the server wants the client to switch to.
    pub plugin_name: &'buf [u8],
    /// Plugin-specific data (the new challenge) with its trailing NUL removed.
    pub plugin_data: &'buf [u8],
}
109
110pub fn read_auth_switch_request(payload: &[u8]) -> Result<AuthSwitchRequest<'_>> {
112 let (header, mut data) = read_int_1(payload)?;
113 if header != 0xFE {
114 return Err(Error::LibraryBug(eyre!(
115 "expected auth switch header 0xFE, got 0x{:02X}",
116 header
117 )));
118 }
119
120 let (plugin_name, rest) = read_string_null(data)?;
121 data = rest;
122
123 if let Some(0) = data.last() {
124 Ok(AuthSwitchRequest {
125 plugin_name,
126 plugin_data: &data[..data.len() - 1],
127 })
128 } else {
129 Err(Error::LibraryBug(eyre!(
130 "auth switch request plugin data not null-terminated"
131 )))
132 }
133}
134
/// Appends an Auth Switch Response payload to `out`.
///
/// The payload is simply the raw auth data with no header or length prefix.
pub fn write_auth_switch_response(out: &mut Vec<u8>, auth_data: &[u8]) {
    out.extend(auth_data);
}
141
142pub fn auth_mysql_native_password(password: &str, challenge: &[u8]) -> [u8; 20] {
158 use sha1::{Digest, Sha1};
159
160 if password.is_empty() {
161 return [0_u8; 20];
162 }
163
164 let stage1_hash = Sha1::digest(password.as_bytes());
166
167 let stage2_hash = Sha1::digest(stage1_hash);
169
170 let mut hasher = Sha1::new();
172 hasher.update(challenge);
173 hasher.update(stage2_hash);
174 let token_hash = hasher.finalize();
175
176 let mut result = [0_u8; 20];
178 for i in 0..20 {
179 result[i] = stage1_hash[i] ^ token_hash[i];
180 }
181
182 result
183}
184
185pub fn auth_caching_sha2_password(password: &str, challenge: &[u8]) -> [u8; 32] {
198 use sha2::{Digest, Sha256};
199
200 if password.is_empty() {
201 return [0_u8; 32];
202 }
203
204 let stage1 = Sha256::digest(password.as_bytes());
206
207 let stage2 = Sha256::digest(stage1);
209
210 let mut hasher = Sha256::new();
212 hasher.update(stage2);
213 hasher.update(challenge);
214 let scramble = hasher.finalize();
215
216 let mut result = [0_u8; 32];
218 for i in 0..32 {
219 result[i] = stage1[i] ^ scramble[i];
220 }
221
222 result
223}
224
/// Outcome byte of the `caching_sha2_password` fast-auth exchange.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CachingSha2PasswordFastAuthResult {
    /// Status `0x03`: the server's cached credentials matched.
    Success,
    /// Status `0x04`: the server needs the full authentication exchange.
    FullAuthRequired,
}
235
236pub fn read_caching_sha2_password_fast_auth_result(
238 payload: &[u8],
239) -> Result<CachingSha2PasswordFastAuthResult> {
240 if payload.is_empty() {
241 return Err(Error::LibraryBug(eyre!(
242 "empty payload for caching_sha2_password fast auth result"
243 )));
244 }
245
246 match payload[0] {
247 0x03 => Ok(CachingSha2PasswordFastAuthResult::Success),
248 0x04 => Ok(CachingSha2PasswordFastAuthResult::FullAuthRequired),
249 _ => Err(Error::LibraryBug(eyre!(
250 "unexpected caching_sha2_password fast auth result: 0x{:02X}",
251 payload[0]
252 ))),
253 }
254}
255
256fn write_ssl_request(
262 out: &mut Vec<u8>,
263 capability_flags: CapabilityFlags,
264 mariadb_capabilities: MariadbCapabilityFlags,
265) {
266 write_int_4(out, capability_flags.bits());
268
269 write_int_4(out, MAX_ALLOWED_PACKET);
271
272 write_int_1(out, UTF8MB4_GENERAL_CI);
274
275 out.extend_from_slice(&[0_u8; 19]);
277
278 if capability_flags.is_mariadb() {
279 write_int_4(out, mariadb_capabilities.bits());
280 } else {
281 write_int_4(out, 0);
282 }
283}
284
/// I/O that the caller must perform before calling `Handshake::step` again.
pub enum HandshakeAction<'buf> {
    /// Read the next packet payload from the server into this buffer.
    ReadPacket(&'buf mut Vec<u8>),
    /// Send the prepared write buffer as a packet with this sequence id.
    WritePacket { sequence_id: u8 },
    /// Send the prepared SSLRequest packet with this sequence id, then
    /// upgrade the connection to TLS.
    UpgradeTls { sequence_id: u8 },
    /// The handshake is complete.
    Finished,
}
299
/// Position of the handshake state machine between calls to `step`.
enum HandshakeState {
    /// Nothing has happened yet.
    Start,
    /// A read of the server greeting was requested.
    WaitingInitialHandshake,
    /// An SSLRequest was prepared; the caller is upgrading the connection.
    WaitingTlsUpgrade,
    /// The handshake response was sent; the server's reply is awaited.
    WaitingAuthResult,
    /// A follow-up authentication packet was sent; the server's verdict is awaited.
    WaitingFinalAuthResult,
    /// Authentication succeeded.
    Connected,
}
315
316pub struct Handshake<'a> {
320 state: HandshakeState,
321 opts: &'a Opts,
322 initial_handshake: Option<InitialHandshake>,
323 next_sequence_id: u8,
324 capability_flags: Option<CapabilityFlags>,
325 mariadb_capabilities: Option<MariadbCapabilityFlags>,
326}
327
328impl<'a> Handshake<'a> {
329 pub fn new(opts: &'a Opts) -> Self {
331 Self {
332 state: HandshakeState::Start,
333 opts,
334 initial_handshake: None,
335 next_sequence_id: 1,
336 capability_flags: None,
337 mariadb_capabilities: None,
338 }
339 }
340
341 pub fn step<'buf>(&mut self, buffer_set: &'buf mut BufferSet) -> Result<HandshakeAction<'buf>> {
345 match &mut self.state {
346 HandshakeState::Start => {
347 self.state = HandshakeState::WaitingInitialHandshake;
348 Ok(HandshakeAction::ReadPacket(
349 &mut buffer_set.initial_handshake,
350 ))
351 }
352
353 HandshakeState::WaitingInitialHandshake => {
354 let handshake = read_initial_handshake(&buffer_set.initial_handshake)?;
355
356 let mut client_caps = CAPABILITIES_ALWAYS_ENABLED
357 | (self.opts.capabilities & CAPABILITIES_CONFIGURABLE);
358 if self.opts.db.is_some() {
359 client_caps |= CapabilityFlags::CLIENT_CONNECT_WITH_DB;
360 }
361 if self.opts.tls {
362 client_caps |= CapabilityFlags::CLIENT_SSL;
363 }
364
365 let negotiated_caps = client_caps & handshake.capability_flags;
366 let mariadb_caps = if negotiated_caps.is_mariadb() {
367 if !handshake
368 .mariadb_capabilities
369 .contains(MARIADB_CAPABILITIES_ENABLED)
370 {
371 return Err(Error::Unsupported(format!(
372 "MariaDB server does not support the required capabilities. Server: {:?} Required: {:?}",
373 handshake.mariadb_capabilities, MARIADB_CAPABILITIES_ENABLED
374 )));
375 }
376 MARIADB_CAPABILITIES_ENABLED
377 } else {
378 MariadbCapabilityFlags::empty()
379 };
380
381 self.capability_flags = Some(negotiated_caps);
383 self.mariadb_capabilities = Some(mariadb_caps);
384 self.initial_handshake = Some(handshake);
385
386 if self.opts.tls && negotiated_caps.contains(CapabilityFlags::CLIENT_SSL) {
388 write_ssl_request(buffer_set.new_write_buffer(), negotiated_caps, mariadb_caps);
389
390 let seq = self.next_sequence_id;
391 self.next_sequence_id = self.next_sequence_id.wrapping_add(1);
392 self.state = HandshakeState::WaitingTlsUpgrade;
393
394 Ok(HandshakeAction::UpgradeTls { sequence_id: seq })
395 } else {
396 self.write_handshake_response(buffer_set)?;
398 let seq = self.next_sequence_id;
399 self.next_sequence_id = self.next_sequence_id.wrapping_add(1);
400 self.state = HandshakeState::WaitingAuthResult;
401
402 Ok(HandshakeAction::WritePacket { sequence_id: seq })
403 }
404 }
405
406 HandshakeState::WaitingTlsUpgrade => {
407 self.write_handshake_response(buffer_set)?;
409
410 let seq = self.next_sequence_id;
411 self.next_sequence_id = self.next_sequence_id.wrapping_add(1);
412 self.state = HandshakeState::WaitingAuthResult;
413
414 Ok(HandshakeAction::WritePacket { sequence_id: seq })
415 }
416
417 HandshakeState::WaitingAuthResult => {
418 let payload = &buffer_set.read_buffer[..];
419 if payload.is_empty() {
420 return Err(Error::LibraryBug(eyre!(
421 "empty payload while waiting for auth result"
422 )));
423 }
424
425 let initial_handshake = self.initial_handshake.as_ref().ok_or_else(|| {
427 Error::LibraryBug(eyre!("initial_handshake not set in WaitingAuthResult"))
428 })?;
429 let initial_plugin =
430 &buffer_set.initial_handshake[initial_handshake.auth_plugin_name.clone()];
431
432 match payload[0] {
433 0x00 => {
434 self.state = HandshakeState::Connected;
436 Ok(HandshakeAction::Finished)
437 }
438 0xFF => {
439 Err(ErrPayloadBytes(payload).into())
441 }
442 0xFE => {
443 if initial_plugin == b"caching_sha2_password" && payload.len() == 2 {
445 let result = read_caching_sha2_password_fast_auth_result(payload)?;
447 match result {
448 CachingSha2PasswordFastAuthResult::Success => {
449 Ok(HandshakeAction::ReadPacket(&mut buffer_set.read_buffer))
451 }
452 CachingSha2PasswordFastAuthResult::FullAuthRequired => {
453 Err(Error::Unsupported(
454 "caching_sha2_password full auth (requires SSL/RSA)"
455 .to_string(),
456 ))
457 }
458 }
459 } else {
460 let auth_switch = read_auth_switch_request(payload)?;
462
463 let auth_response = match auth_switch.plugin_name {
465 b"mysql_native_password" => auth_mysql_native_password(
466 &self.opts.password,
467 auth_switch.plugin_data,
468 )
469 .to_vec(),
470 b"caching_sha2_password" => auth_caching_sha2_password(
471 &self.opts.password,
472 auth_switch.plugin_data,
473 )
474 .to_vec(),
475 plugin => {
476 return Err(Error::Unsupported(
477 String::from_utf8_lossy(plugin).to_string(),
478 ));
479 }
480 };
481
482 write_auth_switch_response(
483 buffer_set.new_write_buffer(),
484 &auth_response,
485 );
486
487 let seq = self.next_sequence_id;
488 self.next_sequence_id = self.next_sequence_id.wrapping_add(1);
489 self.state = HandshakeState::WaitingFinalAuthResult;
490
491 Ok(HandshakeAction::WritePacket { sequence_id: seq })
492 }
493 }
494 header => Err(Error::LibraryBug(eyre!(
495 "unexpected packet header 0x{:02X} while waiting for auth result",
496 header
497 ))),
498 }
499 }
500
501 HandshakeState::WaitingFinalAuthResult => {
502 let payload = &buffer_set.read_buffer[..];
503 if payload.is_empty() {
504 return Err(Error::LibraryBug(eyre!(
505 "empty payload while waiting for final auth result"
506 )));
507 }
508
509 match payload[0] {
510 0x00 => {
511 self.state = HandshakeState::Connected;
513 Ok(HandshakeAction::Finished)
514 }
515 0xFF => {
516 Err(ErrPayloadBytes(payload).into())
518 }
519 header => Err(Error::LibraryBug(eyre!(
520 "unexpected packet header 0x{:02X} while waiting for final auth result",
521 header
522 ))),
523 }
524 }
525
526 HandshakeState::Connected => Err(Error::LibraryBug(eyre!(
527 "step() called after handshake completed"
528 ))),
529 }
530 }
531
532 pub fn finish(self) -> Result<(InitialHandshake, CapabilityFlags, MariadbCapabilityFlags)> {
536 if !matches!(self.state, HandshakeState::Connected) {
537 return Err(Error::LibraryBug(eyre!(
538 "finish() called before handshake completed"
539 )));
540 }
541
542 let initial_handshake = self.initial_handshake.ok_or_else(|| {
543 Error::LibraryBug(eyre!("initial_handshake not set in Connected state"))
544 })?;
545 let capability_flags = self.capability_flags.ok_or_else(|| {
546 Error::LibraryBug(eyre!("capability_flags not set in Connected state"))
547 })?;
548 let mariadb_capabilities = self.mariadb_capabilities.ok_or_else(|| {
549 Error::LibraryBug(eyre!("mariadb_capabilities not set in Connected state"))
550 })?;
551
552 Ok((initial_handshake, capability_flags, mariadb_capabilities))
553 }
554
555 fn write_handshake_response(&self, buffer_set: &mut BufferSet) -> Result<()> {
557 buffer_set.new_write_buffer();
558
559 let handshake = self.initial_handshake.as_ref().ok_or_else(|| {
560 Error::LibraryBug(eyre!(
561 "initial_handshake not set in write_handshake_response"
562 ))
563 })?;
564 let capability_flags = self.capability_flags.ok_or_else(|| {
565 Error::LibraryBug(eyre!(
566 "capability_flags not set in write_handshake_response"
567 ))
568 })?;
569 let mariadb_capabilities = self.mariadb_capabilities.ok_or_else(|| {
570 Error::LibraryBug(eyre!(
571 "mariadb_capabilities not set in write_handshake_response"
572 ))
573 })?;
574
575 let auth_plugin_name = &buffer_set.initial_handshake[handshake.auth_plugin_name.clone()];
577 let auth_response = {
578 match auth_plugin_name {
579 b"mysql_native_password" => {
580 auth_mysql_native_password(&self.opts.password, &handshake.auth_plugin_data)
581 .to_vec()
582 }
583 b"caching_sha2_password" => {
584 auth_caching_sha2_password(&self.opts.password, &handshake.auth_plugin_data)
585 .to_vec()
586 }
587 plugin => {
588 return Err(Error::Unsupported(
589 String::from_utf8_lossy(plugin).to_string(),
590 ));
591 }
592 }
593 };
594
595 let out = &mut buffer_set.write_buffer;
596 write_int_4(out, capability_flags.bits());
598 write_int_4(out, MAX_ALLOWED_PACKET);
600 write_int_1(out, UTF8MB4_GENERAL_CI);
602 out.extend_from_slice(&[0_u8; 19]);
604 write_int_4(out, mariadb_capabilities.bits());
605 write_string_null(out, self.opts.user.as_bytes());
607 if capability_flags.contains(CapabilityFlags::CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA) {
609 write_bytes_lenenc(out, &auth_response);
610 } else {
611 write_int_1(out, auth_response.len() as u8);
612 out.extend_from_slice(&auth_response);
613 }
614 if let Some(db) = &self.opts.db {
616 write_string_null(out, db.as_bytes());
617 }
618
619 if capability_flags.contains(CapabilityFlags::CLIENT_PLUGIN_AUTH) {
621 write_string_null(out, auth_plugin_name);
622 }
623
624 Ok(())
625 }
626}
627
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn handshake_fixed_fields_has_alignment_of_1() {
        assert_eq!(std::mem::align_of::<HandshakeFixedFields>(), 1);
    }

    /// 4 + 8 + 1 + 2 + 1 + 2 + 2 + 1 + 6 + 4 = 31 bytes on the wire; any
    /// padding would shift every field that follows it.
    #[test]
    fn handshake_fixed_fields_has_wire_size() {
        assert_eq!(std::mem::size_of::<HandshakeFixedFields>(), 31);
    }

    #[test]
    fn auth_switch_request_strips_trailing_nul() {
        let payload = b"\xFEmysql_native_password\0abc\0";
        let Ok(request) = read_auth_switch_request(payload) else {
            panic!("valid auth switch request was rejected");
        };
        assert_eq!(request.plugin_name, b"mysql_native_password");
        assert_eq!(request.plugin_data, b"abc");
    }

    #[test]
    fn auth_switch_request_rejects_malformed_packets() {
        assert!(read_auth_switch_request(b"\xFEmysql_native_password\0abc").is_err());
        assert!(read_auth_switch_request(b"\x00mysql_native_password\0abc\0").is_err());
    }

    #[test]
    fn fast_auth_result_status_bytes() {
        assert!(matches!(
            read_caching_sha2_password_fast_auth_result(&[0x03]),
            Ok(CachingSha2PasswordFastAuthResult::Success)
        ));
        assert!(matches!(
            read_caching_sha2_password_fast_auth_result(&[0x04]),
            Ok(CachingSha2PasswordFastAuthResult::FullAuthRequired)
        ));
        assert!(read_caching_sha2_password_fast_auth_result(&[]).is_err());
        assert!(read_caching_sha2_password_fast_auth_result(&[0x01]).is_err());
    }

    #[test]
    fn auth_switch_response_appends_raw_bytes() {
        let mut out = vec![0xAA_u8];
        write_auth_switch_response(&mut out, &[1, 2, 3]);
        assert_eq!(out, vec![0xAA, 1, 2, 3]);
    }
}