Skip to main content

haystack_core/
auth.rs

1//! SCRAM SHA-256 authentication primitives for the Haystack auth protocol.
2//!
3//! This module implements the cryptographic operations needed for SCRAM
4//! (Salted Challenge Response Authentication Mechanism) with SHA-256 as
5//! specified by the [Project Haystack auth spec](https://project-haystack.org/doc/docHaystack/Auth).
6//!
7//! It provides functions shared by both server and client implementations
8//! for the three-phase handshake: HELLO, SCRAM challenge/response, and
9//! BEARER token issuance.
10
11use base64::Engine;
12use base64::engine::general_purpose::STANDARD as BASE64;
13use hmac::{Hmac, Mac};
14use pbkdf2::pbkdf2_hmac;
15use rand::RngExt;
16use sha2::{Digest, Sha256};
17use subtle::ConstantTimeEq;
18use zeroize::Zeroize;
19
20type HmacSha256 = Hmac<Sha256>;
21
22/// Default PBKDF2 iteration count for SCRAM SHA-256.
23pub const DEFAULT_ITERATIONS: u32 = 100_000;
24
25// ---------------------------------------------------------------------------
26// Error type
27// ---------------------------------------------------------------------------
28
29/// Errors that can occur during SCRAM authentication.
30#[derive(Debug, thiserror::Error)]
31pub enum AuthError {
32    #[error("invalid credentials")]
33    InvalidCredentials,
34    #[error("invalid auth header: {0}")]
35    InvalidHeader(String),
36    #[error("handshake failed: {0}")]
37    HandshakeFailed(String),
38    #[error("base64 decode error: {0}")]
39    Base64Error(String),
40}
41
42// ---------------------------------------------------------------------------
43// Types
44// ---------------------------------------------------------------------------
45
46/// Pre-computed SCRAM credentials for a user (stored server-side).
47#[derive(Debug, Clone)]
48pub struct ScramCredentials {
49    pub salt: Vec<u8>,
50    pub iterations: u32,
51    pub stored_key: Vec<u8>,
52    pub server_key: Vec<u8>,
53}
54
55impl Drop for ScramCredentials {
56    fn drop(&mut self) {
57        self.stored_key.zeroize();
58        self.server_key.zeroize();
59    }
60}
61
/// Server-side state for a SCRAM exchange that is still in progress.
///
/// Created by [`server_first_message`] and consumed by
/// [`server_verify_final`] once the client-final-message arrives.
#[derive(Clone, Debug)]
pub struct ScramHandshake {
    /// User name the handshake was started for.
    pub username: String,
    /// Nonce supplied by the client.
    pub client_nonce: String,
    /// Nonce generated by the server for this handshake.
    pub server_nonce: String,
    /// Salt copied from the user's credentials.
    pub salt: Vec<u8>,
    /// PBKDF2 iteration count copied from the user's credentials.
    pub iterations: u32,
    /// The SCRAM `AuthMessage` both sides sign.
    pub auth_message: String,
    /// `HMAC(ServerKey, AuthMessage)`, pre-computed when the handshake starts.
    pub server_signature: Vec<u8>,
    /// Stored key from credentials, needed to verify the client proof.
    stored_key: Vec<u8>,
}
76
/// The three shapes a Haystack `Authorization` header can take, as produced
/// by [`parse_auth_header`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum AuthHeader {
    /// `HELLO` message opening the handshake.
    Hello {
        /// User name, already base64-decoded.
        username: String,
        /// Base64-encoded client-first-message (contains the client nonce).
        data: Option<String>,
    },
    /// `SCRAM` message carrying a handshake step.
    Scram {
        /// Value of the `handshakeToken=` parameter.
        handshake_token: String,
        /// Value of the `data=` parameter, left undecoded.
        data: String,
    },
    /// `BEARER` message carrying a previously issued token.
    Bearer {
        /// Value of the `authToken=` parameter.
        auth_token: String,
    },
}
93
94// ---------------------------------------------------------------------------
95// Internal helpers
96// ---------------------------------------------------------------------------
97
98/// Compute HMAC-SHA-256(key, msg).
99fn hmac_sha256(key: &[u8], msg: &[u8]) -> Vec<u8> {
100    let mut mac = HmacSha256::new_from_slice(key).expect("HMAC accepts keys of any size");
101    mac.update(msg);
102    mac.finalize().into_bytes().to_vec()
103}
104
105/// Compute SHA-256(data).
106fn sha256(data: &[u8]) -> Vec<u8> {
107    let mut hasher = Sha256::new();
108    hasher.update(data);
109    hasher.finalize().to_vec()
110}
111
/// Byte-wise XOR of `a` and `b`.
///
/// # Panics
///
/// Panics if the two slices differ in length.
fn xor_bytes(a: &[u8], b: &[u8]) -> Vec<u8> {
    assert_eq!(a.len(), b.len(), "XOR operands must be the same length");
    let mut mixed = Vec::with_capacity(a.len());
    for (lhs, rhs) in a.iter().zip(b) {
        mixed.push(lhs ^ rhs);
    }
    mixed
}
117
118/// PBKDF2-HMAC-SHA-256 key derivation, producing a 32-byte salted password.
119fn pbkdf2_sha256(password: &[u8], salt: &[u8], iterations: u32) -> Vec<u8> {
120    let mut salted_password = vec![0u8; 32];
121    pbkdf2_hmac::<Sha256>(password, salt, iterations, &mut salted_password);
122    salted_password
123    // Note: caller is responsible for zeroizing via Zeroize trait on Vec<u8>
124}
125
126/// Derive (ClientKey, StoredKey, ServerKey) from a salted password.
127fn derive_keys(salted_password: &[u8]) -> (Vec<u8>, Vec<u8>, Vec<u8>) {
128    let client_key = hmac_sha256(salted_password, b"Client Key");
129    let stored_key = sha256(&client_key);
130    let server_key = hmac_sha256(salted_password, b"Server Key");
131    (client_key, stored_key, server_key)
132}
133
134/// Parse a `key=value` parameter from a SCRAM message segment.
135fn parse_scram_param<'a>(segment: &'a str, prefix: &str) -> Result<&'a str, AuthError> {
136    let trimmed = segment.trim();
137    trimmed.strip_prefix(prefix).ok_or_else(|| {
138        AuthError::HandshakeFailed(format!(
139            "expected prefix '{}' but got '{}'",
140            prefix, trimmed
141        ))
142    })
143}
144
/// Assembles the client-first-message-bare, `n=<username>,r=<client_nonce>`.
///
/// Both values are inserted verbatim.
fn make_client_first_bare(username: &str, client_nonce: &str) -> String {
    // "n=" + ",r=" contribute five bytes on top of the two values.
    let mut bare = String::with_capacity(username.len() + client_nonce.len() + 5);
    bare.push_str("n=");
    bare.push_str(username);
    bare.push_str(",r=");
    bare.push_str(client_nonce);
    bare
}
149
150// ---------------------------------------------------------------------------
151// Public API
152// ---------------------------------------------------------------------------
153
154/// Derive SCRAM credentials from a password (for user creation/storage).
155///
156/// Uses PBKDF2-HMAC-SHA-256 with the given salt and iteration count.
157pub fn derive_credentials(password: &str, salt: &[u8], iterations: u32) -> ScramCredentials {
158    let mut salted_password = pbkdf2_sha256(password.as_bytes(), salt, iterations);
159    let (mut _client_key, stored_key, server_key) = derive_keys(&salted_password);
160    salted_password.zeroize();
161    _client_key.zeroize();
162    ScramCredentials {
163        salt: salt.to_vec(),
164        iterations,
165        stored_key,
166        server_key,
167    }
168}
169
170/// Generate a random nonce string (base64-encoded 18 random bytes).
171pub fn generate_nonce() -> String {
172    let mut bytes = [0u8; 18];
173    rand::rng().fill(&mut bytes);
174    BASE64.encode(bytes)
175}
176
177/// Client-side: Create the client-first-message data (base64-encoded).
178///
179/// Returns `(client_nonce, client_first_data_base64)`.
180///
181/// The client-first-message-bare is `n=<username>,r=<client_nonce>`.
182/// The full message prepends the GS2 header `n,,` (no channel binding).
183pub fn client_first_message(username: &str) -> (String, String) {
184    let client_nonce = generate_nonce();
185    let bare = make_client_first_bare(username, &client_nonce);
186    let full = format!("n,,{}", bare);
187    let encoded = BASE64.encode(full.as_bytes());
188    (client_nonce, encoded)
189}
190
191/// Server-side: Create the server-first-message data and handshake state.
192///
193/// `username` is taken from the HELLO phase. `client_nonce_b64` is the raw
194/// client nonce (as returned by [`client_first_message`]). `credentials` are
195/// the pre-computed SCRAM credentials for this user.
196///
197/// Returns `(handshake_state, server_first_data_base64)`.
198pub fn server_first_message(
199    username: &str,
200    client_nonce_b64: &str,
201    credentials: &ScramCredentials,
202) -> (ScramHandshake, String) {
203    let server_nonce = generate_nonce();
204    let combined_nonce = format!("{}{}", client_nonce_b64, server_nonce);
205    let salt_b64 = BASE64.encode(&credentials.salt);
206
207    // server-first-message: r=<combined>,s=<salt_b64>,i=<iterations>
208    let server_first_msg = format!(
209        "r={},s={},i={}",
210        combined_nonce, salt_b64, credentials.iterations
211    );
212
213    // client-first-message-bare (includes username per SCRAM spec)
214    let cfmb = make_client_first_bare(username, client_nonce_b64);
215
216    // client-final-message-without-proof (anticipated)
217    let client_final_without_proof = format!("c=biws,r={}", combined_nonce);
218
219    // AuthMessage = client-first-bare "," server-first-msg "," client-final-without-proof
220    let auth_message = format!(
221        "{},{},{}",
222        cfmb, server_first_msg, client_final_without_proof
223    );
224
225    // Pre-compute server signature
226    let server_signature = hmac_sha256(&credentials.server_key, auth_message.as_bytes());
227
228    let server_first_b64 = BASE64.encode(server_first_msg.as_bytes());
229
230    let handshake = ScramHandshake {
231        username: username.to_string(),
232        client_nonce: client_nonce_b64.to_string(),
233        server_nonce,
234        salt: credentials.salt.clone(),
235        iterations: credentials.iterations,
236        auth_message,
237        server_signature,
238        stored_key: credentials.stored_key.clone(),
239    };
240
241    (handshake, server_first_b64)
242}
243
244/// Client-side: Process server-first-message, produce client-final-message.
245///
246/// `username` is the same value originally passed to [`client_first_message`].
247/// `password` is the user's plaintext password. `client_nonce` is the nonce
248/// returned by [`client_first_message`]. `server_first_b64` is the base64
249/// server-first-message data received from the server.
250///
251/// Returns `(client_final_data_base64, expected_server_signature)`.
252pub fn client_final_message(
253    password: &str,
254    client_nonce: &str,
255    server_first_b64: &str,
256    username: &str,
257) -> Result<(String, Vec<u8>), AuthError> {
258    // Decode and parse server-first-message
259    let server_first_bytes = BASE64
260        .decode(server_first_b64)
261        .map_err(|e| AuthError::Base64Error(e.to_string()))?;
262    let server_first_msg = String::from_utf8(server_first_bytes)
263        .map_err(|e| AuthError::HandshakeFailed(e.to_string()))?;
264
265    // Expected format: r=<combined_nonce>,s=<salt_b64>,i=<iterations>
266    let parts: Vec<&str> = server_first_msg.splitn(3, ',').collect();
267    if parts.len() != 3 {
268        return Err(AuthError::HandshakeFailed(
269            "invalid server-first-message format".to_string(),
270        ));
271    }
272
273    let combined_nonce = parse_scram_param(parts[0], "r=")?;
274    let salt_b64 = parse_scram_param(parts[1], "s=")?;
275    let iterations_str = parse_scram_param(parts[2], "i=")?;
276
277    // The combined nonce must start with our client nonce
278    if !combined_nonce.starts_with(client_nonce) {
279        return Err(AuthError::HandshakeFailed(
280            "combined nonce does not start with client nonce".to_string(),
281        ));
282    }
283
284    let salt = BASE64
285        .decode(salt_b64)
286        .map_err(|e| AuthError::Base64Error(e.to_string()))?;
287    let iterations: u32 = iterations_str
288        .parse()
289        .map_err(|e: std::num::ParseIntError| AuthError::HandshakeFailed(e.to_string()))?;
290
291    // Key derivation
292    let mut salted_password = pbkdf2_sha256(password.as_bytes(), &salt, iterations);
293    let (mut client_key, stored_key, server_key) = derive_keys(&salted_password);
294    salted_password.zeroize();
295
296    // Build AuthMessage
297    let cfmb = make_client_first_bare(username, client_nonce);
298    let client_final_without_proof = format!("c=biws,r={}", combined_nonce);
299    let auth_message = format!(
300        "{},{},{}",
301        cfmb, server_first_msg, client_final_without_proof
302    );
303
304    // ClientSignature = HMAC(StoredKey, AuthMessage)
305    let client_signature = hmac_sha256(&stored_key, auth_message.as_bytes());
306    // ClientProof = ClientKey XOR ClientSignature
307    let client_proof = xor_bytes(&client_key, &client_signature);
308    // ServerSignature = HMAC(ServerKey, AuthMessage)
309    let server_signature = hmac_sha256(&server_key, auth_message.as_bytes());
310
311    // Zeroize intermediate key material
312    client_key.zeroize();
313
314    // client-final-message: c=biws,r=<combined>,p=<proof_b64>
315    let proof_b64 = BASE64.encode(&client_proof);
316    let client_final_msg = format!("{},p={}", client_final_without_proof, proof_b64);
317    let client_final_b64 = BASE64.encode(client_final_msg.as_bytes());
318
319    Ok((client_final_b64, server_signature))
320}
321
322/// Server-side: Verify client-final-message and produce server signature.
323///
324/// Decodes the client-final-message, verifies the client proof against the
325/// stored key in the handshake state, and returns the server signature for
326/// the client to verify (sent as the `v=` field in server-final-message).
327pub fn server_verify_final(
328    handshake: &ScramHandshake,
329    client_final_b64: &str,
330) -> Result<Vec<u8>, AuthError> {
331    // Decode client-final-message
332    let client_final_bytes = BASE64
333        .decode(client_final_b64)
334        .map_err(|e| AuthError::Base64Error(e.to_string()))?;
335    let client_final_msg = String::from_utf8(client_final_bytes)
336        .map_err(|e| AuthError::HandshakeFailed(e.to_string()))?;
337
338    // Expected format: c=biws,r=<combined_nonce>,p=<proof_b64>
339    let parts: Vec<&str> = client_final_msg.splitn(3, ',').collect();
340    if parts.len() != 3 {
341        return Err(AuthError::HandshakeFailed(
342            "invalid client-final-message format".to_string(),
343        ));
344    }
345
346    // Validate channel binding
347    let channel_binding = parse_scram_param(parts[0], "c=")?;
348    if channel_binding != "biws" {
349        return Err(AuthError::HandshakeFailed(
350            "unexpected channel binding".to_string(),
351        ));
352    }
353
354    // Validate combined nonce
355    let combined_nonce = parse_scram_param(parts[1], "r=")?;
356    let expected_combined = format!("{}{}", handshake.client_nonce, handshake.server_nonce);
357    if !bool::from(
358        combined_nonce
359            .as_bytes()
360            .ct_eq(expected_combined.as_bytes()),
361    ) {
362        return Err(AuthError::HandshakeFailed("nonce mismatch".to_string()));
363    }
364
365    // Extract and decode client proof
366    let proof_b64 = parse_scram_param(parts[2], "p=")?;
367    let client_proof = BASE64
368        .decode(proof_b64)
369        .map_err(|e| AuthError::Base64Error(e.to_string()))?;
370
371    // Verify the proof per RFC 5802:
372    //   ClientSignature = HMAC(StoredKey, AuthMessage)
373    //   RecoveredClientKey = ClientProof XOR ClientSignature
374    //   Check: SHA-256(RecoveredClientKey) == StoredKey
375    let client_signature = hmac_sha256(&handshake.stored_key, handshake.auth_message.as_bytes());
376    let recovered_client_key = xor_bytes(&client_proof, &client_signature);
377    let recovered_stored_key = sha256(&recovered_client_key);
378
379    if recovered_stored_key
380        .ct_eq(&handshake.stored_key)
381        .unwrap_u8()
382        == 0
383    {
384        return Err(AuthError::InvalidCredentials);
385    }
386
387    // Proof verified -- return server signature for the client to verify
388    Ok(handshake.server_signature.clone())
389}
390
391/// Extract the client nonce from a base64-encoded client-first-message.
392///
393/// The client-first-message format is `n,,n=<username>,r=<client_nonce>`.
394/// Returns the raw nonce string.
395pub fn extract_client_nonce(client_first_b64: &str) -> Result<String, AuthError> {
396    let bytes = BASE64
397        .decode(client_first_b64)
398        .map_err(|e| AuthError::Base64Error(e.to_string()))?;
399    let msg = String::from_utf8(bytes).map_err(|e| AuthError::HandshakeFailed(e.to_string()))?;
400    // Strip GS2 header "n,," prefix
401    let bare = msg
402        .strip_prefix("n,,")
403        .ok_or_else(|| AuthError::HandshakeFailed("missing GS2 header in client-first".into()))?;
404    // Parse n=<user>,r=<nonce>
405    for part in bare.split(',') {
406        if let Some(nonce) = part.strip_prefix("r=") {
407            return Ok(nonce.to_string());
408        }
409    }
410    Err(AuthError::HandshakeFailed(
411        "missing r= nonce in client-first-message".into(),
412    ))
413}
414
415/// Parse a Haystack `Authorization` header value.
416///
417/// Supported formats:
418/// - `HELLO username=<base64(username)>`
419/// - `SCRAM handshakeToken=<token>, data=<data>`
420/// - `BEARER authToken=<token>`
421pub fn parse_auth_header(header: &str) -> Result<AuthHeader, AuthError> {
422    let header = header.trim();
423
424    if let Some(rest) = header.strip_prefix("HELLO ") {
425        let mut username_b64_val = None;
426        let mut data_val = None;
427        for part in rest.split(',') {
428            let part = part.trim();
429            if let Some(val) = part.strip_prefix("username=") {
430                username_b64_val = Some(val.trim().to_string());
431            } else if let Some(val) = part.strip_prefix("data=") {
432                data_val = Some(val.trim().to_string());
433            }
434        }
435        let username_b64 = username_b64_val
436            .ok_or_else(|| AuthError::InvalidHeader("missing username= in HELLO".into()))?;
437        let username_bytes = BASE64
438            .decode(&username_b64)
439            .map_err(|e| AuthError::Base64Error(e.to_string()))?;
440        let username = String::from_utf8(username_bytes)
441            .map_err(|e| AuthError::InvalidHeader(e.to_string()))?;
442        Ok(AuthHeader::Hello {
443            username,
444            data: data_val,
445        })
446    } else if let Some(rest) = header.strip_prefix("SCRAM ") {
447        let mut handshake_token = None;
448        let mut data = None;
449        for part in rest.split(',') {
450            let part = part.trim();
451            if let Some(val) = part.strip_prefix("handshakeToken=") {
452                handshake_token = Some(val.trim().to_string());
453            } else if let Some(val) = part.strip_prefix("data=") {
454                data = Some(val.trim().to_string());
455            }
456        }
457        let handshake_token = handshake_token
458            .ok_or_else(|| AuthError::InvalidHeader("missing handshakeToken= in SCRAM".into()))?;
459        let data = data.ok_or_else(|| AuthError::InvalidHeader("missing data= in SCRAM".into()))?;
460        Ok(AuthHeader::Scram {
461            handshake_token,
462            data,
463        })
464    } else if let Some(rest) = header.strip_prefix("BEARER ") {
465        let token = rest
466            .trim()
467            .strip_prefix("authToken=")
468            .ok_or_else(|| AuthError::InvalidHeader("missing authToken= in BEARER".into()))?;
469        Ok(AuthHeader::Bearer {
470            auth_token: token.trim().to_string(),
471        })
472    } else {
473        Err(AuthError::InvalidHeader(format!(
474            "unrecognized auth scheme: {}",
475            header
476        )))
477    }
478}
479
/// Renders the `WWW-Authenticate` header value for a SCRAM challenge.
///
/// Output shape: `SCRAM handshakeToken=<token>, hash=<hash>, data=<data_b64>`.
/// All three arguments are inserted verbatim.
pub fn format_www_authenticate(handshake_token: &str, hash: &str, data_b64: &str) -> String {
    [
        "SCRAM handshakeToken=",
        handshake_token,
        ", hash=",
        hash,
        ", data=",
        data_b64,
    ]
    .concat()
}
489
/// Renders the `Authentication-Info` header value that delivers the auth
/// token.
///
/// Output shape: `authToken=<token>, data=<data_b64>`. Both arguments are
/// inserted verbatim.
pub fn format_auth_info(auth_token: &str, data_b64: &str) -> String {
    ["authToken=", auth_token, ", data=", data_b64].concat()
}
496
497// ---------------------------------------------------------------------------
498// Tests
499// ---------------------------------------------------------------------------
500
#[cfg(test)]
mod tests {
    use super::*;

    /// Base64-encodes a string with the module's engine.
    fn b64(text: &str) -> String {
        BASE64.encode(text.as_bytes())
    }

    /// Decodes base64 into a UTF-8 string, panicking on malformed test data.
    fn decode_to_string(encoded: &str) -> String {
        let raw = BASE64.decode(encoded).unwrap();
        String::from_utf8(raw).unwrap()
    }

    /// Credentials have the right shape, are deterministic, and depend on the password.
    #[test]
    fn test_derive_credentials() {
        let password = "pencil";
        let salt = b"random-salt-value";
        let iterations = 4096;

        let first = derive_credentials(password, salt, iterations);

        // Fields are populated correctly
        assert_eq!(first.salt, salt.to_vec());
        assert_eq!(first.iterations, iterations);
        assert_eq!(first.stored_key.len(), 32); // SHA-256 output length
        assert_eq!(first.server_key.len(), 32);

        // Deterministic: same inputs produce same outputs
        let again = derive_credentials(password, salt, iterations);
        assert_eq!(first.stored_key, again.stored_key);
        assert_eq!(first.server_key, again.server_key);

        // Different password yields different credentials
        let other = derive_credentials("other", salt, iterations);
        assert_ne!(first.stored_key, other.stored_key);
        assert_ne!(first.server_key, other.server_key);
    }

    /// Nonces are unique per call and decode to 18 bytes.
    #[test]
    fn test_generate_nonce() {
        let first = generate_nonce();
        let second = generate_nonce();

        // Each call produces a unique nonce
        assert_ne!(first, second);

        // Valid base64 encoding of 18 bytes
        for nonce in [&first, &second] {
            let decoded = BASE64.decode(nonce).expect("nonce must be valid base64");
            assert_eq!(decoded.len(), 18);
        }
    }

    /// A HELLO header yields the decoded user name and no data.
    #[test]
    fn test_parse_auth_header_hello() {
        let header = format!("HELLO username={}", b64("user"));

        let expected = AuthHeader::Hello {
            username: "user".to_string(),
            data: None,
        };
        assert_eq!(parse_auth_header(&header).unwrap(), expected);
    }

    /// A SCRAM header yields its token and undecoded data.
    #[test]
    fn test_parse_auth_header_scram() {
        let parsed = parse_auth_header("SCRAM handshakeToken=abc123, data=c29tZWRhdGE=").unwrap();

        let expected = AuthHeader::Scram {
            handshake_token: "abc123".to_string(),
            data: "c29tZWRhdGE=".to_string(),
        };
        assert_eq!(parsed, expected);
    }

    /// A BEARER header yields its token.
    #[test]
    fn test_parse_auth_header_bearer() {
        let parsed = parse_auth_header("BEARER authToken=mytoken123").unwrap();

        let expected = AuthHeader::Bearer {
            auth_token: "mytoken123".to_string(),
        };
        assert_eq!(parsed, expected);
    }

    /// Malformed headers are rejected.
    #[test]
    fn test_parse_auth_header_invalid() {
        let rejected = [
            // Unknown scheme
            "UNKNOWN foo=bar",
            // HELLO missing username=
            "HELLO foo=bar",
            // SCRAM missing data=
            "SCRAM handshakeToken=abc",
            // BEARER missing authToken=
            "BEARER token=abc",
            // Empty
            "",
        ];
        for header in rejected {
            assert!(parse_auth_header(header).is_err());
        }
    }

    /// Simulates the complete HELLO -> SCRAM -> BEARER flow.
    #[test]
    fn test_full_handshake() {
        let username = "testuser";
        let password = "s3cret";
        let salt = b"test-salt-12345";
        let iterations = 4096;

        // --- Server: pre-compute credentials (user registration) ---
        let credentials = derive_credentials(password, salt, iterations);

        // --- Client: HELLO phase ---
        let hello_header = format!("HELLO username={}", b64(username));
        match parse_auth_header(&hello_header).unwrap() {
            AuthHeader::Hello { username: u, .. } => assert_eq!(u, username),
            _ => panic!("expected Hello variant"),
        }

        // --- Client: generate client-first-message ---
        let (client_nonce, _) = client_first_message(username);

        // --- Server: generate server-first-message ---
        let (handshake, server_first_b64) =
            server_first_message(username, &client_nonce, &credentials);

        // --- Server: format WWW-Authenticate header ---
        let www_auth = format_www_authenticate("handshake-token-xyz", "SHA-256", &server_first_b64);
        for needle in ["SCRAM", "SHA-256", "handshake-token-xyz"] {
            assert!(www_auth.contains(needle));
        }

        // --- Client: process server-first, produce client-final ---
        let (client_final_b64, expected_server_sig) =
            client_final_message(password, &client_nonce, &server_first_b64, username).unwrap();

        // --- Server: verify client-final ---
        let server_sig = server_verify_final(&handshake, &client_final_b64).unwrap();

        // Server signature should match what the client expects
        assert_eq!(server_sig, expected_server_sig);

        // --- Server: format Authentication-Info header ---
        let server_final_b64 = b64(&format!("v={}", BASE64.encode(&server_sig)));
        let auth_info = format_auth_info("auth-token-abc", &server_final_b64);
        assert!(auth_info.contains("authToken=auth-token-abc"));

        // --- Client: verify server signature from server-final ---
        let server_final = decode_to_string(&server_final_b64);
        let sig_b64 = server_final.strip_prefix("v=").unwrap();
        let received_server_sig = BASE64.decode(sig_b64).unwrap();
        assert_eq!(received_server_sig, expected_server_sig);
    }

    /// Full roundtrip through the public API, including a wrong-password rejection.
    #[test]
    fn test_client_server_roundtrip() {
        let username = "admin";
        let password = "correcthorsebatterystaple";
        let salt = b"unique-salt-value";
        let iterations = DEFAULT_ITERATIONS;

        // 1. Server: create credentials during user registration
        let credentials = derive_credentials(password, salt, iterations);

        // 2. Client: create client-first-message
        let (client_nonce, client_first_b64) = client_first_message(username);

        // Verify client-first is valid base64 and well-formed
        let client_first = decode_to_string(&client_first_b64);
        assert!(client_first.starts_with("n,,"));
        assert!(client_first.contains(&format!("r={}", client_nonce)));

        // 3. Server: create server-first-message
        let (handshake, server_first_b64) =
            server_first_message(username, &client_nonce, &credentials);

        // Verify server-first contains expected SCRAM fields
        let server_first = decode_to_string(&server_first_b64);
        assert!(server_first.starts_with("r="));
        assert!(server_first.contains(",s="));
        assert!(server_first.contains(",i="));
        assert!(server_first.contains(&client_nonce));

        // 4. Client: create client-final-message
        let (client_final_b64, expected_server_sig) =
            client_final_message(password, &client_nonce, &server_first_b64, username).unwrap();

        // Verify client-final structure
        let client_final = decode_to_string(&client_final_b64);
        assert!(client_final.starts_with("c=biws,"));
        assert!(client_final.contains(",p="));

        // 5. Server: verify and get server signature
        let server_sig = server_verify_final(&handshake, &client_final_b64).unwrap();
        assert_eq!(server_sig, expected_server_sig);

        // 6. Wrong password: server rejects the proof
        let (wrong_final_b64, _) =
            client_final_message("wrongpassword", &client_nonce, &server_first_b64, username)
                .unwrap();
        let result = server_verify_final(&handshake, &wrong_final_b64);
        assert!(
            matches!(result, Err(AuthError::InvalidCredentials)),
            "expected InvalidCredentials, got {:?}",
            result
        );
    }

    /// The challenge header has the exact documented layout.
    #[test]
    fn test_format_www_authenticate() {
        let rendered = format_www_authenticate("tok123", "SHA-256", "c29tZQ==");
        assert_eq!(
            rendered,
            "SCRAM handshakeToken=tok123, hash=SHA-256, data=c29tZQ=="
        );
    }

    /// The auth-info header has the exact documented layout.
    #[test]
    fn test_format_auth_info() {
        let rendered = format_auth_info("auth-tok", "ZGF0YQ==");
        assert_eq!(rendered, "authToken=auth-tok, data=ZGF0YQ==");
    }
}
730}