Skip to main content

shuttle_rs/
app.rs

1use std::env;
2use std::net::SocketAddr;
3use std::path::PathBuf;
4
5use crate::core::{Event, EventStore, EventType, Result, ShuttleError};
6use crate::oauth::{self, OAuthConfig, OAuthStore};
7use crate::store::SqliteEventStore;
8use axum::extract::{Form, Path as AxumPath, Query, State};
9use axum::http::{header, HeaderMap, HeaderValue, StatusCode};
10use axum::response::{Html, IntoResponse, Redirect, Response};
11use axum::routing::{get, patch, post};
12use axum::{Json, Router};
13use serde::{Deserialize, Serialize};
14use serde_json::json;
15use uuid::Uuid;
16
/// Shared application state that the router hands to every handler through axum's
/// `State` extractor.
#[derive(Clone)]
pub struct AppRuntime {
    /// Event store the handlers query and append to.
    pub store: SqliteEventStore,
    /// Directory passed to the repository-status and context-assembly helpers.
    pub cwd: PathBuf,
    /// Workspace identifier used to scope task queries and stamped on new events.
    pub workspace_id: String,
    /// Agent name used for inbox lookups and stamped on new events.
    pub agent: String,
    /// Session identifier stamped on new events.
    pub session_id: String,
    /// OAuth configuration and storage. When `None`, the OAuth endpoints answer 404 and
    /// request authorization falls back to the `SHUTTLE_MCP_BEARER_TOKEN` environment variable.
    pub oauth: Option<OAuthRuntime>,
}
26
/// OAuth settings paired with the store that backs client registration, authorization
/// codes, and access-token validation.
#[derive(Clone)]
pub struct OAuthRuntime {
    /// Static configuration; the handlers read `public_url` and `admin_token` from it.
    pub config: OAuthConfig,
    /// Store used to register clients, create and exchange codes, and validate tokens.
    pub store: OAuthStore,
}
32
/// JSON payload returned by `GET /api/dashboard`; the index page renders one section per
/// top-level key.
#[derive(Debug, Serialize)]
struct Dashboard {
    /// Inbox events for the runtime's agent.
    inbox: Vec<Event>,
    /// Open tasks for the runtime's workspace.
    tasks: Vec<crate::task::TaskSummary>,
    /// Memory events read from the store.
    memories: Vec<Event>,
    /// Assembled context, or a placeholder with "unknown" branch/commit when assembly fails.
    context: crate::context::Context,
}
40
41pub async fn serve(runtime: AppRuntime, addr: SocketAddr) -> Result<()> {
42    let app = router(runtime);
43    let listener = tokio::net::TcpListener::bind(addr)
44        .await
45        .map_err(|err| ShuttleError::Store(err.to_string()))?;
46    axum::serve(listener, app)
47        .await
48        .map_err(|err| ShuttleError::Store(err.to_string()))
49}
50
51pub fn router(runtime: AppRuntime) -> Router {
52    Router::new()
53        .route("/", get(index))
54        .route("/api/dashboard", get(dashboard))
55        .route("/api/inbox", get(inbox))
56        .route("/api/tasks", get(tasks))
57        .route("/api/tasks", post(create_task))
58        .route("/api/tasks/{id}", patch(update_task))
59        .route("/api/tasks/{id}/done", post(done_task))
60        .route("/api/memories", get(memories))
61        .route("/api/context", get(context))
62        .route("/api/recall", post(recall))
63        .route("/api/remember", post(remember))
64        .route(
65            "/mcp",
66            get(mcp_health)
67                .post(mcp_post)
68                .delete(mcp_delete)
69                .options(mcp_options),
70        )
71        .route(
72            "/.well-known/oauth-protected-resource",
73            get(oauth_protected_resource),
74        )
75        .route(
76            "/.well-known/oauth-protected-resource/mcp",
77            get(oauth_protected_resource),
78        )
79        .route(
80            "/.well-known/oauth-authorization-server",
81            get(oauth_authorization_server),
82        )
83        .route("/oauth/register", post(oauth_register))
84        .route(
85            "/oauth/authorize",
86            get(oauth_authorize_page).post(oauth_authorize_submit),
87        )
88        .route("/oauth/token", post(oauth_token))
89        .with_state(runtime)
90}
91
92async fn index(headers: HeaderMap, State(runtime): State<AppRuntime>) -> Response {
93    if let Some(response) = mcp_unauthorized_response(runtime.oauth.as_ref(), &headers) {
94        return response;
95    }
96    Html(
97        r#"<!doctype html>
98<html>
99<head>
100  <meta charset="utf-8">
101  <meta name="viewport" content="width=device-width, initial-scale=1">
102  <title>Shuttle</title>
103  <style>
104    body { font-family: system-ui, sans-serif; margin: 2rem; color: #1f2937; }
105    main { display: grid; gap: 1rem; grid-template-columns: repeat(auto-fit, minmax(260px, 1fr)); }
106    section { border: 1px solid #d1d5db; border-radius: 8px; padding: 1rem; }
107    h1 { margin-top: 0; }
108    pre { white-space: pre-wrap; overflow-wrap: anywhere; }
109  </style>
110</head>
111<body>
112  <h1>Shuttle</h1>
113  <main id="dashboard"></main>
114  <script>
115    fetch('/api/dashboard').then(r => r.json()).then(data => {
116      const root = document.getElementById('dashboard');
117      for (const [name, value] of Object.entries(data)) {
118        const section = document.createElement('section');
119        const heading = document.createElement('h2');
120        heading.textContent = name;
121        const pre = document.createElement('pre');
122        pre.textContent = JSON.stringify(value, null, 2);
123        section.append(heading, pre);
124        root.append(section);
125      }
126    });
127  </script>
128</body>
129</html>"#,
130    )
131    .into_response()
132}
133
134async fn dashboard(headers: HeaderMap, State(runtime): State<AppRuntime>) -> Response {
135    if let Some(response) = mcp_unauthorized_response(runtime.oauth.as_ref(), &headers) {
136        return response;
137    }
138    Json(Dashboard {
139        inbox: crate::message::inbox(&runtime.store, &runtime.agent)
140            .await
141            .unwrap_or_default(),
142        tasks: crate::task::open_tasks(&runtime.store, &runtime.workspace_id, Some(20))
143            .await
144            .unwrap_or_default(),
145        memories: crate::memory::memories(&runtime.store)
146            .await
147            .unwrap_or_default(),
148        context: crate::context::assemble_context(
149            &runtime.store,
150            &runtime.cwd,
151            &runtime.workspace_id,
152            &runtime.agent,
153        )
154        .await
155        .unwrap_or_else(|_| crate::context::Context {
156            repo: runtime.cwd.display().to_string(),
157            branch: "unknown".to_owned(),
158            commit: "unknown".to_owned(),
159            git_remote: None,
160            dirty: false,
161            dirty_files: Vec::new(),
162            open_tasks: Vec::new(),
163            claimed_tasks: Vec::new(),
164            recent_decisions: Vec::new(),
165            related_memories: Vec::new(),
166            recent_messages: Vec::new(),
167            pending_handoffs: Vec::new(),
168            recent_completed_handoffs: Vec::new(),
169            inbox: Vec::new(),
170        }),
171    })
172    .into_response()
173}
174
175async fn inbox(headers: HeaderMap, State(runtime): State<AppRuntime>) -> Response {
176    if let Some(response) = mcp_unauthorized_response(runtime.oauth.as_ref(), &headers) {
177        return response;
178    }
179    Json(
180        crate::message::inbox(&runtime.store, &runtime.agent)
181            .await
182            .unwrap_or_default(),
183    )
184    .into_response()
185}
186
187async fn tasks(headers: HeaderMap, State(runtime): State<AppRuntime>) -> Response {
188    if let Some(response) = mcp_unauthorized_response(runtime.oauth.as_ref(), &headers) {
189        return response;
190    }
191    Json(
192        crate::task::open_tasks(&runtime.store, &runtime.workspace_id, Some(20))
193            .await
194            .unwrap_or_default(),
195    )
196    .into_response()
197}
198
199async fn memories(headers: HeaderMap, State(runtime): State<AppRuntime>) -> Response {
200    if let Some(response) = mcp_unauthorized_response(runtime.oauth.as_ref(), &headers) {
201        return response;
202    }
203    Json(
204        crate::memory::memories(&runtime.store)
205            .await
206            .unwrap_or_default(),
207    )
208    .into_response()
209}
210
211async fn context(headers: HeaderMap, State(runtime): State<AppRuntime>) -> Response {
212    if let Some(response) = mcp_unauthorized_response(runtime.oauth.as_ref(), &headers) {
213        return response;
214    }
215    Json(
216        crate::context::assemble_context(
217            &runtime.store,
218            &runtime.cwd,
219            &runtime.workspace_id,
220            &runtime.agent,
221        )
222        .await
223        .ok(),
224    )
225    .into_response()
226}
227
/// JSON body accepted by `POST /api/recall`.
#[derive(Deserialize)]
struct RecallRequest {
    /// Search text; defaults to an empty string when omitted, which the handler rejects.
    #[serde(default)]
    query: String,
}
233
234async fn recall(
235    headers: HeaderMap,
236    State(runtime): State<AppRuntime>,
237    Json(request): Json<RecallRequest>,
238) -> Response {
239    if let Some(response) = mcp_unauthorized_response(runtime.oauth.as_ref(), &headers) {
240        return response;
241    }
242    if request.query.trim().is_empty() {
243        return api_error("query is required");
244    }
245    let status = crate::context::repo_status(&runtime.cwd).ok();
246    let repo_id = status.as_ref().map(crate::context::repo_id);
247    let branch = status.as_ref().map(|status| status.branch.as_str());
248    Json(
249        crate::memory::ranked_recall(
250            &runtime.store,
251            &request.query,
252            None,
253            Some(&runtime.workspace_id),
254            repo_id.as_deref(),
255            branch,
256        )
257        .await
258        .unwrap_or_default(),
259    )
260    .into_response()
261}
262
/// JSON body accepted by `POST /api/remember`.
#[derive(Deserialize)]
struct RememberRequest {
    /// Memory kind; an omitted or empty value is treated as "memory" by the handler.
    #[serde(default)]
    kind: String,
    /// Memory text; defaults to an empty string when omitted, which the handler rejects.
    #[serde(default)]
    text: String,
}
270
271async fn remember(
272    headers: HeaderMap,
273    State(runtime): State<AppRuntime>,
274    Json(request): Json<RememberRequest>,
275) -> Response {
276    if let Some(response) = mcp_unauthorized_response(runtime.oauth.as_ref(), &headers) {
277        return response;
278    }
279    if request.text.trim().is_empty() {
280        return api_error("text is required");
281    }
282    let event_type = match request.kind.as_str() {
283        "" | "memory" => EventType::Memory,
284        "decision" => EventType::Decision,
285        "observation" => EventType::Observation,
286        "pattern" => EventType::Pattern,
287        "fact" => EventType::Fact,
288        "bug" => EventType::Bug,
289        other => return api_error(&format!("unknown memory kind {other:?}")),
290    };
291    let event = crate::memory::new_typed_memory(
292        event_type,
293        runtime.workspace_id.clone(),
294        runtime.agent.clone(),
295        runtime.session_id.clone(),
296        request.text,
297    );
298    match runtime
299        .store
300        .append(with_repo_metadata(event, &runtime))
301        .await
302    {
303        Ok(event) => Json(event).into_response(),
304        Err(err) => api_error(&err.to_string()),
305    }
306}
307
/// JSON body accepted by `POST /api/tasks`.
#[derive(Deserialize)]
struct CreateTaskRequest {
    /// Task title; defaults to an empty string when omitted, which the handler rejects.
    #[serde(default)]
    title: String,
    /// Optional task body; when non-empty it is appended to the title after a blank line.
    #[serde(default)]
    body: String,
}
315
316async fn create_task(
317    headers: HeaderMap,
318    State(runtime): State<AppRuntime>,
319    Json(request): Json<CreateTaskRequest>,
320) -> Response {
321    if let Some(response) = mcp_unauthorized_response(runtime.oauth.as_ref(), &headers) {
322        return response;
323    }
324    if request.title.trim().is_empty() {
325        return api_error("title is required");
326    }
327    let content = if request.body.is_empty() {
328        request.title
329    } else {
330        format!("{}\n\n{}", request.title, request.body)
331    };
332    let event = crate::task::new_task(
333        runtime.workspace_id.clone(),
334        runtime.agent.clone(),
335        runtime.session_id.clone(),
336        content,
337    );
338    match runtime
339        .store
340        .append(with_repo_metadata(event, &runtime))
341        .await
342    {
343        Ok(event) => Json(event).into_response(),
344        Err(err) => api_error(&err.to_string()),
345    }
346}
347
/// JSON body accepted by `PATCH /api/tasks/{id}`.
#[derive(Deserialize)]
struct UpdateTaskRequest {
    /// Update text; defaults to an empty string when omitted, which the handler rejects.
    #[serde(default)]
    text: String,
}
353
354async fn update_task(
355    headers: HeaderMap,
356    State(runtime): State<AppRuntime>,
357    AxumPath(id): AxumPath<Uuid>,
358    Json(request): Json<UpdateTaskRequest>,
359) -> Response {
360    if let Some(response) = mcp_unauthorized_response(runtime.oauth.as_ref(), &headers) {
361        return response;
362    }
363    if request.text.trim().is_empty() {
364        return api_error("text is required");
365    }
366    if let Err(err) =
367        crate::task::ensure_task_exists(&runtime.store, &runtime.workspace_id, id).await
368    {
369        return api_error(&err.to_string());
370    }
371    let event = crate::task::new_task_update(
372        runtime.workspace_id.clone(),
373        runtime.agent.clone(),
374        runtime.session_id.clone(),
375        id,
376        request.text,
377    );
378    match runtime
379        .store
380        .append(with_repo_metadata(event, &runtime))
381        .await
382    {
383        Ok(event) => Json(event).into_response(),
384        Err(err) => api_error(&err.to_string()),
385    }
386}
387
388async fn done_task(
389    headers: HeaderMap,
390    State(runtime): State<AppRuntime>,
391    AxumPath(id): AxumPath<Uuid>,
392) -> Response {
393    if let Some(response) = mcp_unauthorized_response(runtime.oauth.as_ref(), &headers) {
394        return response;
395    }
396    if let Err(err) =
397        crate::task::ensure_task_exists(&runtime.store, &runtime.workspace_id, id).await
398    {
399        return api_error(&err.to_string());
400    }
401    let event = crate::task::new_task_done(
402        runtime.workspace_id.clone(),
403        runtime.agent.clone(),
404        runtime.session_id.clone(),
405        id,
406    );
407    match runtime
408        .store
409        .append(with_repo_metadata(event, &runtime))
410        .await
411    {
412        Ok(event) => Json(event).into_response(),
413        Err(err) => api_error(&err.to_string()),
414    }
415}
416
417fn api_error(message: &str) -> Response {
418    (StatusCode::BAD_REQUEST, Json(json!({ "error": message }))).into_response()
419}
420
421fn with_repo_metadata(mut event: Event, runtime: &AppRuntime) -> Event {
422    if let Ok(status) = crate::context::repo_status(&runtime.cwd) {
423        let repo_id = crate::context::repo_id(&status);
424        event.repo_id = Some(repo_id.clone());
425        event.repo_path = Some(status.repo_path.clone());
426        event.git_remote = status.git_remote.clone();
427        event.branch = Some(status.branch.clone());
428        event.commit = Some(status.commit.clone());
429        event.repo_dirty = Some(status.dirty);
430        if let Some(metadata) = event.metadata_json.as_object_mut() {
431            metadata.insert("repo_id".to_owned(), json!(repo_id));
432            metadata.insert("repo_path".to_owned(), json!(status.repo_path));
433            metadata.insert("git_remote".to_owned(), json!(status.git_remote));
434            metadata.insert("branch".to_owned(), json!(status.branch));
435            metadata.insert("commit".to_owned(), json!(status.commit));
436            metadata.insert("repo_dirty".to_owned(), json!(status.dirty));
437            metadata.insert("dirty_files".to_owned(), json!(status.dirty_files));
438        }
439    }
440    event
441}
442
443async fn mcp_health(headers: HeaderMap, State(runtime): State<AppRuntime>) -> Response {
444    if let Some(response) = mcp_unauthorized_response(runtime.oauth.as_ref(), &headers) {
445        return response;
446    }
447    with_cors((StatusCode::OK, "Shuttle MCP server"))
448}
449
450async fn mcp_delete(headers: HeaderMap, State(runtime): State<AppRuntime>) -> Response {
451    if let Some(response) = mcp_unauthorized_response(runtime.oauth.as_ref(), &headers) {
452        return response;
453    }
454    with_cors((StatusCode::OK, "OK"))
455}
456
457async fn mcp_options() -> Response {
458    with_cors(StatusCode::NO_CONTENT)
459}
460
461async fn mcp_post(
462    headers: HeaderMap,
463    State(runtime): State<AppRuntime>,
464    Json(request): Json<crate::mcp::Request>,
465) -> Response {
466    if let Some(response) = mcp_unauthorized_response(runtime.oauth.as_ref(), &headers) {
467        return response;
468    }
469    let response = crate::mcp::handle_request(
470        &crate::mcp::McpRuntime {
471            store: runtime.store,
472            cwd: runtime.cwd,
473            workspace_id: runtime.workspace_id,
474            agent: runtime.agent,
475            session_id: runtime.session_id,
476        },
477        request,
478    )
479    .await;
480    with_cors(Json(response))
481}
482
483async fn oauth_protected_resource(State(runtime): State<AppRuntime>) -> Response {
484    let Some(oauth) = runtime.oauth else {
485        return (StatusCode::NOT_FOUND, "OAuth is not configured").into_response();
486    };
487    Json(oauth::protected_resource_metadata(&oauth.config)).into_response()
488}
489
490async fn oauth_authorization_server(State(runtime): State<AppRuntime>) -> Response {
491    let Some(oauth) = runtime.oauth else {
492        return (StatusCode::NOT_FOUND, "OAuth is not configured").into_response();
493    };
494    Json(oauth::authorization_server_metadata(&oauth.config)).into_response()
495}
496
497async fn oauth_register(
498    State(runtime): State<AppRuntime>,
499    Json(request): Json<oauth::RegisterRequest>,
500) -> Response {
501    let Some(oauth) = runtime.oauth else {
502        return (StatusCode::NOT_FOUND, "OAuth is not configured").into_response();
503    };
504    match oauth.store.register_client(request) {
505        Ok(client) => Json(json!({
506            "client_id": client.client_id,
507            "client_id_issued_at": chrono::Utc::now().timestamp(),
508            "redirect_uris": client.redirect_uris,
509            "client_name": client.client_name,
510            "token_endpoint_auth_method": "none",
511        }))
512        .into_response(),
513        Err(err) => oauth_error(StatusCode::BAD_REQUEST, "invalid_request", &err.to_string()),
514    }
515}
516
517async fn oauth_authorize_page(
518    State(runtime): State<AppRuntime>,
519    Query(request): Query<oauth::AuthorizeRequest>,
520) -> Response {
521    let Some(oauth) = runtime.oauth else {
522        return (StatusCode::NOT_FOUND, "OAuth is not configured").into_response();
523    };
524    if request.response_type != "code" {
525        return oauth_error(
526            StatusCode::BAD_REQUEST,
527            "unsupported_response_type",
528            "response_type must be code",
529        );
530    }
531    match oauth
532        .store
533        .client_allows_redirect(&request.client_id, &request.redirect_uri)
534    {
535        Ok(true) => {
536            Html(authorize_html(&request, oauth.config.admin_token.is_some())).into_response()
537        }
538        Ok(false) => oauth_error(
539            StatusCode::BAD_REQUEST,
540            "invalid_request",
541            "unknown client_id or redirect_uri",
542        ),
543        Err(_) => oauth_error(
544            StatusCode::INTERNAL_SERVER_ERROR,
545            "server_error",
546            "failed to validate OAuth client",
547        ),
548    }
549}
550
551async fn oauth_authorize_submit(
552    State(runtime): State<AppRuntime>,
553    Form(form): Form<oauth::AuthorizeForm>,
554) -> Response {
555    let Some(oauth) = runtime.oauth else {
556        return (StatusCode::NOT_FOUND, "OAuth is not configured").into_response();
557    };
558    if let Some(expected) = oauth.config.admin_token.as_deref() {
559        if !constant_time_eq(form.admin_token.as_bytes(), expected.as_bytes()) {
560            return oauth_error(
561                StatusCode::UNAUTHORIZED,
562                "access_denied",
563                "invalid admin token",
564            );
565        }
566    }
567    let request = oauth::AuthorizeRequest::from(form);
568    if request.response_type != "code" {
569        return oauth_error(
570            StatusCode::BAD_REQUEST,
571            "unsupported_response_type",
572            "response_type must be code",
573        );
574    }
575    match oauth.store.create_code(request.clone()) {
576        Ok(code) => Redirect::to(&oauth::authorize_redirect(
577            &request.redirect_uri,
578            &code,
579            request.state.as_deref(),
580        ))
581        .into_response(),
582        Err(err) => oauth_error(StatusCode::BAD_REQUEST, "invalid_request", &err.to_string()),
583    }
584}
585
586async fn oauth_token(
587    State(runtime): State<AppRuntime>,
588    Form(request): Form<oauth::TokenRequest>,
589) -> Response {
590    let Some(oauth) = runtime.oauth else {
591        return (StatusCode::NOT_FOUND, "OAuth is not configured").into_response();
592    };
593    if request.grant_type != "authorization_code" {
594        return oauth_error(
595            StatusCode::BAD_REQUEST,
596            "unsupported_grant_type",
597            "grant_type must be authorization_code",
598        );
599    }
600    match oauth.store.exchange_code(request) {
601        Ok(token) => Json(token).into_response(),
602        Err(err) => oauth_error(StatusCode::BAD_REQUEST, "invalid_grant", &err.to_string()),
603    }
604}
605
606fn mcp_unauthorized_response(
607    oauth: Option<&OAuthRuntime>,
608    headers: &HeaderMap,
609) -> Option<Response> {
610    if let Some(oauth) = oauth {
611        let Some(token) = bearer_token(headers) else {
612            return Some(unauthorized_oauth(&oauth.config));
613        };
614        return match oauth.store.validate_access_token(token) {
615            Ok(true) => None,
616            Ok(false) => Some(unauthorized_oauth(&oauth.config)),
617            Err(_) => Some(oauth_error(
618                StatusCode::UNAUTHORIZED,
619                "invalid_token",
620                "failed to validate access token",
621            )),
622        };
623    }
624
625    let token = env::var("SHUTTLE_MCP_BEARER_TOKEN")
626        .ok()
627        .filter(|token| !token.is_empty())?;
628    let expected = format!("Bearer {token}");
629    if headers
630        .get("authorization")
631        .and_then(|header| header.to_str().ok())
632        .is_some_and(|actual| constant_time_eq(actual.as_bytes(), expected.as_bytes()))
633    {
634        None
635    } else {
636        Some(with_cors(StatusCode::UNAUTHORIZED))
637    }
638}
639
640fn bearer_token(headers: &HeaderMap) -> Option<&str> {
641    headers
642        .get(header::AUTHORIZATION)
643        .and_then(|header| header.to_str().ok())
644        .and_then(|value| {
645            let (scheme, token) = value.split_once(' ')?;
646            scheme.eq_ignore_ascii_case("Bearer").then_some(token)
647        })
648}
649
/// Compares two byte strings without exiting early on the first mismatch.
///
/// The loop always covers the longer of the two inputs, treating missing bytes as zero, and
/// the length difference is folded into the result so inputs of different lengths never
/// compare equal.
fn constant_time_eq(left: &[u8], right: &[u8]) -> bool {
    let span = left.len().max(right.len());
    let byte_at = |bytes: &[u8], index: usize| bytes.get(index).copied().unwrap_or(0);
    let mismatch = (0..span).fold(left.len() ^ right.len(), |acc, index| {
        acc | usize::from(byte_at(left, index) ^ byte_at(right, index))
    });
    mismatch == 0
}
659
660fn with_cors(response: impl IntoResponse) -> Response {
661    let (mut parts, body) = response.into_response().into_parts();
662    parts
663        .headers
664        .insert("access-control-allow-origin", HeaderValue::from_static("*"));
665    parts.headers.insert(
666        "access-control-allow-methods",
667        HeaderValue::from_static("GET,POST,DELETE,OPTIONS"),
668    );
669    parts.headers.insert(
670        "access-control-allow-headers",
671        HeaderValue::from_static(
672            "accept,authorization,content-type,mcp-protocol-version,mcp-session-id",
673        ),
674    );
675    parts.headers.insert(
676        "access-control-expose-headers",
677        HeaderValue::from_static("mcp-session-id"),
678    );
679    Response::from_parts(parts, body)
680}
681
682fn unauthorized_oauth(config: &OAuthConfig) -> Response {
683    let mut response = with_cors(StatusCode::UNAUTHORIZED);
684    let header_value = format!(
685        r#"Bearer resource_metadata="{}/.well-known/oauth-protected-resource/mcp", scope="mcp""#,
686        quoted_header_value(&config.public_url)
687    );
688    if let Ok(value) = HeaderValue::from_str(&header_value) {
689        response
690            .headers_mut()
691            .insert(header::WWW_AUTHENTICATE, value);
692    }
693    response
694}
695
696fn oauth_error(status: StatusCode, code: &str, description: &str) -> Response {
697    (
698        status,
699        Json(json!({ "error": code, "error_description": description })),
700    )
701        .into_response()
702}
703
704fn authorize_html(request: &oauth::AuthorizeRequest, requires_admin_token: bool) -> String {
705    let admin = if requires_admin_token {
706        r#"<label>Admin token <input name="admin_token" type="password" autocomplete="current-password" required></label>"#
707    } else {
708        r#"<input name="admin_token" type="hidden" value="">"#
709    };
710    format!(
711        r#"<!doctype html>
712<html>
713<head>
714  <meta charset="utf-8">
715  <meta name="viewport" content="width=device-width, initial-scale=1">
716  <title>Authorize Shuttle</title>
717  <style>
718    body {{ font-family: system-ui, sans-serif; margin: 2rem; color: #1f2937; }}
719    form {{ display: grid; gap: 1rem; max-width: 32rem; }}
720    input, button {{ font: inherit; padding: .6rem; }}
721    label {{ display: grid; gap: .35rem; }}
722  </style>
723</head>
724<body>
725  <h1>Authorize Shuttle MCP</h1>
726  <p>{client_id} is requesting access to Shuttle MCP.</p>
727  <form method="post" action="/oauth/authorize">
728    {admin}
729    <input type="hidden" name="response_type" value="{response_type}">
730    <input type="hidden" name="client_id" value="{client_id}">
731    <input type="hidden" name="redirect_uri" value="{redirect_uri}">
732    <input type="hidden" name="state" value="{state}">
733    <input type="hidden" name="scope" value="{scope}">
734    <input type="hidden" name="code_challenge" value="{code_challenge}">
735    <input type="hidden" name="code_challenge_method" value="{code_challenge_method}">
736    <button type="submit">Authorize</button>
737  </form>
738</body>
739</html>"#,
740        admin = admin,
741        response_type = html_escape(&request.response_type),
742        client_id = html_escape(&request.client_id),
743        redirect_uri = html_escape(&request.redirect_uri),
744        state = html_escape(request.state.as_deref().unwrap_or("")),
745        scope = html_escape(request.scope.as_deref().unwrap_or("mcp")),
746        code_challenge = html_escape(request.code_challenge.as_deref().unwrap_or("")),
747        code_challenge_method =
748            html_escape(request.code_challenge_method.as_deref().unwrap_or("S256")),
749    )
750}
751
/// Escapes `value` for safe inclusion in HTML text and in quoted attribute values.
///
/// Replaces `&`, `<`, `>`, `"`, and `'` with their entity forms in a single pass. The
/// single quote is included so the result is also safe inside single-quoted attributes.
fn html_escape(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '"' => escaped.push_str("&quot;"),
            '\'' => escaped.push_str("&#x27;"),
            other => escaped.push(other),
        }
    }
    escaped
}
759
/// Escapes `value` for use inside a quoted string in an HTTP header by prefixing every
/// backslash and double quote with a backslash.
fn quoted_header_value(value: &str) -> String {
    let mut quoted = String::with_capacity(value.len());
    for ch in value.chars() {
        if matches!(ch, '\\' | '"') {
            quoted.push('\\');
        }
        quoted.push(ch);
    }
    quoted
}
763
764#[cfg(test)]
765mod tests {
766    use super::*;
767    use axum::body::Body;
768    use axum::http::{Method, Request};
769    use base64::engine::general_purpose::URL_SAFE_NO_PAD;
770    use base64::Engine;
771    use http_body_util::BodyExt;
772    use sha2::{Digest, Sha256};
773    use tower::ServiceExt;
774
775    fn runtime(oauth: Option<OAuthRuntime>) -> AppRuntime {
776        let dir = tempfile::tempdir().unwrap().keep();
777        let db = dir.join("shuttle.db");
778        AppRuntime {
779            store: SqliteEventStore::open(&db).unwrap(),
780            cwd: dir,
781            workspace_id: "workspace".to_owned(),
782            agent: "codex".to_owned(),
783            session_id: "session".to_owned(),
784            oauth,
785        }
786    }
787
788    fn oauth_runtime() -> OAuthRuntime {
789        let dir = tempfile::tempdir().unwrap().keep();
790        OAuthRuntime {
791            config: OAuthConfig {
792                public_url: "https://shuttle.example.test".to_owned(),
793                admin_token: Some("admin-token".to_owned()),
794            },
795            store: OAuthStore::open(dir.join("oauth.db")).unwrap(),
796        }
797    }
798
799    fn issue_access_token(oauth: &OAuthRuntime) -> String {
800        let verifier = "abc123abc123abc123abc123abc123abc123abc123abc123";
801        let challenge = URL_SAFE_NO_PAD.encode(Sha256::digest(verifier.as_bytes()));
802        let client = oauth
803            .store
804            .register_client(oauth::RegisterRequest {
805                redirect_uris: vec!["https://client.example.test/callback".to_owned()],
806                client_name: Some("client".to_owned()),
807            })
808            .unwrap();
809        let code = oauth
810            .store
811            .create_code(oauth::AuthorizeRequest {
812                response_type: "code".to_owned(),
813                client_id: client.client_id.clone(),
814                redirect_uri: "https://client.example.test/callback".to_owned(),
815                state: None,
816                scope: Some("mcp".to_owned()),
817                code_challenge: Some(challenge),
818                code_challenge_method: Some("S256".to_owned()),
819            })
820            .unwrap();
821        oauth
822            .store
823            .exchange_code(oauth::TokenRequest {
824                grant_type: "authorization_code".to_owned(),
825                client_id: client.client_id,
826                redirect_uri: "https://client.example.test/callback".to_owned(),
827                code: Some(code),
828                code_verifier: Some(verifier.to_owned()),
829            })
830            .unwrap()
831            .access_token
832    }
833
834    async fn request(
835        runtime: AppRuntime,
836        path: &str,
837        authorization: Option<&str>,
838    ) -> axum::response::Response {
839        let mut builder = Request::builder().method(Method::GET).uri(path);
840        if let Some(authorization) = authorization {
841            builder = builder.header(header::AUTHORIZATION, authorization);
842        }
843        router(runtime)
844            .oneshot(builder.body(Body::empty()).unwrap())
845            .await
846            .unwrap()
847    }
848
849    #[tokio::test]
850    async fn dashboard_routes_require_bearer_when_oauth_is_configured() {
851        let oauth = oauth_runtime();
852        let token = issue_access_token(&oauth);
853
854        let index = request(runtime(Some(oauth.clone())), "/", None).await;
855        let dashboard = request(runtime(Some(oauth.clone())), "/api/dashboard", None).await;
856        let authorized_index = request(
857            runtime(Some(oauth.clone())),
858            "/",
859            Some(&format!("Bearer {token}")),
860        )
861        .await;
862        let authorized_dashboard = request(
863            runtime(Some(oauth)),
864            "/api/dashboard",
865            Some(&format!("Bearer {token}")),
866        )
867        .await;
868
869        assert_eq!(index.status(), StatusCode::UNAUTHORIZED);
870        assert_eq!(dashboard.status(), StatusCode::UNAUTHORIZED);
871        assert_eq!(authorized_index.status(), StatusCode::OK);
872        assert_eq!(authorized_dashboard.status(), StatusCode::OK);
873    }
874
875    #[tokio::test]
876    async fn dashboard_routes_remain_local_open_without_auth_configuration() {
877        let index = request(runtime(None), "/", None).await;
878        let dashboard = request(runtime(None), "/api/dashboard", None).await;
879
880        assert_eq!(index.status(), StatusCode::OK);
881        assert_eq!(dashboard.status(), StatusCode::OK);
882    }
883
884    #[tokio::test]
885    async fn oauth_metadata_is_not_blocked_by_dashboard_auth() {
886        let response = request(
887            runtime(Some(oauth_runtime())),
888            "/.well-known/oauth-authorization-server",
889            None,
890        )
891        .await;
892
893        assert_eq!(response.status(), StatusCode::OK);
894    }
895
896    #[tokio::test]
897    async fn oauth_authorize_submit_redirects_callback_with_get() {
898        let oauth = oauth_runtime();
899        let client = oauth
900            .store
901            .register_client(oauth::RegisterRequest {
902                redirect_uris: vec!["https://claude.ai/api/mcp/auth_callback".to_owned()],
903                client_name: Some("Claude".to_owned()),
904            })
905            .unwrap();
906        let body = format!(
907            "admin_token=admin-token&response_type=code&client_id={}&redirect_uri=https%3A%2F%2Fclaude.ai%2Fapi%2Fmcp%2Fauth_callback&state=state-123&scope=mcp&code_challenge=challenge&code_challenge_method=S256",
908            client.client_id
909        );
910        let response = router(runtime(Some(oauth)))
911            .oneshot(
912                Request::builder()
913                    .method(Method::POST)
914                    .uri("/oauth/authorize")
915                    .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
916                    .body(Body::from(body))
917                    .unwrap(),
918            )
919            .await
920            .unwrap();
921
922        assert_eq!(response.status(), StatusCode::SEE_OTHER);
923        let location = response
924            .headers()
925            .get(header::LOCATION)
926            .and_then(|value| value.to_str().ok())
927            .unwrap();
928        assert!(location.starts_with("https://claude.ai/api/mcp/auth_callback?code=stl_"));
929        assert!(location.contains("&state=state-123"));
930        assert!(!location.contains("&iss="));
931        assert!(!location.contains("?iss="));
932    }
933
934    #[tokio::test]
935    async fn oauth_authorize_rejects_redirect_uri_with_trailing_slash() {
936        let oauth = oauth_runtime();
937        let client = oauth
938            .store
939            .register_client(oauth::RegisterRequest {
940                redirect_uris: vec!["https://claude.ai/api/mcp/auth_callback".to_owned()],
941                client_name: Some("Claude".to_owned()),
942            })
943            .unwrap();
944        let body = format!(
945            "admin_token=admin-token&response_type=code&client_id={}&redirect_uri=https%3A%2F%2Fclaude.ai%2Fapi%2Fmcp%2Fauth_callback%2F&state=state-123&scope=mcp&code_challenge=challenge&code_challenge_method=S256",
946            client.client_id
947        );
948        let response = router(runtime(Some(oauth)))
949            .oneshot(
950                Request::builder()
951                    .method(Method::POST)
952                    .uri("/oauth/authorize")
953                    .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
954                    .body(Body::from(body))
955                    .unwrap(),
956            )
957            .await
958            .unwrap();
959
960        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
961    }
962
963    #[tokio::test]
964    async fn dashboard_html_renders_json_as_text_not_inner_html() {
965        let response = request(runtime(None), "/", None).await;
966        let body = response.into_body().collect().await.unwrap().to_bytes();
967        let html = String::from_utf8(body.to_vec()).unwrap();
968
969        assert!(html.contains("heading.textContent = name"));
970        assert!(html.contains("pre.textContent = JSON.stringify(value, null, 2)"));
971        assert!(!html.contains("section.innerHTML"));
972    }
973
974    #[tokio::test]
975    async fn backend_api_can_remember_and_recall() {
976        let runtime = runtime(None);
977        let app = router(runtime);
978        let remember = app
979            .clone()
980            .oneshot(
981                Request::builder()
982                    .method(Method::POST)
983                    .uri("/api/remember")
984                    .header(header::CONTENT_TYPE, "application/json")
985                    .body(Body::from(
986                        r#"{"kind":"fact","text":"SQLite backs Shuttle"}"#,
987                    ))
988                    .unwrap(),
989            )
990            .await
991            .unwrap();
992        assert_eq!(remember.status(), StatusCode::OK);
993
994        let recall = app
995            .oneshot(
996                Request::builder()
997                    .method(Method::POST)
998                    .uri("/api/recall")
999                    .header(header::CONTENT_TYPE, "application/json")
1000                    .body(Body::from(r#"{"query":"SQLite"}"#))
1001                    .unwrap(),
1002            )
1003            .await
1004            .unwrap();
1005        assert_eq!(recall.status(), StatusCode::OK);
1006        let body = recall.into_body().collect().await.unwrap().to_bytes();
1007        let value: serde_json::Value = serde_json::from_slice(&body).unwrap();
1008        assert!(value
1009            .as_array()
1010            .unwrap()
1011            .iter()
1012            .any(|item| item["event"]["content"] == "SQLite backs Shuttle"));
1013    }
1014}