Skip to main content

haystack_core/
auth.rs

1//! SCRAM SHA-256 authentication primitives for the Haystack auth protocol.
2//!
3//! This module implements the cryptographic operations needed for SCRAM
4//! (Salted Challenge Response Authentication Mechanism) with SHA-256 as
5//! specified by the [Project Haystack auth spec](https://project-haystack.org/doc/docHaystack/Auth).
6//!
7//! It provides functions shared by both server and client implementations
8//! for the three-phase handshake: HELLO, SCRAM challenge/response, and
9//! BEARER token issuance.
10
11use base64::Engine;
12use base64::engine::general_purpose::STANDARD as BASE64;
13use hmac::{Hmac, Mac};
14use pbkdf2::pbkdf2_hmac;
15use rand::RngExt;
16use sha2::{Digest, Sha256};
17use subtle::ConstantTimeEq;
18use zeroize::Zeroize;
19
20type HmacSha256 = Hmac<Sha256>;
21
/// PBKDF2 iteration count used by default when deriving SCRAM-SHA-256
/// credentials.
pub const DEFAULT_ITERATIONS: u32 = 100_000;

/// Largest PBKDF2 iteration count a client will accept from a server's
/// first message; anything above this is rejected during client auth.
pub const MAX_CLIENT_ITERATIONS: u32 = 1_000_000;
27
28// ---------------------------------------------------------------------------
29// Error type
30// ---------------------------------------------------------------------------
31
32/// Errors that can occur during SCRAM authentication.
33#[derive(Debug, thiserror::Error)]
34pub enum AuthError {
35    #[error("invalid credentials")]
36    InvalidCredentials,
37    #[error("invalid auth header: {0}")]
38    InvalidHeader(String),
39    #[error("handshake failed: {0}")]
40    HandshakeFailed(String),
41    #[error("invalid message: {0}")]
42    InvalidMessage(String),
43    #[error("base64 decode error: {0}")]
44    Base64Error(String),
45}
46
47// ---------------------------------------------------------------------------
48// Types
49// ---------------------------------------------------------------------------
50
51/// Pre-computed SCRAM credentials for a user (stored server-side).
52#[derive(Debug)]
53pub struct ScramCredentials {
54    pub salt: Vec<u8>,
55    pub iterations: u32,
56    pub stored_key: Vec<u8>,
57    pub server_key: Vec<u8>,
58}
59
60impl Drop for ScramCredentials {
61    fn drop(&mut self) {
62        self.stored_key.zeroize();
63        self.server_key.zeroize();
64    }
65}
66
/// Server-side state for one in-progress SCRAM exchange, kept between sending
/// the server-first-message and receiving the client-final-message.
pub struct ScramHandshake {
    /// Username received in the HELLO phase.
    pub username: String,
    /// Nonce chosen by the client.
    pub client_nonce: String,
    /// Nonce the server appended to the client nonce.
    pub server_nonce: String,
    /// Salt advertised to the client.
    pub salt: Vec<u8>,
    /// PBKDF2 iteration count advertised to the client.
    pub iterations: u32,
    /// The SCRAM AuthMessage signed by both sides.
    pub auth_message: String,
    /// Server signature to hand back once the client proof verifies.
    pub server_signature: Vec<u8>,
    /// Copy of the user's StoredKey, needed to verify the client proof.
    ///
    /// NOTE(review): this struct has no `Drop` impl (unlike
    /// `ScramCredentials`), so this copy is not zeroized when dropped.
    stored_key: Vec<u8>,
}

impl std::fmt::Debug for ScramHandshake {
    /// Shows the username; the stored key and server signature are rendered as
    /// `[REDACTED]` and all other fields are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "[REDACTED]";
        let mut out = f.debug_struct("ScramHandshake");
        out.field("username", &self.username);
        out.field("stored_key", &REDACTED);
        out.field("server_signature", &REDACTED);
        out.finish()
    }
}
90
/// Parsed Haystack `Authorization` header.
///
/// `Debug` is implemented by hand so that the bearer `auth_token` (a live
/// credential) is printed as `[REDACTED]` instead of leaking into logs.
#[derive(Clone, PartialEq, Eq)]
pub enum AuthHeader {
    /// `HELLO username=…[, data=…]`.
    Hello {
        /// Decoded username.
        username: String,
        /// Base64-encoded client-first-message (contains the client nonce).
        data: Option<String>,
    },
    /// `SCRAM handshakeToken=…, data=…`.
    Scram {
        /// Token identifying the in-flight handshake.
        handshake_token: String,
        /// Base64-encoded SCRAM message.
        data: String,
    },
    /// `BEARER authToken=…`.
    Bearer {
        /// Bearer token issued after a successful handshake.
        auth_token: String,
    },
}

impl std::fmt::Debug for AuthHeader {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            AuthHeader::Hello { username, data } => f
                .debug_struct("Hello")
                .field("username", username)
                .field("data", data)
                .finish(),
            AuthHeader::Scram {
                handshake_token,
                data,
            } => f
                .debug_struct("Scram")
                .field("handshake_token", handshake_token)
                .field("data", data)
                .finish(),
            AuthHeader::Bearer { .. } => f
                .debug_struct("Bearer")
                .field("auth_token", &"[REDACTED]")
                .finish(),
        }
    }
}
107
108// ---------------------------------------------------------------------------
109// Internal helpers
110// ---------------------------------------------------------------------------
111
112/// Compute HMAC-SHA-256(key, msg).
113fn hmac_sha256(key: &[u8], msg: &[u8]) -> Vec<u8> {
114    // HMAC-SHA256 accepts keys of any size per RFC 2104; this cannot fail.
115    let mut mac = HmacSha256::new_from_slice(key).expect("HMAC accepts keys of any size");
116    mac.update(msg);
117    mac.finalize().into_bytes().to_vec()
118}
119
120/// Compute SHA-256(data).
121fn sha256(data: &[u8]) -> Vec<u8> {
122    let mut hasher = Sha256::new();
123    hasher.update(data);
124    hasher.finalize().to_vec()
125}
126
/// Byte-wise XOR of two slices, used to combine SCRAM keys and signatures.
///
/// The operands are expected to have equal length, but this is only checked in
/// debug builds, where a mismatch panics. In release builds the result is
/// silently truncated to the shorter operand. Callers passing
/// attacker-controlled data (such as a decoded client proof) must therefore
/// validate its length first.
fn xor_bytes(a: &[u8], b: &[u8]) -> Vec<u8> {
    debug_assert_eq!(a.len(), b.len(), "XOR operands must be same length");
    let mut out = Vec::with_capacity(a.len().min(b.len()));
    for (lhs, rhs) in a.iter().zip(b) {
        out.push(lhs ^ rhs);
    }
    out
}
134
135/// PBKDF2-HMAC-SHA-256 key derivation, producing a 32-byte salted password.
136fn pbkdf2_sha256(password: &[u8], salt: &[u8], iterations: u32) -> Vec<u8> {
137    let mut salted_password = vec![0u8; 32];
138    pbkdf2_hmac::<Sha256>(password, salt, iterations, &mut salted_password);
139    salted_password
140    // Note: caller is responsible for zeroizing via Zeroize trait on Vec<u8>
141}
142
143/// Derive (ClientKey, StoredKey, ServerKey) from a salted password.
144fn derive_keys(salted_password: &[u8]) -> (Vec<u8>, Vec<u8>, Vec<u8>) {
145    let client_key = hmac_sha256(salted_password, b"Client Key");
146    let stored_key = sha256(&client_key);
147    let server_key = hmac_sha256(salted_password, b"Server Key");
148    (client_key, stored_key, server_key)
149}
150
151/// Parse a `key=value` parameter from a SCRAM message segment.
152fn parse_scram_param<'a>(segment: &'a str, prefix: &str) -> Result<&'a str, AuthError> {
153    let trimmed = segment.trim();
154    trimmed.strip_prefix(prefix).ok_or_else(|| {
155        AuthError::HandshakeFailed(format!(
156            "expected prefix '{}' but got '{}'",
157            prefix, trimmed
158        ))
159    })
160}
161
/// Build the client-first-message-bare: `n=<saslname>,r=<client_nonce>`.
///
/// RFC 5802 §5.1 requires `=` and `,` in the username to be sent as `=3D`
/// and `=2C`. Without this escaping, a username containing `,` could inject
/// extra attributes (such as a second `r=`) into the message. Usernames
/// without those two characters are emitted unchanged.
fn make_client_first_bare(username: &str, client_nonce: &str) -> String {
    let mut bare = String::with_capacity(username.len() + client_nonce.len() + 5);
    bare.push_str("n=");
    for ch in username.chars() {
        match ch {
            '=' => bare.push_str("=3D"),
            ',' => bare.push_str("=2C"),
            other => bare.push(other),
        }
    }
    bare.push_str(",r=");
    bare.push_str(client_nonce);
    bare
}
166
167// ---------------------------------------------------------------------------
168// Public API
169// ---------------------------------------------------------------------------
170
171/// Derive SCRAM credentials from a password (for user creation/storage).
172///
173/// Uses PBKDF2-HMAC-SHA-256 with the given salt and iteration count.
174pub fn derive_credentials(password: &str, salt: &[u8], iterations: u32) -> ScramCredentials {
175    let mut salted_password = pbkdf2_sha256(password.as_bytes(), salt, iterations);
176    let (mut _client_key, stored_key, server_key) = derive_keys(&salted_password);
177    salted_password.zeroize();
178    _client_key.zeroize();
179    ScramCredentials {
180        salt: salt.to_vec(),
181        iterations,
182        stored_key,
183        server_key,
184    }
185}
186
187/// Generate a random nonce string (base64-encoded 18 random bytes).
188pub fn generate_nonce() -> String {
189    let mut bytes = [0u8; 18];
190    rand::rng().fill(&mut bytes);
191    BASE64.encode(bytes)
192}
193
194/// Client-side: Create the client-first-message data (base64-encoded).
195///
196/// Returns `(client_nonce, client_first_data_base64)`.
197///
198/// The client-first-message-bare is `n=<username>,r=<client_nonce>`.
199/// The full message prepends the GS2 header `n,,` (no channel binding).
200pub fn client_first_message(username: &str) -> (String, String) {
201    let client_nonce = generate_nonce();
202    let bare = make_client_first_bare(username, &client_nonce);
203    let full = format!("n,,{}", bare);
204    let encoded = BASE64.encode(full.as_bytes());
205    (client_nonce, encoded)
206}
207
208/// Server-side: Create the server-first-message data and handshake state.
209///
210/// `username` is taken from the HELLO phase. `client_nonce_b64` is the raw
211/// client nonce (as returned by [`client_first_message`]). `credentials` are
212/// the pre-computed SCRAM credentials for this user.
213///
214/// Returns `(handshake_state, server_first_data_base64)`.
215pub fn server_first_message(
216    username: &str,
217    client_nonce_b64: &str,
218    credentials: &ScramCredentials,
219) -> (ScramHandshake, String) {
220    let server_nonce = generate_nonce();
221    let combined_nonce = format!("{}{}", client_nonce_b64, server_nonce);
222    let salt_b64 = BASE64.encode(&credentials.salt);
223
224    // server-first-message: r=<combined>,s=<salt_b64>,i=<iterations>
225    let server_first_msg = format!(
226        "r={},s={},i={}",
227        combined_nonce, salt_b64, credentials.iterations
228    );
229
230    // client-first-message-bare (includes username per SCRAM spec)
231    let cfmb = make_client_first_bare(username, client_nonce_b64);
232
233    // client-final-message-without-proof (anticipated)
234    let client_final_without_proof = format!("c=biws,r={}", combined_nonce);
235
236    // AuthMessage = client-first-bare "," server-first-msg "," client-final-without-proof
237    let auth_message = format!(
238        "{},{},{}",
239        cfmb, server_first_msg, client_final_without_proof
240    );
241
242    // Pre-compute server signature
243    let server_signature = hmac_sha256(&credentials.server_key, auth_message.as_bytes());
244
245    let server_first_b64 = BASE64.encode(server_first_msg.as_bytes());
246
247    let handshake = ScramHandshake {
248        username: username.to_string(),
249        client_nonce: client_nonce_b64.to_string(),
250        server_nonce,
251        salt: credentials.salt.clone(),
252        iterations: credentials.iterations,
253        auth_message,
254        server_signature,
255        stored_key: credentials.stored_key.clone(),
256    };
257
258    (handshake, server_first_b64)
259}
260
261/// Client-side: Process server-first-message, produce client-final-message.
262///
263/// `username` is the same value originally passed to [`client_first_message`].
264/// `password` is the user's plaintext password. `client_nonce` is the nonce
265/// returned by [`client_first_message`]. `server_first_b64` is the base64
266/// server-first-message data received from the server.
267///
268/// Returns `(client_final_data_base64, expected_server_signature)`.
269pub fn client_final_message(
270    password: &str,
271    client_nonce: &str,
272    server_first_b64: &str,
273    username: &str,
274) -> Result<(String, Vec<u8>), AuthError> {
275    // Decode and parse server-first-message
276    let server_first_bytes = BASE64
277        .decode(server_first_b64)
278        .map_err(|e| AuthError::Base64Error(e.to_string()))?;
279    let server_first_msg = String::from_utf8(server_first_bytes)
280        .map_err(|e| AuthError::HandshakeFailed(e.to_string()))?;
281
282    // Expected format: r=<combined_nonce>,s=<salt_b64>,i=<iterations>
283    let parts: Vec<&str> = server_first_msg.splitn(3, ',').collect();
284    if parts.len() != 3 {
285        return Err(AuthError::HandshakeFailed(
286            "invalid server-first-message format".to_string(),
287        ));
288    }
289
290    let combined_nonce = parse_scram_param(parts[0], "r=")?;
291    let salt_b64 = parse_scram_param(parts[1], "s=")?;
292    let iterations_str = parse_scram_param(parts[2], "i=")?;
293
294    // The combined nonce must start with our client nonce
295    if !combined_nonce.starts_with(client_nonce) {
296        return Err(AuthError::HandshakeFailed(
297            "combined nonce does not start with client nonce".to_string(),
298        ));
299    }
300
301    let salt = BASE64
302        .decode(salt_b64)
303        .map_err(|e| AuthError::Base64Error(e.to_string()))?;
304    let iterations: u32 = iterations_str
305        .parse()
306        .map_err(|e: std::num::ParseIntError| AuthError::HandshakeFailed(e.to_string()))?;
307
308    if iterations > MAX_CLIENT_ITERATIONS {
309        return Err(AuthError::InvalidMessage(format!(
310            "server requested {} PBKDF2 iterations, maximum allowed is {}",
311            iterations, MAX_CLIENT_ITERATIONS
312        )));
313    }
314
315    // Key derivation
316    let mut salted_password = pbkdf2_sha256(password.as_bytes(), &salt, iterations);
317    let (mut client_key, stored_key, server_key) = derive_keys(&salted_password);
318    salted_password.zeroize();
319
320    // Build AuthMessage
321    let cfmb = make_client_first_bare(username, client_nonce);
322    let client_final_without_proof = format!("c=biws,r={}", combined_nonce);
323    let auth_message = format!(
324        "{},{},{}",
325        cfmb, server_first_msg, client_final_without_proof
326    );
327
328    // ClientSignature = HMAC(StoredKey, AuthMessage)
329    let client_signature = hmac_sha256(&stored_key, auth_message.as_bytes());
330    // ClientProof = ClientKey XOR ClientSignature
331    let client_proof = xor_bytes(&client_key, &client_signature);
332    // ServerSignature = HMAC(ServerKey, AuthMessage)
333    let server_signature = hmac_sha256(&server_key, auth_message.as_bytes());
334
335    // Zeroize intermediate key material
336    client_key.zeroize();
337
338    // client-final-message: c=biws,r=<combined>,p=<proof_b64>
339    let proof_b64 = BASE64.encode(&client_proof);
340    let client_final_msg = format!("{},p={}", client_final_without_proof, proof_b64);
341    let client_final_b64 = BASE64.encode(client_final_msg.as_bytes());
342
343    Ok((client_final_b64, server_signature))
344}
345
346/// Server-side: Verify client-final-message and produce server signature.
347///
348/// Decodes the client-final-message, verifies the client proof against the
349/// stored key in the handshake state, and returns the server signature for
350/// the client to verify (sent as the `v=` field in server-final-message).
351pub fn server_verify_final(
352    handshake: &ScramHandshake,
353    client_final_b64: &str,
354) -> Result<Vec<u8>, AuthError> {
355    // Decode client-final-message
356    let client_final_bytes = BASE64
357        .decode(client_final_b64)
358        .map_err(|e| AuthError::Base64Error(e.to_string()))?;
359    let client_final_msg = String::from_utf8(client_final_bytes)
360        .map_err(|e| AuthError::HandshakeFailed(e.to_string()))?;
361
362    // Expected format: c=biws,r=<combined_nonce>,p=<proof_b64>
363    let parts: Vec<&str> = client_final_msg.splitn(3, ',').collect();
364    if parts.len() != 3 {
365        return Err(AuthError::HandshakeFailed(
366            "invalid client-final-message format".to_string(),
367        ));
368    }
369
370    // Validate channel binding
371    let channel_binding = parse_scram_param(parts[0], "c=")?;
372    if channel_binding != "biws" {
373        return Err(AuthError::HandshakeFailed(
374            "unexpected channel binding".to_string(),
375        ));
376    }
377
378    // Validate combined nonce
379    let combined_nonce = parse_scram_param(parts[1], "r=")?;
380    let expected_combined = format!("{}{}", handshake.client_nonce, handshake.server_nonce);
381    if !bool::from(
382        combined_nonce
383            .as_bytes()
384            .ct_eq(expected_combined.as_bytes()),
385    ) {
386        return Err(AuthError::HandshakeFailed("nonce mismatch".to_string()));
387    }
388
389    // Extract and decode client proof
390    let proof_b64 = parse_scram_param(parts[2], "p=")?;
391    let client_proof = BASE64
392        .decode(proof_b64)
393        .map_err(|e| AuthError::Base64Error(e.to_string()))?;
394
395    // Verify the proof per RFC 5802:
396    //   ClientSignature = HMAC(StoredKey, AuthMessage)
397    //   RecoveredClientKey = ClientProof XOR ClientSignature
398    //   Check: SHA-256(RecoveredClientKey) == StoredKey
399    let client_signature = hmac_sha256(&handshake.stored_key, handshake.auth_message.as_bytes());
400    let recovered_client_key = xor_bytes(&client_proof, &client_signature);
401    let recovered_stored_key = sha256(&recovered_client_key);
402
403    if recovered_stored_key
404        .ct_eq(&handshake.stored_key)
405        .unwrap_u8()
406        == 0
407    {
408        return Err(AuthError::InvalidCredentials);
409    }
410
411    // Proof verified -- return server signature for the client to verify
412    Ok(handshake.server_signature.clone())
413}
414
415/// Extract the client nonce from a base64-encoded client-first-message.
416///
417/// The client-first-message format is `n,,n=<username>,r=<client_nonce>`.
418/// Returns the raw nonce string.
419pub fn extract_client_nonce(client_first_b64: &str) -> Result<String, AuthError> {
420    let bytes = BASE64
421        .decode(client_first_b64)
422        .map_err(|e| AuthError::Base64Error(e.to_string()))?;
423    let msg = String::from_utf8(bytes).map_err(|e| AuthError::HandshakeFailed(e.to_string()))?;
424    // Strip GS2 header "n,," prefix
425    let bare = msg
426        .strip_prefix("n,,")
427        .ok_or_else(|| AuthError::HandshakeFailed("missing GS2 header in client-first".into()))?;
428    // Parse n=<user>,r=<nonce>
429    for part in bare.split(',') {
430        if let Some(nonce) = part.strip_prefix("r=") {
431            return Ok(nonce.to_string());
432        }
433    }
434    Err(AuthError::HandshakeFailed(
435        "missing r= nonce in client-first-message".into(),
436    ))
437}
438
439/// Parse a Haystack `Authorization` header value.
440///
441/// Supported formats:
442/// - `HELLO username=<base64(username)>`
443/// - `SCRAM handshakeToken=<token>, data=<data>`
444/// - `BEARER authToken=<token>`
445pub fn parse_auth_header(header: &str) -> Result<AuthHeader, AuthError> {
446    let header = header.trim();
447
448    if let Some(rest) = header.strip_prefix("HELLO ") {
449        let mut username_b64_val = None;
450        let mut data_val = None;
451        for part in rest.split(',') {
452            let part = part.trim();
453            if let Some(val) = part.strip_prefix("username=") {
454                username_b64_val = Some(val.trim().to_string());
455            } else if let Some(val) = part.strip_prefix("data=") {
456                data_val = Some(val.trim().to_string());
457            }
458        }
459        let username_b64 = username_b64_val
460            .ok_or_else(|| AuthError::InvalidHeader("missing username= in HELLO".into()))?;
461        let username_bytes = BASE64
462            .decode(&username_b64)
463            .map_err(|e| AuthError::Base64Error(e.to_string()))?;
464        let username = String::from_utf8(username_bytes)
465            .map_err(|e| AuthError::InvalidHeader(e.to_string()))?;
466        Ok(AuthHeader::Hello {
467            username,
468            data: data_val,
469        })
470    } else if let Some(rest) = header.strip_prefix("SCRAM ") {
471        let mut handshake_token = None;
472        let mut data = None;
473        for part in rest.split(',') {
474            let part = part.trim();
475            if let Some(val) = part.strip_prefix("handshakeToken=") {
476                handshake_token = Some(val.trim().to_string());
477            } else if let Some(val) = part.strip_prefix("data=") {
478                data = Some(val.trim().to_string());
479            }
480        }
481        let handshake_token = handshake_token
482            .ok_or_else(|| AuthError::InvalidHeader("missing handshakeToken= in SCRAM".into()))?;
483        let data = data.ok_or_else(|| AuthError::InvalidHeader("missing data= in SCRAM".into()))?;
484        Ok(AuthHeader::Scram {
485            handshake_token,
486            data,
487        })
488    } else if let Some(rest) = header.strip_prefix("BEARER ") {
489        let token = rest
490            .trim()
491            .strip_prefix("authToken=")
492            .ok_or_else(|| AuthError::InvalidHeader("missing authToken= in BEARER".into()))?;
493        Ok(AuthHeader::Bearer {
494            auth_token: token.trim().to_string(),
495        })
496    } else {
497        Err(AuthError::InvalidHeader(format!(
498            "unrecognized auth scheme: {}",
499            header
500        )))
501    }
502}
503
/// Build the `WWW-Authenticate` header value for a Haystack SCRAM challenge:
/// `SCRAM handshakeToken=<token>, hash=<hash>, data=<data_b64>`.
pub fn format_www_authenticate(handshake_token: &str, hash: &str, data_b64: &str) -> String {
    let params = [
        format!("handshakeToken={handshake_token}"),
        format!("hash={hash}"),
        format!("data={data_b64}"),
    ];
    format!("SCRAM {}", params.join(", "))
}
513
/// Build the `Authentication-Info` header value carrying the issued auth
/// token: `authToken=<token>, data=<data_b64>`.
pub fn format_auth_info(auth_token: &str, data_b64: &str) -> String {
    let mut value = String::with_capacity(16 + auth_token.len() + data_b64.len());
    value.push_str("authToken=");
    value.push_str(auth_token);
    value.push_str(", data=");
    value.push_str(data_b64);
    value
}
520
521// ---------------------------------------------------------------------------
522// Tests
523// ---------------------------------------------------------------------------
524
#[cfg(test)]
mod tests {
    use super::*;

    /// Decode base64 and interpret the bytes as UTF-8, panicking on failure.
    fn decode_utf8(b64: &str) -> String {
        let raw = BASE64.decode(b64).expect("valid base64");
        String::from_utf8(raw).expect("valid UTF-8")
    }

    #[test]
    fn test_derive_credentials() {
        let salt = b"random-salt-value";
        let iterations = 4096;
        let creds = derive_credentials("pencil", salt, iterations);

        // Fields are populated; both keys are SHA-256 sized.
        assert_eq!(creds.salt, salt.to_vec());
        assert_eq!(creds.iterations, iterations);
        assert_eq!(creds.stored_key.len(), 32);
        assert_eq!(creds.server_key.len(), 32);

        // Same inputs give the same keys.
        let again = derive_credentials("pencil", salt, iterations);
        assert_eq!(creds.stored_key, again.stored_key);
        assert_eq!(creds.server_key, again.server_key);

        // A different password gives different keys.
        let other = derive_credentials("other", salt, iterations);
        assert_ne!(creds.stored_key, other.stored_key);
        assert_ne!(creds.server_key, other.server_key);
    }

    #[test]
    fn test_generate_nonce() {
        let first = generate_nonce();
        let second = generate_nonce();
        assert_ne!(first, second);

        // Each nonce is the base64 encoding of 18 bytes.
        for nonce in [&first, &second] {
            let raw = BASE64.decode(nonce).expect("nonce must be valid base64");
            assert_eq!(raw.len(), 18);
        }
    }

    #[test]
    fn test_parse_auth_header_hello() {
        let header = format!("HELLO username={}", BASE64.encode("user"));
        let expected = AuthHeader::Hello {
            username: "user".to_string(),
            data: None,
        };
        assert_eq!(parse_auth_header(&header).unwrap(), expected);
    }

    #[test]
    fn test_parse_auth_header_scram() {
        let parsed = parse_auth_header("SCRAM handshakeToken=abc123, data=c29tZWRhdGE=").unwrap();
        let expected = AuthHeader::Scram {
            handshake_token: "abc123".to_string(),
            data: "c29tZWRhdGE=".to_string(),
        };
        assert_eq!(parsed, expected);
    }

    #[test]
    fn test_parse_auth_header_bearer() {
        let parsed = parse_auth_header("BEARER authToken=mytoken123").unwrap();
        let expected = AuthHeader::Bearer {
            auth_token: "mytoken123".to_string(),
        };
        assert_eq!(parsed, expected);
    }

    #[test]
    fn test_parse_auth_header_invalid() {
        let bad_headers = [
            "UNKNOWN foo=bar",          // unknown scheme
            "HELLO foo=bar",            // HELLO without username=
            "SCRAM handshakeToken=abc", // SCRAM without data=
            "BEARER token=abc",         // BEARER without authToken=
            "",                         // empty header
        ];
        for header in bad_headers {
            assert!(parse_auth_header(header).is_err(), "expected error for {header:?}");
        }
    }

    #[test]
    fn test_full_handshake() {
        // End-to-end HELLO -> SCRAM -> BEARER flow.
        let (username, password) = ("testuser", "s3cret");
        let credentials = derive_credentials(password, b"test-salt-12345", 4096);

        // HELLO: the server recovers the username from the header.
        let hello = format!("HELLO username={}", BASE64.encode(username.as_bytes()));
        let parsed = parse_auth_header(&hello).unwrap();
        let AuthHeader::Hello {
            username: parsed_user,
            ..
        } = parsed
        else {
            panic!("expected Hello variant");
        };
        assert_eq!(parsed_user, username);

        // Client nonce, then the server's challenge.
        let (client_nonce, _client_first_b64) = client_first_message(username);
        let (handshake, server_first_b64) =
            server_first_message(username, &client_nonce, &credentials);

        let www_auth =
            format_www_authenticate("handshake-token-xyz", "SHA-256", &server_first_b64);
        for needle in ["SCRAM", "SHA-256", "handshake-token-xyz"] {
            assert!(www_auth.contains(needle));
        }

        // Client proves knowledge of the password; server verifies it.
        let (client_final_b64, expected_server_sig) =
            client_final_message(password, &client_nonce, &server_first_b64, username).unwrap();
        let server_sig = server_verify_final(&handshake, &client_final_b64).unwrap();
        assert_eq!(server_sig, expected_server_sig);

        // Server-final (v=<sig>) is carried in Authentication-Info.
        let server_final_b64 = BASE64.encode(format!("v={}", BASE64.encode(&server_sig)));
        let auth_info = format_auth_info("auth-token-abc", &server_final_b64);
        assert!(auth_info.contains("authToken=auth-token-abc"));

        // Client checks the signature it received.
        let server_final = decode_utf8(&server_final_b64);
        let sig_b64 = server_final.strip_prefix("v=").unwrap();
        assert_eq!(BASE64.decode(sig_b64).unwrap(), expected_server_sig);
    }

    #[test]
    fn test_client_server_roundtrip() {
        let username = "admin";
        let password = "correcthorsebatterystaple";
        let credentials = derive_credentials(password, b"unique-salt-value", DEFAULT_ITERATIONS);

        // client-first: GS2 header followed by our nonce.
        let (client_nonce, client_first_b64) = client_first_message(username);
        let client_first = decode_utf8(&client_first_b64);
        assert!(client_first.starts_with("n,,"));
        assert!(client_first.contains(&format!("r={}", client_nonce)));

        // server-first: r=, s= and i= fields, echoing the client nonce.
        let (handshake, server_first_b64) =
            server_first_message(username, &client_nonce, &credentials);
        let server_first = decode_utf8(&server_first_b64);
        assert!(server_first.starts_with("r="));
        assert!(server_first.contains(",s="));
        assert!(server_first.contains(",i="));
        assert!(server_first.contains(&client_nonce));

        // client-final: c=biws,...,p=<proof>.
        let (client_final_b64, expected_server_sig) =
            client_final_message(password, &client_nonce, &server_first_b64, username).unwrap();
        let client_final = decode_utf8(&client_final_b64);
        assert!(client_final.starts_with("c=biws,"));
        assert!(client_final.contains(",p="));

        // Correct proof: the server returns the signature the client expects.
        assert_eq!(
            server_verify_final(&handshake, &client_final_b64).unwrap(),
            expected_server_sig
        );

        // Wrong password: the proof is rejected as invalid credentials.
        let (wrong_final_b64, _) =
            client_final_message("wrongpassword", &client_nonce, &server_first_b64, username)
                .unwrap();
        match server_verify_final(&handshake, &wrong_final_b64) {
            Err(AuthError::InvalidCredentials) => {}
            other => panic!("expected InvalidCredentials, got {:?}", other),
        }
    }

    #[test]
    fn test_format_www_authenticate() {
        let expected = "SCRAM handshakeToken=tok123, hash=SHA-256, data=c29tZQ==";
        assert_eq!(format_www_authenticate("tok123", "SHA-256", "c29tZQ=="), expected);
    }

    #[test]
    fn test_format_auth_info() {
        let expected = "authToken=auth-tok, data=ZGF0YQ==";
        assert_eq!(format_auth_info("auth-tok", "ZGF0YQ=="), expected);
    }
}
754}