Skip to main content

haystack_client/
auth.rs

1//! Client-side SCRAM SHA-256 authentication handshake.
2//!
3//! Performs the three-phase Haystack auth handshake (HELLO, SCRAM, BEARER)
4//! against a Haystack server, returning the auth token on success.
5
6use base64::Engine;
7use base64::engine::general_purpose::STANDARD as BASE64;
8use reqwest::Client;
9
10use crate::error::ClientError;
11use haystack_core::auth;
12
13/// Perform SCRAM SHA-256 authentication handshake against a Haystack server.
14///
15/// Executes the three-phase handshake:
16/// 1. HELLO: sends username, receives SCRAM challenge
17/// 2. SCRAM: sends client proof, receives auth token and server signature
18/// 3. Returns the auth token for subsequent Bearer authentication
19///
20/// # Arguments
21/// * `client` - The reqwest HTTP client to use
22/// * `base_url` - The server API root (e.g. `http://localhost:8080/api`)
23/// * `username` - The username to authenticate as
24/// * `password` - The user's plaintext password
25///
26/// # Returns
27/// The auth token string on success.
28pub async fn authenticate(
29    client: &Client,
30    base_url: &str,
31    username: &str,
32    password: &str,
33) -> Result<String, ClientError> {
34    let base_url = base_url.trim_end_matches('/');
35    let about_url = format!("{}/about", base_url);
36
37    // -----------------------------------------------------------------------
38    // Phase 1: HELLO
39    // -----------------------------------------------------------------------
40    // Send GET /api/about with Authorization: HELLO username=<base64(username)>, data=<client_first>
41    let username_b64 = BASE64.encode(username.as_bytes());
42    let (client_nonce, client_first_b64) = auth::client_first_message(username);
43    let hello_header = format!(
44        "HELLO username={}, data={}",
45        username_b64, client_first_b64
46    );
47
48    let hello_resp = client
49        .get(&about_url)
50        .header("Authorization", &hello_header)
51        .send()
52        .await
53        .map_err(|e| ClientError::Transport(e.to_string()))?;
54
55    if hello_resp.status() != reqwest::StatusCode::UNAUTHORIZED {
56        return Err(ClientError::AuthFailed(format!(
57            "expected 401 from HELLO, got {}",
58            hello_resp.status()
59        )));
60    }
61
62    // Parse WWW-Authenticate header
63    // Expected: SCRAM handshakeToken=..., hash=SHA-256, data=<server_first_b64>
64    let www_auth = hello_resp
65        .headers()
66        .get("www-authenticate")
67        .and_then(|v| v.to_str().ok())
68        .ok_or_else(|| {
69            ClientError::AuthFailed("missing WWW-Authenticate header in 401 response".to_string())
70        })?
71        .to_string();
72
73    let (handshake_token, server_first_b64) = parse_www_authenticate(&www_auth)?;
74
75    // -----------------------------------------------------------------------
76    // Phase 2: SCRAM
77    // -----------------------------------------------------------------------
78    // Compute client-final-message from server-first-message
79    let (client_final_b64, expected_server_sig) =
80        auth::client_final_message(password, &client_nonce, &server_first_b64, username)
81            .map_err(|e| ClientError::AuthFailed(e.to_string()))?;
82
83    // Send GET /api/about with Authorization: SCRAM handshakeToken=..., data=<client_final>
84    let scram_header = format!(
85        "SCRAM handshakeToken={}, data={}",
86        handshake_token, client_final_b64
87    );
88
89    let scram_resp = client
90        .get(&about_url)
91        .header("Authorization", &scram_header)
92        .send()
93        .await
94        .map_err(|e| ClientError::Transport(e.to_string()))?;
95
96    if !scram_resp.status().is_success() {
97        return Err(ClientError::AuthFailed(format!(
98            "SCRAM phase failed with status {}",
99            scram_resp.status()
100        )));
101    }
102
103    // -----------------------------------------------------------------------
104    // Phase 3: Extract auth token
105    // -----------------------------------------------------------------------
106    // Parse Authentication-Info header: authToken=..., data=<server_final_b64>
107    let auth_info = scram_resp
108        .headers()
109        .get("authentication-info")
110        .and_then(|v| v.to_str().ok())
111        .ok_or_else(|| {
112            ClientError::AuthFailed(
113                "missing Authentication-Info header in SCRAM response".to_string(),
114            )
115        })?
116        .to_string();
117
118    let (auth_token, server_final_b64) = parse_auth_info(&auth_info)?;
119
120    // Verify the server signature from the server-final-message
121    let server_final_bytes = BASE64.decode(&server_final_b64).map_err(|e| {
122        ClientError::AuthFailed(format!("invalid base64 in server-final data: {}", e))
123    })?;
124    let server_final_msg = String::from_utf8(server_final_bytes).map_err(|e| {
125        ClientError::AuthFailed(format!("invalid UTF-8 in server-final data: {}", e))
126    })?;
127    let server_sig_b64 = server_final_msg.strip_prefix("v=").ok_or_else(|| {
128        ClientError::AuthFailed("server-final message missing v= prefix".to_string())
129    })?;
130    let received_server_sig = BASE64.decode(server_sig_b64).map_err(|e| {
131        ClientError::AuthFailed(format!("invalid base64 in server signature: {}", e))
132    })?;
133
134    if received_server_sig != expected_server_sig {
135        return Err(ClientError::AuthFailed(
136            "server signature verification failed".to_string(),
137        ));
138    }
139
140    Ok(auth_token)
141}
142
143/// Parse the WWW-Authenticate header from a SCRAM challenge response.
144///
145/// Expected format: `SCRAM handshakeToken=<token>, hash=SHA-256, data=<b64>`
146///
147/// Returns `(handshake_token, server_first_data_b64)`.
148fn parse_www_authenticate(header: &str) -> Result<(String, String), ClientError> {
149    let rest = header
150        .trim()
151        .strip_prefix("SCRAM ")
152        .ok_or_else(|| ClientError::AuthFailed("WWW-Authenticate not SCRAM scheme".to_string()))?;
153
154    let mut handshake_token = None;
155    let mut data = None;
156
157    for part in rest.split(',') {
158        let part = part.trim();
159        if let Some(val) = part.strip_prefix("handshakeToken=") {
160            handshake_token = Some(val.trim().to_string());
161        } else if let Some(val) = part.strip_prefix("data=") {
162            data = Some(val.trim().to_string());
163        }
164        // hash= is informational; we always use SHA-256
165    }
166
167    let handshake_token = handshake_token.ok_or_else(|| {
168        ClientError::AuthFailed("missing handshakeToken in WWW-Authenticate".to_string())
169    })?;
170    let data = data
171        .ok_or_else(|| ClientError::AuthFailed("missing data in WWW-Authenticate".to_string()))?;
172
173    Ok((handshake_token, data))
174}
175
176/// Parse the Authentication-Info header to extract the auth token and server-final data.
177///
178/// Expected format: `authToken=<token>, data=<b64>`
179///
180/// Returns `(auth_token, server_final_data_b64)`.
181fn parse_auth_info(header: &str) -> Result<(String, String), ClientError> {
182    let mut auth_token = None;
183    let mut data = None;
184
185    for part in header.split(',') {
186        let part = part.trim();
187        if let Some(val) = part.strip_prefix("authToken=") {
188            auth_token = Some(val.trim().to_string());
189        } else if let Some(val) = part.strip_prefix("data=") {
190            data = Some(val.trim().to_string());
191        }
192    }
193
194    let auth_token = auth_token.ok_or_else(|| {
195        ClientError::AuthFailed("missing authToken in Authentication-Info header".to_string())
196    })?;
197    let data = data.ok_or_else(|| {
198        ClientError::AuthFailed("missing data in Authentication-Info header".to_string())
199    })?;
200
201    Ok((auth_token, data))
202}