1use crate::nightly::cold_path;
2use zerocopy::byteorder::little_endian::{U16 as U16LE, U32 as U32LE};
3use zerocopy::{FromBytes, Immutable, KnownLayout};
4
5use crate::buffer::BufferSet;
6use crate::constant::{
7 CAPABILITIES_ALWAYS_ENABLED, CAPABILITIES_CONFIGURABLE, CapabilityFlags,
8 MARIADB_CAPABILITIES_ENABLED, MAX_ALLOWED_PACKET, MariadbCapabilityFlags, UTF8MB4_GENERAL_CI,
9};
10use crate::error::{Error, Result, eyre};
11use crate::opts::Opts;
12use crate::protocol::primitive::*;
13use crate::protocol::response::ErrPayloadBytes;
14
/// Fixed-size part of the server's initial handshake packet that directly follows
/// the NUL-terminated server version string.
///
/// All multi-byte integers are little-endian. `repr(C, packed)` plus the zerocopy
/// derives let [`read_initial_handshake`] reinterpret the packet bytes in place via
/// `ref_from_prefix`, so field order and sizes must match the wire layout exactly.
#[derive(Debug, Clone, Copy, FromBytes, KnownLayout, Immutable)]
#[repr(C, packed)]
struct HandshakeFixedFields {
    /// Server-assigned connection id.
    connection_id: U32LE,
    /// First 8 bytes of the authentication nonce.
    auth_data_part1: [u8; 8],
    _filler1: u8,
    /// Lower 16 bits of the server capability flags.
    capability_flags_lower: U16LE,
    /// Server character set id.
    charset: u8,
    /// Server status flags.
    status_flags: U16LE,
    /// Upper 16 bits of the server capability flags.
    capability_flags_upper: U16LE,
    /// Advertised auth plugin data length; used to size part 2 of the nonce.
    auth_data_len: u8,
    _filler2: [u8; 6],
    /// MariaDB extended capability flags (presumably zero/reserved on MySQL
    /// servers — TODO confirm).
    mariadb_capabilities: U32LE,
}
29
/// Parsed initial handshake packet sent by the server when a connection opens.
///
/// Variable-length strings are stored as byte ranges into the payload that was
/// passed to [`read_initial_handshake`], not as owned strings.
#[derive(Debug, Clone)]
pub struct InitialHandshake {
    /// Protocol version (first byte of the packet).
    pub protocol_version: u8,
    /// Range of the server version string within the payload (terminator excluded).
    pub server_version: std::ops::Range<usize>,
    /// Server-assigned connection id.
    pub connection_id: u32,
    /// Complete authentication nonce: part 1 (8 bytes) followed by part 2.
    pub auth_plugin_data: Vec<u8>,
    /// Capability flags advertised by the server (not the negotiated set).
    pub capability_flags: CapabilityFlags,
    /// MariaDB extended capability flags advertised by the server.
    pub mariadb_capabilities: MariadbCapabilityFlags,
    /// Server character set id.
    pub charset: u8,
    /// Server status flags.
    pub status_flags: crate::constant::ServerStatusFlags,
    /// Range of the default auth plugin name within the payload (terminator excluded).
    pub auth_plugin_name: std::ops::Range<usize>,
}
42
43pub fn read_initial_handshake(payload: &[u8]) -> Result<InitialHandshake> {
45 let (protocol_version, data) = read_int_1(payload)?;
46
47 if protocol_version == 0xFF {
48 cold_path();
49 Err(ErrPayloadBytes(payload))?
50 }
51
52 let server_version_start = payload.len() - data.len();
53 let (server_version_bytes, data) = read_string_null(data)?;
54 let server_version = server_version_start..server_version_start + server_version_bytes.len();
55
56 let (fixed, data) = HandshakeFixedFields::ref_from_prefix(data)?;
57
58 let connection_id = fixed.connection_id.get();
59 let charset = fixed.charset;
60 let status_flags = fixed.status_flags.get();
61 let capability_flags = CapabilityFlags::from_bits(
62 ((fixed.capability_flags_upper.get() as u32) << 16)
63 | (fixed.capability_flags_lower.get() as u32),
64 )
65 .ok_or_else(|| Error::LibraryBug(eyre!("invalid capability flags from server")))?;
66 let mariadb_capabilities = MariadbCapabilityFlags::from_bits(fixed.mariadb_capabilities.get())
67 .ok_or_else(|| Error::LibraryBug(eyre!("invalid mariadb capability flags from server")))?;
68 let auth_data_len = fixed.auth_data_len;
69
70 let auth_data_2_len = (auth_data_len as usize).saturating_sub(9).max(12);
71 let (auth_data_2, data) = read_string_fix(data, auth_data_2_len)?;
72 let (_reserved, data) = read_int_1(data)?;
73
74 let mut auth_plugin_data = Vec::new();
75 auth_plugin_data.extend_from_slice(&fixed.auth_data_part1);
76 auth_plugin_data.extend_from_slice(auth_data_2);
77
78 let auth_plugin_name_start = payload.len() - data.len();
79 let (auth_plugin_name_bytes, rest) = read_string_null(data)?;
80 let auth_plugin_name =
81 auth_plugin_name_start..auth_plugin_name_start + auth_plugin_name_bytes.len();
82
83 if !rest.is_empty() {
84 return Err(Error::LibraryBug(eyre!(
85 "unexpected trailing data in handshake packet: {} bytes",
86 rest.len()
87 )));
88 }
89
90 Ok(InitialHandshake {
91 protocol_version,
92 server_version,
93 connection_id,
94 auth_plugin_data,
95 capability_flags,
96 mariadb_capabilities,
97 charset,
98 status_flags: crate::constant::ServerStatusFlags::from_bits_truncate(status_flags),
99 auth_plugin_name,
100 })
101}
102
/// Parsed `AuthSwitchRequest` packet; both fields borrow from the packet payload.
#[derive(Debug, Clone)]
pub struct AuthSwitchRequest<'buf> {
    /// Name of the auth plugin the server asks the client to switch to.
    pub plugin_name: &'buf [u8],
    /// Plugin data (used as the challenge for the new plugin), with the trailing NUL
    /// byte removed.
    pub plugin_data: &'buf [u8],
}
109
110pub fn read_auth_switch_request(payload: &[u8]) -> Result<AuthSwitchRequest<'_>> {
112 let (header, mut data) = read_int_1(payload)?;
113 if header != 0xFE {
114 return Err(Error::LibraryBug(eyre!(
115 "expected auth switch header 0xFE, got 0x{:02X}",
116 header
117 )));
118 }
119
120 let (plugin_name, rest) = read_string_null(data)?;
121 data = rest;
122
123 if let Some(0) = data.last() {
124 Ok(AuthSwitchRequest {
125 plugin_name,
126 plugin_data: &data[..data.len() - 1],
127 })
128 } else {
129 Err(Error::LibraryBug(eyre!(
130 "auth switch request plugin data not null-terminated"
131 )))
132 }
133}
134
/// Appends an auth switch response to `out`.
///
/// The response has no framing of its own: its payload is exactly the
/// plugin-specific authentication data.
pub fn write_auth_switch_response(out: &mut Vec<u8>, auth_data: &[u8]) {
    out.extend(auth_data.iter().copied());
}
141
142pub fn auth_mysql_native_password(password: &str, challenge: &[u8]) -> [u8; 20] {
158 use sha1::{Digest, Sha1};
159
160 if password.is_empty() {
161 return [0_u8; 20];
162 }
163
164 let stage1_hash = Sha1::digest(password.as_bytes());
166
167 let stage2_hash = Sha1::digest(stage1_hash);
169
170 let mut hasher = Sha1::new();
172 hasher.update(challenge);
173 hasher.update(stage2_hash);
174 let token_hash = hasher.finalize();
175
176 let mut result = [0_u8; 20];
178 for i in 0..20 {
179 result[i] = stage1_hash[i] ^ token_hash[i];
180 }
181
182 result
183}
184
185pub fn auth_caching_sha2_password(password: &str, challenge: &[u8]) -> [u8; 32] {
198 use sha2::{Digest, Sha256};
199
200 if password.is_empty() {
201 return [0_u8; 32];
202 }
203
204 let stage1 = Sha256::digest(password.as_bytes());
206
207 let stage2 = Sha256::digest(stage1);
209
210 let mut hasher = Sha256::new();
212 hasher.update(stage2);
213 hasher.update(challenge);
214 let scramble = hasher.finalize();
215
216 let mut result = [0_u8; 32];
218 for i in 0..32 {
219 result[i] = stage1[i] ^ scramble[i];
220 }
221
222 result
223}
224
/// Status reported by the server after a `caching_sha2_password` fast-auth attempt.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CachingSha2PasswordFastAuthResult {
    /// Status byte `0x03`: the scramble was accepted; the server's final OK/ERR
    /// packet follows.
    Success,
    /// Status byte `0x04`: the client must perform full authentication. That means
    /// the cleartext password over TLS, or an RSA-encrypted password otherwise.
    FullAuthRequired,
}
235
236pub fn read_caching_sha2_password_fast_auth_result(
238 payload: &[u8],
239) -> Result<CachingSha2PasswordFastAuthResult> {
240 if payload.is_empty() {
241 return Err(Error::LibraryBug(eyre!(
242 "empty payload for caching_sha2_password fast auth result"
243 )));
244 }
245
246 match payload[0] {
247 0x03 => Ok(CachingSha2PasswordFastAuthResult::Success),
248 0x04 => Ok(CachingSha2PasswordFastAuthResult::FullAuthRequired),
249 _ => Err(Error::LibraryBug(eyre!(
250 "unexpected caching_sha2_password fast auth result: 0x{:02X}",
251 payload[0]
252 ))),
253 }
254}
255
256fn rsa_encrypt_password(password: &str, scramble: &[u8], pem_str: &str) -> Result<Vec<u8>> {
261 use aws_lc_rs::rsa::{OAEP_SHA1_MGF1SHA1, OaepPublicEncryptingKey, PublicEncryptingKey};
262
263 let pem_data = pem::parse(pem_str)
264 .map_err(|e| Error::LibraryBug(eyre!("failed to parse RSA public key PEM: {}", e)))?;
265
266 let public_key = PublicEncryptingKey::from_der(pem_data.contents())
267 .map_err(|e| Error::LibraryBug(eyre!("failed to parse RSA public key DER: {}", e)))?;
268
269 let oaep_key = OaepPublicEncryptingKey::new(public_key)
270 .map_err(|e| Error::LibraryBug(eyre!("failed to create OAEP key: {}", e)))?;
271
272 if scramble.is_empty() {
273 return Err(Error::LibraryBug(eyre!(
274 "empty scramble in rsa_encrypt_password"
275 )));
276 }
277
278 let mut buf = Vec::with_capacity(password.len() + 1);
280 buf.extend_from_slice(password.as_bytes());
281 buf.push(0);
282
283 for (byte, key) in buf.iter_mut().zip(scramble.iter().cycle()) {
284 *byte ^= key;
285 }
286
287 let mut ciphertext = vec![0u8; oaep_key.ciphertext_size()];
288 let encrypted = oaep_key
289 .encrypt(&OAEP_SHA1_MGF1SHA1, &buf, &mut ciphertext, None)
290 .map_err(|e| Error::LibraryBug(eyre!("RSA encryption failed: {}", e)))?;
291
292 Ok(encrypted.to_vec())
293}
294
295fn write_ssl_request(
301 out: &mut Vec<u8>,
302 capability_flags: CapabilityFlags,
303 mariadb_capabilities: MariadbCapabilityFlags,
304) {
305 write_int_4(out, capability_flags.bits());
307
308 write_int_4(out, MAX_ALLOWED_PACKET);
310
311 write_int_1(out, UTF8MB4_GENERAL_CI);
313
314 out.extend_from_slice(&[0_u8; 19]);
316
317 if capability_flags.is_mariadb() {
318 write_int_4(out, mariadb_capabilities.bits());
319 } else {
320 write_int_4(out, 0);
321 }
322}
323
/// I/O the caller must perform before calling [`Handshake::step`] again.
pub enum HandshakeAction<'buf> {
    /// Read the next packet payload from the server into the given buffer.
    ReadPacket(&'buf mut Vec<u8>),

    /// Send the packet prepared in the `BufferSet` write buffer using `sequence_id`.
    /// NOTE(review): the following `step` reads the server's reply from
    /// `BufferSet::read_buffer`. The caller therefore presumably reads the reply
    /// there right after writing — confirm against the driver loop.
    WritePacket { sequence_id: u8 },

    /// Send the prepared SSL request packet using `sequence_id`, then upgrade the
    /// connection to TLS before the next `step`.
    UpgradeTls { sequence_id: u8 },

    /// Authentication succeeded; call [`Handshake::finish`].
    Finished,
}
338
/// Internal state of [`Handshake`]; each variant names what the next `step()` call
/// processes.
enum HandshakeState {
    /// Nothing done yet; the next step asks the caller to read the server greeting.
    Start,
    /// The server greeting is in `BufferSet::initial_handshake`.
    WaitingInitialHandshake,
    /// SSL request sent; the caller has upgraded the connection to TLS.
    WaitingTlsUpgrade,
    /// Handshake response sent; expecting OK, ERR, AuthMoreData or AuthSwitchRequest.
    WaitingAuthResult,
    /// Expecting the final reply after an auth switch or full-auth exchange.
    /// `caching_sha2` permits an AuthMoreData packet; it is only set after switching
    /// to `caching_sha2_password`.
    WaitingFinalAuthResult { caching_sha2: bool },
    /// Fast auth succeeded; expecting the server's OK (or ERR) packet.
    WaitingCachingSha2FastAuthOk,
    /// Public key request (`0x02`) sent; expecting AuthMoreData carrying a PEM key.
    WaitingRsaPublicKey,
    /// Handshake completed successfully.
    Connected,
}
358
359pub struct Handshake<'a> {
363 state: HandshakeState,
364 opts: &'a Opts,
365 initial_handshake: Option<InitialHandshake>,
366 next_sequence_id: u8,
367 capability_flags: Option<CapabilityFlags>,
368 mariadb_capabilities: Option<MariadbCapabilityFlags>,
369}
370
371impl<'a> Handshake<'a> {
372 pub fn new(opts: &'a Opts) -> Self {
374 Self {
375 state: HandshakeState::Start,
376 opts,
377 initial_handshake: None,
378 next_sequence_id: 1,
379 capability_flags: None,
380 mariadb_capabilities: None,
381 }
382 }
383
384 pub fn step<'buf>(&mut self, buffer_set: &'buf mut BufferSet) -> Result<HandshakeAction<'buf>> {
388 match &mut self.state {
389 HandshakeState::Start => {
390 self.state = HandshakeState::WaitingInitialHandshake;
391 Ok(HandshakeAction::ReadPacket(
392 &mut buffer_set.initial_handshake,
393 ))
394 }
395
396 HandshakeState::WaitingInitialHandshake => {
397 let handshake = read_initial_handshake(&buffer_set.initial_handshake)?;
398
399 let mut client_caps = CAPABILITIES_ALWAYS_ENABLED
400 | (self.opts.capabilities & CAPABILITIES_CONFIGURABLE);
401 if self.opts.db.is_some() {
402 client_caps |= CapabilityFlags::CLIENT_CONNECT_WITH_DB;
403 }
404 if self.opts.tls {
405 client_caps |= CapabilityFlags::CLIENT_SSL;
406 }
407
408 let negotiated_caps = client_caps & handshake.capability_flags;
409 let mariadb_caps = if negotiated_caps.is_mariadb() {
410 if !handshake
411 .mariadb_capabilities
412 .contains(MARIADB_CAPABILITIES_ENABLED)
413 {
414 return Err(Error::Unsupported(format!(
415 "MariaDB server does not support the required capabilities. Server: {:?} Required: {:?}",
416 handshake.mariadb_capabilities, MARIADB_CAPABILITIES_ENABLED
417 )));
418 }
419 MARIADB_CAPABILITIES_ENABLED
420 } else {
421 MariadbCapabilityFlags::empty()
422 };
423
424 self.capability_flags = Some(negotiated_caps);
426 self.mariadb_capabilities = Some(mariadb_caps);
427 self.initial_handshake = Some(handshake);
428
429 if self.opts.tls && negotiated_caps.contains(CapabilityFlags::CLIENT_SSL) {
431 write_ssl_request(buffer_set.new_write_buffer(), negotiated_caps, mariadb_caps);
432
433 let seq = self.next_sequence_id;
434 self.next_sequence_id = self.next_sequence_id.wrapping_add(1);
435 self.state = HandshakeState::WaitingTlsUpgrade;
436
437 Ok(HandshakeAction::UpgradeTls { sequence_id: seq })
438 } else {
439 self.write_handshake_response(buffer_set)?;
441 let seq = self.next_sequence_id;
442 self.next_sequence_id = self.next_sequence_id.wrapping_add(2);
443 self.state = HandshakeState::WaitingAuthResult;
444
445 Ok(HandshakeAction::WritePacket { sequence_id: seq })
446 }
447 }
448
449 HandshakeState::WaitingTlsUpgrade => {
450 self.write_handshake_response(buffer_set)?;
452
453 let seq = self.next_sequence_id;
454 self.next_sequence_id = self.next_sequence_id.wrapping_add(2);
455 self.state = HandshakeState::WaitingAuthResult;
456
457 Ok(HandshakeAction::WritePacket { sequence_id: seq })
458 }
459
460 HandshakeState::WaitingAuthResult => {
461 let payload = &buffer_set.read_buffer[..];
462 if payload.is_empty() {
463 return Err(Error::LibraryBug(eyre!(
464 "empty payload while waiting for auth result"
465 )));
466 }
467
468 let initial_handshake = self.initial_handshake.as_ref().ok_or_else(|| {
470 Error::LibraryBug(eyre!("initial_handshake not set in WaitingAuthResult"))
471 })?;
472 let initial_plugin =
473 &buffer_set.initial_handshake[initial_handshake.auth_plugin_name.clone()];
474
475 match payload[0] {
476 0x00 => {
477 self.state = HandshakeState::Connected;
479 Ok(HandshakeAction::Finished)
480 }
481 0xFF => {
482 Err(ErrPayloadBytes(payload).into())
484 }
485 0x01 => {
486 if initial_plugin == b"caching_sha2_password" {
488 self.handle_auth_more_data(buffer_set)
489 } else {
490 Err(Error::LibraryBug(eyre!(
491 "unexpected AuthMoreData (0x01) for plugin {:?}",
492 String::from_utf8_lossy(initial_plugin)
493 )))
494 }
495 }
496 0xFE => {
497 let auth_switch = read_auth_switch_request(payload)?;
499
500 let (auth_response, is_caching_sha2) = match auth_switch.plugin_name {
502 b"mysql_native_password" => (
503 auth_mysql_native_password(
504 &self.opts.password,
505 auth_switch.plugin_data,
506 )
507 .to_vec(),
508 false,
509 ),
510 b"caching_sha2_password" => (
511 auth_caching_sha2_password(
512 &self.opts.password,
513 auth_switch.plugin_data,
514 )
515 .to_vec(),
516 true,
517 ),
518 plugin => {
519 return Err(Error::Unsupported(
520 String::from_utf8_lossy(plugin).to_string(),
521 ));
522 }
523 };
524
525 write_auth_switch_response(buffer_set.new_write_buffer(), &auth_response);
526
527 let seq = self.next_sequence_id;
528 self.next_sequence_id = self.next_sequence_id.wrapping_add(2);
529 self.state = HandshakeState::WaitingFinalAuthResult {
530 caching_sha2: is_caching_sha2,
531 };
532
533 Ok(HandshakeAction::WritePacket { sequence_id: seq })
534 }
535 header => Err(Error::LibraryBug(eyre!(
536 "unexpected packet header 0x{:02X} while waiting for auth result",
537 header
538 ))),
539 }
540 }
541
542 HandshakeState::WaitingFinalAuthResult { caching_sha2 } => {
543 let payload = &buffer_set.read_buffer[..];
544 if payload.is_empty() {
545 return Err(Error::LibraryBug(eyre!(
546 "empty payload while waiting for final auth result"
547 )));
548 }
549
550 match payload[0] {
551 0x00 => {
552 self.state = HandshakeState::Connected;
554 Ok(HandshakeAction::Finished)
555 }
556 0xFF => {
557 Err(ErrPayloadBytes(payload).into())
559 }
560 0x01 if *caching_sha2 => self.handle_auth_more_data(buffer_set),
561 header => Err(Error::LibraryBug(eyre!(
562 "unexpected packet header 0x{:02X} while waiting for final auth result",
563 header
564 ))),
565 }
566 }
567
568 HandshakeState::WaitingCachingSha2FastAuthOk => {
569 let payload = &buffer_set.read_buffer[..];
570 if payload.is_empty() {
571 return Err(Error::LibraryBug(eyre!(
572 "empty payload while waiting for caching_sha2 OK"
573 )));
574 }
575
576 match payload[0] {
577 0x00 => {
578 self.state = HandshakeState::Connected;
579 Ok(HandshakeAction::Finished)
580 }
581 0xFF => Err(ErrPayloadBytes(payload).into()),
582 header => Err(Error::LibraryBug(eyre!(
583 "unexpected packet header 0x{:02X} while waiting for caching_sha2 OK",
584 header
585 ))),
586 }
587 }
588
589 HandshakeState::WaitingRsaPublicKey => {
590 let payload = &buffer_set.read_buffer[..];
591 if payload.is_empty() {
592 return Err(Error::LibraryBug(eyre!(
593 "empty payload while waiting for RSA public key"
594 )));
595 }
596
597 match payload[0] {
598 0xFF => return Err(ErrPayloadBytes(payload).into()),
599 0x01 if payload.len() >= 2 => {}
600 header => {
601 return Err(Error::LibraryBug(eyre!(
602 "expected AuthMoreData (0x01) with RSA public key, got 0x{:02X}",
603 header
604 )));
605 }
606 }
607
608 let pem = std::str::from_utf8(&payload[1..]).map_err(|e| {
609 Error::LibraryBug(eyre!("RSA public key is not valid UTF-8: {}", e))
610 })?;
611
612 let handshake = self
613 .initial_handshake
614 .as_ref()
615 .ok_or_else(|| Error::LibraryBug(eyre!("initial_handshake not set")))?;
616
617 let encrypted =
618 rsa_encrypt_password(&self.opts.password, &handshake.auth_plugin_data, pem)?;
619
620 let out = buffer_set.new_write_buffer();
621 out.extend_from_slice(&encrypted);
622
623 let seq = self.next_sequence_id;
624 self.next_sequence_id = self.next_sequence_id.wrapping_add(2);
625 self.state = HandshakeState::WaitingFinalAuthResult {
626 caching_sha2: false,
627 };
628
629 Ok(HandshakeAction::WritePacket { sequence_id: seq })
630 }
631
632 HandshakeState::Connected => Err(Error::LibraryBug(eyre!(
633 "step() called after handshake completed"
634 ))),
635 }
636 }
637
638 pub fn finish(self) -> Result<(InitialHandshake, CapabilityFlags, MariadbCapabilityFlags)> {
642 if !matches!(self.state, HandshakeState::Connected) {
643 return Err(Error::LibraryBug(eyre!(
644 "finish() called before handshake completed"
645 )));
646 }
647
648 let initial_handshake = self.initial_handshake.ok_or_else(|| {
649 Error::LibraryBug(eyre!("initial_handshake not set in Connected state"))
650 })?;
651 let capability_flags = self.capability_flags.ok_or_else(|| {
652 Error::LibraryBug(eyre!("capability_flags not set in Connected state"))
653 })?;
654 let mariadb_capabilities = self.mariadb_capabilities.ok_or_else(|| {
655 Error::LibraryBug(eyre!("mariadb_capabilities not set in Connected state"))
656 })?;
657
658 Ok((initial_handshake, capability_flags, mariadb_capabilities))
659 }
660
661 fn write_handshake_response(&self, buffer_set: &mut BufferSet) -> Result<()> {
663 buffer_set.new_write_buffer();
664
665 let handshake = self.initial_handshake.as_ref().ok_or_else(|| {
666 Error::LibraryBug(eyre!(
667 "initial_handshake not set in write_handshake_response"
668 ))
669 })?;
670 let capability_flags = self.capability_flags.ok_or_else(|| {
671 Error::LibraryBug(eyre!(
672 "capability_flags not set in write_handshake_response"
673 ))
674 })?;
675 let mariadb_capabilities = self.mariadb_capabilities.ok_or_else(|| {
676 Error::LibraryBug(eyre!(
677 "mariadb_capabilities not set in write_handshake_response"
678 ))
679 })?;
680
681 let auth_plugin_name = &buffer_set.initial_handshake[handshake.auth_plugin_name.clone()];
683 let auth_response = {
684 match auth_plugin_name {
685 b"mysql_native_password" => {
686 auth_mysql_native_password(&self.opts.password, &handshake.auth_plugin_data)
687 .to_vec()
688 }
689 b"caching_sha2_password" => {
690 auth_caching_sha2_password(&self.opts.password, &handshake.auth_plugin_data)
691 .to_vec()
692 }
693 plugin => {
694 return Err(Error::Unsupported(
695 String::from_utf8_lossy(plugin).to_string(),
696 ));
697 }
698 }
699 };
700
701 let out = &mut buffer_set.write_buffer;
702 write_int_4(out, capability_flags.bits());
704 write_int_4(out, MAX_ALLOWED_PACKET);
706 write_int_1(out, UTF8MB4_GENERAL_CI);
708 out.extend_from_slice(&[0_u8; 19]);
710 write_int_4(out, mariadb_capabilities.bits());
711 write_string_null(out, self.opts.user.as_bytes());
713 if capability_flags.contains(CapabilityFlags::CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA) {
715 write_bytes_lenenc(out, &auth_response);
716 } else {
717 write_int_1(out, auth_response.len() as u8);
718 out.extend_from_slice(&auth_response);
719 }
720 if let Some(db) = &self.opts.db {
722 write_string_null(out, db.as_bytes());
723 }
724
725 if capability_flags.contains(CapabilityFlags::CLIENT_PLUGIN_AUTH) {
727 write_string_null(out, auth_plugin_name);
728 }
729
730 Ok(())
731 }
732
733 fn handle_auth_more_data<'buf>(
737 &mut self,
738 buffer_set: &'buf mut BufferSet,
739 ) -> Result<HandshakeAction<'buf>> {
740 let payload = &buffer_set.read_buffer[..];
741 if payload.len() < 2 {
742 return Err(Error::LibraryBug(eyre!(
743 "AuthMoreData packet too short: {} bytes",
744 payload.len()
745 )));
746 }
747
748 let result = read_caching_sha2_password_fast_auth_result(&payload[1..])?;
749
750 match result {
751 CachingSha2PasswordFastAuthResult::Success => {
752 self.state = HandshakeState::WaitingCachingSha2FastAuthOk;
754 Ok(HandshakeAction::ReadPacket(&mut buffer_set.read_buffer))
755 }
756 CachingSha2PasswordFastAuthResult::FullAuthRequired => {
757 let capability_flags = self
758 .capability_flags
759 .ok_or_else(|| Error::LibraryBug(eyre!("capability_flags not set")))?;
760
761 if capability_flags.contains(CapabilityFlags::CLIENT_SSL) {
762 let out = buffer_set.new_write_buffer();
764 out.extend_from_slice(self.opts.password.as_bytes());
765 out.push(0);
766
767 let seq = self.next_sequence_id;
768 self.next_sequence_id = self.next_sequence_id.wrapping_add(2);
769 self.state = HandshakeState::WaitingFinalAuthResult {
770 caching_sha2: false,
771 };
772
773 Ok(HandshakeAction::WritePacket { sequence_id: seq })
774 } else {
775 let out = buffer_set.new_write_buffer();
777 out.push(0x02);
778
779 let seq = self.next_sequence_id;
780 self.next_sequence_id = self.next_sequence_id.wrapping_add(2);
781 self.state = HandshakeState::WaitingRsaPublicKey;
782
783 Ok(HandshakeAction::WritePacket { sequence_id: seq })
784 }
785 }
786 }
787 }
788}
789
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_macros::{check_eq, check_err};

    /// Builds a well-formed MySQL initial handshake payload that advertises no
    /// capability bits and the `mysql_native_password` plugin.
    fn sample_initial_handshake() -> Vec<u8> {
        let mut payload = vec![10_u8]; // protocol version
        payload.extend_from_slice(b"8.0.36\0"); // server version
        payload.extend_from_slice(&42_u32.to_le_bytes()); // connection id
        payload.extend_from_slice(b"abcdefgh"); // auth plugin data, part 1
        payload.push(0); // filler
        payload.extend_from_slice(&[0, 0]); // capability flags, lower 16 bits
        payload.push(45); // charset
        payload.extend_from_slice(&[0, 0]); // status flags
        payload.extend_from_slice(&[0, 0]); // capability flags, upper 16 bits
        payload.push(21); // auth plugin data length
        payload.extend_from_slice(&[0; 10]); // reserved (incl. MariaDB capabilities)
        payload.extend_from_slice(b"ijklmnopqrst"); // auth plugin data, part 2
        payload.push(0); // byte following part 2
        payload.extend_from_slice(b"mysql_native_password\0");
        payload
    }

    #[test]
    fn handshake_fixed_fields_has_alignment_of_1() {
        assert_eq!(std::mem::align_of::<HandshakeFixedFields>(), 1);
    }

    #[test]
    fn handshake_fixed_fields_match_wire_size() {
        // 4 + 8 + 1 + 2 + 1 + 2 + 2 + 1 + 6 + 4 bytes on the wire.
        assert_eq!(std::mem::size_of::<HandshakeFixedFields>(), 31);
    }

    #[test]
    fn initial_handshake_parsing() -> crate::error::Result<()> {
        let payload = sample_initial_handshake();
        let handshake = read_initial_handshake(&payload)?;

        assert_eq!(handshake.protocol_version, 10);
        assert_eq!(&payload[handshake.server_version.clone()], b"8.0.36");
        assert_eq!(handshake.connection_id, 42);
        assert_eq!(handshake.auth_plugin_data, b"abcdefghijklmnopqrst");
        assert_eq!(handshake.charset, 45);
        assert_eq!(
            &payload[handshake.auth_plugin_name.clone()],
            b"mysql_native_password"
        );
        Ok(())
    }

    #[test]
    fn initial_handshake_rejects_malformed_packets() -> crate::error::Result<()> {
        let mut payload = sample_initial_handshake();
        // Truncated inside the fixed-size fields.
        check_err!(read_initial_handshake(&payload[..20]));
        // Extra bytes after the auth plugin name.
        payload.push(0xAB);
        check_err!(read_initial_handshake(&payload));
        Ok(())
    }

    #[test]
    fn auth_switch_request_parsing() -> crate::error::Result<()> {
        let mut payload = vec![0xFE_u8];
        payload.extend_from_slice(b"caching_sha2_password\0");
        payload.extend_from_slice(b"01234567890123456789\0");

        let request = read_auth_switch_request(&payload)?;
        assert_eq!(request.plugin_name, b"caching_sha2_password");
        assert_eq!(request.plugin_data, b"01234567890123456789");

        // Plugin data must be NUL-terminated.
        check_err!(read_auth_switch_request(b"\xFEmysql_native_password\0abc"));
        // The header byte must be 0xFE.
        check_err!(read_auth_switch_request(b"\x01mysql_native_password\0abc\0"));
        Ok(())
    }

    #[test]
    fn auth_switch_response_is_raw_auth_data() {
        let mut out = vec![0xAA_u8];
        write_auth_switch_response(&mut out, &[1, 2, 3]);
        assert_eq!(out, [0xAA, 1, 2, 3]);
    }

    #[test]
    fn empty_password_yields_zeroed_scramble() {
        let challenge = b"01234567890123456789";
        assert_eq!(auth_mysql_native_password("", challenge), [0_u8; 20]);
        assert_eq!(auth_caching_sha2_password("", challenge), [0_u8; 32]);
        assert_ne!(auth_mysql_native_password("secret", challenge), [0_u8; 20]);
    }

    #[test]
    #[expect(clippy::unwrap_used)]
    fn rsa_encrypt_password_xors_and_encrypts() {
        use aws_lc_rs::encoding::AsDer;
        use aws_lc_rs::rsa::{
            KeySize, OAEP_SHA1_MGF1SHA1, OaepPrivateDecryptingKey, PrivateDecryptingKey,
        };
        use aws_lc_rs::signature::KeyPair;

        let key_pair = aws_lc_rs::rsa::KeyPair::generate(KeySize::Rsa2048).unwrap();
        let private_key_pkcs8 = key_pair.as_der().unwrap();
        let public_key_der = key_pair.public_key().as_der().unwrap();

        let pem_data = pem::Pem::new("PUBLIC KEY", public_key_der.as_ref().to_vec());
        let pem_string = pem::encode(&pem_data);

        let password = "test_password";
        let scramble = b"01234567890123456789";

        let encrypted = super::rsa_encrypt_password(password, scramble, &pem_string).unwrap();

        let private_key = PrivateDecryptingKey::from_pkcs8(private_key_pkcs8.as_ref()).unwrap();
        let oaep_key = OaepPrivateDecryptingKey::new(private_key).unwrap();
        let mut plaintext = vec![0u8; encrypted.len()];
        let decrypted = oaep_key
            .decrypt(&OAEP_SHA1_MGF1SHA1, &encrypted, &mut plaintext, None)
            .unwrap();

        let mut expected = password.as_bytes().to_vec();
        expected.push(0);
        for (byte, key) in expected.iter_mut().zip(scramble.iter().cycle()) {
            *byte ^= key;
        }
        assert_eq!(decrypted, expected);
    }

    #[test]
    fn fast_auth_result_parsing() -> crate::error::Result<()> {
        check_eq!(
            read_caching_sha2_password_fast_auth_result(&[0x03])?,
            CachingSha2PasswordFastAuthResult::Success,
        );
        check_eq!(
            read_caching_sha2_password_fast_auth_result(&[0x04])?,
            CachingSha2PasswordFastAuthResult::FullAuthRequired,
        );
        check_err!(read_caching_sha2_password_fast_auth_result(&[0x05]));
        check_err!(read_caching_sha2_password_fast_auth_result(&[]));
        Ok(())
    }
}