1use base64::Engine;
12use base64::engine::general_purpose::STANDARD as BASE64;
13use hmac::{Hmac, Mac};
14use pbkdf2::pbkdf2_hmac;
15use rand::RngExt;
16use sha2::{Digest, Sha256};
17use subtle::ConstantTimeEq;
18use zeroize::Zeroize;
19
/// HMAC keyed with SHA-256; the PRF used for every keyed hash in this module.
type HmacSha256 = Hmac<Sha256>;

/// Default PBKDF2 iteration count for newly derived credentials.
///
/// Nothing in this module applies it automatically — callers pass it to
/// `derive_credentials` explicitly.
pub const DEFAULT_ITERATIONS: u32 = 100_000;

/// Largest PBKDF2 iteration count a client will accept from a server's
/// server-first-message; larger requests are rejected before any key
/// derivation runs, so a hostile server cannot force unbounded CPU work.
pub const MAX_CLIENT_ITERATIONS: u32 = 1_000_000;
27
/// Errors returned by the auth header parser and the SCRAM handshake helpers.
#[derive(Debug, thiserror::Error)]
pub enum AuthError {
    /// The client's proof did not verify against the stored credentials.
    #[error("invalid credentials")]
    InvalidCredentials,
    /// An `Authorization` header value could not be parsed (unknown scheme,
    /// missing required parameter, or a username that is not UTF-8).
    #[error("invalid auth header: {0}")]
    InvalidHeader(String),
    /// A SCRAM message was malformed or inconsistent (wrong field layout,
    /// missing `x=` prefix, nonce mismatch, unexpected channel binding, ...).
    #[error("handshake failed: {0}")]
    HandshakeFailed(String),
    /// A SCRAM message was well-formed but carried an unacceptable value,
    /// e.g. a PBKDF2 iteration count outside the range the client accepts.
    #[error("invalid message: {0}")]
    InvalidMessage(String),
    /// Base64 decoding failed; carries the decoder's error text.
    #[error("base64 decode error: {0}")]
    Base64Error(String),
}
46
/// Server-side SCRAM-SHA-256 verifier for one user, as produced by
/// `derive_credentials`.
///
/// Only StoredKey and ServerKey are kept — never the password or ClientKey.
/// The `Debug` output redacts both keys so the verifier cannot leak through
/// logs or panic messages; `salt` and `iterations` are shown because they are
/// sent to clients in the server-first-message anyway.
pub struct ScramCredentials {
    /// PBKDF2 salt (public; sent to clients as `s=`).
    pub salt: Vec<u8>,
    /// PBKDF2 iteration count (public; sent to clients as `i=`).
    pub iterations: u32,
    /// `SHA-256(ClientKey)`; used to verify client proofs. Secret.
    pub stored_key: Vec<u8>,
    /// `HMAC(SaltedPassword, "Server Key")`; used to sign server-final. Secret.
    pub server_key: Vec<u8>,
}

impl std::fmt::Debug for ScramCredentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same redaction marker as `ScramHandshake`'s Debug impl.
        f.debug_struct("ScramCredentials")
            .field("salt", &self.salt)
            .field("iterations", &self.iterations)
            .field("stored_key", &"[REDACTED]")
            .field("server_key", &"[REDACTED]")
            .finish()
    }
}
59
60impl Drop for ScramCredentials {
61 fn drop(&mut self) {
62 self.stored_key.zeroize();
63 self.server_key.zeroize();
64 }
65}
66
/// Server-side state for one in-progress SCRAM exchange.
///
/// Created by `server_first_message` and consumed by `server_verify_final`.
///
/// NOTE(review): unlike `ScramCredentials`, this type has no `Drop` impl,
/// so the copied `stored_key` is not zeroized when the handshake is dropped.
/// Adding one would forbid moving the public fields out — confirm no caller
/// does that before changing it.
pub struct ScramHandshake {
    /// Username the handshake was started for.
    pub username: String,
    /// Nonce taken from the client-first-message.
    pub client_nonce: String,
    /// Nonce generated by the server; the client must echo
    /// `client_nonce + server_nonce` in its final message.
    pub server_nonce: String,
    /// PBKDF2 salt copied from the user's credentials.
    pub salt: Vec<u8>,
    /// PBKDF2 iteration count copied from the user's credentials.
    pub iterations: u32,
    /// SCRAM AuthMessage (client-first-bare, server-first, and
    /// client-final-without-proof joined by commas).
    pub auth_message: String,
    /// `HMAC(ServerKey, AuthMessage)`, returned after a successful proof.
    pub server_signature: Vec<u8>,
    /// Copy of the credential's StoredKey; kept private because it is secret.
    stored_key: Vec<u8>,
}
80
81impl std::fmt::Debug for ScramHandshake {
82 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83 f.debug_struct("ScramHandshake")
84 .field("username", &self.username)
85 .field("stored_key", &"[REDACTED]")
86 .field("server_signature", &"[REDACTED]")
87 .finish()
88 }
89}
90
/// A parsed HTTP `Authorization` header value (see `parse_auth_header`).
///
/// `Debug` is implemented by hand so that a bearer token is never written
/// to logs or assertion messages; the other variants print exactly as a
/// derived `Debug` would.
#[derive(Clone, PartialEq, Eq)]
pub enum AuthHeader {
    /// `HELLO username=<base64>[, data=<...>]` — starts a handshake.
    Hello {
        /// Username, already base64-decoded.
        username: String,
        /// Raw `data=` parameter, if present (not decoded).
        data: Option<String>,
    },
    /// `SCRAM handshakeToken=<token>, data=<...>` — continues a handshake.
    Scram {
        /// Opaque token identifying the in-progress handshake.
        handshake_token: String,
        /// Raw `data=` parameter (not decoded).
        data: String,
    },
    /// `BEARER authToken=<token>` — presents an issued auth token.
    Bearer {
        /// The bearer credential; redacted in `Debug` output.
        auth_token: String,
    },
}

impl std::fmt::Debug for AuthHeader {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            AuthHeader::Hello { username, data } => f
                .debug_struct("Hello")
                .field("username", username)
                .field("data", data)
                .finish(),
            AuthHeader::Scram {
                handshake_token,
                data,
            } => f
                .debug_struct("Scram")
                .field("handshake_token", handshake_token)
                .field("data", data)
                .finish(),
            // The token alone authenticates the caller; never print it.
            AuthHeader::Bearer { .. } => f
                .debug_struct("Bearer")
                .field("auth_token", &"[REDACTED]")
                .finish(),
        }
    }
}
107
108fn hmac_sha256(key: &[u8], msg: &[u8]) -> Vec<u8> {
114 let mut mac = HmacSha256::new_from_slice(key).expect("HMAC accepts keys of any size");
116 mac.update(msg);
117 mac.finalize().into_bytes().to_vec()
118}
119
120fn sha256(data: &[u8]) -> Vec<u8> {
122 let mut hasher = Sha256::new();
123 hasher.update(data);
124 hasher.finalize().to_vec()
125}
126
/// Byte-wise XOR of two slices.
///
/// Callers must pass equal-length slices. This is only checked in debug
/// builds; in release builds the output is truncated to the shorter input.
fn xor_bytes(a: &[u8], b: &[u8]) -> Vec<u8> {
    debug_assert_eq!(a.len(), b.len(), "XOR operands must be same length");
    std::iter::zip(a, b).map(|(lhs, rhs)| lhs ^ rhs).collect()
}
134
135fn pbkdf2_sha256(password: &[u8], salt: &[u8], iterations: u32) -> Vec<u8> {
137 let mut salted_password = vec![0u8; 32];
138 pbkdf2_hmac::<Sha256>(password, salt, iterations, &mut salted_password);
139 salted_password
140 }
142
143fn derive_keys(salted_password: &[u8]) -> (Vec<u8>, Vec<u8>, Vec<u8>) {
145 let client_key = hmac_sha256(salted_password, b"Client Key");
146 let stored_key = sha256(&client_key);
147 let server_key = hmac_sha256(salted_password, b"Server Key");
148 (client_key, stored_key, server_key)
149}
150
151fn parse_scram_param<'a>(segment: &'a str, prefix: &str) -> Result<&'a str, AuthError> {
153 let trimmed = segment.trim();
154 trimmed.strip_prefix(prefix).ok_or_else(|| {
155 AuthError::HandshakeFailed(format!(
156 "expected prefix '{}' but got '{}'",
157 prefix, trimmed
158 ))
159 })
160}
161
/// Builds the RFC 5802 `client-first-message-bare`: `n=<saslname>,r=<nonce>`.
///
/// Both the client (`client_first_message`) and the server (when rebuilding
/// the AuthMessage in `server_first_message`) call this, so both sides
/// always agree on the exact bytes.
///
/// Per RFC 5802 §5.1, `=` and `,` in the username are sent as `=3D` and
/// `=2C`. Without this, a username containing `,` would spill into
/// neighbouring attributes and make the message ambiguous. Usernames without
/// those characters are emitted unchanged.
fn make_client_first_bare(username: &str, client_nonce: &str) -> String {
    let mut bare = String::with_capacity(username.len() + client_nonce.len() + 5);
    bare.push_str("n=");
    for ch in username.chars() {
        match ch {
            '=' => bare.push_str("=3D"),
            ',' => bare.push_str("=2C"),
            other => bare.push(other),
        }
    }
    bare.push_str(",r=");
    bare.push_str(client_nonce);
    bare
}
166
167pub fn derive_credentials(password: &str, salt: &[u8], iterations: u32) -> ScramCredentials {
175 let mut salted_password = pbkdf2_sha256(password.as_bytes(), salt, iterations);
176 let (mut _client_key, stored_key, server_key) = derive_keys(&salted_password);
177 salted_password.zeroize();
178 _client_key.zeroize();
179 ScramCredentials {
180 salt: salt.to_vec(),
181 iterations,
182 stored_key,
183 server_key,
184 }
185}
186
187pub fn generate_nonce() -> String {
189 let mut bytes = [0u8; 18];
190 rand::rng().fill(&mut bytes);
191 BASE64.encode(bytes)
192}
193
194pub fn client_first_message(username: &str) -> (String, String) {
201 let client_nonce = generate_nonce();
202 let bare = make_client_first_bare(username, &client_nonce);
203 let full = format!("n,,{}", bare);
204 let encoded = BASE64.encode(full.as_bytes());
205 (client_nonce, encoded)
206}
207
208pub fn server_first_message(
216 username: &str,
217 client_nonce_b64: &str,
218 credentials: &ScramCredentials,
219) -> (ScramHandshake, String) {
220 let server_nonce = generate_nonce();
221 let combined_nonce = format!("{}{}", client_nonce_b64, server_nonce);
222 let salt_b64 = BASE64.encode(&credentials.salt);
223
224 let server_first_msg = format!(
226 "r={},s={},i={}",
227 combined_nonce, salt_b64, credentials.iterations
228 );
229
230 let cfmb = make_client_first_bare(username, client_nonce_b64);
232
233 let client_final_without_proof = format!("c=biws,r={}", combined_nonce);
235
236 let auth_message = format!(
238 "{},{},{}",
239 cfmb, server_first_msg, client_final_without_proof
240 );
241
242 let server_signature = hmac_sha256(&credentials.server_key, auth_message.as_bytes());
244
245 let server_first_b64 = BASE64.encode(server_first_msg.as_bytes());
246
247 let handshake = ScramHandshake {
248 username: username.to_string(),
249 client_nonce: client_nonce_b64.to_string(),
250 server_nonce,
251 salt: credentials.salt.clone(),
252 iterations: credentials.iterations,
253 auth_message,
254 server_signature,
255 stored_key: credentials.stored_key.clone(),
256 };
257
258 (handshake, server_first_b64)
259}
260
261pub fn client_final_message(
270 password: &str,
271 client_nonce: &str,
272 server_first_b64: &str,
273 username: &str,
274) -> Result<(String, Vec<u8>), AuthError> {
275 let server_first_bytes = BASE64
277 .decode(server_first_b64)
278 .map_err(|e| AuthError::Base64Error(e.to_string()))?;
279 let server_first_msg = String::from_utf8(server_first_bytes)
280 .map_err(|e| AuthError::HandshakeFailed(e.to_string()))?;
281
282 let parts: Vec<&str> = server_first_msg.splitn(3, ',').collect();
284 if parts.len() != 3 {
285 return Err(AuthError::HandshakeFailed(
286 "invalid server-first-message format".to_string(),
287 ));
288 }
289
290 let combined_nonce = parse_scram_param(parts[0], "r=")?;
291 let salt_b64 = parse_scram_param(parts[1], "s=")?;
292 let iterations_str = parse_scram_param(parts[2], "i=")?;
293
294 if !combined_nonce.starts_with(client_nonce) {
296 return Err(AuthError::HandshakeFailed(
297 "combined nonce does not start with client nonce".to_string(),
298 ));
299 }
300
301 let salt = BASE64
302 .decode(salt_b64)
303 .map_err(|e| AuthError::Base64Error(e.to_string()))?;
304 let iterations: u32 = iterations_str
305 .parse()
306 .map_err(|e: std::num::ParseIntError| AuthError::HandshakeFailed(e.to_string()))?;
307
308 if iterations > MAX_CLIENT_ITERATIONS {
309 return Err(AuthError::InvalidMessage(format!(
310 "server requested {} PBKDF2 iterations, maximum allowed is {}",
311 iterations, MAX_CLIENT_ITERATIONS
312 )));
313 }
314
315 let mut salted_password = pbkdf2_sha256(password.as_bytes(), &salt, iterations);
317 let (mut client_key, stored_key, server_key) = derive_keys(&salted_password);
318 salted_password.zeroize();
319
320 let cfmb = make_client_first_bare(username, client_nonce);
322 let client_final_without_proof = format!("c=biws,r={}", combined_nonce);
323 let auth_message = format!(
324 "{},{},{}",
325 cfmb, server_first_msg, client_final_without_proof
326 );
327
328 let client_signature = hmac_sha256(&stored_key, auth_message.as_bytes());
330 let client_proof = xor_bytes(&client_key, &client_signature);
332 let server_signature = hmac_sha256(&server_key, auth_message.as_bytes());
334
335 client_key.zeroize();
337
338 let proof_b64 = BASE64.encode(&client_proof);
340 let client_final_msg = format!("{},p={}", client_final_without_proof, proof_b64);
341 let client_final_b64 = BASE64.encode(client_final_msg.as_bytes());
342
343 Ok((client_final_b64, server_signature))
344}
345
346pub fn server_verify_final(
352 handshake: &ScramHandshake,
353 client_final_b64: &str,
354) -> Result<Vec<u8>, AuthError> {
355 let client_final_bytes = BASE64
357 .decode(client_final_b64)
358 .map_err(|e| AuthError::Base64Error(e.to_string()))?;
359 let client_final_msg = String::from_utf8(client_final_bytes)
360 .map_err(|e| AuthError::HandshakeFailed(e.to_string()))?;
361
362 let parts: Vec<&str> = client_final_msg.splitn(3, ',').collect();
364 if parts.len() != 3 {
365 return Err(AuthError::HandshakeFailed(
366 "invalid client-final-message format".to_string(),
367 ));
368 }
369
370 let channel_binding = parse_scram_param(parts[0], "c=")?;
372 if channel_binding != "biws" {
373 return Err(AuthError::HandshakeFailed(
374 "unexpected channel binding".to_string(),
375 ));
376 }
377
378 let combined_nonce = parse_scram_param(parts[1], "r=")?;
380 let expected_combined = format!("{}{}", handshake.client_nonce, handshake.server_nonce);
381 if !bool::from(
382 combined_nonce
383 .as_bytes()
384 .ct_eq(expected_combined.as_bytes()),
385 ) {
386 return Err(AuthError::HandshakeFailed("nonce mismatch".to_string()));
387 }
388
389 let proof_b64 = parse_scram_param(parts[2], "p=")?;
391 let client_proof = BASE64
392 .decode(proof_b64)
393 .map_err(|e| AuthError::Base64Error(e.to_string()))?;
394
395 let client_signature = hmac_sha256(&handshake.stored_key, handshake.auth_message.as_bytes());
400 let recovered_client_key = xor_bytes(&client_proof, &client_signature);
401 let recovered_stored_key = sha256(&recovered_client_key);
402
403 if recovered_stored_key
404 .ct_eq(&handshake.stored_key)
405 .unwrap_u8()
406 == 0
407 {
408 return Err(AuthError::InvalidCredentials);
409 }
410
411 Ok(handshake.server_signature.clone())
413}
414
415pub fn extract_client_nonce(client_first_b64: &str) -> Result<String, AuthError> {
420 let bytes = BASE64
421 .decode(client_first_b64)
422 .map_err(|e| AuthError::Base64Error(e.to_string()))?;
423 let msg = String::from_utf8(bytes).map_err(|e| AuthError::HandshakeFailed(e.to_string()))?;
424 let bare = msg
426 .strip_prefix("n,,")
427 .ok_or_else(|| AuthError::HandshakeFailed("missing GS2 header in client-first".into()))?;
428 for part in bare.split(',') {
430 if let Some(nonce) = part.strip_prefix("r=") {
431 return Ok(nonce.to_string());
432 }
433 }
434 Err(AuthError::HandshakeFailed(
435 "missing r= nonce in client-first-message".into(),
436 ))
437}
438
439pub fn parse_auth_header(header: &str) -> Result<AuthHeader, AuthError> {
446 let header = header.trim();
447
448 if let Some(rest) = header.strip_prefix("HELLO ") {
449 let mut username_b64_val = None;
450 let mut data_val = None;
451 for part in rest.split(',') {
452 let part = part.trim();
453 if let Some(val) = part.strip_prefix("username=") {
454 username_b64_val = Some(val.trim().to_string());
455 } else if let Some(val) = part.strip_prefix("data=") {
456 data_val = Some(val.trim().to_string());
457 }
458 }
459 let username_b64 = username_b64_val
460 .ok_or_else(|| AuthError::InvalidHeader("missing username= in HELLO".into()))?;
461 let username_bytes = BASE64
462 .decode(&username_b64)
463 .map_err(|e| AuthError::Base64Error(e.to_string()))?;
464 let username = String::from_utf8(username_bytes)
465 .map_err(|e| AuthError::InvalidHeader(e.to_string()))?;
466 Ok(AuthHeader::Hello {
467 username,
468 data: data_val,
469 })
470 } else if let Some(rest) = header.strip_prefix("SCRAM ") {
471 let mut handshake_token = None;
472 let mut data = None;
473 for part in rest.split(',') {
474 let part = part.trim();
475 if let Some(val) = part.strip_prefix("handshakeToken=") {
476 handshake_token = Some(val.trim().to_string());
477 } else if let Some(val) = part.strip_prefix("data=") {
478 data = Some(val.trim().to_string());
479 }
480 }
481 let handshake_token = handshake_token
482 .ok_or_else(|| AuthError::InvalidHeader("missing handshakeToken= in SCRAM".into()))?;
483 let data = data.ok_or_else(|| AuthError::InvalidHeader("missing data= in SCRAM".into()))?;
484 Ok(AuthHeader::Scram {
485 handshake_token,
486 data,
487 })
488 } else if let Some(rest) = header.strip_prefix("BEARER ") {
489 let token = rest
490 .trim()
491 .strip_prefix("authToken=")
492 .ok_or_else(|| AuthError::InvalidHeader("missing authToken= in BEARER".into()))?;
493 Ok(AuthHeader::Bearer {
494 auth_token: token.trim().to_string(),
495 })
496 } else {
497 Err(AuthError::InvalidHeader(format!(
498 "unrecognized auth scheme: {}",
499 header
500 )))
501 }
502}
503
/// Renders a `WWW-Authenticate` header value for a SCRAM challenge:
/// `SCRAM handshakeToken=<token>, hash=<hash>, data=<data_b64>`.
pub fn format_www_authenticate(handshake_token: &str, hash: &str, data_b64: &str) -> String {
    format!("SCRAM handshakeToken={handshake_token}, hash={hash}, data={data_b64}")
}
513
/// Renders an `Authentication-Info` header value:
/// `authToken=<token>, data=<data_b64>`.
pub fn format_auth_info(auth_token: &str, data_b64: &str) -> String {
    ["authToken=", auth_token, ", data=", data_b64].concat()
}
520
// Unit tests: credential derivation, nonce generation, header parsing and
// formatting, and full client/server SCRAM round-trips using this module's
// own client and server helpers.
#[cfg(test)]
mod tests {
    use super::*;

    // Derivation yields 32-byte keys, is deterministic for fixed inputs, and
    // depends on the password.
    #[test]
    fn test_derive_credentials() {
        let password = "pencil";
        let salt = b"random-salt-value";
        let iterations = 4096;

        let creds = derive_credentials(password, salt, iterations);

        assert_eq!(creds.salt, salt.to_vec());
        assert_eq!(creds.iterations, iterations);
        assert_eq!(creds.stored_key.len(), 32); assert_eq!(creds.server_key.len(), 32);

        // Same inputs -> same keys.
        let creds2 = derive_credentials(password, salt, iterations);
        assert_eq!(creds.stored_key, creds2.stored_key);
        assert_eq!(creds.server_key, creds2.server_key);

        // Different password -> different keys.
        let creds3 = derive_credentials("other", salt, iterations);
        assert_ne!(creds.stored_key, creds3.stored_key);
        assert_ne!(creds.server_key, creds3.server_key);
    }

    // Nonces differ between calls and decode to the 18 random bytes.
    #[test]
    fn test_generate_nonce() {
        let n1 = generate_nonce();
        let n2 = generate_nonce();

        assert_ne!(n1, n2);

        let decoded1 = BASE64.decode(&n1).expect("nonce must be valid base64");
        assert_eq!(decoded1.len(), 18);

        let decoded2 = BASE64.decode(&n2).expect("nonce must be valid base64");
        assert_eq!(decoded2.len(), 18);
    }

    // HELLO: the username is base64-decoded; absent data= parses as None.
    #[test]
    fn test_parse_auth_header_hello() {
        let username = "user";
        let username_b64 = BASE64.encode(username.as_bytes());
        let header = format!("HELLO username={}", username_b64);

        let parsed = parse_auth_header(&header).unwrap();
        assert_eq!(
            parsed,
            AuthHeader::Hello {
                username: "user".to_string(),
                data: None,
            }
        );
    }

    // SCRAM: both parameters are kept verbatim, whitespace after ',' ignored.
    #[test]
    fn test_parse_auth_header_scram() {
        let header = "SCRAM handshakeToken=abc123, data=c29tZWRhdGE=";
        let parsed = parse_auth_header(header).unwrap();
        assert_eq!(
            parsed,
            AuthHeader::Scram {
                handshake_token: "abc123".to_string(),
                data: "c29tZWRhdGE=".to_string(),
            }
        );
    }

    // BEARER: the token after authToken= is returned as-is.
    #[test]
    fn test_parse_auth_header_bearer() {
        let header = "BEARER authToken=mytoken123";
        let parsed = parse_auth_header(header).unwrap();
        assert_eq!(
            parsed,
            AuthHeader::Bearer {
                auth_token: "mytoken123".to_string(),
            }
        );
    }

    // Unknown scheme, missing required parameters, and empty input all fail.
    #[test]
    fn test_parse_auth_header_invalid() {
        assert!(parse_auth_header("UNKNOWN foo=bar").is_err());
        assert!(parse_auth_header("HELLO foo=bar").is_err());
        assert!(parse_auth_header("SCRAM handshakeToken=abc").is_err());
        assert!(parse_auth_header("BEARER token=abc").is_err());
        assert!(parse_auth_header("").is_err());
    }

    // End-to-end: HELLO -> server-first (in WWW-Authenticate) -> client-final
    // -> server verification -> server-final (in Authentication-Info), with
    // the client checking the returned ServerSignature.
    #[test]
    fn test_full_handshake() {
        let username = "testuser";
        let password = "s3cret";
        let salt = b"test-salt-12345";
        let iterations = 4096;

        let credentials = derive_credentials(password, salt, iterations);

        let username_b64 = BASE64.encode(username.as_bytes());
        let hello_header = format!("HELLO username={}", username_b64);
        let parsed = parse_auth_header(&hello_header).unwrap();
        match &parsed {
            AuthHeader::Hello { username: u, .. } => assert_eq!(u, username),
            _ => panic!("expected Hello variant"),
        }

        let (client_nonce, _client_first_b64) = client_first_message(username);

        let (handshake, server_first_b64) =
            server_first_message(username, &client_nonce, &credentials);

        let www_auth = format_www_authenticate("handshake-token-xyz", "SHA-256", &server_first_b64);
        assert!(www_auth.contains("SCRAM"));
        assert!(www_auth.contains("SHA-256"));
        assert!(www_auth.contains("handshake-token-xyz"));

        let (client_final_b64, expected_server_sig) =
            client_final_message(password, &client_nonce, &server_first_b64, username).unwrap();

        let server_sig = server_verify_final(&handshake, &client_final_b64).unwrap();

        // Client and server must agree on the ServerSignature.
        assert_eq!(server_sig, expected_server_sig);

        // Server-final is "v=<base64 sig>", itself base64-wrapped for the header.
        let server_final_msg = format!("v={}", BASE64.encode(&server_sig));
        let server_final_b64 = BASE64.encode(server_final_msg.as_bytes());
        let auth_info = format_auth_info("auth-token-abc", &server_final_b64);
        assert!(auth_info.contains("authToken=auth-token-abc"));

        // Client side: unwrap server-final and compare with its expectation.
        let server_final_decoded = BASE64.decode(&server_final_b64).unwrap();
        let server_final_str = String::from_utf8(server_final_decoded).unwrap();
        let sig_b64 = server_final_str.strip_prefix("v=").unwrap();
        let received_server_sig = BASE64.decode(sig_b64).unwrap();
        assert_eq!(received_server_sig, expected_server_sig);
    }

    // Checks the wire shape of each message and that a wrong password is
    // rejected with InvalidCredentials.
    #[test]
    fn test_client_server_roundtrip() {
        let username = "admin";
        let password = "correcthorsebatterystaple";
        let salt = b"unique-salt-value";
        let iterations = DEFAULT_ITERATIONS;

        let credentials = derive_credentials(password, salt, iterations);

        let (client_nonce, client_first_b64) = client_first_message(username);

        let client_first_decoded = BASE64.decode(&client_first_b64).unwrap();
        let client_first_str = String::from_utf8(client_first_decoded).unwrap();
        assert!(client_first_str.starts_with("n,,"));
        assert!(client_first_str.contains(&format!("r={}", client_nonce)));

        let (handshake, server_first_b64) =
            server_first_message(username, &client_nonce, &credentials);

        let server_first_decoded = BASE64.decode(&server_first_b64).unwrap();
        let server_first_str = String::from_utf8(server_first_decoded).unwrap();
        assert!(server_first_str.starts_with("r="));
        assert!(server_first_str.contains(",s="));
        assert!(server_first_str.contains(",i="));
        assert!(server_first_str.contains(&client_nonce));

        let (client_final_b64, expected_server_sig) =
            client_final_message(password, &client_nonce, &server_first_b64, username).unwrap();

        let client_final_decoded = BASE64.decode(&client_final_b64).unwrap();
        let client_final_str = String::from_utf8(client_final_decoded).unwrap();
        assert!(client_final_str.starts_with("c=biws,"));
        assert!(client_final_str.contains(",p="));

        let server_sig = server_verify_final(&handshake, &client_final_b64).unwrap();
        assert_eq!(server_sig, expected_server_sig);

        // Same handshake, wrong password: the proof must not verify.
        let (wrong_final_b64, _) =
            client_final_message("wrongpassword", &client_nonce, &server_first_b64, username)
                .unwrap();
        let result = server_verify_final(&handshake, &wrong_final_b64);
        assert!(result.is_err());
        match result {
            Err(AuthError::InvalidCredentials) => {} other => panic!("expected InvalidCredentials, got {:?}", other),
        }
    }

    // Exact wire format of the WWW-Authenticate challenge.
    #[test]
    fn test_format_www_authenticate() {
        let result = format_www_authenticate("tok123", "SHA-256", "c29tZQ==");
        assert_eq!(
            result,
            "SCRAM handshakeToken=tok123, hash=SHA-256, data=c29tZQ=="
        );
    }

    // Exact wire format of the Authentication-Info value.
    #[test]
    fn test_format_auth_info() {
        let result = format_auth_info("auth-tok", "ZGF0YQ==");
        assert_eq!(result, "authToken=auth-tok, data=ZGF0YQ==");
    }
}