1use crate::nightly::cold_path;
2use zerocopy::byteorder::little_endian::{U16 as U16LE, U32 as U32LE};
3use zerocopy::{FromBytes, Immutable, KnownLayout};
4
5use crate::buffer::BufferSet;
6use crate::constant::{
7 CAPABILITIES_ALWAYS_ENABLED, CAPABILITIES_CONFIGURABLE, CapabilityFlags,
8 MARIADB_CAPABILITIES_ENABLED, MAX_ALLOWED_PACKET, MariadbCapabilityFlags, UTF8MB4_GENERAL_CI,
9};
10use crate::error::{Error, Result, eyre};
11use crate::opts::Opts;
12use crate::protocol::primitive::*;
13use crate::protocol::response::ErrPayloadBytes;
14
/// Fixed-size section of the initial handshake packet that immediately
/// follows the NUL-terminated server version string.
///
/// Field order mirrors the wire layout and must not be changed.
/// `repr(C, packed)` gives the struct an alignment of 1 (asserted by a unit
/// test in this file). That lets `ref_from_prefix` overlay it directly on an
/// arbitrary, unaligned byte slice. All multi-byte fields are little-endian.
#[derive(Debug, Clone, Copy, FromBytes, KnownLayout, Immutable)]
#[repr(C, packed)]
struct HandshakeFixedFields {
    connection_id: U32LE,
    /// First 8 bytes of the authentication scramble.
    auth_data_part1: [u8; 8],
    _filler1: u8,
    /// Low 16 bits of the server capability flags.
    capability_flags_lower: U16LE,
    charset: u8,
    status_flags: U16LE,
    /// High 16 bits of the server capability flags.
    capability_flags_upper: U16LE,
    /// Advertised scramble length; `read_initial_handshake` derives the size
    /// of scramble part 2 from it (and tolerates 0).
    auth_data_len: u8,
    _filler2: [u8; 6],
    /// MariaDB extended capability flags.
    /// NOTE(review): presumably zero/reserved when the server is MySQL — confirm.
    mariadb_capabilities: U32LE,
}
29
/// Parsed form of the server's initial handshake (greeting) packet.
///
/// `server_version` and `auth_plugin_name` are byte ranges into the payload
/// passed to `read_initial_handshake`, not owned strings. Resolve them by
/// indexing that same buffer.
#[derive(Debug, Clone)]
pub struct InitialHandshake {
    pub protocol_version: u8,
    /// Byte range of the server version string inside the greeting payload.
    pub server_version: std::ops::Range<usize>,
    pub connection_id: u32,
    /// Authentication scramble: the 8 bytes of part 1 followed by part 2.
    pub auth_plugin_data: Vec<u8>,
    pub capability_flags: CapabilityFlags,
    pub mariadb_capabilities: MariadbCapabilityFlags,
    pub charset: u8,
    pub status_flags: crate::constant::ServerStatusFlags,
    /// Byte range of the server's default auth plugin name inside the
    /// greeting payload.
    pub auth_plugin_name: std::ops::Range<usize>,
}
42
43pub fn read_initial_handshake(payload: &[u8]) -> Result<InitialHandshake> {
45 let (protocol_version, data) = read_int_1(payload)?;
46
47 if protocol_version == 0xFF {
48 cold_path();
49 Err(ErrPayloadBytes(payload))?
50 }
51
52 let server_version_start = payload.len() - data.len();
53 let (server_version_bytes, data) = read_string_null(data)?;
54 let server_version = server_version_start..server_version_start + server_version_bytes.len();
55
56 let (fixed, data) = HandshakeFixedFields::ref_from_prefix(data)?;
57
58 let connection_id = fixed.connection_id.get();
59 let charset = fixed.charset;
60 let status_flags = fixed.status_flags.get();
61 let capability_flags = CapabilityFlags::from_bits(
62 ((fixed.capability_flags_upper.get() as u32) << 16)
63 | (fixed.capability_flags_lower.get() as u32),
64 )
65 .ok_or_else(|| Error::LibraryBug(eyre!("invalid capability flags from server")))?;
66 let mariadb_capabilities = MariadbCapabilityFlags::from_bits(fixed.mariadb_capabilities.get())
67 .ok_or_else(|| Error::LibraryBug(eyre!("invalid mariadb capability flags from server")))?;
68 let auth_data_len = fixed.auth_data_len;
69
70 let auth_data_2_len = (auth_data_len as usize).saturating_sub(9).max(12);
71 let (auth_data_2, data) = read_string_fix(data, auth_data_2_len)?;
72 let (_reserved, data) = read_int_1(data)?;
73
74 let mut auth_plugin_data = Vec::new();
75 auth_plugin_data.extend_from_slice(&fixed.auth_data_part1);
76 auth_plugin_data.extend_from_slice(auth_data_2);
77
78 let auth_plugin_name_start = payload.len() - data.len();
79 let (auth_plugin_name_bytes, rest) = read_string_null(data)?;
80 let auth_plugin_name =
81 auth_plugin_name_start..auth_plugin_name_start + auth_plugin_name_bytes.len();
82
83 if !rest.is_empty() {
84 return Err(Error::LibraryBug(eyre!(
85 "unexpected trailing data in handshake packet: {} bytes",
86 rest.len()
87 )));
88 }
89
90 Ok(InitialHandshake {
91 protocol_version,
92 server_version,
93 connection_id,
94 auth_plugin_data,
95 capability_flags,
96 mariadb_capabilities,
97 charset,
98 status_flags: crate::constant::ServerStatusFlags::from_bits_truncate(status_flags),
99 auth_plugin_name,
100 })
101}
102
/// An AuthSwitchRequest packet from the server, borrowing from its payload.
#[derive(Debug, Clone)]
pub struct AuthSwitchRequest<'buf> {
    /// Name of the auth plugin the server asks the client to switch to.
    pub plugin_name: &'buf [u8],
    /// Plugin data with its trailing NUL already stripped. The supported
    /// plugins use it as the new scramble.
    pub plugin_data: &'buf [u8],
}
109
110pub fn read_auth_switch_request(payload: &[u8]) -> Result<AuthSwitchRequest<'_>> {
112 let (header, mut data) = read_int_1(payload)?;
113 if header != 0xFE {
114 return Err(Error::LibraryBug(eyre!(
115 "expected auth switch header 0xFE, got 0x{:02X}",
116 header
117 )));
118 }
119
120 let (plugin_name, rest) = read_string_null(data)?;
121 data = rest;
122
123 if let Some(0) = data.last() {
124 Ok(AuthSwitchRequest {
125 plugin_name,
126 plugin_data: &data[..data.len() - 1],
127 })
128 } else {
129 Err(Error::LibraryBug(eyre!(
130 "auth switch request plugin data not null-terminated"
131 )))
132 }
133}
134
/// Appends the body of an AuthSwitchResponse to `out`.
///
/// The response body is just the raw auth data; `out` is not cleared first.
pub fn write_auth_switch_response(out: &mut Vec<u8>, auth_data: &[u8]) {
    out.extend(auth_data.iter().copied());
}
141
142pub fn auth_mysql_native_password(password: &str, challenge: &[u8]) -> [u8; 20] {
158 use sha1::{Digest, Sha1};
159
160 if password.is_empty() {
161 return [0_u8; 20];
162 }
163
164 let stage1_hash = Sha1::digest(password.as_bytes());
166
167 let stage2_hash = Sha1::digest(stage1_hash);
169
170 let mut hasher = Sha1::new();
172 hasher.update(challenge);
173 hasher.update(stage2_hash);
174 let token_hash = hasher.finalize();
175
176 let mut result = [0_u8; 20];
178 for i in 0..20 {
179 result[i] = stage1_hash[i] ^ token_hash[i];
180 }
181
182 result
183}
184
185pub fn auth_caching_sha2_password(password: &str, challenge: &[u8]) -> [u8; 32] {
198 use sha2::{Digest, Sha256};
199
200 if password.is_empty() {
201 return [0_u8; 32];
202 }
203
204 let stage1 = Sha256::digest(password.as_bytes());
206
207 let stage2 = Sha256::digest(stage1);
209
210 let mut hasher = Sha256::new();
212 hasher.update(stage2);
213 hasher.update(challenge);
214 let scramble = hasher.finalize();
215
216 let mut result = [0_u8; 32];
218 for i in 0..32 {
219 result[i] = stage1[i] ^ scramble[i];
220 }
221
222 result
223}
224
/// Status carried by a caching_sha2_password AuthMoreData packet, as decoded
/// by `read_caching_sha2_password_fast_auth_result`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CachingSha2PasswordFastAuthResult {
    /// Status byte `0x03`: fast authentication succeeded.
    Success,
    /// Status byte `0x04`: the server requires full authentication.
    FullAuthRequired,
}
235
236pub fn read_caching_sha2_password_fast_auth_result(
238 payload: &[u8],
239) -> Result<CachingSha2PasswordFastAuthResult> {
240 if payload.is_empty() {
241 return Err(Error::LibraryBug(eyre!(
242 "empty payload for caching_sha2_password fast auth result"
243 )));
244 }
245
246 match payload[0] {
247 0x03 => Ok(CachingSha2PasswordFastAuthResult::Success),
248 0x04 => Ok(CachingSha2PasswordFastAuthResult::FullAuthRequired),
249 _ => Err(Error::LibraryBug(eyre!(
250 "unexpected caching_sha2_password fast auth result: 0x{:02X}",
251 payload[0]
252 ))),
253 }
254}
255
256fn rsa_encrypt_password(password: &str, scramble: &[u8], pem: &str) -> Result<Vec<u8>> {
261 use rsa::pkcs8::DecodePublicKey;
262 use rsa::{Oaep, RsaPublicKey};
263
264 let public_key = RsaPublicKey::from_public_key_pem(pem)
265 .map_err(|e| Error::LibraryBug(eyre!("failed to parse RSA public key: {}", e)))?;
266
267 if scramble.is_empty() {
268 return Err(Error::LibraryBug(eyre!(
269 "empty scramble in rsa_encrypt_password"
270 )));
271 }
272
273 let mut buf = Vec::with_capacity(password.len() + 1);
275 buf.extend_from_slice(password.as_bytes());
276 buf.push(0);
277
278 for (byte, key) in buf.iter_mut().zip(scramble.iter().cycle()) {
279 *byte ^= key;
280 }
281
282 let padding = Oaep::new::<sha1::Sha1>();
283 public_key
284 .encrypt(&mut rsa::rand_core::OsRng, padding, &buf)
285 .map_err(|e| Error::LibraryBug(eyre!("RSA encryption failed: {}", e)))
286}
287
288fn write_ssl_request(
294 out: &mut Vec<u8>,
295 capability_flags: CapabilityFlags,
296 mariadb_capabilities: MariadbCapabilityFlags,
297) {
298 write_int_4(out, capability_flags.bits());
300
301 write_int_4(out, MAX_ALLOWED_PACKET);
303
304 write_int_1(out, UTF8MB4_GENERAL_CI);
306
307 out.extend_from_slice(&[0_u8; 19]);
309
310 if capability_flags.is_mariadb() {
311 write_int_4(out, mariadb_capabilities.bits());
312 } else {
313 write_int_4(out, 0);
314 }
315}
316
/// What the caller of `Handshake::step` must do before calling `step` again.
pub enum HandshakeAction<'buf> {
    /// Read the next packet payload into the given buffer.
    ReadPacket(&'buf mut Vec<u8>),

    /// Send the packet prepared in the `BufferSet` write buffer with this
    /// sequence id. NOTE(review): the caller presumably then reads the
    /// server's reply into `BufferSet::read_buffer` — confirm.
    WritePacket { sequence_id: u8 },

    /// Send the prepared SSL request with this sequence id, then upgrade the
    /// connection to TLS.
    UpgradeTls { sequence_id: u8 },

    /// Authentication completed; the handshake can be finished.
    Finished,
}
331
/// Internal state of `Handshake`: which server packet the next `step` call
/// expects.
enum HandshakeState {
    /// Nothing has been requested yet.
    Start,
    /// Expecting the server greeting in `BufferSet::initial_handshake`.
    WaitingInitialHandshake,
    /// The SSL request was sent; the handshake response goes out next.
    WaitingTlsUpgrade,
    /// The handshake response was sent; expecting OK, ERR, AuthMoreData or
    /// an AuthSwitchRequest.
    WaitingAuthResult,
    /// A follow-up auth packet was sent. `caching_sha2` permits an
    /// AuthMoreData (`0x01`) reply.
    WaitingFinalAuthResult { caching_sha2: bool },
    /// caching_sha2_password fast auth succeeded; expecting the final OK.
    WaitingCachingSha2FastAuthOk,
    /// The public-key request was sent; expecting AuthMoreData with the PEM key.
    WaitingRsaPublicKey,
    /// Authentication completed.
    Connected,
}
351
352pub struct Handshake<'a> {
356 state: HandshakeState,
357 opts: &'a Opts,
358 initial_handshake: Option<InitialHandshake>,
359 next_sequence_id: u8,
360 capability_flags: Option<CapabilityFlags>,
361 mariadb_capabilities: Option<MariadbCapabilityFlags>,
362}
363
364impl<'a> Handshake<'a> {
365 pub fn new(opts: &'a Opts) -> Self {
367 Self {
368 state: HandshakeState::Start,
369 opts,
370 initial_handshake: None,
371 next_sequence_id: 1,
372 capability_flags: None,
373 mariadb_capabilities: None,
374 }
375 }
376
377 pub fn step<'buf>(&mut self, buffer_set: &'buf mut BufferSet) -> Result<HandshakeAction<'buf>> {
381 match &mut self.state {
382 HandshakeState::Start => {
383 self.state = HandshakeState::WaitingInitialHandshake;
384 Ok(HandshakeAction::ReadPacket(
385 &mut buffer_set.initial_handshake,
386 ))
387 }
388
389 HandshakeState::WaitingInitialHandshake => {
390 let handshake = read_initial_handshake(&buffer_set.initial_handshake)?;
391
392 let mut client_caps = CAPABILITIES_ALWAYS_ENABLED
393 | (self.opts.capabilities & CAPABILITIES_CONFIGURABLE);
394 if self.opts.db.is_some() {
395 client_caps |= CapabilityFlags::CLIENT_CONNECT_WITH_DB;
396 }
397 if self.opts.tls {
398 client_caps |= CapabilityFlags::CLIENT_SSL;
399 }
400
401 let negotiated_caps = client_caps & handshake.capability_flags;
402 let mariadb_caps = if negotiated_caps.is_mariadb() {
403 if !handshake
404 .mariadb_capabilities
405 .contains(MARIADB_CAPABILITIES_ENABLED)
406 {
407 return Err(Error::Unsupported(format!(
408 "MariaDB server does not support the required capabilities. Server: {:?} Required: {:?}",
409 handshake.mariadb_capabilities, MARIADB_CAPABILITIES_ENABLED
410 )));
411 }
412 MARIADB_CAPABILITIES_ENABLED
413 } else {
414 MariadbCapabilityFlags::empty()
415 };
416
417 self.capability_flags = Some(negotiated_caps);
419 self.mariadb_capabilities = Some(mariadb_caps);
420 self.initial_handshake = Some(handshake);
421
422 if self.opts.tls && negotiated_caps.contains(CapabilityFlags::CLIENT_SSL) {
424 write_ssl_request(buffer_set.new_write_buffer(), negotiated_caps, mariadb_caps);
425
426 let seq = self.next_sequence_id;
427 self.next_sequence_id = self.next_sequence_id.wrapping_add(1);
428 self.state = HandshakeState::WaitingTlsUpgrade;
429
430 Ok(HandshakeAction::UpgradeTls { sequence_id: seq })
431 } else {
432 self.write_handshake_response(buffer_set)?;
434 let seq = self.next_sequence_id;
435 self.next_sequence_id = self.next_sequence_id.wrapping_add(2);
436 self.state = HandshakeState::WaitingAuthResult;
437
438 Ok(HandshakeAction::WritePacket { sequence_id: seq })
439 }
440 }
441
442 HandshakeState::WaitingTlsUpgrade => {
443 self.write_handshake_response(buffer_set)?;
445
446 let seq = self.next_sequence_id;
447 self.next_sequence_id = self.next_sequence_id.wrapping_add(2);
448 self.state = HandshakeState::WaitingAuthResult;
449
450 Ok(HandshakeAction::WritePacket { sequence_id: seq })
451 }
452
453 HandshakeState::WaitingAuthResult => {
454 let payload = &buffer_set.read_buffer[..];
455 if payload.is_empty() {
456 return Err(Error::LibraryBug(eyre!(
457 "empty payload while waiting for auth result"
458 )));
459 }
460
461 let initial_handshake = self.initial_handshake.as_ref().ok_or_else(|| {
463 Error::LibraryBug(eyre!("initial_handshake not set in WaitingAuthResult"))
464 })?;
465 let initial_plugin =
466 &buffer_set.initial_handshake[initial_handshake.auth_plugin_name.clone()];
467
468 match payload[0] {
469 0x00 => {
470 self.state = HandshakeState::Connected;
472 Ok(HandshakeAction::Finished)
473 }
474 0xFF => {
475 Err(ErrPayloadBytes(payload).into())
477 }
478 0x01 => {
479 if initial_plugin == b"caching_sha2_password" {
481 self.handle_auth_more_data(buffer_set)
482 } else {
483 Err(Error::LibraryBug(eyre!(
484 "unexpected AuthMoreData (0x01) for plugin {:?}",
485 String::from_utf8_lossy(initial_plugin)
486 )))
487 }
488 }
489 0xFE => {
490 let auth_switch = read_auth_switch_request(payload)?;
492
493 let (auth_response, is_caching_sha2) = match auth_switch.plugin_name {
495 b"mysql_native_password" => (
496 auth_mysql_native_password(
497 &self.opts.password,
498 auth_switch.plugin_data,
499 )
500 .to_vec(),
501 false,
502 ),
503 b"caching_sha2_password" => (
504 auth_caching_sha2_password(
505 &self.opts.password,
506 auth_switch.plugin_data,
507 )
508 .to_vec(),
509 true,
510 ),
511 plugin => {
512 return Err(Error::Unsupported(
513 String::from_utf8_lossy(plugin).to_string(),
514 ));
515 }
516 };
517
518 write_auth_switch_response(buffer_set.new_write_buffer(), &auth_response);
519
520 let seq = self.next_sequence_id;
521 self.next_sequence_id = self.next_sequence_id.wrapping_add(2);
522 self.state = HandshakeState::WaitingFinalAuthResult {
523 caching_sha2: is_caching_sha2,
524 };
525
526 Ok(HandshakeAction::WritePacket { sequence_id: seq })
527 }
528 header => Err(Error::LibraryBug(eyre!(
529 "unexpected packet header 0x{:02X} while waiting for auth result",
530 header
531 ))),
532 }
533 }
534
535 HandshakeState::WaitingFinalAuthResult { caching_sha2 } => {
536 let payload = &buffer_set.read_buffer[..];
537 if payload.is_empty() {
538 return Err(Error::LibraryBug(eyre!(
539 "empty payload while waiting for final auth result"
540 )));
541 }
542
543 match payload[0] {
544 0x00 => {
545 self.state = HandshakeState::Connected;
547 Ok(HandshakeAction::Finished)
548 }
549 0xFF => {
550 Err(ErrPayloadBytes(payload).into())
552 }
553 0x01 if *caching_sha2 => self.handle_auth_more_data(buffer_set),
554 header => Err(Error::LibraryBug(eyre!(
555 "unexpected packet header 0x{:02X} while waiting for final auth result",
556 header
557 ))),
558 }
559 }
560
561 HandshakeState::WaitingCachingSha2FastAuthOk => {
562 let payload = &buffer_set.read_buffer[..];
563 if payload.is_empty() {
564 return Err(Error::LibraryBug(eyre!(
565 "empty payload while waiting for caching_sha2 OK"
566 )));
567 }
568
569 match payload[0] {
570 0x00 => {
571 self.state = HandshakeState::Connected;
572 Ok(HandshakeAction::Finished)
573 }
574 0xFF => Err(ErrPayloadBytes(payload).into()),
575 header => Err(Error::LibraryBug(eyre!(
576 "unexpected packet header 0x{:02X} while waiting for caching_sha2 OK",
577 header
578 ))),
579 }
580 }
581
582 HandshakeState::WaitingRsaPublicKey => {
583 let payload = &buffer_set.read_buffer[..];
584 if payload.is_empty() {
585 return Err(Error::LibraryBug(eyre!(
586 "empty payload while waiting for RSA public key"
587 )));
588 }
589
590 match payload[0] {
591 0xFF => return Err(ErrPayloadBytes(payload).into()),
592 0x01 if payload.len() >= 2 => {}
593 header => {
594 return Err(Error::LibraryBug(eyre!(
595 "expected AuthMoreData (0x01) with RSA public key, got 0x{:02X}",
596 header
597 )));
598 }
599 }
600
601 let pem = std::str::from_utf8(&payload[1..]).map_err(|e| {
602 Error::LibraryBug(eyre!("RSA public key is not valid UTF-8: {}", e))
603 })?;
604
605 let handshake = self
606 .initial_handshake
607 .as_ref()
608 .ok_or_else(|| Error::LibraryBug(eyre!("initial_handshake not set")))?;
609
610 let encrypted =
611 rsa_encrypt_password(&self.opts.password, &handshake.auth_plugin_data, pem)?;
612
613 let out = buffer_set.new_write_buffer();
614 out.extend_from_slice(&encrypted);
615
616 let seq = self.next_sequence_id;
617 self.next_sequence_id = self.next_sequence_id.wrapping_add(2);
618 self.state = HandshakeState::WaitingFinalAuthResult {
619 caching_sha2: false,
620 };
621
622 Ok(HandshakeAction::WritePacket { sequence_id: seq })
623 }
624
625 HandshakeState::Connected => Err(Error::LibraryBug(eyre!(
626 "step() called after handshake completed"
627 ))),
628 }
629 }
630
631 pub fn finish(self) -> Result<(InitialHandshake, CapabilityFlags, MariadbCapabilityFlags)> {
635 if !matches!(self.state, HandshakeState::Connected) {
636 return Err(Error::LibraryBug(eyre!(
637 "finish() called before handshake completed"
638 )));
639 }
640
641 let initial_handshake = self.initial_handshake.ok_or_else(|| {
642 Error::LibraryBug(eyre!("initial_handshake not set in Connected state"))
643 })?;
644 let capability_flags = self.capability_flags.ok_or_else(|| {
645 Error::LibraryBug(eyre!("capability_flags not set in Connected state"))
646 })?;
647 let mariadb_capabilities = self.mariadb_capabilities.ok_or_else(|| {
648 Error::LibraryBug(eyre!("mariadb_capabilities not set in Connected state"))
649 })?;
650
651 Ok((initial_handshake, capability_flags, mariadb_capabilities))
652 }
653
654 fn write_handshake_response(&self, buffer_set: &mut BufferSet) -> Result<()> {
656 buffer_set.new_write_buffer();
657
658 let handshake = self.initial_handshake.as_ref().ok_or_else(|| {
659 Error::LibraryBug(eyre!(
660 "initial_handshake not set in write_handshake_response"
661 ))
662 })?;
663 let capability_flags = self.capability_flags.ok_or_else(|| {
664 Error::LibraryBug(eyre!(
665 "capability_flags not set in write_handshake_response"
666 ))
667 })?;
668 let mariadb_capabilities = self.mariadb_capabilities.ok_or_else(|| {
669 Error::LibraryBug(eyre!(
670 "mariadb_capabilities not set in write_handshake_response"
671 ))
672 })?;
673
674 let auth_plugin_name = &buffer_set.initial_handshake[handshake.auth_plugin_name.clone()];
676 let auth_response = {
677 match auth_plugin_name {
678 b"mysql_native_password" => {
679 auth_mysql_native_password(&self.opts.password, &handshake.auth_plugin_data)
680 .to_vec()
681 }
682 b"caching_sha2_password" => {
683 auth_caching_sha2_password(&self.opts.password, &handshake.auth_plugin_data)
684 .to_vec()
685 }
686 plugin => {
687 return Err(Error::Unsupported(
688 String::from_utf8_lossy(plugin).to_string(),
689 ));
690 }
691 }
692 };
693
694 let out = &mut buffer_set.write_buffer;
695 write_int_4(out, capability_flags.bits());
697 write_int_4(out, MAX_ALLOWED_PACKET);
699 write_int_1(out, UTF8MB4_GENERAL_CI);
701 out.extend_from_slice(&[0_u8; 19]);
703 write_int_4(out, mariadb_capabilities.bits());
704 write_string_null(out, self.opts.user.as_bytes());
706 if capability_flags.contains(CapabilityFlags::CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA) {
708 write_bytes_lenenc(out, &auth_response);
709 } else {
710 write_int_1(out, auth_response.len() as u8);
711 out.extend_from_slice(&auth_response);
712 }
713 if let Some(db) = &self.opts.db {
715 write_string_null(out, db.as_bytes());
716 }
717
718 if capability_flags.contains(CapabilityFlags::CLIENT_PLUGIN_AUTH) {
720 write_string_null(out, auth_plugin_name);
721 }
722
723 Ok(())
724 }
725
726 fn handle_auth_more_data<'buf>(
730 &mut self,
731 buffer_set: &'buf mut BufferSet,
732 ) -> Result<HandshakeAction<'buf>> {
733 let payload = &buffer_set.read_buffer[..];
734 if payload.len() < 2 {
735 return Err(Error::LibraryBug(eyre!(
736 "AuthMoreData packet too short: {} bytes",
737 payload.len()
738 )));
739 }
740
741 let result = read_caching_sha2_password_fast_auth_result(&payload[1..])?;
742
743 match result {
744 CachingSha2PasswordFastAuthResult::Success => {
745 self.state = HandshakeState::WaitingCachingSha2FastAuthOk;
747 Ok(HandshakeAction::ReadPacket(&mut buffer_set.read_buffer))
748 }
749 CachingSha2PasswordFastAuthResult::FullAuthRequired => {
750 let capability_flags = self
751 .capability_flags
752 .ok_or_else(|| Error::LibraryBug(eyre!("capability_flags not set")))?;
753
754 if capability_flags.contains(CapabilityFlags::CLIENT_SSL) {
755 let out = buffer_set.new_write_buffer();
757 out.extend_from_slice(self.opts.password.as_bytes());
758 out.push(0);
759
760 let seq = self.next_sequence_id;
761 self.next_sequence_id = self.next_sequence_id.wrapping_add(2);
762 self.state = HandshakeState::WaitingFinalAuthResult {
763 caching_sha2: false,
764 };
765
766 Ok(HandshakeAction::WritePacket { sequence_id: seq })
767 } else {
768 let out = buffer_set.new_write_buffer();
770 out.push(0x02);
771
772 let seq = self.next_sequence_id;
773 self.next_sequence_id = self.next_sequence_id.wrapping_add(2);
774 self.state = HandshakeState::WaitingRsaPublicKey;
775
776 Ok(HandshakeAction::WritePacket { sequence_id: seq })
777 }
778 }
779 }
780 }
781}
782
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn handshake_fixed_fields_has_alignment_of_1() {
        // `ref_from_prefix` overlays the struct on an unaligned packet slice.
        let alignment = std::mem::align_of::<HandshakeFixedFields>();
        assert_eq!(alignment, 1);
    }

    #[test]
    #[allow(clippy::unwrap_used)]
    fn rsa_encrypt_password_xors_and_encrypts() {
        use rsa::pkcs8::{EncodePublicKey, LineEnding};
        use rsa::{Oaep, RsaPrivateKey};

        let mut rng = rsa::rand_core::OsRng;
        let private_key = RsaPrivateKey::new(&mut rng, 2048).unwrap();
        let pem = rsa::RsaPublicKey::from(&private_key)
            .to_public_key_pem(LineEnding::LF)
            .unwrap();

        let password = "test_password";
        let scramble = b"01234567890123456789";

        let ciphertext = super::rsa_encrypt_password(password, scramble, &pem).unwrap();
        let plaintext = private_key
            .decrypt(Oaep::new::<sha1::Sha1>(), &ciphertext)
            .unwrap();

        // NUL-terminated password XORed with the cycled scramble.
        let expected: Vec<u8> = password
            .bytes()
            .chain(std::iter::once(0))
            .zip(scramble.iter().cycle())
            .map(|(plain, key)| plain ^ key)
            .collect();
        assert_eq!(plaintext, expected);
    }

    #[test]
    fn fast_auth_result_parsing() {
        let accepted = [
            (0x03_u8, CachingSha2PasswordFastAuthResult::Success),
            (0x04_u8, CachingSha2PasswordFastAuthResult::FullAuthRequired),
        ];
        for (status, expected) in accepted {
            assert_eq!(
                read_caching_sha2_password_fast_auth_result(&[status]).unwrap(),
                expected,
            );
        }

        for rejected in [&[0x05_u8][..], &[][..]] {
            assert!(read_caching_sha2_password_fast_auth_result(rejected).is_err());
        }
    }
}