1use base64::Engine;
12use base64::engine::general_purpose::STANDARD as BASE64;
13use hmac::{Hmac, Mac};
14use pbkdf2::pbkdf2_hmac;
15use rand::RngExt;
16use sha2::{Digest, Sha256};
17use subtle::ConstantTimeEq;
18use zeroize::Zeroize;
19
/// HMAC instantiated with SHA-256, the MAC used for every keyed step of SCRAM-SHA-256 here.
type HmacSha256 = Hmac<Sha256>;

/// Default PBKDF2 iteration count (the tests pass it to `derive_credentials`).
///
/// NOTE(review): RFC 7677 only requires at least 4096; current OWASP guidance for
/// PBKDF2-HMAC-SHA256 is considerably higher — confirm this value against policy.
pub const DEFAULT_ITERATIONS: u32 = 100_000;
24
/// Errors produced by the SCRAM handshake and `Authorization` header helpers.
#[derive(Debug, thiserror::Error)]
pub enum AuthError {
    /// The client proof did not match the stored key (wrong password).
    #[error("invalid credentials")]
    InvalidCredentials,
    /// A header was malformed: unknown scheme, missing parameter, or non-UTF-8 username.
    #[error("invalid auth header: {0}")]
    InvalidHeader(String),
    /// A SCRAM message was malformed or failed a protocol check
    /// (message layout, nonce, channel binding, iteration count, ...).
    #[error("handshake failed: {0}")]
    HandshakeFailed(String),
    /// A base64 field could not be decoded.
    #[error("base64 decode error: {0}")]
    Base64Error(String),
}
41
42#[derive(Debug, Clone)]
48pub struct ScramCredentials {
49 pub salt: Vec<u8>,
50 pub iterations: u32,
51 pub stored_key: Vec<u8>,
52 pub server_key: Vec<u8>,
53}
54
55impl Drop for ScramCredentials {
56 fn drop(&mut self) {
57 self.stored_key.zeroize();
58 self.server_key.zeroize();
59 }
60}
61
/// Server-side state for one in-flight SCRAM exchange, created by
/// `server_first_message` and consumed by `server_verify_final`.
///
/// NOTE(review): `stored_key` is a copy of the user's StoredKey and is not wiped
/// on drop (this type has no `Drop` impl). `Debug` output redacts it.
#[derive(Clone)]
pub struct ScramHandshake {
    /// Username the exchange was started for.
    pub username: String,
    /// Nonce supplied by the client in its first message.
    pub client_nonce: String,
    /// Nonce the server appended to the client nonce.
    pub server_nonce: String,
    /// Salt sent to the client.
    pub salt: Vec<u8>,
    /// Iteration count sent to the client.
    pub iterations: u32,
    /// `client-first-bare,server-first,client-final-without-proof` as the server expects it.
    pub auth_message: String,
    /// `HMAC(ServerKey, AuthMessage)`; released to the client only after a valid proof.
    pub server_signature: Vec<u8>,
    /// Copy of the user's StoredKey, used to check the client proof.
    stored_key: Vec<u8>,
}

// Hand-written instead of derived so the StoredKey copy and the not-yet-earned
// server signature never appear in logs via `{:?}`.
impl std::fmt::Debug for ScramHandshake {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ScramHandshake")
            .field("username", &self.username)
            .field("client_nonce", &self.client_nonce)
            .field("server_nonce", &self.server_nonce)
            .field("salt", &self.salt)
            .field("iterations", &self.iterations)
            .field("auth_message", &self.auth_message)
            .field("server_signature", &"<redacted>")
            .field("stored_key", &"<redacted>")
            .finish()
    }
}
76
/// A parsed `Authorization` header value (see `parse_auth_header`).
#[derive(Clone, PartialEq, Eq)]
pub enum AuthHeader {
    /// `HELLO username=<base64>[, data=<...>]` — opens a handshake.
    Hello {
        /// Username, already base64-decoded to UTF-8.
        username: String,
        /// Optional `data=` parameter, kept verbatim.
        data: Option<String>,
    },
    /// `SCRAM handshakeToken=<...>, data=<...>` — one SCRAM step.
    Scram {
        handshake_token: String,
        /// SCRAM message payload, kept verbatim (base64 as sent).
        data: String,
    },
    /// `BEARER authToken=<...>` — a session token.
    Bearer { auth_token: String },
}

// Hand-written instead of derived so a bearer token (a live credential) is
// never written to logs or assertion messages via `{:?}`. The other variants
// print exactly as the derived impl would.
impl std::fmt::Debug for AuthHeader {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            AuthHeader::Hello { username, data } => f
                .debug_struct("Hello")
                .field("username", username)
                .field("data", data)
                .finish(),
            AuthHeader::Scram {
                handshake_token,
                data,
            } => f
                .debug_struct("Scram")
                .field("handshake_token", handshake_token)
                .field("data", data)
                .finish(),
            AuthHeader::Bearer { .. } => f
                .debug_struct("Bearer")
                .field("auth_token", &"<redacted>")
                .finish(),
        }
    }
}
93
94fn hmac_sha256(key: &[u8], msg: &[u8]) -> Vec<u8> {
100 let mut mac = HmacSha256::new_from_slice(key).expect("HMAC accepts keys of any size");
101 mac.update(msg);
102 mac.finalize().into_bytes().to_vec()
103}
104
105fn sha256(data: &[u8]) -> Vec<u8> {
107 let mut hasher = Sha256::new();
108 hasher.update(data);
109 hasher.finalize().to_vec()
110}
111
/// XORs two byte slices element-wise.
///
/// # Panics
/// Panics if `a` and `b` have different lengths.
fn xor_bytes(a: &[u8], b: &[u8]) -> Vec<u8> {
    assert_eq!(a.len(), b.len(), "XOR operands must be the same length");
    let mut out = Vec::with_capacity(a.len());
    for (&x, &y) in a.iter().zip(b) {
        out.push(x ^ y);
    }
    out
}
117
118fn pbkdf2_sha256(password: &[u8], salt: &[u8], iterations: u32) -> Vec<u8> {
120 let mut salted_password = vec![0u8; 32];
121 pbkdf2_hmac::<Sha256>(password, salt, iterations, &mut salted_password);
122 salted_password
123 }
125
126fn derive_keys(salted_password: &[u8]) -> (Vec<u8>, Vec<u8>, Vec<u8>) {
128 let client_key = hmac_sha256(salted_password, b"Client Key");
129 let stored_key = sha256(&client_key);
130 let server_key = hmac_sha256(salted_password, b"Server Key");
131 (client_key, stored_key, server_key)
132}
133
134fn parse_scram_param<'a>(segment: &'a str, prefix: &str) -> Result<&'a str, AuthError> {
136 let trimmed = segment.trim();
137 trimmed.strip_prefix(prefix).ok_or_else(|| {
138 AuthError::HandshakeFailed(format!(
139 "expected prefix '{}' but got '{}'",
140 prefix, trimmed
141 ))
142 })
143}
144
/// Builds the SCRAM `client-first-message-bare`: `n=<saslname>,r=<nonce>`.
///
/// RFC 5802 §5.1 requires `=` and `,` in the username to be sent as `=3D`
/// and `=2C`. Escaping keeps the AuthMessage identical to what compliant peers
/// compute, and stops a username from injecting extra attributes such as
/// `,r=`. The nonce is inserted verbatim.
fn make_client_first_bare(username: &str, client_nonce: &str) -> String {
    // "n=" + ",r=" is 5 bytes; escapes may grow the username slightly beyond this.
    let mut bare = String::with_capacity(username.len() + client_nonce.len() + 5);
    bare.push_str("n=");
    for ch in username.chars() {
        // Single pass, so the '=' produced by an escape is never re-escaped.
        match ch {
            '=' => bare.push_str("=3D"),
            ',' => bare.push_str("=2C"),
            other => bare.push(other),
        }
    }
    bare.push_str(",r=");
    bare.push_str(client_nonce);
    bare
}
149
150pub fn derive_credentials(password: &str, salt: &[u8], iterations: u32) -> ScramCredentials {
158 let mut salted_password = pbkdf2_sha256(password.as_bytes(), salt, iterations);
159 let (mut _client_key, stored_key, server_key) = derive_keys(&salted_password);
160 salted_password.zeroize();
161 _client_key.zeroize();
162 ScramCredentials {
163 salt: salt.to_vec(),
164 iterations,
165 stored_key,
166 server_key,
167 }
168}
169
170pub fn generate_nonce() -> String {
172 let mut bytes = [0u8; 18];
173 rand::rng().fill(&mut bytes);
174 BASE64.encode(bytes)
175}
176
177pub fn client_first_message(username: &str) -> (String, String) {
184 let client_nonce = generate_nonce();
185 let bare = make_client_first_bare(username, &client_nonce);
186 let full = format!("n,,{}", bare);
187 let encoded = BASE64.encode(full.as_bytes());
188 (client_nonce, encoded)
189}
190
191pub fn server_first_message(
199 username: &str,
200 client_nonce_b64: &str,
201 credentials: &ScramCredentials,
202) -> (ScramHandshake, String) {
203 let server_nonce = generate_nonce();
204 let combined_nonce = format!("{}{}", client_nonce_b64, server_nonce);
205 let salt_b64 = BASE64.encode(&credentials.salt);
206
207 let server_first_msg = format!(
209 "r={},s={},i={}",
210 combined_nonce, salt_b64, credentials.iterations
211 );
212
213 let cfmb = make_client_first_bare(username, client_nonce_b64);
215
216 let client_final_without_proof = format!("c=biws,r={}", combined_nonce);
218
219 let auth_message = format!(
221 "{},{},{}",
222 cfmb, server_first_msg, client_final_without_proof
223 );
224
225 let server_signature = hmac_sha256(&credentials.server_key, auth_message.as_bytes());
227
228 let server_first_b64 = BASE64.encode(server_first_msg.as_bytes());
229
230 let handshake = ScramHandshake {
231 username: username.to_string(),
232 client_nonce: client_nonce_b64.to_string(),
233 server_nonce,
234 salt: credentials.salt.clone(),
235 iterations: credentials.iterations,
236 auth_message,
237 server_signature,
238 stored_key: credentials.stored_key.clone(),
239 };
240
241 (handshake, server_first_b64)
242}
243
244pub fn client_final_message(
253 password: &str,
254 client_nonce: &str,
255 server_first_b64: &str,
256 username: &str,
257) -> Result<(String, Vec<u8>), AuthError> {
258 let server_first_bytes = BASE64
260 .decode(server_first_b64)
261 .map_err(|e| AuthError::Base64Error(e.to_string()))?;
262 let server_first_msg = String::from_utf8(server_first_bytes)
263 .map_err(|e| AuthError::HandshakeFailed(e.to_string()))?;
264
265 let parts: Vec<&str> = server_first_msg.splitn(3, ',').collect();
267 if parts.len() != 3 {
268 return Err(AuthError::HandshakeFailed(
269 "invalid server-first-message format".to_string(),
270 ));
271 }
272
273 let combined_nonce = parse_scram_param(parts[0], "r=")?;
274 let salt_b64 = parse_scram_param(parts[1], "s=")?;
275 let iterations_str = parse_scram_param(parts[2], "i=")?;
276
277 if !combined_nonce.starts_with(client_nonce) {
279 return Err(AuthError::HandshakeFailed(
280 "combined nonce does not start with client nonce".to_string(),
281 ));
282 }
283
284 let salt = BASE64
285 .decode(salt_b64)
286 .map_err(|e| AuthError::Base64Error(e.to_string()))?;
287 let iterations: u32 = iterations_str
288 .parse()
289 .map_err(|e: std::num::ParseIntError| AuthError::HandshakeFailed(e.to_string()))?;
290
291 let mut salted_password = pbkdf2_sha256(password.as_bytes(), &salt, iterations);
293 let (mut client_key, stored_key, server_key) = derive_keys(&salted_password);
294 salted_password.zeroize();
295
296 let cfmb = make_client_first_bare(username, client_nonce);
298 let client_final_without_proof = format!("c=biws,r={}", combined_nonce);
299 let auth_message = format!(
300 "{},{},{}",
301 cfmb, server_first_msg, client_final_without_proof
302 );
303
304 let client_signature = hmac_sha256(&stored_key, auth_message.as_bytes());
306 let client_proof = xor_bytes(&client_key, &client_signature);
308 let server_signature = hmac_sha256(&server_key, auth_message.as_bytes());
310
311 client_key.zeroize();
313
314 let proof_b64 = BASE64.encode(&client_proof);
316 let client_final_msg = format!("{},p={}", client_final_without_proof, proof_b64);
317 let client_final_b64 = BASE64.encode(client_final_msg.as_bytes());
318
319 Ok((client_final_b64, server_signature))
320}
321
322pub fn server_verify_final(
328 handshake: &ScramHandshake,
329 client_final_b64: &str,
330) -> Result<Vec<u8>, AuthError> {
331 let client_final_bytes = BASE64
333 .decode(client_final_b64)
334 .map_err(|e| AuthError::Base64Error(e.to_string()))?;
335 let client_final_msg = String::from_utf8(client_final_bytes)
336 .map_err(|e| AuthError::HandshakeFailed(e.to_string()))?;
337
338 let parts: Vec<&str> = client_final_msg.splitn(3, ',').collect();
340 if parts.len() != 3 {
341 return Err(AuthError::HandshakeFailed(
342 "invalid client-final-message format".to_string(),
343 ));
344 }
345
346 let channel_binding = parse_scram_param(parts[0], "c=")?;
348 if channel_binding != "biws" {
349 return Err(AuthError::HandshakeFailed(
350 "unexpected channel binding".to_string(),
351 ));
352 }
353
354 let combined_nonce = parse_scram_param(parts[1], "r=")?;
356 let expected_combined = format!("{}{}", handshake.client_nonce, handshake.server_nonce);
357 if !bool::from(
358 combined_nonce
359 .as_bytes()
360 .ct_eq(expected_combined.as_bytes()),
361 ) {
362 return Err(AuthError::HandshakeFailed("nonce mismatch".to_string()));
363 }
364
365 let proof_b64 = parse_scram_param(parts[2], "p=")?;
367 let client_proof = BASE64
368 .decode(proof_b64)
369 .map_err(|e| AuthError::Base64Error(e.to_string()))?;
370
371 let client_signature = hmac_sha256(&handshake.stored_key, handshake.auth_message.as_bytes());
376 let recovered_client_key = xor_bytes(&client_proof, &client_signature);
377 let recovered_stored_key = sha256(&recovered_client_key);
378
379 if recovered_stored_key
380 .ct_eq(&handshake.stored_key)
381 .unwrap_u8()
382 == 0
383 {
384 return Err(AuthError::InvalidCredentials);
385 }
386
387 Ok(handshake.server_signature.clone())
389}
390
391pub fn extract_client_nonce(client_first_b64: &str) -> Result<String, AuthError> {
396 let bytes = BASE64
397 .decode(client_first_b64)
398 .map_err(|e| AuthError::Base64Error(e.to_string()))?;
399 let msg = String::from_utf8(bytes).map_err(|e| AuthError::HandshakeFailed(e.to_string()))?;
400 let bare = msg
402 .strip_prefix("n,,")
403 .ok_or_else(|| AuthError::HandshakeFailed("missing GS2 header in client-first".into()))?;
404 for part in bare.split(',') {
406 if let Some(nonce) = part.strip_prefix("r=") {
407 return Ok(nonce.to_string());
408 }
409 }
410 Err(AuthError::HandshakeFailed(
411 "missing r= nonce in client-first-message".into(),
412 ))
413}
414
415pub fn parse_auth_header(header: &str) -> Result<AuthHeader, AuthError> {
422 let header = header.trim();
423
424 if let Some(rest) = header.strip_prefix("HELLO ") {
425 let mut username_b64_val = None;
426 let mut data_val = None;
427 for part in rest.split(',') {
428 let part = part.trim();
429 if let Some(val) = part.strip_prefix("username=") {
430 username_b64_val = Some(val.trim().to_string());
431 } else if let Some(val) = part.strip_prefix("data=") {
432 data_val = Some(val.trim().to_string());
433 }
434 }
435 let username_b64 = username_b64_val
436 .ok_or_else(|| AuthError::InvalidHeader("missing username= in HELLO".into()))?;
437 let username_bytes = BASE64
438 .decode(&username_b64)
439 .map_err(|e| AuthError::Base64Error(e.to_string()))?;
440 let username = String::from_utf8(username_bytes)
441 .map_err(|e| AuthError::InvalidHeader(e.to_string()))?;
442 Ok(AuthHeader::Hello {
443 username,
444 data: data_val,
445 })
446 } else if let Some(rest) = header.strip_prefix("SCRAM ") {
447 let mut handshake_token = None;
448 let mut data = None;
449 for part in rest.split(',') {
450 let part = part.trim();
451 if let Some(val) = part.strip_prefix("handshakeToken=") {
452 handshake_token = Some(val.trim().to_string());
453 } else if let Some(val) = part.strip_prefix("data=") {
454 data = Some(val.trim().to_string());
455 }
456 }
457 let handshake_token = handshake_token
458 .ok_or_else(|| AuthError::InvalidHeader("missing handshakeToken= in SCRAM".into()))?;
459 let data = data.ok_or_else(|| AuthError::InvalidHeader("missing data= in SCRAM".into()))?;
460 Ok(AuthHeader::Scram {
461 handshake_token,
462 data,
463 })
464 } else if let Some(rest) = header.strip_prefix("BEARER ") {
465 let token = rest
466 .trim()
467 .strip_prefix("authToken=")
468 .ok_or_else(|| AuthError::InvalidHeader("missing authToken= in BEARER".into()))?;
469 Ok(AuthHeader::Bearer {
470 auth_token: token.trim().to_string(),
471 })
472 } else {
473 Err(AuthError::InvalidHeader(format!(
474 "unrecognized auth scheme: {}",
475 header
476 )))
477 }
478}
479
/// Builds a `WWW-Authenticate` challenge value of the form
/// `SCRAM handshakeToken=<token>, hash=<hash>, data=<data>`.
pub fn format_www_authenticate(handshake_token: &str, hash: &str, data_b64: &str) -> String {
    let mut challenge = String::from("SCRAM handshakeToken=");
    challenge.push_str(handshake_token);
    challenge.push_str(", hash=");
    challenge.push_str(hash);
    challenge.push_str(", data=");
    challenge.push_str(data_b64);
    challenge
}
489
/// Builds an authentication-info value of the form `authToken=<token>, data=<data>`.
pub fn format_auth_info(auth_token: &str, data_b64: &str) -> String {
    ["authToken=", auth_token, ", data=", data_b64].concat()
}
496
// Unit tests covering credential derivation, nonce generation, header
// parsing/formatting, and complete client/server SCRAM exchanges.
#[cfg(test)]
mod tests {
    use super::*;

    // derive_credentials is deterministic, yields 32-byte keys, and depends on the password.
    #[test]
    fn test_derive_credentials() {
        let password = "pencil";
        let salt = b"random-salt-value";
        let iterations = 4096;

        let creds = derive_credentials(password, salt, iterations);

        assert_eq!(creds.salt, salt.to_vec());
        assert_eq!(creds.iterations, iterations);
        // SHA-256 / HMAC-SHA-256 outputs are 32 bytes.
        assert_eq!(creds.stored_key.len(), 32);
        assert_eq!(creds.server_key.len(), 32);

        // Same inputs must reproduce the same keys.
        let creds2 = derive_credentials(password, salt, iterations);
        assert_eq!(creds.stored_key, creds2.stored_key);
        assert_eq!(creds.server_key, creds2.server_key);

        // A different password must change both keys.
        let creds3 = derive_credentials("other", salt, iterations);
        assert_ne!(creds.stored_key, creds3.stored_key);
        assert_ne!(creds.server_key, creds3.server_key);
    }

    // Nonces are fresh on each call and decode to 18 raw bytes.
    #[test]
    fn test_generate_nonce() {
        let n1 = generate_nonce();
        let n2 = generate_nonce();

        assert_ne!(n1, n2);

        let decoded1 = BASE64.decode(&n1).expect("nonce must be valid base64");
        assert_eq!(decoded1.len(), 18);

        let decoded2 = BASE64.decode(&n2).expect("nonce must be valid base64");
        assert_eq!(decoded2.len(), 18);
    }

    // HELLO: the username is base64-decoded; data is optional.
    #[test]
    fn test_parse_auth_header_hello() {
        let username = "user";
        let username_b64 = BASE64.encode(username.as_bytes());
        let header = format!("HELLO username={}", username_b64);

        let parsed = parse_auth_header(&header).unwrap();
        assert_eq!(
            parsed,
            AuthHeader::Hello {
                username: "user".to_string(),
                data: None,
            }
        );
    }

    // SCRAM: parameters are comma-separated, whitespace after the comma is tolerated,
    // and data is kept verbatim (not decoded).
    #[test]
    fn test_parse_auth_header_scram() {
        let header = "SCRAM handshakeToken=abc123, data=c29tZWRhdGE=";
        let parsed = parse_auth_header(header).unwrap();
        assert_eq!(
            parsed,
            AuthHeader::Scram {
                handshake_token: "abc123".to_string(),
                data: "c29tZWRhdGE=".to_string(),
            }
        );
    }

    #[test]
    fn test_parse_auth_header_bearer() {
        let header = "BEARER authToken=mytoken123";
        let parsed = parse_auth_header(header).unwrap();
        assert_eq!(
            parsed,
            AuthHeader::Bearer {
                auth_token: "mytoken123".to_string(),
            }
        );
    }

    #[test]
    fn test_parse_auth_header_invalid() {
        // Unknown scheme.
        assert!(parse_auth_header("UNKNOWN foo=bar").is_err());
        // HELLO without username=.
        assert!(parse_auth_header("HELLO foo=bar").is_err());
        // SCRAM without data=.
        assert!(parse_auth_header("SCRAM handshakeToken=abc").is_err());
        // BEARER without authToken=.
        assert!(parse_auth_header("BEARER token=abc").is_err());
        // Empty header.
        assert!(parse_auth_header("").is_err());
    }

    // End-to-end: HELLO -> server-first -> client-final -> server verification,
    // then round-trips the server signature through a `v=` server-final message.
    #[test]
    fn test_full_handshake() {
        let username = "testuser";
        let password = "s3cret";
        let salt = b"test-salt-12345";
        let iterations = 4096;

        let credentials = derive_credentials(password, salt, iterations);

        let username_b64 = BASE64.encode(username.as_bytes());
        let hello_header = format!("HELLO username={}", username_b64);
        let parsed = parse_auth_header(&hello_header).unwrap();
        match &parsed {
            AuthHeader::Hello { username: u, .. } => assert_eq!(u, username),
            _ => panic!("expected Hello variant"),
        }

        let (client_nonce, _client_first_b64) = client_first_message(username);

        let (handshake, server_first_b64) =
            server_first_message(username, &client_nonce, &credentials);

        let www_auth = format_www_authenticate("handshake-token-xyz", "SHA-256", &server_first_b64);
        assert!(www_auth.contains("SCRAM"));
        assert!(www_auth.contains("SHA-256"));
        assert!(www_auth.contains("handshake-token-xyz"));

        let (client_final_b64, expected_server_sig) =
            client_final_message(password, &client_nonce, &server_first_b64, username).unwrap();

        let server_sig = server_verify_final(&handshake, &client_final_b64).unwrap();

        // Client and server must agree on the server signature.
        assert_eq!(server_sig, expected_server_sig);

        let server_final_msg = format!("v={}", BASE64.encode(&server_sig));
        let server_final_b64 = BASE64.encode(server_final_msg.as_bytes());
        let auth_info = format_auth_info("auth-token-abc", &server_final_b64);
        assert!(auth_info.contains("authToken=auth-token-abc"));

        // Client side: decode the server-final message and check `v=`.
        let server_final_decoded = BASE64.decode(&server_final_b64).unwrap();
        let server_final_str = String::from_utf8(server_final_decoded).unwrap();
        let sig_b64 = server_final_str.strip_prefix("v=").unwrap();
        let received_server_sig = BASE64.decode(sig_b64).unwrap();
        assert_eq!(received_server_sig, expected_server_sig);
    }

    // Checks the wire format of each message and that a wrong password is
    // rejected with InvalidCredentials. Uses DEFAULT_ITERATIONS, so it performs
    // three full-cost PBKDF2 derivations.
    #[test]
    fn test_client_server_roundtrip() {
        let username = "admin";
        let password = "correcthorsebatterystaple";
        let salt = b"unique-salt-value";
        let iterations = DEFAULT_ITERATIONS;

        let credentials = derive_credentials(password, salt, iterations);

        let (client_nonce, client_first_b64) = client_first_message(username);

        let client_first_decoded = BASE64.decode(&client_first_b64).unwrap();
        let client_first_str = String::from_utf8(client_first_decoded).unwrap();
        assert!(client_first_str.starts_with("n,,"));
        assert!(client_first_str.contains(&format!("r={}", client_nonce)));

        let (handshake, server_first_b64) =
            server_first_message(username, &client_nonce, &credentials);

        let server_first_decoded = BASE64.decode(&server_first_b64).unwrap();
        let server_first_str = String::from_utf8(server_first_decoded).unwrap();
        assert!(server_first_str.starts_with("r="));
        assert!(server_first_str.contains(",s="));
        assert!(server_first_str.contains(",i="));
        assert!(server_first_str.contains(&client_nonce));

        let (client_final_b64, expected_server_sig) =
            client_final_message(password, &client_nonce, &server_first_b64, username).unwrap();

        let client_final_decoded = BASE64.decode(&client_final_b64).unwrap();
        let client_final_str = String::from_utf8(client_final_decoded).unwrap();
        assert!(client_final_str.starts_with("c=biws,"));
        assert!(client_final_str.contains(",p="));

        let server_sig = server_verify_final(&handshake, &client_final_b64).unwrap();
        assert_eq!(server_sig, expected_server_sig);

        // Same handshake, wrong password: the proof must not verify.
        let (wrong_final_b64, _) =
            client_final_message("wrongpassword", &client_nonce, &server_first_b64, username)
                .unwrap();
        let result = server_verify_final(&handshake, &wrong_final_b64);
        assert!(result.is_err());
        match result {
            Err(AuthError::InvalidCredentials) => {}
            other => panic!("expected InvalidCredentials, got {:?}", other),
        }
    }

    #[test]
    fn test_format_www_authenticate() {
        let result = format_www_authenticate("tok123", "SHA-256", "c29tZQ==");
        assert_eq!(
            result,
            "SCRAM handshakeToken=tok123, hash=SHA-256, data=c29tZQ=="
        );
    }

    #[test]
    fn test_format_auth_info() {
        let result = format_auth_info("auth-tok", "ZGF0YQ==");
        assert_eq!(result, "authToken=auth-tok, data=ZGF0YQ==");
    }
}