1use base64::Engine;
7use base64::engine::general_purpose::STANDARD as BASE64;
8use reqwest::Client;
9
10use crate::error::ClientError;
11use haystack_core::auth;
12
13pub async fn authenticate(
29 client: &Client,
30 base_url: &str,
31 username: &str,
32 password: &str,
33) -> Result<String, ClientError> {
34 let base_url = base_url.trim_end_matches('/');
35 let about_url = format!("{}/about", base_url);
36
37 let username_b64 = BASE64.encode(username.as_bytes());
42 let hello_header = format!("HELLO username={}", username_b64);
43
44 let (client_nonce, _client_first_b64) = auth::client_first_message(username);
45
46 let hello_resp = client
47 .get(&about_url)
48 .header("Authorization", &hello_header)
49 .send()
50 .await
51 .map_err(|e| ClientError::Transport(e.to_string()))?;
52
53 if hello_resp.status() != reqwest::StatusCode::UNAUTHORIZED {
54 return Err(ClientError::AuthFailed(format!(
55 "expected 401 from HELLO, got {}",
56 hello_resp.status()
57 )));
58 }
59
60 let www_auth = hello_resp
63 .headers()
64 .get("www-authenticate")
65 .and_then(|v| v.to_str().ok())
66 .ok_or_else(|| {
67 ClientError::AuthFailed("missing WWW-Authenticate header in 401 response".to_string())
68 })?
69 .to_string();
70
71 let (handshake_token, server_first_b64) = parse_www_authenticate(&www_auth)?;
72
73 let (client_final_b64, expected_server_sig) =
78 auth::client_final_message(password, &client_nonce, &server_first_b64, username)
79 .map_err(|e| ClientError::AuthFailed(e.to_string()))?;
80
81 let scram_header = format!(
83 "SCRAM handshakeToken={}, data={}",
84 handshake_token, client_final_b64
85 );
86
87 let scram_resp = client
88 .get(&about_url)
89 .header("Authorization", &scram_header)
90 .send()
91 .await
92 .map_err(|e| ClientError::Transport(e.to_string()))?;
93
94 if !scram_resp.status().is_success() {
95 return Err(ClientError::AuthFailed(format!(
96 "SCRAM phase failed with status {}",
97 scram_resp.status()
98 )));
99 }
100
101 let auth_info = scram_resp
106 .headers()
107 .get("authentication-info")
108 .and_then(|v| v.to_str().ok())
109 .ok_or_else(|| {
110 ClientError::AuthFailed(
111 "missing Authentication-Info header in SCRAM response".to_string(),
112 )
113 })?
114 .to_string();
115
116 let (auth_token, server_final_b64) = parse_auth_info(&auth_info)?;
117
118 let server_final_bytes = BASE64.decode(&server_final_b64).map_err(|e| {
120 ClientError::AuthFailed(format!("invalid base64 in server-final data: {}", e))
121 })?;
122 let server_final_msg = String::from_utf8(server_final_bytes).map_err(|e| {
123 ClientError::AuthFailed(format!("invalid UTF-8 in server-final data: {}", e))
124 })?;
125 let server_sig_b64 = server_final_msg.strip_prefix("v=").ok_or_else(|| {
126 ClientError::AuthFailed("server-final message missing v= prefix".to_string())
127 })?;
128 let received_server_sig = BASE64.decode(server_sig_b64).map_err(|e| {
129 ClientError::AuthFailed(format!("invalid base64 in server signature: {}", e))
130 })?;
131
132 if received_server_sig != expected_server_sig {
133 return Err(ClientError::AuthFailed(
134 "server signature verification failed".to_string(),
135 ));
136 }
137
138 Ok(auth_token)
139}
140
141fn parse_www_authenticate(header: &str) -> Result<(String, String), ClientError> {
147 let rest = header
148 .trim()
149 .strip_prefix("SCRAM ")
150 .ok_or_else(|| ClientError::AuthFailed("WWW-Authenticate not SCRAM scheme".to_string()))?;
151
152 let mut handshake_token = None;
153 let mut data = None;
154
155 for part in rest.split(',') {
156 let part = part.trim();
157 if let Some(val) = part.strip_prefix("handshakeToken=") {
158 handshake_token = Some(val.trim().to_string());
159 } else if let Some(val) = part.strip_prefix("data=") {
160 data = Some(val.trim().to_string());
161 }
162 }
164
165 let handshake_token = handshake_token.ok_or_else(|| {
166 ClientError::AuthFailed("missing handshakeToken in WWW-Authenticate".to_string())
167 })?;
168 let data = data
169 .ok_or_else(|| ClientError::AuthFailed("missing data in WWW-Authenticate".to_string()))?;
170
171 Ok((handshake_token, data))
172}
173
174fn parse_auth_info(header: &str) -> Result<(String, String), ClientError> {
180 let mut auth_token = None;
181 let mut data = None;
182
183 for part in header.split(',') {
184 let part = part.trim();
185 if let Some(val) = part.strip_prefix("authToken=") {
186 auth_token = Some(val.trim().to_string());
187 } else if let Some(val) = part.strip_prefix("data=") {
188 data = Some(val.trim().to_string());
189 }
190 }
191
192 let auth_token = auth_token.ok_or_else(|| {
193 ClientError::AuthFailed("missing authToken in Authentication-Info header".to_string())
194 })?;
195 let data = data.ok_or_else(|| {
196 ClientError::AuthFailed("missing data in Authentication-Info header".to_string())
197 })?;
198
199 Ok((auth_token, data))
200}