1use base64::Engine;
12use base64::engine::general_purpose::STANDARD as BASE64;
13use hmac::{Hmac, Mac};
14use pbkdf2::pbkdf2_hmac;
15use rand::RngExt;
16use sha2::{Digest, Sha256};
17use subtle::ConstantTimeEq;
18
19type HmacSha256 = Hmac<Sha256>;
20
/// Iteration count offered as a default for callers of `derive_credentials`.
///
/// Nothing in this file applies it automatically; callers pass it explicitly.
pub const DEFAULT_ITERATIONS: u32 = 100_000;
23
24#[derive(Debug, thiserror::Error)]
30pub enum AuthError {
31 #[error("invalid credentials")]
32 InvalidCredentials,
33 #[error("invalid auth header: {0}")]
34 InvalidHeader(String),
35 #[error("handshake failed: {0}")]
36 HandshakeFailed(String),
37 #[error("base64 decode error: {0}")]
38 Base64Error(String),
39}
40
/// Server-side SCRAM verifier for one user, as produced by `derive_credentials`.
///
/// `Debug` is implemented by hand so that the key material (`stored_key`,
/// `server_key`) is never written to logs or panic messages.
#[derive(Clone)]
pub struct ScramCredentials {
    /// Salt fed to PBKDF2 when the password was stretched.
    pub salt: Vec<u8>,
    /// PBKDF2 iteration count used together with `salt`.
    pub iterations: u32,
    /// SHA-256 of the client key; used to verify the client proof.
    pub stored_key: Vec<u8>,
    /// HMAC key used to produce the server signature.
    pub server_key: Vec<u8>,
}

impl std::fmt::Debug for ScramCredentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ScramCredentials")
            .field("salt", &self.salt)
            .field("iterations", &self.iterations)
            // Key material is intentionally not printed.
            .field("stored_key", &format_args!("<redacted>"))
            .field("server_key", &format_args!("<redacted>"))
            .finish()
    }
}
53
/// Server-side state kept between the server-first and client-final messages.
///
/// `Debug` is implemented by hand: `stored_key` is a private field and must not
/// become visible through log output or panic messages.
#[derive(Clone)]
pub struct ScramHandshake {
    /// Username the handshake was started for.
    pub username: String,
    /// Nonce supplied by the client.
    pub client_nonce: String,
    /// Nonce generated by the server for this handshake.
    pub server_nonce: String,
    /// Salt sent to the client in the server-first message.
    pub salt: Vec<u8>,
    /// Iteration count sent to the client in the server-first message.
    pub iterations: u32,
    /// The SCRAM auth message both signatures are computed over.
    pub auth_message: String,
    /// Server signature, precomputed over `auth_message`.
    pub server_signature: Vec<u8>,
    /// Verifier copied from the user's credentials; checked against the client proof.
    stored_key: Vec<u8>,
}

impl std::fmt::Debug for ScramHandshake {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ScramHandshake")
            .field("username", &self.username)
            .field("client_nonce", &self.client_nonce)
            .field("server_nonce", &self.server_nonce)
            .field("salt", &self.salt)
            .field("iterations", &self.iterations)
            .field("auth_message", &self.auth_message)
            .field("server_signature", &self.server_signature)
            // Verifier material is intentionally not printed.
            .field("stored_key", &format_args!("<redacted>"))
            .finish()
    }
}
68
/// Parsed form of the three auth header schemes accepted by `parse_auth_header`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum AuthHeader {
    /// `HELLO username=<base64>[, data=<...>]`
    Hello {
        /// Username after base64 decoding.
        username: String,
        /// Optional `data=` parameter, kept exactly as received (not decoded).
        data: Option<String>,
    },
    /// `SCRAM handshakeToken=<...>, data=<...>`
    Scram {
        /// Value of the `handshakeToken=` parameter.
        handshake_token: String,
        /// Value of the `data=` parameter, kept exactly as received (not decoded).
        data: String,
    },
    /// `BEARER authToken=<...>`
    Bearer {
        /// Value of the `authToken=` parameter.
        auth_token: String,
    },
}
85
86fn hmac_sha256(key: &[u8], msg: &[u8]) -> Vec<u8> {
92 let mut mac = HmacSha256::new_from_slice(key).expect("HMAC accepts keys of any size");
93 mac.update(msg);
94 mac.finalize().into_bytes().to_vec()
95}
96
97fn sha256(data: &[u8]) -> Vec<u8> {
99 let mut hasher = Sha256::new();
100 hasher.update(data);
101 hasher.finalize().to_vec()
102}
103
/// XORs two equal-length byte slices element by element.
///
/// # Panics
///
/// Panics if `a` and `b` differ in length.
fn xor_bytes(a: &[u8], b: &[u8]) -> Vec<u8> {
    assert_eq!(a.len(), b.len(), "XOR operands must be the same length");
    let mut mixed = Vec::with_capacity(a.len());
    for (lhs, rhs) in a.iter().zip(b) {
        mixed.push(lhs ^ rhs);
    }
    mixed
}
109
110fn pbkdf2_sha256(password: &[u8], salt: &[u8], iterations: u32) -> Vec<u8> {
112 let mut salted_password = vec![0u8; 32];
113 pbkdf2_hmac::<Sha256>(password, salt, iterations, &mut salted_password);
114 salted_password
115}
116
117fn derive_keys(salted_password: &[u8]) -> (Vec<u8>, Vec<u8>, Vec<u8>) {
119 let client_key = hmac_sha256(salted_password, b"Client Key");
120 let stored_key = sha256(&client_key);
121 let server_key = hmac_sha256(salted_password, b"Server Key");
122 (client_key, stored_key, server_key)
123}
124
125fn parse_scram_param<'a>(segment: &'a str, prefix: &str) -> Result<&'a str, AuthError> {
127 let trimmed = segment.trim();
128 trimmed.strip_prefix(prefix).ok_or_else(|| {
129 AuthError::HandshakeFailed(format!(
130 "expected prefix '{}' but got '{}'",
131 prefix, trimmed
132 ))
133 })
134}
135
/// Builds the `n=<user>,r=<nonce>` part of the client-first message.
///
/// The message is comma-delimited, so `,` and `=` inside the username are
/// escaped as `=2C` and `=3D` (RFC 5802 `saslname`). Without this, a username
/// containing a comma would split into extra attributes and could inject a
/// bogus `r=` field. Usernames without those characters are emitted unchanged.
fn make_client_first_bare(username: &str, client_nonce: &str) -> String {
    let mut escaped = String::with_capacity(username.len());
    for ch in username.chars() {
        match ch {
            '=' => escaped.push_str("=3D"),
            ',' => escaped.push_str("=2C"),
            other => escaped.push(other),
        }
    }
    format!("n={},r={}", escaped, client_nonce)
}
140
141pub fn derive_credentials(password: &str, salt: &[u8], iterations: u32) -> ScramCredentials {
149 let salted_password = pbkdf2_sha256(password.as_bytes(), salt, iterations);
150 let (_client_key, stored_key, server_key) = derive_keys(&salted_password);
151 ScramCredentials {
152 salt: salt.to_vec(),
153 iterations,
154 stored_key,
155 server_key,
156 }
157}
158
159pub fn generate_nonce() -> String {
161 let mut bytes = [0u8; 18];
162 rand::rng().fill(&mut bytes);
163 BASE64.encode(bytes)
164}
165
166pub fn client_first_message(username: &str) -> (String, String) {
173 let client_nonce = generate_nonce();
174 let bare = make_client_first_bare(username, &client_nonce);
175 let full = format!("n,,{}", bare);
176 let encoded = BASE64.encode(full.as_bytes());
177 (client_nonce, encoded)
178}
179
180pub fn server_first_message(
188 username: &str,
189 client_nonce_b64: &str,
190 credentials: &ScramCredentials,
191) -> (ScramHandshake, String) {
192 let server_nonce = generate_nonce();
193 let combined_nonce = format!("{}{}", client_nonce_b64, server_nonce);
194 let salt_b64 = BASE64.encode(&credentials.salt);
195
196 let server_first_msg = format!(
198 "r={},s={},i={}",
199 combined_nonce, salt_b64, credentials.iterations
200 );
201
202 let cfmb = make_client_first_bare(username, client_nonce_b64);
204
205 let client_final_without_proof = format!("c=biws,r={}", combined_nonce);
207
208 let auth_message = format!(
210 "{},{},{}",
211 cfmb, server_first_msg, client_final_without_proof
212 );
213
214 let server_signature = hmac_sha256(&credentials.server_key, auth_message.as_bytes());
216
217 let server_first_b64 = BASE64.encode(server_first_msg.as_bytes());
218
219 let handshake = ScramHandshake {
220 username: username.to_string(),
221 client_nonce: client_nonce_b64.to_string(),
222 server_nonce,
223 salt: credentials.salt.clone(),
224 iterations: credentials.iterations,
225 auth_message,
226 server_signature,
227 stored_key: credentials.stored_key.clone(),
228 };
229
230 (handshake, server_first_b64)
231}
232
233pub fn client_final_message(
242 password: &str,
243 client_nonce: &str,
244 server_first_b64: &str,
245 username: &str,
246) -> Result<(String, Vec<u8>), AuthError> {
247 let server_first_bytes = BASE64
249 .decode(server_first_b64)
250 .map_err(|e| AuthError::Base64Error(e.to_string()))?;
251 let server_first_msg = String::from_utf8(server_first_bytes)
252 .map_err(|e| AuthError::HandshakeFailed(e.to_string()))?;
253
254 let parts: Vec<&str> = server_first_msg.splitn(3, ',').collect();
256 if parts.len() != 3 {
257 return Err(AuthError::HandshakeFailed(
258 "invalid server-first-message format".to_string(),
259 ));
260 }
261
262 let combined_nonce = parse_scram_param(parts[0], "r=")?;
263 let salt_b64 = parse_scram_param(parts[1], "s=")?;
264 let iterations_str = parse_scram_param(parts[2], "i=")?;
265
266 if !combined_nonce.starts_with(client_nonce) {
268 return Err(AuthError::HandshakeFailed(
269 "combined nonce does not start with client nonce".to_string(),
270 ));
271 }
272
273 let salt = BASE64
274 .decode(salt_b64)
275 .map_err(|e| AuthError::Base64Error(e.to_string()))?;
276 let iterations: u32 = iterations_str
277 .parse()
278 .map_err(|e: std::num::ParseIntError| AuthError::HandshakeFailed(e.to_string()))?;
279
280 let salted_password = pbkdf2_sha256(password.as_bytes(), &salt, iterations);
282 let (client_key, stored_key, server_key) = derive_keys(&salted_password);
283
284 let cfmb = make_client_first_bare(username, client_nonce);
286 let client_final_without_proof = format!("c=biws,r={}", combined_nonce);
287 let auth_message = format!(
288 "{},{},{}",
289 cfmb, server_first_msg, client_final_without_proof
290 );
291
292 let client_signature = hmac_sha256(&stored_key, auth_message.as_bytes());
294 let client_proof = xor_bytes(&client_key, &client_signature);
296 let server_signature = hmac_sha256(&server_key, auth_message.as_bytes());
298
299 let proof_b64 = BASE64.encode(&client_proof);
301 let client_final_msg = format!("{},p={}", client_final_without_proof, proof_b64);
302 let client_final_b64 = BASE64.encode(client_final_msg.as_bytes());
303
304 Ok((client_final_b64, server_signature))
305}
306
307pub fn server_verify_final(
313 handshake: &ScramHandshake,
314 client_final_b64: &str,
315) -> Result<Vec<u8>, AuthError> {
316 let client_final_bytes = BASE64
318 .decode(client_final_b64)
319 .map_err(|e| AuthError::Base64Error(e.to_string()))?;
320 let client_final_msg = String::from_utf8(client_final_bytes)
321 .map_err(|e| AuthError::HandshakeFailed(e.to_string()))?;
322
323 let parts: Vec<&str> = client_final_msg.splitn(3, ',').collect();
325 if parts.len() != 3 {
326 return Err(AuthError::HandshakeFailed(
327 "invalid client-final-message format".to_string(),
328 ));
329 }
330
331 let channel_binding = parse_scram_param(parts[0], "c=")?;
333 if channel_binding != "biws" {
334 return Err(AuthError::HandshakeFailed(
335 "unexpected channel binding".to_string(),
336 ));
337 }
338
339 let combined_nonce = parse_scram_param(parts[1], "r=")?;
341 let expected_combined = format!("{}{}", handshake.client_nonce, handshake.server_nonce);
342 if !bool::from(
343 combined_nonce
344 .as_bytes()
345 .ct_eq(expected_combined.as_bytes()),
346 ) {
347 return Err(AuthError::HandshakeFailed("nonce mismatch".to_string()));
348 }
349
350 let proof_b64 = parse_scram_param(parts[2], "p=")?;
352 let client_proof = BASE64
353 .decode(proof_b64)
354 .map_err(|e| AuthError::Base64Error(e.to_string()))?;
355
356 let client_signature = hmac_sha256(&handshake.stored_key, handshake.auth_message.as_bytes());
361 let recovered_client_key = xor_bytes(&client_proof, &client_signature);
362 let recovered_stored_key = sha256(&recovered_client_key);
363
364 if recovered_stored_key
365 .ct_eq(&handshake.stored_key)
366 .unwrap_u8()
367 == 0
368 {
369 return Err(AuthError::InvalidCredentials);
370 }
371
372 Ok(handshake.server_signature.clone())
374}
375
376pub fn extract_client_nonce(client_first_b64: &str) -> Result<String, AuthError> {
381 let bytes = BASE64
382 .decode(client_first_b64)
383 .map_err(|e| AuthError::Base64Error(e.to_string()))?;
384 let msg = String::from_utf8(bytes).map_err(|e| AuthError::HandshakeFailed(e.to_string()))?;
385 let bare = msg
387 .strip_prefix("n,,")
388 .ok_or_else(|| AuthError::HandshakeFailed("missing GS2 header in client-first".into()))?;
389 for part in bare.split(',') {
391 if let Some(nonce) = part.strip_prefix("r=") {
392 return Ok(nonce.to_string());
393 }
394 }
395 Err(AuthError::HandshakeFailed(
396 "missing r= nonce in client-first-message".into(),
397 ))
398}
399
400pub fn parse_auth_header(header: &str) -> Result<AuthHeader, AuthError> {
407 let header = header.trim();
408
409 if let Some(rest) = header.strip_prefix("HELLO ") {
410 let mut username_b64_val = None;
411 let mut data_val = None;
412 for part in rest.split(',') {
413 let part = part.trim();
414 if let Some(val) = part.strip_prefix("username=") {
415 username_b64_val = Some(val.trim().to_string());
416 } else if let Some(val) = part.strip_prefix("data=") {
417 data_val = Some(val.trim().to_string());
418 }
419 }
420 let username_b64 = username_b64_val
421 .ok_or_else(|| AuthError::InvalidHeader("missing username= in HELLO".into()))?;
422 let username_bytes = BASE64
423 .decode(&username_b64)
424 .map_err(|e| AuthError::Base64Error(e.to_string()))?;
425 let username = String::from_utf8(username_bytes)
426 .map_err(|e| AuthError::InvalidHeader(e.to_string()))?;
427 Ok(AuthHeader::Hello {
428 username,
429 data: data_val,
430 })
431 } else if let Some(rest) = header.strip_prefix("SCRAM ") {
432 let mut handshake_token = None;
433 let mut data = None;
434 for part in rest.split(',') {
435 let part = part.trim();
436 if let Some(val) = part.strip_prefix("handshakeToken=") {
437 handshake_token = Some(val.trim().to_string());
438 } else if let Some(val) = part.strip_prefix("data=") {
439 data = Some(val.trim().to_string());
440 }
441 }
442 let handshake_token = handshake_token
443 .ok_or_else(|| AuthError::InvalidHeader("missing handshakeToken= in SCRAM".into()))?;
444 let data = data.ok_or_else(|| AuthError::InvalidHeader("missing data= in SCRAM".into()))?;
445 Ok(AuthHeader::Scram {
446 handshake_token,
447 data,
448 })
449 } else if let Some(rest) = header.strip_prefix("BEARER ") {
450 let token = rest
451 .trim()
452 .strip_prefix("authToken=")
453 .ok_or_else(|| AuthError::InvalidHeader("missing authToken= in BEARER".into()))?;
454 Ok(AuthHeader::Bearer {
455 auth_token: token.trim().to_string(),
456 })
457 } else {
458 Err(AuthError::InvalidHeader(format!(
459 "unrecognized auth scheme: {}",
460 header
461 )))
462 }
463}
464
/// Formats the value of a SCRAM challenge header.
///
/// Produces `SCRAM handshakeToken=<token>, hash=<hash>, data=<data>`. The
/// arguments are inserted verbatim, without escaping.
pub fn format_www_authenticate(handshake_token: &str, hash: &str, data_b64: &str) -> String {
    let header = format!(
        "SCRAM handshakeToken={}, hash={}, data={}",
        handshake_token, hash, data_b64
    );
    header
}
474
/// Formats the value sent after a successful handshake.
///
/// Produces `authToken=<token>, data=<data>`. The arguments are inserted
/// verbatim, without escaping.
pub fn format_auth_info(auth_token: &str, data_b64: &str) -> String {
    let info = format!("authToken={}, data={}", auth_token, data_b64);
    info
}
481
#[cfg(test)]
mod tests {
    use super::*;

    /// Derivation is deterministic and sensitive to the password.
    #[test]
    fn test_derive_credentials() {
        let password = "pencil";
        let salt = b"random-salt-value";
        let iterations = 4096;

        let creds = derive_credentials(password, salt, iterations);

        assert_eq!(creds.salt, salt.to_vec());
        assert_eq!(creds.iterations, iterations);
        assert_eq!(creds.stored_key.len(), 32);
        assert_eq!(creds.server_key.len(), 32);

        // Same inputs give the same keys.
        let creds2 = derive_credentials(password, salt, iterations);
        assert_eq!(creds.stored_key, creds2.stored_key);
        assert_eq!(creds.server_key, creds2.server_key);

        // A different password gives different keys.
        let creds3 = derive_credentials("other", salt, iterations);
        assert_ne!(creds.stored_key, creds3.stored_key);
        assert_ne!(creds.server_key, creds3.server_key);
    }

    /// Nonces are unique and decode to 18 bytes.
    #[test]
    fn test_generate_nonce() {
        let n1 = generate_nonce();
        let n2 = generate_nonce();

        assert_ne!(n1, n2);

        let decoded1 = BASE64.decode(&n1).expect("nonce must be valid base64");
        assert_eq!(decoded1.len(), 18);

        let decoded2 = BASE64.decode(&n2).expect("nonce must be valid base64");
        assert_eq!(decoded2.len(), 18);
    }

    /// The server recovers the exact nonce the client generated, and malformed
    /// client-first messages are rejected with the matching error variant.
    #[test]
    fn test_extract_client_nonce() {
        let (client_nonce, client_first_b64) = client_first_message("user");
        assert_eq!(
            extract_client_nonce(&client_first_b64).unwrap(),
            client_nonce
        );

        // Not base64 at all.
        assert!(matches!(
            extract_client_nonce("!!!"),
            Err(AuthError::Base64Error(_))
        ));

        // Missing the leading "n,," header.
        let no_gs2 = BASE64.encode("n=user,r=abc");
        assert!(matches!(
            extract_client_nonce(&no_gs2),
            Err(AuthError::HandshakeFailed(_))
        ));

        // Header present but no r= attribute.
        let no_nonce = BASE64.encode("n,,n=user");
        assert!(matches!(
            extract_client_nonce(&no_nonce),
            Err(AuthError::HandshakeFailed(_))
        ));
    }

    #[test]
    fn test_parse_auth_header_hello() {
        let username = "user";
        let username_b64 = BASE64.encode(username.as_bytes());
        let header = format!("HELLO username={}", username_b64);

        let parsed = parse_auth_header(&header).unwrap();
        assert_eq!(
            parsed,
            AuthHeader::Hello {
                username: "user".to_string(),
                data: None,
            }
        );
    }

    #[test]
    fn test_parse_auth_header_scram() {
        let header = "SCRAM handshakeToken=abc123, data=c29tZWRhdGE=";
        let parsed = parse_auth_header(header).unwrap();
        assert_eq!(
            parsed,
            AuthHeader::Scram {
                handshake_token: "abc123".to_string(),
                data: "c29tZWRhdGE=".to_string(),
            }
        );
    }

    #[test]
    fn test_parse_auth_header_bearer() {
        let header = "BEARER authToken=mytoken123";
        let parsed = parse_auth_header(header).unwrap();
        assert_eq!(
            parsed,
            AuthHeader::Bearer {
                auth_token: "mytoken123".to_string(),
            }
        );
    }

    #[test]
    fn test_parse_auth_header_invalid() {
        // Unknown scheme.
        assert!(parse_auth_header("UNKNOWN foo=bar").is_err());
        // HELLO without username=.
        assert!(parse_auth_header("HELLO foo=bar").is_err());
        // SCRAM without data=.
        assert!(parse_auth_header("SCRAM handshakeToken=abc").is_err());
        // BEARER without authToken=.
        assert!(parse_auth_header("BEARER token=abc").is_err());
        // Empty header.
        assert!(parse_auth_header("").is_err());
    }

    /// Walks the whole exchange, including the header formatting helpers and
    /// the server-side extraction of the client nonce.
    #[test]
    fn test_full_handshake() {
        let username = "testuser";
        let password = "s3cret";
        let salt = b"test-salt-12345";
        let iterations = 4096;

        let credentials = derive_credentials(password, salt, iterations);

        let username_b64 = BASE64.encode(username.as_bytes());
        let hello_header = format!("HELLO username={}", username_b64);
        let parsed = parse_auth_header(&hello_header).unwrap();
        match &parsed {
            AuthHeader::Hello { username: u, .. } => assert_eq!(u, username),
            _ => panic!("expected Hello variant"),
        }

        let (client_nonce, client_first_b64) = client_first_message(username);

        // The server only sees the encoded client-first message.
        let server_seen_nonce = extract_client_nonce(&client_first_b64).unwrap();
        assert_eq!(server_seen_nonce, client_nonce);

        let (handshake, server_first_b64) =
            server_first_message(username, &server_seen_nonce, &credentials);

        let www_auth = format_www_authenticate("handshake-token-xyz", "SHA-256", &server_first_b64);
        assert!(www_auth.contains("SCRAM"));
        assert!(www_auth.contains("SHA-256"));
        assert!(www_auth.contains("handshake-token-xyz"));

        let (client_final_b64, expected_server_sig) =
            client_final_message(password, &client_nonce, &server_first_b64, username).unwrap();

        let server_sig = server_verify_final(&handshake, &client_final_b64).unwrap();

        assert_eq!(server_sig, expected_server_sig);

        let server_final_msg = format!("v={}", BASE64.encode(&server_sig));
        let server_final_b64 = BASE64.encode(server_final_msg.as_bytes());
        let auth_info = format_auth_info("auth-token-abc", &server_final_b64);
        assert!(auth_info.contains("authToken=auth-token-abc"));

        let server_final_decoded = BASE64.decode(&server_final_b64).unwrap();
        let server_final_str = String::from_utf8(server_final_decoded).unwrap();
        let sig_b64 = server_final_str.strip_prefix("v=").unwrap();
        let received_server_sig = BASE64.decode(sig_b64).unwrap();
        assert_eq!(received_server_sig, expected_server_sig);
    }

    /// Checks the wire shape of each message and that a wrong password fails.
    #[test]
    fn test_client_server_roundtrip() {
        let username = "admin";
        let password = "correcthorsebatterystaple";
        let salt = b"unique-salt-value";
        let iterations = DEFAULT_ITERATIONS;

        let credentials = derive_credentials(password, salt, iterations);

        let (client_nonce, client_first_b64) = client_first_message(username);

        let client_first_decoded = BASE64.decode(&client_first_b64).unwrap();
        let client_first_str = String::from_utf8(client_first_decoded).unwrap();
        assert!(client_first_str.starts_with("n,,"));
        assert!(client_first_str.contains(&format!("r={}", client_nonce)));

        let (handshake, server_first_b64) =
            server_first_message(username, &client_nonce, &credentials);

        let server_first_decoded = BASE64.decode(&server_first_b64).unwrap();
        let server_first_str = String::from_utf8(server_first_decoded).unwrap();
        assert!(server_first_str.starts_with("r="));
        assert!(server_first_str.contains(",s="));
        assert!(server_first_str.contains(",i="));
        assert!(server_first_str.contains(&client_nonce));

        let (client_final_b64, expected_server_sig) =
            client_final_message(password, &client_nonce, &server_first_b64, username).unwrap();

        let client_final_decoded = BASE64.decode(&client_final_b64).unwrap();
        let client_final_str = String::from_utf8(client_final_decoded).unwrap();
        assert!(client_final_str.starts_with("c=biws,"));
        assert!(client_final_str.contains(",p="));

        let server_sig = server_verify_final(&handshake, &client_final_b64).unwrap();
        assert_eq!(server_sig, expected_server_sig);

        // A proof derived from the wrong password must be rejected.
        let (wrong_final_b64, _) =
            client_final_message("wrongpassword", &client_nonce, &server_first_b64, username)
                .unwrap();
        let result = server_verify_final(&handshake, &wrong_final_b64);
        assert!(result.is_err());
        match result {
            Err(AuthError::InvalidCredentials) => {}
            other => panic!("expected InvalidCredentials, got {:?}", other),
        }
    }

    /// A client-final message carrying a nonce the server never issued is
    /// rejected before the proof is looked at.
    #[test]
    fn test_server_rejects_nonce_mismatch() {
        let credentials = derive_credentials("pw", b"some-salt", 4096);
        let (client_nonce, _) = client_first_message("user");
        let (handshake, _) = server_first_message("user", &client_nonce, &credentials);

        let forged = BASE64.encode("c=biws,r=not-the-nonce,p=AAAA");
        assert!(matches!(
            server_verify_final(&handshake, &forged),
            Err(AuthError::HandshakeFailed(_))
        ));
    }

    #[test]
    fn test_format_www_authenticate() {
        let result = format_www_authenticate("tok123", "SHA-256", "c29tZQ==");
        assert_eq!(
            result,
            "SCRAM handshakeToken=tok123, hash=SHA-256, data=c29tZQ=="
        );
    }

    #[test]
    fn test_format_auth_info() {
        let result = format_auth_info("auth-tok", "ZGF0YQ==");
        assert_eq!(result, "authToken=auth-tok, data=ZGF0YQ==");
    }
}