1use std::net::SocketAddr;
2use std::time::Duration;
3
4use base64::engine::general_purpose::URL_SAFE_NO_PAD;
5use base64::Engine as _;
6use ed25519_dalek::SigningKey;
7use reqwest::Client as HttpClient;
8use serde_json::{json, Value};
9use tokio::io::{AsyncReadExt, AsyncWriteExt};
10use tokio::net::TcpListener;
11use tokio::sync::oneshot;
12use tokio::time::timeout;
13use url::Url;
14
15use crate::client::{connect_admin_client_async, AuthClient};
16use crate::models::{
17 AdminLoginOutcome, AdminSessionState, BindResponse, BindResponseBound, BoundSession,
18 BrowserLoginChallenge, CallbackOutcome, CallbackTokenRequest, StartBrowserLoginOpts,
19};
20use crate::TrellisAuthError;
21use trellis_client::SessionAuth;
22
23fn base64url_encode(bytes: &[u8]) -> String {
24 URL_SAFE_NO_PAD.encode(bytes)
25}
26
27fn encode_contract_query(contract_json: &str) -> Result<String, TrellisAuthError> {
28 let parsed: Value = serde_json::from_str(contract_json)?;
29 Ok(base64url_encode(serde_json::to_string(&parsed)?.as_bytes()))
30}
31
/// HTML served at `GET /callback` on the loopback callback server.
///
/// The page's script reads `flowId` / `authError` from its query string and
/// POSTs them as JSON to `/token`. It then reports the result in the
/// `#status` paragraph.
///
/// This is a raw string literal, so attribute quotes must be written as plain
/// `"`. A `\"` escape would be emitted verbatim and corrupt the attributes,
/// e.g. `id=\"status\"`, which breaks `getElementById("status")`.
pub(crate) fn callback_page_html() -> &'static str {
    r#"<!doctype html>
<html>
  <head>
    <meta charset="utf-8" />
    <title>Trellis CLI Login</title>
  </head>
  <body>
    <p id="status">Completing Trellis CLI login...</p>
    <script>
      const params = new URLSearchParams(window.location.search);
      const flowId = params.get("flowId");
      const authError = params.get("authError");
      const status = document.getElementById("status");
      if (!flowId && !authError) {
        status.textContent = "Missing auth result in callback URL.";
      } else {
        fetch("/token", {
          method: "POST",
          headers: { "content-type": "application/json" },
          body: JSON.stringify({ flowId, authError })
        }).then(async (response) => {
          if (!response.ok) {
            throw new Error(await response.text());
          }
          status.textContent = authError
            ? `Login failed: ${authError}`
            : "Login complete. You can close this window.";
        }).catch((error) => {
          status.textContent = `Login handoff failed: ${error}`;
        });
      }
    </script>
  </body>
</html>
"#
}
69
/// Serializes a complete `HTTP/1.1` response.
///
/// The response consists of the status line, `Content-Type`,
/// `Content-Length` (the byte length of `body`) and `Connection: close`
/// headers, a blank line, and then `body` verbatim.
fn http_response(status_line: &str, content_type: &str, body: &[u8]) -> Vec<u8> {
    let body_len = body.len();
    let head = format!(
        "HTTP/1.1 {status_line}\r\nContent-Type: {content_type}\r\nContent-Length: {body_len}\r\nConnection: close\r\n\r\n"
    );
    [head.as_bytes(), body].concat()
}
79
/// Returns the offset of the first `\r\n\r\n` in `bytes`, if any.
///
/// The offset marks the end of an HTTP header block. The terminator itself
/// starts at the returned index.
fn find_header_end(bytes: &[u8]) -> Option<usize> {
    const TERMINATOR: &[u8] = b"\r\n\r\n";
    let candidate_starts = bytes.len().saturating_sub(TERMINATOR.len() - 1);
    (0..candidate_starts).find(|&start| bytes[start..].starts_with(TERMINATOR))
}
83
84async fn read_http_request(
85 stream: &mut tokio::net::TcpStream,
86) -> Result<(String, String, Vec<u8>), TrellisAuthError> {
87 let mut buffer = Vec::new();
88 let mut chunk = [0u8; 4096];
89 let mut header_end = None;
90 let mut content_length = 0usize;
91
92 loop {
93 let read = stream.read(&mut chunk).await?;
94 if read == 0 {
95 return Err(TrellisAuthError::InvalidCallbackRequest);
96 }
97 buffer.extend_from_slice(&chunk[..read]);
98
99 if header_end.is_none() {
100 header_end = find_header_end(&buffer);
101 if let Some(end) = header_end {
102 let header_text = String::from_utf8_lossy(&buffer[..end]);
103 for line in header_text.lines() {
104 if let Some(value) = line.strip_prefix("Content-Length:") {
105 content_length = value
106 .trim()
107 .parse()
108 .map_err(|_| TrellisAuthError::InvalidCallbackRequest)?;
109 }
110 }
111 }
112 }
113
114 if let Some(end) = header_end {
115 let body_start = end + 4;
116 if buffer.len() >= body_start + content_length {
117 let header_text = String::from_utf8_lossy(&buffer[..end]);
118 let request_line = header_text
119 .lines()
120 .next()
121 .ok_or(TrellisAuthError::InvalidCallbackRequest)?;
122 let mut parts = request_line.split_whitespace();
123 let method = parts
124 .next()
125 .ok_or(TrellisAuthError::InvalidCallbackRequest)?
126 .to_string();
127 let path = parts
128 .next()
129 .ok_or(TrellisAuthError::InvalidCallbackRequest)?
130 .to_string();
131 let body = buffer[body_start..body_start + content_length].to_vec();
132 return Ok((method, path, body));
133 }
134 }
135 }
136}
137
138pub fn generate_session_keypair() -> (String, String) {
140 let seed: [u8; 32] = rand::random();
141 let signing_key = SigningKey::from_bytes(&seed);
142 let public_key = signing_key.verifying_key().to_bytes();
143 (base64url_encode(&seed), base64url_encode(&public_key))
144}
145
146pub fn build_auth_login_url(
149 auth_url: &str,
150 redirect_to: &str,
151 auth: &SessionAuth,
152 contract_json: &str,
153) -> Result<String, TrellisAuthError> {
154 let sig = auth.sign_sha256_domain("oauth-init", &format!("{redirect_to}:null"));
155 let mut url = Url::parse(auth_url)?;
156 url.set_path("/auth/login");
157 url.query_pairs_mut()
158 .append_pair("redirectTo", redirect_to)
159 .append_pair("sessionKey", &auth.session_key)
160 .append_pair("sig", &sig)
161 .append_pair("contract", &encode_contract_query(contract_json)?);
162 Ok(url.to_string())
163}
164
165async fn start_callback_server(
166 listen: &str,
167) -> Result<
168 (
169 SocketAddr,
170 oneshot::Receiver<CallbackOutcome>,
171 tokio::task::JoinHandle<()>,
172 ),
173 TrellisAuthError,
174> {
175 let listener = TcpListener::bind(listen).await?;
176 let local_addr = listener.local_addr()?;
177 let (token_tx, token_rx) = oneshot::channel::<CallbackOutcome>();
178 let shared_tx = std::sync::Arc::new(std::sync::Mutex::new(Some(token_tx)));
179
180 let handle = tokio::spawn(async move {
181 loop {
182 let Ok((mut stream, _)) = listener.accept().await else {
183 break;
184 };
185 let response = match read_http_request(&mut stream).await {
186 Ok((method, path, _body)) if method == "GET" && path.starts_with("/callback") => {
187 http_response(
188 "200 OK",
189 "text/html; charset=utf-8",
190 callback_page_html().as_bytes(),
191 )
192 }
193 Ok((method, path, body)) if method == "POST" && path == "/token" => {
194 let parsed = serde_json::from_slice::<CallbackTokenRequest>(&body);
195 match parsed {
196 Ok(payload) => {
197 let outcome = payload
198 .flow_id
199 .filter(|value| !value.is_empty())
200 .map(CallbackOutcome::FlowId)
201 .or_else(|| {
202 payload
203 .auth_error
204 .filter(|value| !value.is_empty())
205 .map(CallbackOutcome::AuthError)
206 });
207 match outcome {
208 Some(value) => {
209 if let Some(sender) =
210 shared_tx.lock().expect("callback mutex poisoned").take()
211 {
212 let _ = sender.send(value);
213 }
214 http_response("200 OK", "text/plain; charset=utf-8", b"ok")
215 }
216 None => http_response(
217 "400 Bad Request",
218 "text/plain; charset=utf-8",
219 b"invalid auth callback payload",
220 ),
221 }
222 }
223 Err(_) => http_response(
224 "400 Bad Request",
225 "text/plain; charset=utf-8",
226 b"invalid auth callback payload",
227 ),
228 }
229 }
230 Ok(_) => http_response("404 Not Found", "text/plain; charset=utf-8", b"not found"),
231 Err(_) => http_response(
232 "400 Bad Request",
233 "text/plain; charset=utf-8",
234 b"invalid request",
235 ),
236 };
237
238 let _ = stream.write_all(&response).await;
239 let _ = stream.shutdown().await;
240 }
241 });
242
243 Ok((local_addr, token_rx, handle))
244}
245
246async fn bind_session(
247 auth_url: &str,
248 auth: &SessionAuth,
249 flow_id: &str,
250) -> Result<BoundSession, TrellisAuthError> {
251 let client = HttpClient::builder().build()?;
252 let bind_url = format!(
253 "{}/auth/flow/{}/bind",
254 auth_url.trim_end_matches('/'),
255 flow_id
256 );
257 let sig = auth.sign_sha256_domain("bind-flow", flow_id);
258 let response = client
259 .post(bind_url)
260 .json(&json!({
261 "sessionKey": auth.session_key,
262 "sig": sig,
263 }))
264 .send()
265 .await?;
266 let status = response.status();
267 let text = response.text().await?;
268 if !status.is_success() {
269 return Err(TrellisAuthError::BindHttpFailure(status.as_u16(), text));
270 }
271
272 match serde_json::from_str::<BindResponse>(&text)? {
273 BindResponse::Bound(BindResponseBound {
274 binding_token,
275 inbox_prefix,
276 expires,
277 sentinel,
278 }) => Ok(BoundSession {
279 binding_token,
280 inbox_prefix,
281 expires,
282 sentinel,
283 }),
284 BindResponse::ApprovalRequired { approval } => Err(TrellisAuthError::UnexpectedBindStatus(
285 format!("approval_required:{approval}"),
286 )),
287 BindResponse::ApprovalDenied { approval } => Err(TrellisAuthError::UnexpectedBindStatus(
288 format!("approval_denied:{approval}"),
289 )),
290 BindResponse::InsufficientCapabilities {
291 approval,
292 missing_capabilities,
293 } => Err(TrellisAuthError::UnexpectedBindStatus(format!(
294 "insufficient_capabilities:{approval}:{missing_capabilities:?}"
295 ))),
296 }
297}
298
299impl BrowserLoginChallenge {
300 pub fn login_url(&self) -> &str {
302 &self.login_url
303 }
304
305 pub async fn complete(
307 self,
308 auth_url: &str,
309 nats_servers: &str,
310 ) -> Result<AdminLoginOutcome, TrellisAuthError> {
311 let outcome = timeout(Duration::from_secs(300), self.receiver)
312 .await
313 .map_err(|_| TrellisAuthError::LoginTimedOut)?
314 .map_err(|_| TrellisAuthError::LoginInterrupted)?;
315 self.server_handle.abort();
316
317 let flow_id = match outcome {
318 CallbackOutcome::FlowId(value) => value,
319 CallbackOutcome::AuthError(value) => {
320 return Err(TrellisAuthError::AuthFlowFailed(value))
321 }
322 };
323
324 let bound = bind_session(auth_url, &self.auth, &flow_id).await?;
325 let mut state = AdminSessionState {
326 auth_url: auth_url.to_string(),
327 nats_servers: nats_servers.to_string(),
328 session_seed: self.session_seed,
329 session_key: self.auth.session_key.clone(),
330 binding_token: bound.binding_token,
331 sentinel_jwt: bound.sentinel.jwt,
332 sentinel_seed: bound.sentinel.seed,
333 expires: bound.expires,
334 };
335
336 let client = connect_admin_client_async(&state).await?;
337 let auth_client = AuthClient::new(&client);
338 let user = auth_client.me().await?;
339 if !user
340 .capabilities
341 .iter()
342 .any(|capability| capability == "admin")
343 {
344 return Err(TrellisAuthError::NotAdmin);
345 }
346 auth_client.renew_binding_token(&mut state).await?;
347
348 Ok(AdminLoginOutcome { state, user })
349 }
350}
351
352pub async fn start_browser_login(
354 opts: &StartBrowserLoginOpts<'_>,
355) -> Result<BrowserLoginChallenge, TrellisAuthError> {
356 let (session_seed, _session_key) = generate_session_keypair();
357 let auth = SessionAuth::from_seed_base64url(&session_seed)?;
358 let (callback_addr, receiver, server_handle) = start_callback_server(opts.listen).await?;
359 let redirect_to = format!("http://{callback_addr}/callback");
360 let login_url = build_auth_login_url(
361 opts.auth_url,
362 &redirect_to,
363 &auth,
364 opts.contract_json,
365 )?;
366
367 Ok(BrowserLoginChallenge {
368 login_url,
369 session_seed,
370 auth,
371 receiver,
372 server_handle,
373 })
374}