1use std::env;
2use std::net::SocketAddr;
3use std::path::PathBuf;
4
5use crate::core::{Event, EventStore, EventType, Result, ShuttleError};
6use crate::oauth::{self, OAuthConfig, OAuthStore};
7use crate::store::SqliteEventStore;
8use axum::extract::{Form, Path as AxumPath, Query, State};
9use axum::http::{header, HeaderMap, HeaderValue, StatusCode};
10use axum::response::{Html, IntoResponse, Redirect, Response};
11use axum::routing::{get, patch, post};
12use axum::{Json, Router};
13use serde::{Deserialize, Serialize};
14use serde_json::json;
15use uuid::Uuid;
16
/// Shared state handed to every route handler via axum's `State` extractor.
#[derive(Clone)]
pub struct AppRuntime {
    /// Event store that all dashboard, task, memory and MCP handlers read and append to.
    pub store: SqliteEventStore,
    /// Working directory used to derive repository status and context.
    pub cwd: PathBuf,
    /// Workspace that tasks, recall and new events are scoped to.
    pub workspace_id: String,
    /// Agent name: the inbox recipient and the author of events created here.
    pub agent: String,
    /// Session identifier recorded on events created through this server.
    pub session_id: String,
    /// OAuth configuration; when `Some`, protected routes require a valid bearer
    /// token issued by this server's OAuth store.
    pub oauth: Option<OAuthRuntime>,
}
26
/// OAuth settings plus the store that persists clients, codes and tokens.
#[derive(Clone)]
pub struct OAuthRuntime {
    /// Public URL and optional admin token used by the OAuth endpoints.
    pub config: OAuthConfig,
    /// Backing store for client registration, authorization codes and access tokens.
    pub store: OAuthStore,
}
32
/// JSON payload served by `GET /api/dashboard`.
///
/// The dashboard page renders one section per top-level key in serialization
/// order, so field order here is user-visible.
#[derive(Debug, Serialize)]
struct Dashboard {
    /// Messages in the runtime agent's inbox.
    inbox: Vec<Event>,
    /// Open tasks in the workspace (the handler requests at most 20).
    tasks: Vec<crate::task::TaskSummary>,
    /// Stored memory events.
    memories: Vec<Event>,
    /// Assembled working context (repo state, tasks, decisions, messages, ...).
    context: crate::context::Context,
}
40
41pub async fn serve(runtime: AppRuntime, addr: SocketAddr) -> Result<()> {
42 let app = router(runtime);
43 let listener = tokio::net::TcpListener::bind(addr)
44 .await
45 .map_err(|err| ShuttleError::Store(err.to_string()))?;
46 axum::serve(listener, app)
47 .await
48 .map_err(|err| ShuttleError::Store(err.to_string()))
49}
50
51pub fn router(runtime: AppRuntime) -> Router {
52 Router::new()
53 .route("/", get(index))
54 .route("/api/dashboard", get(dashboard))
55 .route("/api/inbox", get(inbox))
56 .route("/api/tasks", get(tasks))
57 .route("/api/tasks", post(create_task))
58 .route("/api/tasks/{id}", patch(update_task))
59 .route("/api/tasks/{id}/done", post(done_task))
60 .route("/api/memories", get(memories))
61 .route("/api/context", get(context))
62 .route("/api/recall", post(recall))
63 .route("/api/remember", post(remember))
64 .route(
65 "/mcp",
66 get(mcp_health)
67 .post(mcp_post)
68 .delete(mcp_delete)
69 .options(mcp_options),
70 )
71 .route(
72 "/.well-known/oauth-protected-resource",
73 get(oauth_protected_resource),
74 )
75 .route(
76 "/.well-known/oauth-protected-resource/mcp",
77 get(oauth_protected_resource),
78 )
79 .route(
80 "/.well-known/oauth-authorization-server",
81 get(oauth_authorization_server),
82 )
83 .route("/oauth/register", post(oauth_register))
84 .route(
85 "/oauth/authorize",
86 get(oauth_authorize_page).post(oauth_authorize_submit),
87 )
88 .route("/oauth/token", post(oauth_token))
89 .with_state(runtime)
90}
91
/// `GET /`: serves the single-page Shuttle dashboard.
///
/// The page is gated by `mcp_unauthorized_response`; a rejected request gets
/// that response instead of the HTML. The embedded script fetches
/// `/api/dashboard` and renders one `<section>` per top-level key, writing the
/// key and the pretty-printed JSON through `textContent` so stored data is
/// never interpreted as markup.
///
/// NOTE(review): the `fetch` call sends no `Authorization` header and has no
/// error handler, so with auth enabled the data request may be rejected and
/// the page stays empty. Confirm the intended browser flow.
async fn index(headers: HeaderMap, State(runtime): State<AppRuntime>) -> Response {
    // Auth gate: return the rejection (e.g. 401) if the request is not allowed.
    if let Some(response) = mcp_unauthorized_response(runtime.oauth.as_ref(), &headers) {
        return response;
    }
    Html(
        r#"<!doctype html>
<html>
<head>
  <meta charset="utf-8">
  <meta name="viewport" content="width=device-width, initial-scale=1">
  <title>Shuttle</title>
  <style>
    body { font-family: system-ui, sans-serif; margin: 2rem; color: #1f2937; }
    main { display: grid; gap: 1rem; grid-template-columns: repeat(auto-fit, minmax(260px, 1fr)); }
    section { border: 1px solid #d1d5db; border-radius: 8px; padding: 1rem; }
    h1 { margin-top: 0; }
    pre { white-space: pre-wrap; overflow-wrap: anywhere; }
  </style>
</head>
<body>
  <h1>Shuttle</h1>
  <main id="dashboard"></main>
  <script>
    fetch('/api/dashboard').then(r => r.json()).then(data => {
      const root = document.getElementById('dashboard');
      for (const [name, value] of Object.entries(data)) {
        const section = document.createElement('section');
        const heading = document.createElement('h2');
        heading.textContent = name;
        const pre = document.createElement('pre');
        pre.textContent = JSON.stringify(value, null, 2);
        section.append(heading, pre);
        root.append(section);
      }
    });
  </script>
</body>
</html>"#,
    )
    .into_response()
}
133
134async fn dashboard(headers: HeaderMap, State(runtime): State<AppRuntime>) -> Response {
135 if let Some(response) = mcp_unauthorized_response(runtime.oauth.as_ref(), &headers) {
136 return response;
137 }
138 Json(Dashboard {
139 inbox: crate::message::inbox(&runtime.store, &runtime.agent)
140 .await
141 .unwrap_or_default(),
142 tasks: crate::task::open_tasks(&runtime.store, &runtime.workspace_id, Some(20))
143 .await
144 .unwrap_or_default(),
145 memories: crate::memory::memories(&runtime.store)
146 .await
147 .unwrap_or_default(),
148 context: crate::context::assemble_context(
149 &runtime.store,
150 &runtime.cwd,
151 &runtime.workspace_id,
152 &runtime.agent,
153 )
154 .await
155 .unwrap_or_else(|_| crate::context::Context {
156 repo: runtime.cwd.display().to_string(),
157 branch: "unknown".to_owned(),
158 commit: "unknown".to_owned(),
159 git_remote: None,
160 dirty: false,
161 dirty_files: Vec::new(),
162 open_tasks: Vec::new(),
163 claimed_tasks: Vec::new(),
164 recent_decisions: Vec::new(),
165 related_memories: Vec::new(),
166 recent_messages: Vec::new(),
167 pending_handoffs: Vec::new(),
168 recent_completed_handoffs: Vec::new(),
169 inbox: Vec::new(),
170 }),
171 })
172 .into_response()
173}
174
175async fn inbox(headers: HeaderMap, State(runtime): State<AppRuntime>) -> Response {
176 if let Some(response) = mcp_unauthorized_response(runtime.oauth.as_ref(), &headers) {
177 return response;
178 }
179 Json(
180 crate::message::inbox(&runtime.store, &runtime.agent)
181 .await
182 .unwrap_or_default(),
183 )
184 .into_response()
185}
186
187async fn tasks(headers: HeaderMap, State(runtime): State<AppRuntime>) -> Response {
188 if let Some(response) = mcp_unauthorized_response(runtime.oauth.as_ref(), &headers) {
189 return response;
190 }
191 Json(
192 crate::task::open_tasks(&runtime.store, &runtime.workspace_id, Some(20))
193 .await
194 .unwrap_or_default(),
195 )
196 .into_response()
197}
198
199async fn memories(headers: HeaderMap, State(runtime): State<AppRuntime>) -> Response {
200 if let Some(response) = mcp_unauthorized_response(runtime.oauth.as_ref(), &headers) {
201 return response;
202 }
203 Json(
204 crate::memory::memories(&runtime.store)
205 .await
206 .unwrap_or_default(),
207 )
208 .into_response()
209}
210
211async fn context(headers: HeaderMap, State(runtime): State<AppRuntime>) -> Response {
212 if let Some(response) = mcp_unauthorized_response(runtime.oauth.as_ref(), &headers) {
213 return response;
214 }
215 Json(
216 crate::context::assemble_context(
217 &runtime.store,
218 &runtime.cwd,
219 &runtime.workspace_id,
220 &runtime.agent,
221 )
222 .await
223 .ok(),
224 )
225 .into_response()
226}
227
228#[derive(Deserialize)]
229struct RecallRequest {
230 #[serde(default)]
231 query: String,
232}
233
234async fn recall(
235 headers: HeaderMap,
236 State(runtime): State<AppRuntime>,
237 Json(request): Json<RecallRequest>,
238) -> Response {
239 if let Some(response) = mcp_unauthorized_response(runtime.oauth.as_ref(), &headers) {
240 return response;
241 }
242 if request.query.trim().is_empty() {
243 return api_error("query is required");
244 }
245 let status = crate::context::repo_status(&runtime.cwd).ok();
246 let repo_id = status.as_ref().map(crate::context::repo_id);
247 let branch = status.as_ref().map(|status| status.branch.as_str());
248 Json(
249 crate::memory::ranked_recall(
250 &runtime.store,
251 &request.query,
252 None,
253 Some(&runtime.workspace_id),
254 repo_id.as_deref(),
255 branch,
256 )
257 .await
258 .unwrap_or_default(),
259 )
260 .into_response()
261}
262
263#[derive(Deserialize)]
264struct RememberRequest {
265 #[serde(default)]
266 kind: String,
267 #[serde(default)]
268 text: String,
269}
270
271async fn remember(
272 headers: HeaderMap,
273 State(runtime): State<AppRuntime>,
274 Json(request): Json<RememberRequest>,
275) -> Response {
276 if let Some(response) = mcp_unauthorized_response(runtime.oauth.as_ref(), &headers) {
277 return response;
278 }
279 if request.text.trim().is_empty() {
280 return api_error("text is required");
281 }
282 let event_type = match request.kind.as_str() {
283 "" | "memory" => EventType::Memory,
284 "decision" => EventType::Decision,
285 "observation" => EventType::Observation,
286 "pattern" => EventType::Pattern,
287 "fact" => EventType::Fact,
288 "bug" => EventType::Bug,
289 other => return api_error(&format!("unknown memory kind {other:?}")),
290 };
291 let event = crate::memory::new_typed_memory(
292 event_type,
293 runtime.workspace_id.clone(),
294 runtime.agent.clone(),
295 runtime.session_id.clone(),
296 request.text,
297 );
298 match runtime
299 .store
300 .append(with_repo_metadata(event, &runtime))
301 .await
302 {
303 Ok(event) => Json(event).into_response(),
304 Err(err) => api_error(&err.to_string()),
305 }
306}
307
308#[derive(Deserialize)]
309struct CreateTaskRequest {
310 #[serde(default)]
311 title: String,
312 #[serde(default)]
313 body: String,
314}
315
316async fn create_task(
317 headers: HeaderMap,
318 State(runtime): State<AppRuntime>,
319 Json(request): Json<CreateTaskRequest>,
320) -> Response {
321 if let Some(response) = mcp_unauthorized_response(runtime.oauth.as_ref(), &headers) {
322 return response;
323 }
324 if request.title.trim().is_empty() {
325 return api_error("title is required");
326 }
327 let content = if request.body.is_empty() {
328 request.title
329 } else {
330 format!("{}\n\n{}", request.title, request.body)
331 };
332 let event = crate::task::new_task(
333 runtime.workspace_id.clone(),
334 runtime.agent.clone(),
335 runtime.session_id.clone(),
336 content,
337 );
338 match runtime
339 .store
340 .append(with_repo_metadata(event, &runtime))
341 .await
342 {
343 Ok(event) => Json(event).into_response(),
344 Err(err) => api_error(&err.to_string()),
345 }
346}
347
348#[derive(Deserialize)]
349struct UpdateTaskRequest {
350 #[serde(default)]
351 text: String,
352}
353
354async fn update_task(
355 headers: HeaderMap,
356 State(runtime): State<AppRuntime>,
357 AxumPath(id): AxumPath<Uuid>,
358 Json(request): Json<UpdateTaskRequest>,
359) -> Response {
360 if let Some(response) = mcp_unauthorized_response(runtime.oauth.as_ref(), &headers) {
361 return response;
362 }
363 if request.text.trim().is_empty() {
364 return api_error("text is required");
365 }
366 if let Err(err) =
367 crate::task::ensure_task_exists(&runtime.store, &runtime.workspace_id, id).await
368 {
369 return api_error(&err.to_string());
370 }
371 let event = crate::task::new_task_update(
372 runtime.workspace_id.clone(),
373 runtime.agent.clone(),
374 runtime.session_id.clone(),
375 id,
376 request.text,
377 );
378 match runtime
379 .store
380 .append(with_repo_metadata(event, &runtime))
381 .await
382 {
383 Ok(event) => Json(event).into_response(),
384 Err(err) => api_error(&err.to_string()),
385 }
386}
387
388async fn done_task(
389 headers: HeaderMap,
390 State(runtime): State<AppRuntime>,
391 AxumPath(id): AxumPath<Uuid>,
392) -> Response {
393 if let Some(response) = mcp_unauthorized_response(runtime.oauth.as_ref(), &headers) {
394 return response;
395 }
396 if let Err(err) =
397 crate::task::ensure_task_exists(&runtime.store, &runtime.workspace_id, id).await
398 {
399 return api_error(&err.to_string());
400 }
401 let event = crate::task::new_task_done(
402 runtime.workspace_id.clone(),
403 runtime.agent.clone(),
404 runtime.session_id.clone(),
405 id,
406 );
407 match runtime
408 .store
409 .append(with_repo_metadata(event, &runtime))
410 .await
411 {
412 Ok(event) => Json(event).into_response(),
413 Err(err) => api_error(&err.to_string()),
414 }
415}
416
417fn api_error(message: &str) -> Response {
418 (StatusCode::BAD_REQUEST, Json(json!({ "error": message }))).into_response()
419}
420
421fn with_repo_metadata(mut event: Event, runtime: &AppRuntime) -> Event {
422 if let Ok(status) = crate::context::repo_status(&runtime.cwd) {
423 let repo_id = crate::context::repo_id(&status);
424 event.repo_id = Some(repo_id.clone());
425 event.repo_path = Some(status.repo_path.clone());
426 event.git_remote = status.git_remote.clone();
427 event.branch = Some(status.branch.clone());
428 event.commit = Some(status.commit.clone());
429 event.repo_dirty = Some(status.dirty);
430 if let Some(metadata) = event.metadata_json.as_object_mut() {
431 metadata.insert("repo_id".to_owned(), json!(repo_id));
432 metadata.insert("repo_path".to_owned(), json!(status.repo_path));
433 metadata.insert("git_remote".to_owned(), json!(status.git_remote));
434 metadata.insert("branch".to_owned(), json!(status.branch));
435 metadata.insert("commit".to_owned(), json!(status.commit));
436 metadata.insert("repo_dirty".to_owned(), json!(status.dirty));
437 metadata.insert("dirty_files".to_owned(), json!(status.dirty_files));
438 }
439 }
440 event
441}
442
443async fn mcp_health(headers: HeaderMap, State(runtime): State<AppRuntime>) -> Response {
444 if let Some(response) = mcp_unauthorized_response(runtime.oauth.as_ref(), &headers) {
445 return response;
446 }
447 with_cors((StatusCode::OK, "Shuttle MCP server"))
448}
449
450async fn mcp_delete(headers: HeaderMap, State(runtime): State<AppRuntime>) -> Response {
451 if let Some(response) = mcp_unauthorized_response(runtime.oauth.as_ref(), &headers) {
452 return response;
453 }
454 with_cors((StatusCode::OK, "OK"))
455}
456
457async fn mcp_options() -> Response {
458 with_cors(StatusCode::NO_CONTENT)
459}
460
461async fn mcp_post(
462 headers: HeaderMap,
463 State(runtime): State<AppRuntime>,
464 Json(request): Json<crate::mcp::Request>,
465) -> Response {
466 if let Some(response) = mcp_unauthorized_response(runtime.oauth.as_ref(), &headers) {
467 return response;
468 }
469 let response = crate::mcp::handle_request(
470 &crate::mcp::McpRuntime {
471 store: runtime.store,
472 cwd: runtime.cwd,
473 workspace_id: runtime.workspace_id,
474 agent: runtime.agent,
475 session_id: runtime.session_id,
476 },
477 request,
478 )
479 .await;
480 with_cors(Json(response))
481}
482
483async fn oauth_protected_resource(State(runtime): State<AppRuntime>) -> Response {
484 let Some(oauth) = runtime.oauth else {
485 return (StatusCode::NOT_FOUND, "OAuth is not configured").into_response();
486 };
487 Json(oauth::protected_resource_metadata(&oauth.config)).into_response()
488}
489
490async fn oauth_authorization_server(State(runtime): State<AppRuntime>) -> Response {
491 let Some(oauth) = runtime.oauth else {
492 return (StatusCode::NOT_FOUND, "OAuth is not configured").into_response();
493 };
494 Json(oauth::authorization_server_metadata(&oauth.config)).into_response()
495}
496
497async fn oauth_register(
498 State(runtime): State<AppRuntime>,
499 Json(request): Json<oauth::RegisterRequest>,
500) -> Response {
501 let Some(oauth) = runtime.oauth else {
502 return (StatusCode::NOT_FOUND, "OAuth is not configured").into_response();
503 };
504 match oauth.store.register_client(request) {
505 Ok(client) => Json(json!({
506 "client_id": client.client_id,
507 "client_id_issued_at": chrono::Utc::now().timestamp(),
508 "redirect_uris": client.redirect_uris,
509 "client_name": client.client_name,
510 "token_endpoint_auth_method": "none",
511 }))
512 .into_response(),
513 Err(err) => oauth_error(StatusCode::BAD_REQUEST, "invalid_request", &err.to_string()),
514 }
515}
516
517async fn oauth_authorize_page(
518 State(runtime): State<AppRuntime>,
519 Query(request): Query<oauth::AuthorizeRequest>,
520) -> Response {
521 let Some(oauth) = runtime.oauth else {
522 return (StatusCode::NOT_FOUND, "OAuth is not configured").into_response();
523 };
524 if request.response_type != "code" {
525 return oauth_error(
526 StatusCode::BAD_REQUEST,
527 "unsupported_response_type",
528 "response_type must be code",
529 );
530 }
531 match oauth
532 .store
533 .client_allows_redirect(&request.client_id, &request.redirect_uri)
534 {
535 Ok(true) => {
536 Html(authorize_html(&request, oauth.config.admin_token.is_some())).into_response()
537 }
538 Ok(false) => oauth_error(
539 StatusCode::BAD_REQUEST,
540 "invalid_request",
541 "unknown client_id or redirect_uri",
542 ),
543 Err(_) => oauth_error(
544 StatusCode::INTERNAL_SERVER_ERROR,
545 "server_error",
546 "failed to validate OAuth client",
547 ),
548 }
549}
550
551async fn oauth_authorize_submit(
552 State(runtime): State<AppRuntime>,
553 Form(form): Form<oauth::AuthorizeForm>,
554) -> Response {
555 let Some(oauth) = runtime.oauth else {
556 return (StatusCode::NOT_FOUND, "OAuth is not configured").into_response();
557 };
558 if let Some(expected) = oauth.config.admin_token.as_deref() {
559 if !constant_time_eq(form.admin_token.as_bytes(), expected.as_bytes()) {
560 return oauth_error(
561 StatusCode::UNAUTHORIZED,
562 "access_denied",
563 "invalid admin token",
564 );
565 }
566 }
567 let request = oauth::AuthorizeRequest::from(form);
568 if request.response_type != "code" {
569 return oauth_error(
570 StatusCode::BAD_REQUEST,
571 "unsupported_response_type",
572 "response_type must be code",
573 );
574 }
575 match oauth.store.create_code(request.clone()) {
576 Ok(code) => Redirect::to(&oauth::authorize_redirect(
577 &request.redirect_uri,
578 &code,
579 request.state.as_deref(),
580 ))
581 .into_response(),
582 Err(err) => oauth_error(StatusCode::BAD_REQUEST, "invalid_request", &err.to_string()),
583 }
584}
585
586async fn oauth_token(
587 State(runtime): State<AppRuntime>,
588 Form(request): Form<oauth::TokenRequest>,
589) -> Response {
590 let Some(oauth) = runtime.oauth else {
591 return (StatusCode::NOT_FOUND, "OAuth is not configured").into_response();
592 };
593 if request.grant_type != "authorization_code" {
594 return oauth_error(
595 StatusCode::BAD_REQUEST,
596 "unsupported_grant_type",
597 "grant_type must be authorization_code",
598 );
599 }
600 match oauth.store.exchange_code(request) {
601 Ok(token) => Json(token).into_response(),
602 Err(err) => oauth_error(StatusCode::BAD_REQUEST, "invalid_grant", &err.to_string()),
603 }
604}
605
606fn mcp_unauthorized_response(
607 oauth: Option<&OAuthRuntime>,
608 headers: &HeaderMap,
609) -> Option<Response> {
610 if let Some(oauth) = oauth {
611 let Some(token) = bearer_token(headers) else {
612 return Some(unauthorized_oauth(&oauth.config));
613 };
614 return match oauth.store.validate_access_token(token) {
615 Ok(true) => None,
616 Ok(false) => Some(unauthorized_oauth(&oauth.config)),
617 Err(_) => Some(oauth_error(
618 StatusCode::UNAUTHORIZED,
619 "invalid_token",
620 "failed to validate access token",
621 )),
622 };
623 }
624
625 let token = env::var("SHUTTLE_MCP_BEARER_TOKEN")
626 .ok()
627 .filter(|token| !token.is_empty())?;
628 let expected = format!("Bearer {token}");
629 if headers
630 .get("authorization")
631 .and_then(|header| header.to_str().ok())
632 .is_some_and(|actual| constant_time_eq(actual.as_bytes(), expected.as_bytes()))
633 {
634 None
635 } else {
636 Some(with_cors(StatusCode::UNAUTHORIZED))
637 }
638}
639
640fn bearer_token(headers: &HeaderMap) -> Option<&str> {
641 headers
642 .get(header::AUTHORIZATION)
643 .and_then(|header| header.to_str().ok())
644 .and_then(|value| {
645 let (scheme, token) = value.split_once(' ')?;
646 scheme.eq_ignore_ascii_case("Bearer").then_some(token)
647 })
648}
649
/// Compares two byte strings without short-circuiting on the first difference.
///
/// Every position up to the longer length is visited and the differences are
/// OR-ed into an accumulator. The running time therefore depends only on the
/// longer input's length, not on where the inputs differ. Missing bytes are
/// read as `0`, but any length mismatch already forces `false`.
fn constant_time_eq(left: &[u8], right: &[u8]) -> bool {
    let span = left.len().max(right.len());
    let byte_at = |bytes: &[u8], index: usize| bytes.get(index).copied().unwrap_or(0);
    let mismatch = (0..span).fold(left.len() ^ right.len(), |acc, index| {
        acc | usize::from(byte_at(left, index) ^ byte_at(right, index))
    });
    mismatch == 0
}
659
660fn with_cors(response: impl IntoResponse) -> Response {
661 let (mut parts, body) = response.into_response().into_parts();
662 parts
663 .headers
664 .insert("access-control-allow-origin", HeaderValue::from_static("*"));
665 parts.headers.insert(
666 "access-control-allow-methods",
667 HeaderValue::from_static("GET,POST,DELETE,OPTIONS"),
668 );
669 parts.headers.insert(
670 "access-control-allow-headers",
671 HeaderValue::from_static(
672 "accept,authorization,content-type,mcp-protocol-version,mcp-session-id",
673 ),
674 );
675 parts.headers.insert(
676 "access-control-expose-headers",
677 HeaderValue::from_static("mcp-session-id"),
678 );
679 Response::from_parts(parts, body)
680}
681
682fn unauthorized_oauth(config: &OAuthConfig) -> Response {
683 let mut response = with_cors(StatusCode::UNAUTHORIZED);
684 let header_value = format!(
685 r#"Bearer resource_metadata="{}/.well-known/oauth-protected-resource/mcp", scope="mcp""#,
686 quoted_header_value(&config.public_url)
687 );
688 if let Ok(value) = HeaderValue::from_str(&header_value) {
689 response
690 .headers_mut()
691 .insert(header::WWW_AUTHENTICATE, value);
692 }
693 response
694}
695
696fn oauth_error(status: StatusCode, code: &str, description: &str) -> Response {
697 (
698 status,
699 Json(json!({ "error": code, "error_description": description })),
700 )
701 .into_response()
702}
703
/// Renders the OAuth consent page for an authorization request.
///
/// The form posts back to `/oauth/authorize` and round-trips every request
/// parameter in hidden inputs. When `requires_admin_token` is true, a
/// password field asks for the admin token. Otherwise an empty hidden
/// `admin_token` input is emitted, so the submitted form always includes the
/// field.
///
/// Every interpolated request value is passed through `html_escape`. Absent
/// optional values default to `""` (`state`, `code_challenge`), `"mcp"`
/// (`scope`) and `"S256"` (`code_challenge_method`).
fn authorize_html(request: &oauth::AuthorizeRequest, requires_admin_token: bool) -> String {
    // `admin` is static trusted markup, inserted without escaping.
    let admin = if requires_admin_token {
        r#"<label>Admin token <input name="admin_token" type="password" autocomplete="current-password" required></label>"#
    } else {
        r#"<input name="admin_token" type="hidden" value="">"#
    };
    // CSS braces are doubled because this is a `format!` template.
    format!(
        r#"<!doctype html>
<html>
<head>
  <meta charset="utf-8">
  <meta name="viewport" content="width=device-width, initial-scale=1">
  <title>Authorize Shuttle</title>
  <style>
    body {{ font-family: system-ui, sans-serif; margin: 2rem; color: #1f2937; }}
    form {{ display: grid; gap: 1rem; max-width: 32rem; }}
    input, button {{ font: inherit; padding: .6rem; }}
    label {{ display: grid; gap: .35rem; }}
  </style>
</head>
<body>
  <h1>Authorize Shuttle MCP</h1>
  <p>{client_id} is requesting access to Shuttle MCP.</p>
  <form method="post" action="/oauth/authorize">
    {admin}
    <input type="hidden" name="response_type" value="{response_type}">
    <input type="hidden" name="client_id" value="{client_id}">
    <input type="hidden" name="redirect_uri" value="{redirect_uri}">
    <input type="hidden" name="state" value="{state}">
    <input type="hidden" name="scope" value="{scope}">
    <input type="hidden" name="code_challenge" value="{code_challenge}">
    <input type="hidden" name="code_challenge_method" value="{code_challenge_method}">
    <button type="submit">Authorize</button>
  </form>
</body>
</html>"#,
        admin = admin,
        response_type = html_escape(&request.response_type),
        client_id = html_escape(&request.client_id),
        redirect_uri = html_escape(&request.redirect_uri),
        state = html_escape(request.state.as_deref().unwrap_or("")),
        scope = html_escape(request.scope.as_deref().unwrap_or("mcp")),
        code_challenge = html_escape(request.code_challenge.as_deref().unwrap_or("")),
        code_challenge_method =
            html_escape(request.code_challenge_method.as_deref().unwrap_or("S256")),
    )
}
751
/// Escapes text for safe interpolation into HTML element content and into
/// single- or double-quoted attribute values.
///
/// Replaces `&`, `<`, `>`, `"` and `'` with character references. The single
/// pass over the input never re-escapes an entity it has just produced.
///
/// The previous implementation mapped each character to itself (and the `"`
/// arm did not compile). That left the consent page open to HTML/attribute
/// injection through attacker-controlled query parameters.
fn html_escape(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '"' => escaped.push_str("&quot;"),
            '\'' => escaped.push_str("&#x27;"),
            other => escaped.push(other),
        }
    }
    escaped
}
759
/// Escapes `value` for use inside an HTTP quoted-string by prefixing every
/// backslash and double quote with a backslash.
fn quoted_header_value(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        if matches!(ch, '\\' | '"') {
            escaped.push('\\');
        }
        escaped.push(ch);
    }
    escaped
}
763
// Router-level tests: each drives the full `router` through
// `tower::ServiceExt::oneshot`, usually with a freshly built runtime.
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::{Method, Request};
    use base64::engine::general_purpose::URL_SAFE_NO_PAD;
    use base64::Engine;
    use http_body_util::BodyExt;
    use sha2::{Digest, Sha256};
    use tower::ServiceExt;

    /// Builds an `AppRuntime` over a fresh SQLite database in its own temp
    /// directory, which also serves as `cwd`.
    ///
    /// `keep()` presumably detaches the directory from `TempDir` cleanup so the
    /// path stays valid for the whole test (it is not deleted afterwards).
    fn runtime(oauth: Option<OAuthRuntime>) -> AppRuntime {
        let dir = tempfile::tempdir().unwrap().keep();
        let db = dir.join("shuttle.db");
        AppRuntime {
            store: SqliteEventStore::open(&db).unwrap(),
            cwd: dir,
            workspace_id: "workspace".to_owned(),
            agent: "codex".to_owned(),
            session_id: "session".to_owned(),
            oauth,
        }
    }

    /// OAuth runtime with a fixed public URL, admin token `admin-token`, and
    /// an OAuth store in its own temp directory.
    fn oauth_runtime() -> OAuthRuntime {
        let dir = tempfile::tempdir().unwrap().keep();
        OAuthRuntime {
            config: OAuthConfig {
                public_url: "https://shuttle.example.test".to_owned(),
                admin_token: Some("admin-token".to_owned()),
            },
            store: OAuthStore::open(dir.join("oauth.db")).unwrap(),
        }
    }

    /// Runs the code flow directly against the store (register a client,
    /// create an S256 code, exchange it) and returns the access token.
    fn issue_access_token(oauth: &OAuthRuntime) -> String {
        // PKCE pair: challenge = unpadded base64url(SHA-256(verifier)).
        let verifier = "abc123abc123abc123abc123abc123abc123abc123abc123";
        let challenge = URL_SAFE_NO_PAD.encode(Sha256::digest(verifier.as_bytes()));
        let client = oauth
            .store
            .register_client(oauth::RegisterRequest {
                redirect_uris: vec!["https://client.example.test/callback".to_owned()],
                client_name: Some("client".to_owned()),
            })
            .unwrap();
        let code = oauth
            .store
            .create_code(oauth::AuthorizeRequest {
                response_type: "code".to_owned(),
                client_id: client.client_id.clone(),
                redirect_uri: "https://client.example.test/callback".to_owned(),
                state: None,
                scope: Some("mcp".to_owned()),
                code_challenge: Some(challenge),
                code_challenge_method: Some("S256".to_owned()),
            })
            .unwrap();
        oauth
            .store
            .exchange_code(oauth::TokenRequest {
                grant_type: "authorization_code".to_owned(),
                client_id: client.client_id,
                redirect_uri: "https://client.example.test/callback".to_owned(),
                code: Some(code),
                code_verifier: Some(verifier.to_owned()),
            })
            .unwrap()
            .access_token
    }

    /// Sends a bodiless `GET path` to a fresh router, optionally with an
    /// `Authorization` header value.
    async fn request(
        runtime: AppRuntime,
        path: &str,
        authorization: Option<&str>,
    ) -> axum::response::Response {
        let mut builder = Request::builder().method(Method::GET).uri(path);
        if let Some(authorization) = authorization {
            builder = builder.header(header::AUTHORIZATION, authorization);
        }
        router(runtime)
            .oneshot(builder.body(Body::empty()).unwrap())
            .await
            .unwrap()
    }

    // With OAuth configured, `/` and `/api/dashboard` return 401 without a token
    // and 200 with a store-issued bearer token.
    #[tokio::test]
    async fn dashboard_routes_require_bearer_when_oauth_is_configured() {
        let oauth = oauth_runtime();
        let token = issue_access_token(&oauth);

        let index = request(runtime(Some(oauth.clone())), "/", None).await;
        let dashboard = request(runtime(Some(oauth.clone())), "/api/dashboard", None).await;
        let authorized_index = request(
            runtime(Some(oauth.clone())),
            "/",
            Some(&format!("Bearer {token}")),
        )
        .await;
        let authorized_dashboard = request(
            runtime(Some(oauth)),
            "/api/dashboard",
            Some(&format!("Bearer {token}")),
        )
        .await;

        assert_eq!(index.status(), StatusCode::UNAUTHORIZED);
        assert_eq!(dashboard.status(), StatusCode::UNAUTHORIZED);
        assert_eq!(authorized_index.status(), StatusCode::OK);
        assert_eq!(authorized_dashboard.status(), StatusCode::OK);
    }

    // Without OAuth (and assuming SHUTTLE_MCP_BEARER_TOKEN is unset in the test
    // environment), the dashboard routes are open.
    #[tokio::test]
    async fn dashboard_routes_remain_local_open_without_auth_configuration() {
        let index = request(runtime(None), "/", None).await;
        let dashboard = request(runtime(None), "/api/dashboard", None).await;

        assert_eq!(index.status(), StatusCode::OK);
        assert_eq!(dashboard.status(), StatusCode::OK);
    }

    // Discovery metadata must be reachable without a bearer token.
    #[tokio::test]
    async fn oauth_metadata_is_not_blocked_by_dashboard_auth() {
        let response = request(
            runtime(Some(oauth_runtime())),
            "/.well-known/oauth-authorization-server",
            None,
        )
        .await;

        assert_eq!(response.status(), StatusCode::OK);
    }

    // A valid consent submission answers 303 to the registered callback with
    // `code` and `state`, and without an `iss` parameter.
    #[tokio::test]
    async fn oauth_authorize_submit_redirects_callback_with_get() {
        let oauth = oauth_runtime();
        let client = oauth
            .store
            .register_client(oauth::RegisterRequest {
                redirect_uris: vec!["https://claude.ai/api/mcp/auth_callback".to_owned()],
                client_name: Some("Claude".to_owned()),
            })
            .unwrap();
        let body = format!(
            "admin_token=admin-token&response_type=code&client_id={}&redirect_uri=https%3A%2F%2Fclaude.ai%2Fapi%2Fmcp%2Fauth_callback&state=state-123&scope=mcp&code_challenge=challenge&code_challenge_method=S256",
            client.client_id
        );
        let response = router(runtime(Some(oauth)))
            .oneshot(
                Request::builder()
                    .method(Method::POST)
                    .uri("/oauth/authorize")
                    .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
                    .body(Body::from(body))
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::SEE_OTHER);
        let location = response
            .headers()
            .get(header::LOCATION)
            .and_then(|value| value.to_str().ok())
            .unwrap();
        assert!(location.starts_with("https://claude.ai/api/mcp/auth_callback?code=stl_"));
        assert!(location.contains("&state=state-123"));
        assert!(!location.contains("&iss="));
        assert!(!location.contains("?iss="));
    }

    // Redirect URIs are matched exactly: a trailing slash is rejected with 400.
    #[tokio::test]
    async fn oauth_authorize_rejects_redirect_uri_with_trailing_slash() {
        let oauth = oauth_runtime();
        let client = oauth
            .store
            .register_client(oauth::RegisterRequest {
                redirect_uris: vec!["https://claude.ai/api/mcp/auth_callback".to_owned()],
                client_name: Some("Claude".to_owned()),
            })
            .unwrap();
        let body = format!(
            "admin_token=admin-token&response_type=code&client_id={}&redirect_uri=https%3A%2F%2Fclaude.ai%2Fapi%2Fmcp%2Fauth_callback%2F&state=state-123&scope=mcp&code_challenge=challenge&code_challenge_method=S256",
            client.client_id
        );
        let response = router(runtime(Some(oauth)))
            .oneshot(
                Request::builder()
                    .method(Method::POST)
                    .uri("/oauth/authorize")
                    .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
                    .body(Body::from(body))
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
    }

    // The dashboard script writes data via `textContent`, never `innerHTML`.
    #[tokio::test]
    async fn dashboard_html_renders_json_as_text_not_inner_html() {
        let response = request(runtime(None), "/", None).await;
        let body = response.into_body().collect().await.unwrap().to_bytes();
        let html = String::from_utf8(body.to_vec()).unwrap();

        assert!(html.contains("heading.textContent = name"));
        assert!(html.contains("pre.textContent = JSON.stringify(value, null, 2)"));
        assert!(!html.contains("section.innerHTML"));
    }

    // Round trip: a remembered fact is returned by a recall on the same store.
    #[tokio::test]
    async fn backend_api_can_remember_and_recall() {
        let runtime = runtime(None);
        let app = router(runtime);
        let remember = app
            .clone()
            .oneshot(
                Request::builder()
                    .method(Method::POST)
                    .uri("/api/remember")
                    .header(header::CONTENT_TYPE, "application/json")
                    .body(Body::from(
                        r#"{"kind":"fact","text":"SQLite backs Shuttle"}"#,
                    ))
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(remember.status(), StatusCode::OK);

        let recall = app
            .oneshot(
                Request::builder()
                    .method(Method::POST)
                    .uri("/api/recall")
                    .header(header::CONTENT_TYPE, "application/json")
                    .body(Body::from(r#"{"query":"SQLite"}"#))
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(recall.status(), StatusCode::OK);
        let body = recall.into_body().collect().await.unwrap().to_bytes();
        let value: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert!(value
            .as_array()
            .unwrap()
            .iter()
            .any(|item| item["event"]["content"] == "SQLite backs Shuttle"));
    }
}