Skip to main content

trellis_auth/
browser_login.rs

1use std::net::SocketAddr;
2use std::time::Duration;
3
4use base64::engine::general_purpose::URL_SAFE_NO_PAD;
5use base64::Engine as _;
6use ed25519_dalek::SigningKey;
7use reqwest::Client as HttpClient;
8use serde_json::{json, Value};
9use tokio::io::{AsyncReadExt, AsyncWriteExt};
10use tokio::net::TcpListener;
11use tokio::sync::oneshot;
12use tokio::time::timeout;
13use url::Url;
14
15use crate::client::{connect_admin_client_async, AuthClient};
16use crate::models::{
17    AdminLoginOutcome, AdminSessionState, BindResponse, BindResponseBound, BoundSession,
18    BrowserLoginChallenge, CallbackOutcome, CallbackTokenRequest, StartBrowserLoginOpts,
19};
20use crate::TrellisAuthError;
21use trellis_client::SessionAuth;
22
23fn base64url_encode(bytes: &[u8]) -> String {
24    URL_SAFE_NO_PAD.encode(bytes)
25}
26
27fn encode_contract_query(contract_json: &str) -> Result<String, TrellisAuthError> {
28    let parsed: Value = serde_json::from_str(contract_json)?;
29    Ok(base64url_encode(serde_json::to_string(&parsed)?.as_bytes()))
30}
31
/// HTML page served by the local callback server at `GET /callback`.
///
/// The page reads `flowId` and `authError` from its own query string. It POSTs them as JSON to
/// `/token` on the same origin and reports the result in the `#status` paragraph.
///
/// Attribute values use plain `"` quotes on purpose. Inside a raw string, `\"` is a literal
/// backslash followed by a quote. That would make the element id `\"status\"`, so
/// `getElementById("status")` would return `null` and every status update would throw.
pub(crate) fn callback_page_html() -> &'static str {
    r#"<!doctype html>
<html>
  <head>
    <meta charset="utf-8" />
    <title>Trellis CLI Login</title>
  </head>
  <body>
    <p id="status">Completing Trellis CLI login...</p>
    <script>
      const params = new URLSearchParams(window.location.search);
      const flowId = params.get("flowId");
      const authError = params.get("authError");
      const status = document.getElementById("status");
      if (!flowId && !authError) {
        status.textContent = "Missing auth result in callback URL.";
      } else {
        fetch("/token", {
          method: "POST",
          headers: { "content-type": "application/json" },
          body: JSON.stringify({ flowId, authError })
        }).then(async (response) => {
          if (!response.ok) {
            throw new Error(await response.text());
          }
          status.textContent = authError
            ? `Login failed: ${authError}`
            : "Login complete. You can close this window.";
        }).catch((error) => {
          status.textContent = `Login handoff failed: ${error}`;
        });
      }
    </script>
  </body>
</html>
"#
}
69
/// Serialize a minimal HTTP/1.1 response.
///
/// Writes the status line, then `Content-Type`, `Content-Length` (taken from `body`) and
/// `Connection: close` headers, then a blank line and the raw body bytes.
fn http_response(status_line: &str, content_type: &str, body: &[u8]) -> Vec<u8> {
    let head = [
        format!("HTTP/1.1 {status_line}"),
        format!("Content-Type: {content_type}"),
        format!("Content-Length: {}", body.len()),
        String::from("Connection: close"),
    ]
    .join("\r\n");

    let mut response = Vec::with_capacity(head.len() + 4 + body.len());
    response.extend_from_slice(head.as_bytes());
    response.extend_from_slice(b"\r\n\r\n");
    response.extend_from_slice(body);
    response
}
79
/// Locate the blank line (`\r\n\r\n`) that terminates an HTTP header block.
///
/// Returns the byte offset of the first `\r` of the earliest terminator. Returns `None` if no
/// complete terminator is present in `bytes` yet.
fn find_header_end(bytes: &[u8]) -> Option<usize> {
    const TERMINATOR: &[u8] = b"\r\n\r\n";
    (0..bytes.len().saturating_sub(TERMINATOR.len() - 1))
        .find(|&start| bytes[start..].starts_with(TERMINATOR))
}
83
84async fn read_http_request(
85    stream: &mut tokio::net::TcpStream,
86) -> Result<(String, String, Vec<u8>), TrellisAuthError> {
87    let mut buffer = Vec::new();
88    let mut chunk = [0u8; 4096];
89    let mut header_end = None;
90    let mut content_length = 0usize;
91
92    loop {
93        let read = stream.read(&mut chunk).await?;
94        if read == 0 {
95            return Err(TrellisAuthError::InvalidCallbackRequest);
96        }
97        buffer.extend_from_slice(&chunk[..read]);
98
99        if header_end.is_none() {
100            header_end = find_header_end(&buffer);
101            if let Some(end) = header_end {
102                let header_text = String::from_utf8_lossy(&buffer[..end]);
103                for line in header_text.lines() {
104                    if let Some(value) = line.strip_prefix("Content-Length:") {
105                        content_length = value
106                            .trim()
107                            .parse()
108                            .map_err(|_| TrellisAuthError::InvalidCallbackRequest)?;
109                    }
110                }
111            }
112        }
113
114        if let Some(end) = header_end {
115            let body_start = end + 4;
116            if buffer.len() >= body_start + content_length {
117                let header_text = String::from_utf8_lossy(&buffer[..end]);
118                let request_line = header_text
119                    .lines()
120                    .next()
121                    .ok_or(TrellisAuthError::InvalidCallbackRequest)?;
122                let mut parts = request_line.split_whitespace();
123                let method = parts
124                    .next()
125                    .ok_or(TrellisAuthError::InvalidCallbackRequest)?
126                    .to_string();
127                let path = parts
128                    .next()
129                    .ok_or(TrellisAuthError::InvalidCallbackRequest)?
130                    .to_string();
131                let body = buffer[body_start..body_start + content_length].to_vec();
132                return Ok((method, path, body));
133            }
134        }
135    }
136}
137
138/// Generate a new base64url-encoded Ed25519 session seed and public key.
139pub fn generate_session_keypair() -> (String, String) {
140    let seed: [u8; 32] = rand::random();
141    let signing_key = SigningKey::from_bytes(&seed);
142    let public_key = signing_key.verifying_key().to_bytes();
143    (base64url_encode(&seed), base64url_encode(&public_key))
144}
145
146/// Build the Trellis `GET /auth/login` URL that creates a browser flow and
147/// redirects into the deployment portal.
148pub fn build_auth_login_url(
149    auth_url: &str,
150    redirect_to: &str,
151    auth: &SessionAuth,
152    contract_json: &str,
153) -> Result<String, TrellisAuthError> {
154    let sig = auth.sign_sha256_domain("oauth-init", &format!("{redirect_to}:null"));
155    let mut url = Url::parse(auth_url)?;
156    url.set_path("/auth/login");
157    url.query_pairs_mut()
158        .append_pair("redirectTo", redirect_to)
159        .append_pair("sessionKey", &auth.session_key)
160        .append_pair("sig", &sig)
161        .append_pair("contract", &encode_contract_query(contract_json)?);
162    Ok(url.to_string())
163}
164
165async fn start_callback_server(
166    listen: &str,
167) -> Result<
168    (
169        SocketAddr,
170        oneshot::Receiver<CallbackOutcome>,
171        tokio::task::JoinHandle<()>,
172    ),
173    TrellisAuthError,
174> {
175    let listener = TcpListener::bind(listen).await?;
176    let local_addr = listener.local_addr()?;
177    let (token_tx, token_rx) = oneshot::channel::<CallbackOutcome>();
178    let shared_tx = std::sync::Arc::new(std::sync::Mutex::new(Some(token_tx)));
179
180    let handle = tokio::spawn(async move {
181        loop {
182            let Ok((mut stream, _)) = listener.accept().await else {
183                break;
184            };
185            let response = match read_http_request(&mut stream).await {
186                Ok((method, path, _body)) if method == "GET" && path.starts_with("/callback") => {
187                    http_response(
188                        "200 OK",
189                        "text/html; charset=utf-8",
190                        callback_page_html().as_bytes(),
191                    )
192                }
193                Ok((method, path, body)) if method == "POST" && path == "/token" => {
194                    let parsed = serde_json::from_slice::<CallbackTokenRequest>(&body);
195                    match parsed {
196                        Ok(payload) => {
197                            let outcome = payload
198                                .flow_id
199                                .filter(|value| !value.is_empty())
200                                .map(CallbackOutcome::FlowId)
201                                .or_else(|| {
202                                    payload
203                                        .auth_error
204                                        .filter(|value| !value.is_empty())
205                                        .map(CallbackOutcome::AuthError)
206                                });
207                            match outcome {
208                                Some(value) => {
209                                    if let Some(sender) =
210                                        shared_tx.lock().expect("callback mutex poisoned").take()
211                                    {
212                                        let _ = sender.send(value);
213                                    }
214                                    http_response("200 OK", "text/plain; charset=utf-8", b"ok")
215                                }
216                                None => http_response(
217                                    "400 Bad Request",
218                                    "text/plain; charset=utf-8",
219                                    b"invalid auth callback payload",
220                                ),
221                            }
222                        }
223                        Err(_) => http_response(
224                            "400 Bad Request",
225                            "text/plain; charset=utf-8",
226                            b"invalid auth callback payload",
227                        ),
228                    }
229                }
230                Ok(_) => http_response("404 Not Found", "text/plain; charset=utf-8", b"not found"),
231                Err(_) => http_response(
232                    "400 Bad Request",
233                    "text/plain; charset=utf-8",
234                    b"invalid request",
235                ),
236            };
237
238            let _ = stream.write_all(&response).await;
239            let _ = stream.shutdown().await;
240        }
241    });
242
243    Ok((local_addr, token_rx, handle))
244}
245
246async fn bind_session(
247    auth_url: &str,
248    auth: &SessionAuth,
249    flow_id: &str,
250) -> Result<BoundSession, TrellisAuthError> {
251    let client = HttpClient::builder().build()?;
252    let bind_url = format!(
253        "{}/auth/flow/{}/bind",
254        auth_url.trim_end_matches('/'),
255        flow_id
256    );
257    let sig = auth.sign_sha256_domain("bind-flow", flow_id);
258    let response = client
259        .post(bind_url)
260        .json(&json!({
261            "sessionKey": auth.session_key,
262            "sig": sig,
263        }))
264        .send()
265        .await?;
266    let status = response.status();
267    let text = response.text().await?;
268    if !status.is_success() {
269        return Err(TrellisAuthError::BindHttpFailure(status.as_u16(), text));
270    }
271
272    match serde_json::from_str::<BindResponse>(&text)? {
273        BindResponse::Bound(BindResponseBound {
274            binding_token,
275            inbox_prefix,
276            expires,
277            sentinel,
278        }) => Ok(BoundSession {
279            binding_token,
280            inbox_prefix,
281            expires,
282            sentinel,
283        }),
284        BindResponse::ApprovalRequired { approval } => Err(TrellisAuthError::UnexpectedBindStatus(
285            format!("approval_required:{approval}"),
286        )),
287        BindResponse::ApprovalDenied { approval } => Err(TrellisAuthError::UnexpectedBindStatus(
288            format!("approval_denied:{approval}"),
289        )),
290        BindResponse::InsufficientCapabilities {
291            approval,
292            missing_capabilities,
293        } => Err(TrellisAuthError::UnexpectedBindStatus(format!(
294            "insufficient_capabilities:{approval}:{missing_capabilities:?}"
295        ))),
296    }
297}
298
299impl BrowserLoginChallenge {
300    /// Return the URL the user should open to complete login.
301    pub fn login_url(&self) -> &str {
302        &self.login_url
303    }
304
305    /// Wait for the callback, bind the session, and confirm the user is an admin.
306    pub async fn complete(
307        self,
308        auth_url: &str,
309        nats_servers: &str,
310    ) -> Result<AdminLoginOutcome, TrellisAuthError> {
311        let outcome = timeout(Duration::from_secs(300), self.receiver)
312            .await
313            .map_err(|_| TrellisAuthError::LoginTimedOut)?
314            .map_err(|_| TrellisAuthError::LoginInterrupted)?;
315        self.server_handle.abort();
316
317        let flow_id = match outcome {
318            CallbackOutcome::FlowId(value) => value,
319            CallbackOutcome::AuthError(value) => {
320                return Err(TrellisAuthError::AuthFlowFailed(value))
321            }
322        };
323
324        let bound = bind_session(auth_url, &self.auth, &flow_id).await?;
325        let mut state = AdminSessionState {
326            auth_url: auth_url.to_string(),
327            nats_servers: nats_servers.to_string(),
328            session_seed: self.session_seed,
329            session_key: self.auth.session_key.clone(),
330            binding_token: bound.binding_token,
331            sentinel_jwt: bound.sentinel.jwt,
332            sentinel_seed: bound.sentinel.seed,
333            expires: bound.expires,
334        };
335
336        let client = connect_admin_client_async(&state).await?;
337        let auth_client = AuthClient::new(&client);
338        let user = auth_client.me().await?;
339        if !user
340            .capabilities
341            .iter()
342            .any(|capability| capability == "admin")
343        {
344            return Err(TrellisAuthError::NotAdmin);
345        }
346        auth_client.renew_binding_token(&mut state).await?;
347
348        Ok(AdminLoginOutcome { state, user })
349    }
350}
351
352/// Start the browser login flow and local callback listener.
353pub async fn start_browser_login(
354    opts: &StartBrowserLoginOpts<'_>,
355) -> Result<BrowserLoginChallenge, TrellisAuthError> {
356    let (session_seed, _session_key) = generate_session_keypair();
357    let auth = SessionAuth::from_seed_base64url(&session_seed)?;
358    let (callback_addr, receiver, server_handle) = start_callback_server(opts.listen).await?;
359    let redirect_to = format!("http://{callback_addr}/callback");
360    let login_url = build_auth_login_url(
361        opts.auth_url,
362        &redirect_to,
363        &auth,
364        opts.contract_json,
365    )?;
366
367    Ok(BrowserLoginChallenge {
368        login_url,
369        session_seed,
370        auth,
371        receiver,
372        server_handle,
373    })
374}