1use std::collections::{BTreeMap, BTreeSet};
2use std::env;
3use std::io::Write;
4use std::net::{IpAddr, SocketAddr};
5use std::path::{Path, PathBuf};
6use std::sync::{Arc, Mutex, MutexGuard};
7use std::time::Duration;
8
9use async_trait::async_trait;
10use axum::extract::{Form, Path as AxumPath, Query, State};
11use axum::http::{header, HeaderMap, HeaderValue, StatusCode};
12use axum::response::{Html, IntoResponse, Redirect, Response};
13use axum::routing::{get, patch, post};
14use axum::{Json, Router};
15use serde::{Deserialize, Serialize};
16use serde_json::{json, Value};
17use tokio::process::Command;
18use toml_edit::{value, DocumentMut, Item, Table};
19
20use crate::core::{Result, ShuttleError};
21use crate::oauth::{self, OAuthConfig, OAuthStore};
22
23#[derive(Debug, Clone, Deserialize)]
24pub struct GatewayConfig {
25 #[serde(skip)]
26 config_path: Option<PathBuf>,
27 #[serde(default)]
28 pub server: ServerConfig,
29 #[serde(default)]
30 pub auth: AuthConfig,
31 #[serde(default)]
32 pub oauth: OAuthGatewayConfig,
33 #[serde(default)]
34 pub defaults: DefaultsConfig,
35 #[serde(default)]
36 pub listeners: Vec<ListenerConfig>,
37 pub projects: BTreeMap<String, ProjectConfig>,
38}
39
40#[derive(Debug, Clone, Deserialize)]
41pub struct ServerConfig {
42 #[serde(default = "default_addr")]
43 pub addr: SocketAddr,
44}
45
46impl Default for ServerConfig {
47 fn default() -> Self {
48 Self {
49 addr: default_addr(),
50 }
51 }
52}
53
54#[derive(Debug, Clone, Deserialize)]
55pub struct AuthConfig {
56 #[serde(default = "default_gateway_token_env")]
57 pub bearer_token_env: String,
58}
59
60impl Default for AuthConfig {
61 fn default() -> Self {
62 Self {
63 bearer_token_env: default_gateway_token_env(),
64 }
65 }
66}
67
68#[derive(Debug, Clone, Deserialize)]
69pub struct OAuthGatewayConfig {
70 #[serde(default)]
71 pub public_url: String,
72 #[serde(default)]
73 pub db_path: Option<PathBuf>,
74 #[serde(default = "default_oauth_admin_token_env")]
75 pub admin_token_env: String,
76}
77
78impl Default for OAuthGatewayConfig {
79 fn default() -> Self {
80 Self {
81 public_url: String::new(),
82 db_path: None,
83 admin_token_env: default_oauth_admin_token_env(),
84 }
85 }
86}
87
88#[derive(Debug, Clone, Default, Deserialize)]
89pub struct DefaultsConfig {
90 #[serde(default)]
91 pub project: String,
92}
93
94#[derive(Debug, Clone, Deserialize)]
95pub struct ProjectConfig {
96 #[serde(default)]
97 pub backend: ProjectBackendKind,
98 #[serde(default)]
99 pub repo: Option<PathBuf>,
100 #[serde(default)]
101 pub db: Option<PathBuf>,
102 #[serde(default)]
103 pub url: String,
104 #[serde(default)]
105 pub token_env: Option<String>,
106 #[serde(default)]
107 pub description: Option<String>,
108}
109
110#[derive(Debug, Clone, Copy, Default, Deserialize, Serialize, PartialEq, Eq)]
111#[serde(rename_all = "lowercase")]
112pub enum ProjectBackendKind {
113 #[default]
114 Local,
115 Http,
116}
117
118#[derive(Debug, Clone, Deserialize)]
119pub struct ListenerConfig {
120 pub name: String,
121 pub addr: SocketAddr,
122 pub auth: ListenerAuthKind,
123 #[serde(default)]
124 pub public_url: String,
125 #[serde(default)]
126 pub oauth_db_path: Option<PathBuf>,
127 #[serde(default = "default_oauth_admin_token_env")]
128 pub oauth_admin_token_env: String,
129 #[serde(default = "default_gateway_token_env")]
130 pub bearer_token_env: String,
131}
132
133#[derive(Debug, Clone, Copy, Deserialize, PartialEq, Eq)]
134#[serde(rename_all = "lowercase")]
135pub enum ListenerAuthKind {
136 OAuth,
137 Bearer,
138 None,
139}
140
141impl GatewayConfig {
142 pub fn load(path: impl AsRef<Path>) -> Result<Self> {
143 let path = path.as_ref();
144 let abs_path = path
145 .canonicalize()
146 .or_else(|_| {
147 path.parent()
148 .unwrap_or_else(|| Path::new("."))
149 .canonicalize()
150 .map(|parent| parent.join(path.file_name().unwrap_or_default()))
151 })
152 .map_err(|err| ShuttleError::Store(err.to_string()))?;
153 let raw =
154 std::fs::read_to_string(path).map_err(|err| ShuttleError::Store(err.to_string()))?;
155 let mut cfg: GatewayConfig =
156 toml::from_str(&raw).map_err(|err| ShuttleError::Serialization(err.to_string()))?;
157 cfg.config_path = Some(abs_path.clone());
158
159 cfg.oauth.public_url = normalize_public_url(&cfg.oauth.public_url);
160 if cfg.projects.is_empty() {
161 return Err(ShuttleError::Store(
162 "at least one project is required".to_owned(),
163 ));
164 }
165 for (name, project) in &cfg.projects {
166 validate_project_config(name, project)?;
167 }
168 if !cfg.defaults.project.is_empty() && !cfg.projects.contains_key(&cfg.defaults.project) {
169 return Err(ShuttleError::Store(format!(
170 "default project {:?} is not configured",
171 cfg.defaults.project
172 )));
173 }
174 if !cfg.oauth.public_url.is_empty() {
175 match &cfg.oauth.db_path {
176 Some(path) if !path.is_absolute() => {
177 return Err(ShuttleError::Store(
178 "oauth db_path must be an absolute path when set".to_owned(),
179 ))
180 }
181 Some(_) => {}
182 None => {
183 cfg.oauth.db_path = Some(
184 abs_path
185 .parent()
186 .unwrap_or_else(|| Path::new("."))
187 .join("gateway-oauth.db"),
188 );
189 }
190 }
191 }
192 let config_dir = abs_path.parent().unwrap_or_else(|| Path::new("."));
193 for listener in &mut cfg.listeners {
194 listener.public_url = normalize_public_url(&listener.public_url);
195 if listener.name.trim().is_empty() {
196 return Err(ShuttleError::Store(
197 "listener name cannot be empty".to_owned(),
198 ));
199 }
200 match listener.auth {
201 ListenerAuthKind::OAuth => {
202 if listener.public_url.is_empty() {
203 return Err(ShuttleError::Store(format!(
204 "listener {:?} public_url is required for oauth auth",
205 listener.name
206 )));
207 }
208 match &listener.oauth_db_path {
209 Some(path) if !path.is_absolute() => {
210 return Err(ShuttleError::Store(format!(
211 "listener {:?} oauth_db_path must be an absolute path when set",
212 listener.name
213 )));
214 }
215 Some(_) => {}
216 None => {
217 listener.oauth_db_path = Some(
218 config_dir.join(format!("gateway-{}-oauth.db", listener.name)),
219 );
220 }
221 }
222 }
223 ListenerAuthKind::Bearer => {}
224 ListenerAuthKind::None => {
225 if !is_loopback_addr(listener.addr) {
226 return Err(ShuttleError::Store(format!(
227 "listener {:?} auth none is only allowed on loopback addresses",
228 listener.name
229 )));
230 }
231 }
232 }
233 }
234 Ok(cfg)
235 }
236}
237
238#[derive(Clone)]
239pub struct GatewayRuntime {
240 service: Arc<GatewayService>,
241 auth: GatewayAuth,
242}
243
244#[derive(Clone)]
245pub struct GatewayListener {
246 pub name: String,
247 pub addr: SocketAddr,
248 pub runtime: GatewayRuntime,
249}
250
251impl GatewayRuntime {
252 pub fn from_config(config: GatewayConfig, stl: PathBuf, timeout: Duration) -> Result<Self> {
253 let listeners = Self::listeners_from_config(config, stl, timeout)?;
254 if listeners.len() != 1 {
255 return Err(ShuttleError::Store(
256 "GatewayRuntime::from_config requires exactly one listener".to_owned(),
257 ));
258 }
259 Ok(listeners.into_iter().next().unwrap().runtime)
260 }
261
262 pub fn listeners_from_config(
263 mut config: GatewayConfig,
264 stl: PathBuf,
265 timeout: Duration,
266 ) -> Result<Vec<GatewayListener>> {
267 let listener_configs = if config.listeners.is_empty() {
268 vec![legacy_listener_config(&config)]
269 } else {
270 std::mem::take(&mut config.listeners)
271 };
272 let config_path = config.config_path.clone();
273 let registry = ProjectRegistry::new(config.defaults.project, config.projects)?;
274 let service = Arc::new(GatewayService::new_with_config_path(
275 registry,
276 Arc::new(SubprocessRunner {
277 binary: stl,
278 timeout,
279 }),
280 config_path,
281 ));
282 listener_configs
283 .into_iter()
284 .map(|listener| {
285 let auth = auth_from_listener(&listener)?;
286 Ok(GatewayListener {
287 name: listener.name,
288 addr: listener.addr,
289 runtime: GatewayRuntime {
290 service: service.clone(),
291 auth,
292 },
293 })
294 })
295 .collect()
296 }
297}
298
299#[derive(Clone)]
300enum GatewayAuth {
301 Bearer { token_env: String },
302 OAuth(Arc<OAuthRuntime>),
303 None,
304}
305
306#[derive(Clone)]
307struct OAuthRuntime {
308 config: OAuthConfig,
309 store: OAuthStore,
310}
311
312pub async fn serve(runtime: GatewayRuntime, addr: SocketAddr) -> Result<()> {
313 let listener = tokio::net::TcpListener::bind(addr)
314 .await
315 .map_err(|err| ShuttleError::Store(err.to_string()))?;
316 axum::serve(listener, router(runtime))
317 .await
318 .map_err(|err| ShuttleError::Store(err.to_string()))
319}
320
321pub async fn serve_listeners(listeners: Vec<GatewayListener>) -> Result<()> {
322 if listeners.is_empty() {
323 return Err(ShuttleError::Store(
324 "at least one listener is required".to_owned(),
325 ));
326 }
327 let mut tasks = Vec::new();
328 for listener in listeners {
329 let addr = listener.addr;
330 let runtime = listener.runtime;
331 let name = listener.name;
332 tasks.push(tokio::spawn(async move {
333 let tcp = tokio::net::TcpListener::bind(addr)
334 .await
335 .map_err(|err| ShuttleError::Store(format!("listener {name}: {err}")))?;
336 axum::serve(tcp, router(runtime))
337 .await
338 .map_err(|err| ShuttleError::Store(format!("listener {name}: {err}")))
339 }));
340 }
341 for task in tasks {
342 task.await
343 .map_err(|err| ShuttleError::Store(err.to_string()))??;
344 }
345 Ok(())
346}
347
348pub fn router(runtime: GatewayRuntime) -> Router {
349 Router::new()
350 .route("/api/projects", get(api_projects).post(api_add_project))
351 .route("/api/projects/current", get(api_current_project))
352 .route("/api/projects/use", post(api_use_project))
353 .route("/api/recall", post(api_recall))
354 .route("/api/remember", post(api_remember))
355 .route("/api/context", get(api_context))
356 .route("/api/tasks", get(api_tasks).post(api_create_task))
357 .route("/api/tasks/{id}", patch(api_update_task))
358 .route("/api/tasks/{id}/done", post(api_done_task))
359 .route(
360 "/mcp",
361 get(mcp_health)
362 .post(mcp_post)
363 .delete(mcp_delete)
364 .options(mcp_options),
365 )
366 .route(
367 "/.well-known/oauth-protected-resource",
368 get(oauth_protected_resource),
369 )
370 .route(
371 "/.well-known/oauth-protected-resource/mcp",
372 get(oauth_protected_resource),
373 )
374 .route(
375 "/.well-known/oauth-authorization-server",
376 get(oauth_authorization_server),
377 )
378 .route("/oauth/register", post(oauth_register))
379 .route(
380 "/oauth/authorize",
381 get(oauth_authorize_page).post(oauth_authorize_submit),
382 )
383 .route("/oauth/token", post(oauth_token))
384 .with_state(runtime)
385}
386
387#[derive(Debug, Clone, Serialize, PartialEq, Eq)]
388pub struct Project {
389 pub name: String,
390 pub backend: ProjectBackendKind,
391 #[serde(skip_serializing_if = "Option::is_none")]
392 pub repo: Option<PathBuf>,
393 #[serde(skip_serializing_if = "Option::is_none")]
394 pub db: Option<PathBuf>,
395 #[serde(skip_serializing_if = "String::is_empty")]
396 pub url: String,
397 #[serde(skip)]
398 pub token_env: Option<String>,
399 #[serde(skip_serializing_if = "Option::is_none")]
400 pub description: Option<String>,
401}
402
403#[derive(Debug)]
404pub struct ProjectRegistry {
405 default_project: String,
406 projects: BTreeMap<String, Project>,
407}
408
409impl ProjectRegistry {
410 pub fn new(default_project: String, configs: BTreeMap<String, ProjectConfig>) -> Result<Self> {
411 for (name, cfg) in &configs {
412 validate_project_config(name, cfg)?;
413 }
414 let projects = configs
415 .into_iter()
416 .map(|(name, cfg)| {
417 let project = project_from_config(name.clone(), cfg);
418 (name, project)
419 })
420 .collect::<BTreeMap<_, _>>();
421 if !default_project.is_empty() && !projects.contains_key(&default_project) {
422 return Err(ShuttleError::Store(format!(
423 "default project {default_project:?} is not configured"
424 )));
425 }
426 Ok(Self {
427 default_project,
428 projects,
429 })
430 }
431
432 pub fn list(&self) -> Vec<Project> {
433 self.projects.values().cloned().collect()
434 }
435
436 pub fn names(&self) -> BTreeSet<String> {
437 self.projects.keys().cloned().collect()
438 }
439
440 pub fn insert_named(&mut self, name: String, config: ProjectConfig) -> Result<Project> {
441 validate_project_config(&name, &config)?;
442 if self.projects.contains_key(&name) {
443 return Err(ShuttleError::Store(format!(
444 "project {name:?} is already configured"
445 )));
446 }
447 let project = project_from_config(name.clone(), config);
448 self.projects.insert(name, project.clone());
449 Ok(project)
450 }
451
452 pub fn get(&self, name: &str) -> Option<Project> {
453 self.projects.get(name).cloned()
454 }
455
456 pub fn default(&self) -> Option<Project> {
457 (!self.default_project.is_empty())
458 .then(|| self.get(&self.default_project))
459 .flatten()
460 }
461
462 pub fn resolve(&self, project: &str, write: bool) -> Result<Project> {
463 if !project.is_empty() {
464 return self
465 .get(project)
466 .ok_or_else(|| ShuttleError::Store(format!("unknown project {project:?}")));
467 }
468 if write {
469 return Err(ShuttleError::Store(
470 "project is required for write operations".to_owned(),
471 ));
472 }
473 self.default()
474 .ok_or_else(|| ShuttleError::Store("project is required".to_owned()))
475 }
476}
477
478fn project_from_config(name: String, cfg: ProjectConfig) -> Project {
479 Project {
480 name,
481 backend: cfg.backend,
482 repo: cfg.repo,
483 db: cfg.db,
484 url: normalize_public_url(&cfg.url),
485 token_env: cfg.token_env,
486 description: cfg.description,
487 }
488}
489
490fn validate_project_config(name: &str, project: &ProjectConfig) -> Result<()> {
491 validate_project_name(name)?;
492 match project.backend {
493 ProjectBackendKind::Local => {
494 let Some(repo) = &project.repo else {
495 return Err(ShuttleError::Store(format!(
496 "project {name:?} repo is required for local backend"
497 )));
498 };
499 if !repo.is_absolute() {
500 return Err(ShuttleError::Store(format!(
501 "project {name:?} repo must be an absolute path"
502 )));
503 }
504 if let Some(db) = &project.db {
505 if !db.is_absolute() {
506 return Err(ShuttleError::Store(format!(
507 "project {name:?} db must be an absolute path when set"
508 )));
509 }
510 }
511 }
512 ProjectBackendKind::Http => {
513 if project.url.trim().is_empty() {
514 return Err(ShuttleError::Store(format!(
515 "project {name:?} url is required for http backend"
516 )));
517 }
518 }
519 }
520 Ok(())
521}
522
523fn validate_project_name(name: &str) -> Result<()> {
524 if name.trim().is_empty() {
525 return Err(ShuttleError::Store(
526 "project name cannot be empty".to_owned(),
527 ));
528 }
529 Ok(())
530}
531
532fn normalize_project_name(name: &str) -> Result<String> {
533 validate_project_name(name)?;
534 Ok(name.trim().to_owned())
535}
536
/// Returns `base_name` if it is free, otherwise the first free name of the form
/// `base_name-2`, `base_name-3`, ….
fn unique_project_name(base_name: &str, existing_names: &BTreeSet<String>) -> String {
    let is_free = |candidate: &str| !existing_names.contains(candidate);
    if is_free(base_name) {
        return base_name.to_owned();
    }
    (2..)
        .map(|index| format!("{base_name}-{index}"))
        .find(|candidate| is_free(candidate.as_str()))
        .unwrap_or_else(|| unreachable!("unbounded suffix search always returns"))
}
549
550fn persist_project_config(
551 config_path: &Path,
552 base_name: &str,
553 config: &ProjectConfig,
554 registry_names: &BTreeSet<String>,
555) -> Result<String> {
556 let raw =
557 std::fs::read_to_string(config_path).map_err(|err| ShuttleError::Store(err.to_string()))?;
558 let mut document = raw
559 .parse::<DocumentMut>()
560 .map_err(|err| ShuttleError::Serialization(err.to_string()))?;
561 let mut existing_names = registry_names.clone();
562 if let Some(projects) = document.get("projects").and_then(Item::as_table) {
563 existing_names.extend(projects.iter().map(|(name, _)| name.to_owned()));
564 }
565 let name = unique_project_name(base_name, &existing_names);
566
567 let projects = document
568 .entry("projects")
569 .or_insert_with(|| Item::Table(Table::new()))
570 .as_table_mut()
571 .ok_or_else(|| ShuttleError::Store("projects must be a TOML table".to_owned()))?;
572 projects[&name] = project_config_item(config);
573 write_config_atomically(config_path, document.to_string().as_bytes())?;
574 Ok(name)
575}
576
577fn project_config_item(config: &ProjectConfig) -> Item {
578 let mut table = Table::new();
579 table["backend"] = value(match config.backend {
580 ProjectBackendKind::Local => "local",
581 ProjectBackendKind::Http => "http",
582 });
583 if let Some(repo) = &config.repo {
584 table["repo"] = value(repo.display().to_string());
585 }
586 if let Some(db) = &config.db {
587 table["db"] = value(db.display().to_string());
588 }
589 if !config.url.trim().is_empty() {
590 table["url"] = value(normalize_public_url(&config.url));
591 }
592 if let Some(token_env) = &config.token_env {
593 if !token_env.trim().is_empty() {
594 table["token_env"] = value(token_env.trim());
595 }
596 }
597 if let Some(description) = &config.description {
598 if !description.trim().is_empty() {
599 table["description"] = value(description.trim());
600 }
601 }
602 Item::Table(table)
603}
604
605fn write_config_atomically(path: &Path, contents: &[u8]) -> Result<()> {
606 let dir = path.parent().unwrap_or_else(|| Path::new("."));
607 let mut temp =
608 tempfile::NamedTempFile::new_in(dir).map_err(|err| ShuttleError::Store(err.to_string()))?;
609 temp.write_all(contents)
610 .map_err(|err| ShuttleError::Store(err.to_string()))?;
611 temp.flush()
612 .map_err(|err| ShuttleError::Store(err.to_string()))?;
613 temp.persist(path)
614 .map_err(|err| ShuttleError::Store(err.to_string()))?;
615 Ok(())
616}
617
618#[derive(Debug, Serialize)]
619pub struct ServiceResponse {
620 pub project: String,
621 pub result: Value,
622 #[serde(skip_serializing_if = "Option::is_none")]
623 pub stored: Option<bool>,
624}
625
626pub struct GatewayService {
627 projects: Mutex<ProjectRegistry>,
628 runner: Arc<dyn Runner>,
629 config_path: Option<PathBuf>,
630 current: Mutex<String>,
631}
632
633impl GatewayService {
634 pub fn new(projects: ProjectRegistry, runner: Arc<dyn Runner>) -> Self {
635 Self::new_with_config_path(projects, runner, None)
636 }
637
638 pub fn new_with_config_path(
639 projects: ProjectRegistry,
640 runner: Arc<dyn Runner>,
641 config_path: Option<PathBuf>,
642 ) -> Self {
643 Self {
644 projects: Mutex::new(projects),
645 runner,
646 config_path,
647 current: Mutex::new(String::new()),
648 }
649 }
650
651 pub fn list_projects(&self) -> Result<Vec<Project>> {
652 Ok(self.projects()?.list())
653 }
654
655 pub fn add_project(
656 &self,
657 name: &str,
658 config: ProjectConfig,
659 make_current: bool,
660 ) -> Result<Project> {
661 let base_name = normalize_project_name(name)?;
664 validate_project_config(&base_name, &config)?;
665 let project = {
666 let mut projects = self.projects()?;
667 let name = if let Some(config_path) = &self.config_path {
668 persist_project_config(config_path, &base_name, &config, &projects.names())?
669 } else {
670 unique_project_name(&base_name, &projects.names())
671 };
672 projects.insert_named(name, config)?
673 };
674 if make_current {
675 *self.current()? = project.name.clone();
676 }
677 Ok(project)
678 }
679
680 pub fn use_project(&self, name: &str) -> Result<Project> {
681 let project = {
684 let projects = self.projects()?;
685 projects
686 .get(name)
687 .ok_or_else(|| ShuttleError::Store(format!("unknown project {name:?}")))?
688 };
689 *self.current()? = name.to_owned();
690 Ok(project)
691 }
692
693 pub fn current_project(&self) -> Result<Project> {
694 let current = self.current()?.clone();
695 if !current.is_empty() {
696 if let Some(project) = self.projects()?.get(¤t) {
697 return Ok(project);
698 }
699 }
700 self.projects()?
701 .default()
702 .ok_or_else(|| ShuttleError::Store("no current or default project".to_owned()))
703 }
704
705 pub async fn context(&self, project: &str) -> Result<ServiceResponse> {
706 self.run(project, false, &["context"]).await
707 }
708
709 pub async fn recall(&self, project: &str, query: &str) -> Result<ServiceResponse> {
710 require_non_empty(query, "query is required")?;
711 self.run(project, false, &["recall", query]).await
712 }
713
714 pub async fn remember(&self, project: &str, kind: &str, text: &str) -> Result<ServiceResponse> {
715 require_non_empty(text, "text is required")?;
716 let command = match kind {
717 "" | "memory" => "remember",
718 "decision" => "decide",
719 "observation" => "observe",
720 "pattern" => "pattern",
721 "fact" => "fact",
722 "bug" => "bug",
723 other => {
724 return Err(ShuttleError::Store(format!(
725 "unknown memory kind {other:?}"
726 )))
727 }
728 };
729 self.run(project, true, &[command, text]).await
730 }
731
732 pub async fn task_list(&self, project: &str) -> Result<ServiceResponse> {
733 self.run(project, false, &["task", "list"]).await
734 }
735
736 pub async fn task_create(
737 &self,
738 project: &str,
739 title: &str,
740 body: &str,
741 ) -> Result<ServiceResponse> {
742 require_non_empty(title, "title is required")?;
743 let content = if body.is_empty() {
744 title.to_owned()
745 } else {
746 format!("{title}\n\n{body}")
747 };
748 self.run(project, true, &["task", "create", &content]).await
749 }
750
751 pub async fn task_update(
752 &self,
753 project: &str,
754 id: &str,
755 text: &str,
756 ) -> Result<ServiceResponse> {
757 require_non_empty(id, "task id is required")?;
758 require_non_empty(text, "text is required")?;
759 self.run(project, true, &["task", "update", id, text]).await
760 }
761
762 pub async fn task_done(&self, project: &str, id: &str) -> Result<ServiceResponse> {
763 require_non_empty(id, "task id is required")?;
764 self.run(project, true, &["task", "done", id]).await
765 }
766
767 async fn run(&self, project: &str, write: bool, args: &[&str]) -> Result<ServiceResponse> {
768 let project = {
769 let projects = self.projects()?;
770 projects.resolve(project, write)?
771 };
772 let result = self.runner.run(&project, args).await.map_err(|err| {
773 ShuttleError::Store(format!("stl failed for project {}: {err}", project.name))
774 })?;
775 Ok(ServiceResponse {
776 project: project.name,
777 result,
778 stored: write.then_some(true),
779 })
780 }
781
782 fn projects(&self) -> Result<MutexGuard<'_, ProjectRegistry>> {
783 self.projects
784 .lock()
785 .map_err(|err| ShuttleError::Store(err.to_string()))
786 }
787
788 fn current(&self) -> Result<MutexGuard<'_, String>> {
789 self.current
790 .lock()
791 .map_err(|err| ShuttleError::Store(err.to_string()))
792 }
793}
794
795fn require_non_empty(value: &str, message: &str) -> Result<()> {
796 if value.trim().is_empty() {
797 return Err(ShuttleError::Store(message.to_owned()));
798 }
799 Ok(())
800}
801
802#[async_trait]
803pub trait Runner: Send + Sync {
804 async fn run(&self, project: &Project, args: &[&str]) -> std::result::Result<Value, String>;
805}
806
807pub struct SubprocessRunner {
808 binary: PathBuf,
809 timeout: Duration,
810}
811
812#[async_trait]
813impl Runner for SubprocessRunner {
814 async fn run(&self, project: &Project, args: &[&str]) -> std::result::Result<Value, String> {
815 if project.backend == ProjectBackendKind::Http {
816 return self.run_http(project, args).await;
817 }
818 let repo = project
819 .repo
820 .as_ref()
821 .ok_or_else(|| "repo is required for local backend".to_owned())?;
822 let mut command = Command::new(&self.binary);
823 command.arg("--json").args(args).current_dir(repo);
824 let output = tokio::time::timeout(self.timeout, command.output())
825 .await
826 .map_err(|_| format!("timed out after {}s", self.timeout.as_secs()))?
827 .map_err(|err| err.to_string())?;
828 if !output.status.success() {
829 let stderr = String::from_utf8_lossy(&output.stderr).trim().to_owned();
830 return Err(if stderr.is_empty() {
831 format!("exit status {}", output.status)
832 } else {
833 stderr
834 });
835 }
836 serde_json::from_slice(&output.stdout).map_err(|err| err.to_string())
837 }
838}
839
840impl SubprocessRunner {
841 async fn run_http(
842 &self,
843 project: &Project,
844 args: &[&str],
845 ) -> std::result::Result<Value, String> {
846 let project = project.clone();
847 let args = args.iter().map(|arg| (*arg).to_owned()).collect::<Vec<_>>();
848 let timeout = self.timeout;
849 tokio::task::spawn_blocking(move || http_backend_call(&project, &args, timeout))
850 .await
851 .map_err(|err| err.to_string())?
852 }
853}
854
855fn http_backend_call(
856 project: &Project,
857 args: &[String],
858 timeout: Duration,
859) -> std::result::Result<Value, String> {
860 let base = project.url.trim_end_matches('/');
861 let agent = ureq::AgentBuilder::new().timeout(timeout).build();
862 let request = |method: &str, path: &str| {
863 let req = match method {
864 "GET" => agent.get(&format!("{base}{path}")),
865 "PATCH" => agent.request("PATCH", &format!("{base}{path}")),
866 _ => agent.post(&format!("{base}{path}")),
867 };
868 if let Some(token_env) = &project.token_env {
869 if let Ok(token) = env::var(token_env) {
870 if !token.is_empty() {
871 return req.set(header::AUTHORIZATION.as_str(), &format!("Bearer {token}"));
872 }
873 }
874 }
875 req
876 };
877 let response = match args {
878 [cmd] if cmd == "context" => request("GET", "/api/context").call(),
879 [cmd, query] if cmd == "recall" => {
880 request("POST", "/api/recall").send_json(json!({ "query": query }))
881 }
882 [cmd, text] if is_memory_command(cmd) => request("POST", "/api/remember")
883 .send_json(json!({ "kind": memory_kind_for_command(cmd), "text": text })),
884 [task, cmd] if task == "task" && cmd == "list" => request("GET", "/api/tasks").call(),
885 [task, cmd, content] if task == "task" && cmd == "create" => {
886 request("POST", "/api/tasks").send_json(json!({ "title": content, "body": "" }))
887 }
888 [task, cmd, id, text] if task == "task" && cmd == "update" => {
889 request("PATCH", &format!("/api/tasks/{id}")).send_json(json!({ "text": text }))
890 }
891 [task, cmd, id] if task == "task" && cmd == "done" => {
892 request("POST", &format!("/api/tasks/{id}/done")).send_json(json!({}))
893 }
894 _ => {
895 return Err(format!(
896 "unsupported http backend command: {}",
897 args.join(" ")
898 ))
899 }
900 };
901 let response = response.map_err(|err| match err {
902 ureq::Error::Status(status, response) => {
903 let body = response.into_string().unwrap_or_default();
904 if body.trim().is_empty() {
905 format!("http backend returned status {status}")
906 } else {
907 format!("http backend returned status {status}: {body}")
908 }
909 }
910 ureq::Error::Transport(err) => err.to_string(),
911 })?;
912 response.into_json::<Value>().map_err(|err| err.to_string())
913}
914
/// Returns `true` for the `stl` subcommands that store a memory entry.
fn is_memory_command(command: &str) -> bool {
    ["remember", "decide", "observe", "pattern", "fact", "bug"].contains(&command)
}

/// Maps a memory-storing subcommand back to the `kind` accepted by
/// `/api/remember`. Anything not in the table maps to `"memory"`.
fn memory_kind_for_command(command: &str) -> &str {
    const KINDS: [(&str, &str); 5] = [
        ("decide", "decision"),
        ("observe", "observation"),
        ("pattern", "pattern"),
        ("fact", "fact"),
        ("bug", "bug"),
    ];
    KINDS
        .iter()
        .find(|(cmd, _)| *cmd == command)
        .map_or("memory", |&(_, kind)| kind)
}
932
933async fn api_projects(State(runtime): State<GatewayRuntime>, headers: HeaderMap) -> Response {
934 if let Err(response) = authorize(&runtime, &headers, "/api/projects", false) {
935 return *response;
936 }
937 match runtime.service.list_projects() {
938 Ok(projects) => Json(json!({ "projects": projects })).into_response(),
939 Err(err) => error_response(err),
940 }
941}
942
943#[derive(Debug, Deserialize)]
944struct AddProjectRequest {
945 #[serde(default)]
946 name: String,
947 #[serde(default)]
948 backend: ProjectBackendKind,
949 #[serde(default)]
950 repo: Option<PathBuf>,
951 #[serde(default)]
952 db: Option<PathBuf>,
953 #[serde(default)]
954 url: String,
955 #[serde(default)]
956 token_env: Option<String>,
957 #[serde(default)]
958 description: Option<String>,
959 #[serde(default)]
960 make_current: bool,
961}
962
963impl AddProjectRequest {
964 fn into_parts(self) -> (String, ProjectConfig, bool) {
965 (
966 self.name,
967 ProjectConfig {
968 backend: self.backend,
969 repo: self.repo,
970 db: self.db,
971 url: self.url,
972 token_env: self.token_env,
973 description: self.description,
974 },
975 self.make_current,
976 )
977 }
978}
979
980async fn api_add_project(
981 State(runtime): State<GatewayRuntime>,
982 headers: HeaderMap,
983 Json(request): Json<AddProjectRequest>,
984) -> Response {
985 if let Err(response) = authorize(&runtime, &headers, "/api/projects", false) {
986 return *response;
987 }
988 let (name, config, make_current) = request.into_parts();
989 match runtime.service.add_project(&name, config, make_current) {
990 Ok(project) => (StatusCode::CREATED, Json(project)).into_response(),
991 Err(err) => error_response(err),
992 }
993}
994
995async fn api_current_project(
996 State(runtime): State<GatewayRuntime>,
997 headers: HeaderMap,
998) -> Response {
999 if let Err(response) = authorize(&runtime, &headers, "/api/projects/current", false) {
1000 return *response;
1001 }
1002 match runtime.service.current_project() {
1003 Ok(project) => Json(project).into_response(),
1004 Err(err) if err.to_string().contains("no current or default project") => (
1005 StatusCode::NOT_FOUND,
1006 Json(json!({"error": err.to_string()})),
1007 )
1008 .into_response(),
1009 Err(err) => error_response(err),
1010 }
1011}
1012
1013#[derive(Deserialize)]
1014struct ProjectRequest {
1015 #[serde(default)]
1016 project: String,
1017}
1018
1019async fn api_use_project(
1020 State(runtime): State<GatewayRuntime>,
1021 headers: HeaderMap,
1022 Json(request): Json<ProjectRequest>,
1023) -> Response {
1024 if let Err(response) = authorize(&runtime, &headers, "/api/projects/use", false) {
1025 return *response;
1026 }
1027 match runtime.service.use_project(&request.project) {
1028 Ok(project) => Json(project).into_response(),
1029 Err(err) => error_response(err),
1030 }
1031}
1032
1033#[derive(Deserialize)]
1034struct RecallRequest {
1035 #[serde(default)]
1036 project: String,
1037 #[serde(default)]
1038 query: String,
1039}
1040
1041async fn api_recall(
1042 State(runtime): State<GatewayRuntime>,
1043 headers: HeaderMap,
1044 Json(request): Json<RecallRequest>,
1045) -> Response {
1046 service_response(
1047 &runtime,
1048 &headers,
1049 "/api/recall",
1050 runtime
1051 .service
1052 .recall(&request.project, &request.query)
1053 .await,
1054 )
1055}
1056
1057#[derive(Deserialize)]
1058struct RememberRequest {
1059 #[serde(default)]
1060 project: String,
1061 #[serde(default)]
1062 kind: String,
1063 #[serde(default)]
1064 text: String,
1065}
1066
1067async fn api_remember(
1068 State(runtime): State<GatewayRuntime>,
1069 headers: HeaderMap,
1070 Json(request): Json<RememberRequest>,
1071) -> Response {
1072 service_response(
1073 &runtime,
1074 &headers,
1075 "/api/remember",
1076 runtime
1077 .service
1078 .remember(&request.project, &request.kind, &request.text)
1079 .await,
1080 )
1081}
1082
1083async fn api_context(
1084 State(runtime): State<GatewayRuntime>,
1085 headers: HeaderMap,
1086 Query(request): Query<ProjectRequest>,
1087) -> Response {
1088 service_response(
1089 &runtime,
1090 &headers,
1091 "/api/context",
1092 runtime.service.context(&request.project).await,
1093 )
1094}
1095
1096async fn api_tasks(
1097 State(runtime): State<GatewayRuntime>,
1098 headers: HeaderMap,
1099 Query(request): Query<ProjectRequest>,
1100) -> Response {
1101 service_response(
1102 &runtime,
1103 &headers,
1104 "/api/tasks",
1105 runtime.service.task_list(&request.project).await,
1106 )
1107}
1108
1109#[derive(Deserialize)]
1110struct CreateTaskRequest {
1111 #[serde(default)]
1112 project: String,
1113 #[serde(default)]
1114 title: String,
1115 #[serde(default)]
1116 body: String,
1117}
1118
1119async fn api_create_task(
1120 State(runtime): State<GatewayRuntime>,
1121 headers: HeaderMap,
1122 Json(request): Json<CreateTaskRequest>,
1123) -> Response {
1124 service_response(
1125 &runtime,
1126 &headers,
1127 "/api/tasks",
1128 runtime
1129 .service
1130 .task_create(&request.project, &request.title, &request.body)
1131 .await,
1132 )
1133}
1134
1135#[derive(Deserialize)]
1136struct UpdateTaskRequest {
1137 #[serde(default)]
1138 project: String,
1139 #[serde(default)]
1140 text: String,
1141}
1142
1143async fn api_update_task(
1144 State(runtime): State<GatewayRuntime>,
1145 headers: HeaderMap,
1146 AxumPath(id): AxumPath<String>,
1147 Json(request): Json<UpdateTaskRequest>,
1148) -> Response {
1149 service_response(
1150 &runtime,
1151 &headers,
1152 "/api/tasks",
1153 runtime
1154 .service
1155 .task_update(&request.project, &id, &request.text)
1156 .await,
1157 )
1158}
1159
1160async fn api_done_task(
1161 State(runtime): State<GatewayRuntime>,
1162 headers: HeaderMap,
1163 AxumPath(id): AxumPath<String>,
1164 Json(request): Json<ProjectRequest>,
1165) -> Response {
1166 service_response(
1167 &runtime,
1168 &headers,
1169 "/api/tasks",
1170 runtime.service.task_done(&request.project, &id).await,
1171 )
1172}
1173
1174fn service_response(
1175 runtime: &GatewayRuntime,
1176 headers: &HeaderMap,
1177 path: &str,
1178 response: Result<ServiceResponse>,
1179) -> Response {
1180 if let Err(response) = authorize(runtime, headers, path, false) {
1181 return *response;
1182 }
1183 match response {
1184 Ok(value) => Json(value).into_response(),
1185 Err(err) => error_response(err),
1186 }
1187}
1188
/// An incoming JSON-RPC message posted to the MCP endpoint.
///
/// `handle_mcp` rejects the message unless `jsonrpc` is exactly `"2.0"`. A missing `id` is
/// treated as `null` when building the reply.
#[derive(Deserialize)]
struct RpcRequest {
    // Protocol version string; expected to be "2.0".
    jsonrpc: Option<String>,
    // Request id echoed back in the response; absent for notifications.
    id: Option<Value>,
    // Method name, e.g. "initialize", "tools/list", "tools/call".
    method: String,
    // Method parameters; defaults to `Value::Null` when omitted.
    #[serde(default)]
    params: Value,
}
1197
1198async fn mcp_health(State(runtime): State<GatewayRuntime>, headers: HeaderMap) -> Response {
1199 authorize(&runtime, &headers, "/mcp", true)
1200 .map(|_| with_cors(Json(json!({ "status": "ok" }))))
1201 .unwrap_or_else(|response| *response)
1202}
1203
1204async fn mcp_delete(State(runtime): State<GatewayRuntime>, headers: HeaderMap) -> Response {
1205 authorize(&runtime, &headers, "/mcp", true)
1206 .map(|_| with_cors(StatusCode::OK))
1207 .unwrap_or_else(|response| *response)
1208}
1209
1210async fn mcp_options() -> Response {
1211 with_cors(StatusCode::NO_CONTENT)
1212}
1213
1214async fn mcp_post(
1215 State(runtime): State<GatewayRuntime>,
1216 headers: HeaderMap,
1217 Json(request): Json<RpcRequest>,
1218) -> Response {
1219 match authorize(&runtime, &headers, "/mcp", true) {
1220 Ok(()) if request.method == "notifications/initialized" => {
1221 with_cors(StatusCode::NO_CONTENT)
1222 }
1223 Ok(()) => with_cors(Json(handle_mcp(&runtime.service, request).await)),
1224 Err(response) => *response,
1225 }
1226}
1227
1228async fn handle_mcp(service: &GatewayService, request: RpcRequest) -> Value {
1229 let id = request.id.unwrap_or(Value::Null);
1230 if request.jsonrpc.as_deref() != Some("2.0") {
1231 return rpc_error(id, -32600, "invalid jsonrpc version");
1232 }
1233 match request.method.as_str() {
1234 "initialize" => rpc_ok(
1235 id,
1236 json!({
1237 "protocolVersion": "2025-11-25",
1238 "capabilities": { "tools": {} },
1239 "serverInfo": { "name": "shuttle-gateway", "version": env!("CARGO_PKG_VERSION") }
1240 }),
1241 ),
1242 "notifications/initialized" => json!({"jsonrpc": "2.0"}),
1243 "tools/list" => rpc_ok(id, json!({ "tools": gateway_tools() })),
1244 "tools/call" => match mcp_call_tool(service, request.params).await {
1245 Ok(value) => rpc_ok(
1246 id,
1247 json!({
1248 "content": [{ "type": "text", "text": value.to_string() }],
1249 "structuredContent": value,
1250 }),
1251 ),
1252 Err(err) => rpc_error(id, -32603, &err.to_string()),
1253 },
1254 _ => rpc_error(id, -32601, "method not found"),
1255 }
1256}
1257
1258async fn mcp_call_tool(service: &GatewayService, params: Value) -> Result<Value> {
1259 let name = params
1260 .get("name")
1261 .and_then(Value::as_str)
1262 .ok_or_else(|| ShuttleError::Store("missing tool name".to_owned()))?;
1263 let args = params
1264 .get("arguments")
1265 .cloned()
1266 .unwrap_or_else(|| json!({}));
1267 match name {
1268 "shuttle_projects" => Ok(json!({ "projects": service.list_projects()? })),
1269 "shuttle_project_add" => {
1270 let (name, config, make_current) = project_add_args(&args)?;
1271 serde_json::to_value(service.add_project(&name, config, make_current)?)
1272 .map_err(|err| ShuttleError::Serialization(err.to_string()))
1273 }
1274 "shuttle_current_project" => serde_json::to_value(service.current_project()?)
1275 .map_err(|err| ShuttleError::Serialization(err.to_string())),
1276 "shuttle_use_project" => {
1277 serde_json::to_value(service.use_project(str_arg(&args, "project")?)?)
1278 .map_err(|err| ShuttleError::Serialization(err.to_string()))
1279 }
1280 "shuttle_context" => service
1281 .context(optional_str_arg(&args, "project"))
1282 .await
1283 .and_then(to_value),
1284 "shuttle_recall" => service
1285 .recall(optional_str_arg(&args, "project"), str_arg(&args, "query")?)
1286 .await
1287 .and_then(to_value),
1288 "shuttle_remember" => service
1289 .remember(
1290 str_arg(&args, "project")?,
1291 optional_str_arg(&args, "kind"),
1292 str_arg(&args, "text")?,
1293 )
1294 .await
1295 .and_then(to_value),
1296 "shuttle_task_list" => service
1297 .task_list(optional_str_arg(&args, "project"))
1298 .await
1299 .and_then(to_value),
1300 "shuttle_task_create" => service
1301 .task_create(
1302 str_arg(&args, "project")?,
1303 str_arg(&args, "title")?,
1304 optional_str_arg(&args, "body"),
1305 )
1306 .await
1307 .and_then(to_value),
1308 "shuttle_task_update" => service
1309 .task_update(
1310 str_arg(&args, "project")?,
1311 str_arg(&args, "task_id")?,
1312 str_arg(&args, "text")?,
1313 )
1314 .await
1315 .and_then(to_value),
1316 "shuttle_task_done" => service
1317 .task_done(str_arg(&args, "project")?, str_arg(&args, "task_id")?)
1318 .await
1319 .and_then(to_value),
1320 other => Err(ShuttleError::Store(format!("unknown tool: {other}"))),
1321 }
1322}
1323
1324fn to_value(response: ServiceResponse) -> Result<Value> {
1325 serde_json::to_value(response).map_err(|err| ShuttleError::Serialization(err.to_string()))
1326}
1327
1328fn str_arg<'a>(args: &'a Value, key: &str) -> Result<&'a str> {
1329 args.get(key)
1330 .and_then(Value::as_str)
1331 .filter(|value| !value.is_empty())
1332 .ok_or_else(|| ShuttleError::Store(format!("{key} is required")))
1333}
1334
1335fn optional_str_arg<'a>(args: &'a Value, key: &str) -> &'a str {
1336 args.get(key).and_then(Value::as_str).unwrap_or("")
1337}
1338
1339fn project_add_args(args: &Value) -> Result<(String, ProjectConfig, bool)> {
1340 Ok((
1341 str_arg(args, "name")?.to_owned(),
1342 ProjectConfig {
1343 backend: project_backend_arg(args, "backend")?,
1344 repo: optional_path_arg(args, "repo"),
1345 db: optional_path_arg(args, "db"),
1346 url: optional_string_arg(args, "url").unwrap_or_default(),
1347 token_env: optional_string_arg(args, "token_env"),
1348 description: optional_string_arg(args, "description"),
1349 },
1350 optional_bool_arg(args, "make_current"),
1351 ))
1352}
1353
1354fn project_backend_arg(args: &Value, key: &str) -> Result<ProjectBackendKind> {
1355 match optional_str_arg(args, key) {
1356 "" => Ok(ProjectBackendKind::Local),
1357 "local" => Ok(ProjectBackendKind::Local),
1358 "http" => Ok(ProjectBackendKind::Http),
1359 other => Err(ShuttleError::Store(format!(
1360 "{key} must be one of: local, http; got {other:?}"
1361 ))),
1362 }
1363}
1364
1365fn optional_path_arg(args: &Value, key: &str) -> Option<PathBuf> {
1366 optional_string_arg(args, key).map(PathBuf::from)
1367}
1368
1369fn optional_string_arg(args: &Value, key: &str) -> Option<String> {
1370 args.get(key)
1371 .and_then(Value::as_str)
1372 .map(str::trim)
1373 .filter(|value| !value.is_empty())
1374 .map(ToOwned::to_owned)
1375}
1376
1377fn optional_bool_arg(args: &Value, key: &str) -> bool {
1378 args.get(key).and_then(Value::as_bool).unwrap_or(false)
1379}
1380
1381fn rpc_ok(id: Value, result: Value) -> Value {
1382 json!({ "jsonrpc": "2.0", "id": id, "result": result })
1383}
1384
1385fn rpc_error(id: Value, code: i64, message: &str) -> Value {
1386 json!({ "jsonrpc": "2.0", "id": id, "error": { "code": code, "message": message } })
1387}
1388
1389fn gateway_tools() -> Vec<Value> {
1390 vec![
1391 tool(
1392 "shuttle_projects",
1393 "List configured Shuttle projects",
1394 json!({}),
1395 vec![],
1396 projects_output_schema(),
1397 ),
1398 tool(
1399 "shuttle_project_add",
1400 "Add a Shuttle project to the running gateway",
1401 json!({
1402 "name": string_schema("Project name"),
1403 "backend": enum_schema("Project backend; defaults to local", &["local", "http"]),
1404 "repo": nullable_string_schema("Absolute local repository path for local backends"),
1405 "db": nullable_string_schema("Absolute local Shuttle database path"),
1406 "url": string_schema("HTTP project base URL for http backends"),
1407 "token_env": nullable_string_schema("Environment variable name containing the backend bearer token"),
1408 "description": nullable_string_schema("Project description"),
1409 "make_current": bool_schema("Set the added project as the current project"),
1410 }),
1411 vec!["name"],
1412 project_output_schema(),
1413 ),
1414 tool(
1415 "shuttle_current_project",
1416 "Read the current or default project",
1417 json!({}),
1418 vec![],
1419 project_output_schema(),
1420 ),
1421 tool(
1422 "shuttle_use_project",
1423 "Set the current project",
1424 json!({"project": string_schema("Configured project name")}),
1425 vec!["project"],
1426 project_output_schema(),
1427 ),
1428 tool(
1429 "shuttle_context",
1430 "Read Shuttle context for a project",
1431 json!({"project": string_schema("Configured project name; optional with default project")}),
1432 vec![],
1433 service_response_output_schema(),
1434 ),
1435 tool(
1436 "shuttle_recall",
1437 "Search Shuttle memories in a project",
1438 json!({"project": string_schema("Configured project name; optional with default project"), "query": string_schema("Recall query")}),
1439 vec!["query"],
1440 service_response_output_schema(),
1441 ),
1442 tool(
1443 "shuttle_remember",
1444 "Store a Shuttle memory in a project",
1445 json!({"project": string_schema("Configured project name"), "kind": enum_schema("Memory kind", &["memory", "decision", "observation", "pattern", "fact", "bug"]), "text": string_schema("Memory text")}),
1446 vec!["project", "text"],
1447 service_response_output_schema(),
1448 ),
1449 tool(
1450 "shuttle_task_list",
1451 "List Shuttle tasks in a project",
1452 json!({"project": string_schema("Configured project name; optional with default project")}),
1453 vec![],
1454 service_response_output_schema(),
1455 ),
1456 tool(
1457 "shuttle_task_create",
1458 "Create a Shuttle task in a project",
1459 json!({"project": string_schema("Configured project name"), "title": string_schema("Task title"), "body": string_schema("Optional task body")}),
1460 vec!["project", "title"],
1461 service_response_output_schema(),
1462 ),
1463 tool(
1464 "shuttle_task_update",
1465 "Update a Shuttle task in a project",
1466 json!({"project": string_schema("Configured project name"), "task_id": string_schema("Task UUID"), "text": string_schema("Update text")}),
1467 vec!["project", "task_id", "text"],
1468 service_response_output_schema(),
1469 ),
1470 tool(
1471 "shuttle_task_done",
1472 "Complete a Shuttle task in a project",
1473 json!({"project": string_schema("Configured project name"), "task_id": string_schema("Task UUID")}),
1474 vec!["project", "task_id"],
1475 service_response_output_schema(),
1476 ),
1477 ]
1478}
1479
1480fn tool(
1481 name: &str,
1482 description: &str,
1483 properties: Value,
1484 required: Vec<&str>,
1485 output_schema: Value,
1486) -> Value {
1487 json!({
1488 "name": name,
1489 "description": description,
1490 "inputSchema": {
1491 "type": "object",
1492 "properties": properties,
1493 "required": required,
1494 "additionalProperties": false,
1495 },
1496 "outputSchema": output_schema,
1497 })
1498}
1499
1500fn projects_output_schema() -> Value {
1501 object_schema(
1502 json!({ "projects": { "type": "array", "items": project_schema() } }),
1503 vec!["projects"],
1504 )
1505}
1506
/// Output schema for tools that return a single project record.
fn project_output_schema() -> Value {
    // Single-project tools share the exact schema of one project entry.
    project_schema()
}
1510
1511fn service_response_output_schema() -> Value {
1512 object_schema(
1513 json!({
1514 "project": string_schema("Configured project name"),
1515 "result": json_schema("Tool result from the selected project"),
1516 "stored": {
1517 "type": "boolean",
1518 "description": "Whether the operation stored data",
1519 },
1520 }),
1521 vec!["project", "result"],
1522 )
1523}
1524
1525fn project_schema() -> Value {
1526 object_schema(
1527 json!({
1528 "name": string_schema("Configured project name"),
1529 "backend": enum_schema("Project backend", &["local", "http"]),
1530 "repo": nullable_string_schema("Local repository path"),
1531 "db": nullable_string_schema("Local Shuttle database path"),
1532 "url": string_schema("HTTP project base URL"),
1533 "description": nullable_string_schema("Project description"),
1534 }),
1535 vec!["name", "backend", "url"],
1536 )
1537}
1538
1539fn object_schema(properties: Value, required: Vec<&str>) -> Value {
1540 json!({
1541 "type": "object",
1542 "properties": properties,
1543 "required": required,
1544 "additionalProperties": true,
1545 })
1546}
1547
1548fn string_schema(description: &str) -> Value {
1549 json!({ "type": "string", "description": description })
1550}
1551
1552fn bool_schema(description: &str) -> Value {
1553 json!({ "type": "boolean", "description": description })
1554}
1555
1556fn nullable_string_schema(description: &str) -> Value {
1557 json!({ "type": ["string", "null"], "description": description })
1558}
1559
1560fn json_schema(description: &str) -> Value {
1561 json!({
1562 "type": ["object", "array", "string", "number", "integer", "boolean", "null"],
1563 "description": description,
1564 })
1565}
1566
1567fn enum_schema(description: &str, values: &[&str]) -> Value {
1568 json!({ "type": "string", "description": description, "enum": values })
1569}
1570
1571fn authorize(
1572 runtime: &GatewayRuntime,
1573 headers: &HeaderMap,
1574 path: &str,
1575 cors: bool,
1576) -> std::result::Result<(), Box<Response>> {
1577 if is_oauth_public_route(path) {
1578 return Ok(());
1579 }
1580 match &runtime.auth {
1581 GatewayAuth::Bearer { token_env } => {
1582 let Some(token) = env::var(token_env).ok().filter(|token| !token.is_empty()) else {
1583 return Ok(());
1584 };
1585 let expected = format!("Bearer {token}");
1586 let ok = headers
1587 .get(header::AUTHORIZATION)
1588 .and_then(|header| header.to_str().ok())
1589 .is_some_and(|actual| constant_time_eq(actual.as_bytes(), expected.as_bytes()));
1590 if ok {
1591 Ok(())
1592 } else if cors {
1593 Err(Box::new(with_cors(StatusCode::UNAUTHORIZED)))
1594 } else {
1595 Err(Box::new(
1596 (
1597 StatusCode::UNAUTHORIZED,
1598 Json(json!({"error": "unauthorized"})),
1599 )
1600 .into_response(),
1601 ))
1602 }
1603 }
1604 GatewayAuth::OAuth(oauth) => {
1605 let Some(token) = bearer_token(headers) else {
1606 return Err(Box::new(unauthorized_oauth(&oauth.config)));
1607 };
1608 match oauth.store.validate_access_token(token) {
1609 Ok(true) => Ok(()),
1610 Ok(false) => Err(Box::new(unauthorized_oauth(&oauth.config))),
1611 Err(_) => Err(Box::new(oauth_error(
1612 StatusCode::UNAUTHORIZED,
1613 "invalid_token",
1614 "failed to validate access token",
1615 ))),
1616 }
1617 }
1618 GatewayAuth::None => Ok(()),
1619 }
1620}
1621
/// Returns `true` for OAuth discovery and flow endpoints that must be reachable without a token.
///
/// Matching is exact; a trailing slash or extra segment does not match.
fn is_oauth_public_route(path: &str) -> bool {
    const PUBLIC_ROUTES: [&str; 6] = [
        "/.well-known/oauth-protected-resource",
        "/.well-known/oauth-protected-resource/mcp",
        "/.well-known/oauth-authorization-server",
        "/oauth/register",
        "/oauth/token",
        "/oauth/authorize",
    ];
    PUBLIC_ROUTES.contains(&path)
}
1633
1634async fn oauth_protected_resource(State(runtime): State<GatewayRuntime>) -> Response {
1635 let GatewayAuth::OAuth(oauth) = &runtime.auth else {
1636 return (
1637 StatusCode::NOT_FOUND,
1638 Json(json!({"error": "oauth is not configured"})),
1639 )
1640 .into_response();
1641 };
1642 Json(oauth::protected_resource_metadata(&oauth.config)).into_response()
1643}
1644
1645async fn oauth_authorization_server(State(runtime): State<GatewayRuntime>) -> Response {
1646 let GatewayAuth::OAuth(oauth) = &runtime.auth else {
1647 return (
1648 StatusCode::NOT_FOUND,
1649 Json(json!({"error": "oauth is not configured"})),
1650 )
1651 .into_response();
1652 };
1653 Json(oauth::authorization_server_metadata(&oauth.config)).into_response()
1654}
1655
1656async fn oauth_register(
1657 State(runtime): State<GatewayRuntime>,
1658 Json(request): Json<oauth::RegisterRequest>,
1659) -> Response {
1660 let GatewayAuth::OAuth(oauth) = &runtime.auth else {
1661 return (
1662 StatusCode::NOT_FOUND,
1663 Json(json!({"error": "oauth is not configured"})),
1664 )
1665 .into_response();
1666 };
1667 match oauth.store.register_client(request) {
1668 Ok(client) => {
1669 let mut body = json!({
1670 "client_id": client.client_id,
1671 "redirect_uris": client.redirect_uris,
1672 "client_name": client.client_name,
1673 "token_endpoint_auth_method": "none",
1674 });
1675 if let Some(secret) = client.client_secret {
1676 body["client_secret"] = json!(secret);
1677 }
1678 (StatusCode::CREATED, Json(body)).into_response()
1679 }
1680 Err(err) => oauth_error(StatusCode::BAD_REQUEST, "invalid_request", &err.to_string()),
1681 }
1682}
1683
1684async fn oauth_authorize_page(
1685 State(runtime): State<GatewayRuntime>,
1686 Query(request): Query<oauth::AuthorizeRequest>,
1687) -> Response {
1688 let GatewayAuth::OAuth(oauth) = &runtime.auth else {
1689 return (
1690 StatusCode::NOT_FOUND,
1691 Json(json!({"error": "oauth is not configured"})),
1692 )
1693 .into_response();
1694 };
1695 if request.response_type != "code" {
1696 return oauth_error(
1697 StatusCode::BAD_REQUEST,
1698 "unsupported_response_type",
1699 "response_type must be code",
1700 );
1701 }
1702 match oauth
1703 .store
1704 .client_allows_redirect(&request.client_id, &request.redirect_uri)
1705 {
1706 Ok(true) => {
1707 Html(authorize_html(&request, oauth.config.admin_token.is_some())).into_response()
1708 }
1709 Ok(false) => oauth_error(
1710 StatusCode::BAD_REQUEST,
1711 "invalid_request",
1712 "unknown client_id or redirect_uri",
1713 ),
1714 Err(_) => oauth_error(
1715 StatusCode::INTERNAL_SERVER_ERROR,
1716 "server_error",
1717 "failed to validate OAuth client",
1718 ),
1719 }
1720}
1721
1722async fn oauth_authorize_submit(
1723 State(runtime): State<GatewayRuntime>,
1724 Form(form): Form<oauth::AuthorizeForm>,
1725) -> Response {
1726 let GatewayAuth::OAuth(oauth) = &runtime.auth else {
1727 return (
1728 StatusCode::NOT_FOUND,
1729 Json(json!({"error": "oauth is not configured"})),
1730 )
1731 .into_response();
1732 };
1733 if let Some(expected) = oauth.config.admin_token.as_deref() {
1734 if !constant_time_eq(form.admin_token.as_bytes(), expected.as_bytes()) {
1735 return oauth_error(
1736 StatusCode::UNAUTHORIZED,
1737 "access_denied",
1738 "invalid admin token",
1739 );
1740 }
1741 }
1742 let request = oauth::AuthorizeRequest::from(form);
1743 if request.response_type != "code" {
1744 return oauth_error(
1745 StatusCode::BAD_REQUEST,
1746 "unsupported_response_type",
1747 "response_type must be code",
1748 );
1749 }
1750 match oauth.store.create_code(request.clone()) {
1751 Ok(code) => Redirect::to(&oauth::authorize_redirect(
1752 &request.redirect_uri,
1753 &code,
1754 request.state.as_deref(),
1755 ))
1756 .into_response(),
1757 Err(err) => oauth_error(StatusCode::BAD_REQUEST, "invalid_request", &err.to_string()),
1758 }
1759}
1760
1761async fn oauth_token(
1762 State(runtime): State<GatewayRuntime>,
1763 Form(request): Form<oauth::TokenRequest>,
1764) -> Response {
1765 let GatewayAuth::OAuth(oauth) = &runtime.auth else {
1766 return (
1767 StatusCode::NOT_FOUND,
1768 Json(json!({"error": "oauth is not configured"})),
1769 )
1770 .into_response();
1771 };
1772 if request.grant_type != "authorization_code" {
1773 return oauth_error(
1774 StatusCode::BAD_REQUEST,
1775 "unsupported_grant_type",
1776 "grant_type must be authorization_code",
1777 );
1778 }
1779 match oauth.store.exchange_code(request) {
1780 Ok(token) => Json(token).into_response(),
1781 Err(err) => oauth_error(StatusCode::BAD_REQUEST, "invalid_grant", &err.to_string()),
1782 }
1783}
1784
1785fn error_response(err: ShuttleError) -> Response {
1786 (
1787 StatusCode::BAD_REQUEST,
1788 Json(json!({"error": err.to_string()})),
1789 )
1790 .into_response()
1791}
1792
1793fn bearer_token(headers: &HeaderMap) -> Option<&str> {
1794 headers
1795 .get(header::AUTHORIZATION)
1796 .and_then(|header| header.to_str().ok())
1797 .and_then(|value| {
1798 let (scheme, token) = value.split_once(' ')?;
1799 scheme.eq_ignore_ascii_case("Bearer").then_some(token)
1800 })
1801}
1802
/// Compares two byte strings without short-circuiting on the first mismatch.
///
/// Every position up to the longer length is visited, and a length difference also counts as a
/// mismatch. This limits timing leaks when comparing secrets.
fn constant_time_eq(left: &[u8], right: &[u8]) -> bool {
    let span = left.len().max(right.len());
    let mismatch = (0..span).fold(left.len() ^ right.len(), |acc, index| {
        let a = left.get(index).copied().unwrap_or(0);
        let b = right.get(index).copied().unwrap_or(0);
        acc | usize::from(a ^ b)
    });
    mismatch == 0
}
1812
1813fn with_cors(response: impl IntoResponse) -> Response {
1814 let (mut parts, body) = response.into_response().into_parts();
1815 parts
1816 .headers
1817 .insert("access-control-allow-origin", HeaderValue::from_static("*"));
1818 parts.headers.insert(
1819 "access-control-allow-methods",
1820 HeaderValue::from_static("GET,POST,DELETE,OPTIONS"),
1821 );
1822 parts.headers.insert(
1823 "access-control-allow-headers",
1824 HeaderValue::from_static(
1825 "accept,authorization,content-type,mcp-protocol-version,mcp-session-id",
1826 ),
1827 );
1828 parts.headers.insert(
1829 "access-control-expose-headers",
1830 HeaderValue::from_static("mcp-session-id"),
1831 );
1832 Response::from_parts(parts, body)
1833}
1834
1835fn unauthorized_oauth(config: &OAuthConfig) -> Response {
1836 let mut response = with_cors(StatusCode::UNAUTHORIZED);
1837 let header_value = format!(
1838 r#"Bearer resource_metadata="{}/.well-known/oauth-protected-resource/mcp", scope="mcp""#,
1839 quoted_header_value(&config.public_url)
1840 );
1841 if let Ok(value) = HeaderValue::from_str(&header_value) {
1842 response
1843 .headers_mut()
1844 .insert(header::WWW_AUTHENTICATE, value);
1845 }
1846 response
1847}
1848
1849fn oauth_error(status: StatusCode, code: &str, description: &str) -> Response {
1850 (
1851 status,
1852 Json(json!({ "error": code, "error_description": description })),
1853 )
1854 .into_response()
1855}
1856
1857fn authorize_html(request: &oauth::AuthorizeRequest, requires_admin_token: bool) -> String {
1858 let admin = if requires_admin_token {
1859 r#"<label>Admin token <input name="admin_token" type="password" autocomplete="current-password" required></label>"#
1860 } else {
1861 r#"<input name="admin_token" type="hidden" value="">"#
1862 };
1863 format!(
1864 r#"<!doctype html>
1865<html>
1866<head><meta charset="utf-8"><title>Authorize Shuttle Gateway</title></head>
1867<body>
1868 <h1>Authorize Shuttle Gateway</h1>
1869 <p>{client_id} is requesting access to Shuttle MCP.</p>
1870 <form method="post" action="/oauth/authorize">
1871 {admin}
1872 <input type="hidden" name="response_type" value="{response_type}">
1873 <input type="hidden" name="client_id" value="{client_id}">
1874 <input type="hidden" name="redirect_uri" value="{redirect_uri}">
1875 <input type="hidden" name="state" value="{state}">
1876 <input type="hidden" name="scope" value="{scope}">
1877 <input type="hidden" name="code_challenge" value="{code_challenge}">
1878 <input type="hidden" name="code_challenge_method" value="{code_challenge_method}">
1879 <button type="submit">Authorize</button>
1880 </form>
1881</body>
1882</html>"#,
1883 admin = admin,
1884 response_type = html_escape(&request.response_type),
1885 client_id = html_escape(&request.client_id),
1886 redirect_uri = html_escape(&request.redirect_uri),
1887 state = html_escape(request.state.as_deref().unwrap_or("")),
1888 scope = html_escape(request.scope.as_deref().unwrap_or("mcp")),
1889 code_challenge = html_escape(request.code_challenge.as_deref().unwrap_or("")),
1890 code_challenge_method =
1891 html_escape(request.code_challenge_method.as_deref().unwrap_or("S256")),
1892 )
1893}
1894
/// Escapes text for safe interpolation into HTML element content and quoted attributes.
///
/// `&`, `<`, `>`, `"` and `'` are replaced by their character references. As written, the
/// previous version replaced each character with itself, so untrusted query parameters (such
/// as `client_id` or `state`) reached the consent page unescaped. A single pass over the input
/// also means already-inserted `&` from references is never escaped twice.
fn html_escape(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '"' => escaped.push_str("&quot;"),
            '\'' => escaped.push_str("&#39;"),
            other => escaped.push(other),
        }
    }
    escaped
}
1903
/// Escapes `\` and `"` so the value can sit inside an HTTP quoted-string.
fn quoted_header_value(value: &str) -> String {
    let mut quoted = String::with_capacity(value.len());
    for ch in value.chars() {
        if matches!(ch, '\\' | '"') {
            quoted.push('\\');
        }
        quoted.push(ch);
    }
    quoted
}
1907
/// Default gateway bind address: loopback port 8787.
fn default_addr() -> SocketAddr {
    SocketAddr::from(([127, 0, 0, 1], 8787))
}
1911
/// Default environment variable holding the gateway's static bearer token.
fn default_gateway_token_env() -> String {
    String::from("SHUTTLE_GATEWAY_TOKEN")
}
1915
/// Default environment variable holding the OAuth consent-page admin token.
fn default_oauth_admin_token_env() -> String {
    String::from("SHUTTLE_OAUTH_ADMIN_TOKEN")
}
1919
1920fn legacy_listener_config(config: &GatewayConfig) -> ListenerConfig {
1921 if config.oauth.public_url.is_empty() {
1922 ListenerConfig {
1923 name: "default".to_owned(),
1924 addr: config.server.addr,
1925 auth: ListenerAuthKind::Bearer,
1926 public_url: String::new(),
1927 oauth_db_path: None,
1928 oauth_admin_token_env: config.oauth.admin_token_env.clone(),
1929 bearer_token_env: config.auth.bearer_token_env.clone(),
1930 }
1931 } else {
1932 ListenerConfig {
1933 name: "default".to_owned(),
1934 addr: config.server.addr,
1935 auth: ListenerAuthKind::OAuth,
1936 public_url: config.oauth.public_url.clone(),
1937 oauth_db_path: config.oauth.db_path.clone(),
1938 oauth_admin_token_env: config.oauth.admin_token_env.clone(),
1939 bearer_token_env: config.auth.bearer_token_env.clone(),
1940 }
1941 }
1942}
1943
1944fn auth_from_listener(listener: &ListenerConfig) -> Result<GatewayAuth> {
1945 match listener.auth {
1946 ListenerAuthKind::Bearer => Ok(GatewayAuth::Bearer {
1947 token_env: listener.bearer_token_env.clone(),
1948 }),
1949 ListenerAuthKind::None => Ok(GatewayAuth::None),
1950 ListenerAuthKind::OAuth => {
1951 let admin_token = env::var(&listener.oauth_admin_token_env).map_err(|_| {
1952 ShuttleError::Store(format!(
1953 "{} is required when oauth listener {:?} is configured",
1954 listener.oauth_admin_token_env, listener.name
1955 ))
1956 })?;
1957 if admin_token.is_empty() {
1958 return Err(ShuttleError::Store(format!(
1959 "{} is required when oauth listener {:?} is configured",
1960 listener.oauth_admin_token_env, listener.name
1961 )));
1962 }
1963 let db_path = listener.oauth_db_path.clone().ok_or_else(|| {
1964 ShuttleError::Store(format!(
1965 "oauth_db_path is required for listener {:?}",
1966 listener.name
1967 ))
1968 })?;
1969 Ok(GatewayAuth::OAuth(Arc::new(OAuthRuntime {
1970 config: OAuthConfig {
1971 public_url: listener.public_url.clone(),
1972 admin_token: Some(admin_token),
1973 },
1974 store: OAuthStore::open(db_path)?,
1975 })))
1976 }
1977 }
1978}
1979
/// Returns `true` when the socket address is an IPv4 or IPv6 loopback address.
///
/// IPv4-mapped IPv6 addresses such as `::ffff:127.0.0.1` are not treated as loopback.
fn is_loopback_addr(addr: SocketAddr) -> bool {
    addr.ip().is_loopback()
}
1986
/// Trims surrounding whitespace and every trailing `/` from a public base URL.
fn normalize_public_url(url: &str) -> String {
    let without_whitespace = url.trim();
    String::from(without_whitespace.trim_end_matches('/'))
}
1990
1991#[cfg(test)]
1992mod tests {
1993 use super::*;
1994 use axum::body::Body;
1995 use axum::http::{Method, Request};
1996 use http_body_util::BodyExt;
1997 use std::sync::Mutex;
1998 use tower::ServiceExt;
1999
2000 #[derive(Default)]
2001 struct FakeRunner {
2002 calls: Mutex<Vec<(String, Vec<String>)>>,
2003 }
2004
2005 #[async_trait]
2006 impl Runner for FakeRunner {
2007 async fn run(
2008 &self,
2009 project: &Project,
2010 args: &[&str],
2011 ) -> std::result::Result<Value, String> {
2012 self.calls.lock().unwrap().push((
2013 project.name.clone(),
2014 args.iter().map(|arg| (*arg).to_owned()).collect(),
2015 ));
2016 Ok(json!({"ok": true}))
2017 }
2018 }
2019
2020 fn registry() -> ProjectRegistry {
2021 ProjectRegistry::new(
2022 "demo".to_owned(),
2023 BTreeMap::from([(
2024 "demo".to_owned(),
2025 ProjectConfig {
2026 backend: ProjectBackendKind::Local,
2027 repo: Some(PathBuf::from("/tmp/demo")),
2028 db: None,
2029 url: String::new(),
2030 token_env: None,
2031 description: None,
2032 },
2033 )]),
2034 )
2035 .unwrap()
2036 }
2037
2038 #[test]
2039 fn config_rejects_relative_repo_and_applies_defaults() {
2040 let dir = tempfile::tempdir().unwrap();
2041 let path = dir.path().join("projects.toml");
2042 std::fs::write(&path, "[projects.demo]\nrepo = \"relative\"\n").unwrap();
2043 assert!(GatewayConfig::load(&path).is_err());
2044
2045 std::fs::write(&path, "[projects.demo]\nrepo = \"/tmp/demo\"\n").unwrap();
2046 let cfg = GatewayConfig::load(&path).unwrap();
2047 assert_eq!(cfg.server.addr, default_addr());
2048 assert_eq!(cfg.auth.bearer_token_env, "SHUTTLE_GATEWAY_TOKEN");
2049 assert_eq!(cfg.oauth.admin_token_env, "SHUTTLE_OAUTH_ADMIN_TOKEN");
2050 }
2051
2052 #[test]
2053 fn config_normalizes_oauth_defaults() {
2054 let dir = tempfile::tempdir().unwrap();
2055 let path = dir.path().join("projects.toml");
2056 std::fs::write(
2057 &path,
2058 "[oauth]\npublic_url = \"https://shuttle.example.test/\"\n\n[projects.demo]\nrepo = \"/tmp/demo\"\n",
2059 )
2060 .unwrap();
2061 let cfg = GatewayConfig::load(&path).unwrap();
2062 assert_eq!(cfg.oauth.public_url, "https://shuttle.example.test");
2063 assert_eq!(
2064 cfg.oauth.db_path.unwrap().file_name().unwrap(),
2065 "gateway-oauth.db"
2066 );
2067 }
2068
2069 #[test]
2070 fn config_accepts_http_projects_without_repo() {
2071 let dir = tempfile::tempdir().unwrap();
2072 let path = dir.path().join("projects.toml");
2073 std::fs::write(
2074 &path,
2075 "[projects.demo]\nbackend = \"http\"\nurl = \"http://127.0.0.1:8787\"\ntoken_env = \"DEMO_TOKEN\"\n",
2076 )
2077 .unwrap();
2078 let cfg = GatewayConfig::load(&path).unwrap();
2079 let project = cfg.projects.get("demo").unwrap();
2080
2081 assert_eq!(project.backend, ProjectBackendKind::Http);
2082 assert!(project.repo.is_none());
2083 assert_eq!(project.token_env.as_deref(), Some("DEMO_TOKEN"));
2084 }
2085
2086 #[test]
2087 fn config_rejects_http_projects_without_url() {
2088 let dir = tempfile::tempdir().unwrap();
2089 let path = dir.path().join("projects.toml");
2090 std::fs::write(&path, "[projects.demo]\nbackend = \"http\"\n").unwrap();
2091
2092 assert!(GatewayConfig::load(&path)
2093 .unwrap_err()
2094 .to_string()
2095 .contains("url is required"));
2096 }
2097
2098 #[test]
2099 fn config_rejects_none_listener_on_non_loopback() {
2100 let dir = tempfile::tempdir().unwrap();
2101 let path = dir.path().join("projects.toml");
2102 std::fs::write(
2103 &path,
2104 "[[listeners]]\nname = \"open\"\naddr = \"0.0.0.0:8787\"\nauth = \"none\"\n\n[projects.demo]\nrepo = \"/tmp/demo\"\n",
2105 )
2106 .unwrap();
2107
2108 assert!(GatewayConfig::load(&path)
2109 .unwrap_err()
2110 .to_string()
2111 .contains("only allowed on loopback"));
2112 }
2113
2114 #[tokio::test]
2115 async fn remember_requires_explicit_project_and_maps_kind() {
2116 let runner = Arc::new(FakeRunner::default());
2117 let service = GatewayService::new(registry(), runner.clone());
2118 assert!(service.remember("", "decision", "ship it").await.is_err());
2119 let response = service
2120 .remember("demo", "decision", "ship it")
2121 .await
2122 .unwrap();
2123 assert_eq!(response.project, "demo");
2124 assert_eq!(response.stored, Some(true));
2125 let calls = runner.calls.lock().unwrap();
2126 assert_eq!(calls[0].1, vec!["decide", "ship it"]);
2127 }
2128
2129 #[tokio::test]
2130 async fn task_create_combines_title_and_body() {
2131 let runner = Arc::new(FakeRunner::default());
2132 let service = GatewayService::new(registry(), runner.clone());
2133 service.task_create("demo", "title", "body").await.unwrap();
2134 let calls = runner.calls.lock().unwrap();
2135 assert_eq!(calls[0].1, vec!["task", "create", "title\n\nbody"]);
2136 }
2137
2138 #[tokio::test]
2139 async fn service_add_project_routes_to_runtime_project() {
2140 let runner = Arc::new(FakeRunner::default());
2141 let service = GatewayService::new(registry(), runner.clone());
2142 let project = service
2143 .add_project(
2144 " extra ",
2145 ProjectConfig {
2146 backend: ProjectBackendKind::Http,
2147 repo: None,
2148 db: None,
2149 url: "http://127.0.0.1:9999/".to_owned(),
2150 token_env: Some("EXTRA_TOKEN".to_owned()),
2151 description: Some("extra project".to_owned()),
2152 },
2153 true,
2154 )
2155 .unwrap();
2156
2157 assert_eq!(project.name, "extra");
2158 assert_eq!(project.url, "http://127.0.0.1:9999");
2159 assert_eq!(service.current_project().unwrap().name, "extra");
2160 service.task_list("extra").await.unwrap();
2161
2162 let calls = runner.calls.lock().unwrap();
2163 assert_eq!(calls[0].0, "extra");
2164 assert_eq!(calls[0].1, vec!["task", "list"]);
2165 }
2166
2167 #[test]
2168 fn service_add_project_suffixes_duplicates_and_rejects_invalid_config() {
2169 let service = GatewayService::new(registry(), Arc::new(FakeRunner::default()));
2170 let duplicate = service
2171 .add_project(
2172 "demo",
2173 ProjectConfig {
2174 backend: ProjectBackendKind::Http,
2175 repo: None,
2176 db: None,
2177 url: "http://127.0.0.1:9999".to_owned(),
2178 token_env: None,
2179 description: None,
2180 },
2181 false,
2182 )
2183 .unwrap();
2184 assert_eq!(duplicate.name, "demo-2");
2185 assert!(service
2186 .add_project(
2187 "relative",
2188 ProjectConfig {
2189 backend: ProjectBackendKind::Local,
2190 repo: Some(PathBuf::from("relative")),
2191 db: None,
2192 url: String::new(),
2193 token_env: None,
2194 description: None,
2195 },
2196 false,
2197 )
2198 .unwrap_err()
2199 .to_string()
2200 .contains("absolute path"));
2201 }
2202
2203 #[test]
2204 fn service_add_project_persists_local_project_config() {
2205 let dir = tempfile::tempdir().unwrap();
2206 let config_path = write_gateway_config(dir.path(), false);
2207 let runtime = file_backed_runtime(&config_path, Arc::new(FakeRunner::default()));
2208 let repo = dir.path().join("repo");
2209 let db = dir.path().join("repo/.shuttle/shuttle.db");
2210
2211 let project = runtime
2212 .service
2213 .add_project(
2214 "local-extra",
2215 ProjectConfig {
2216 backend: ProjectBackendKind::Local,
2217 repo: Some(repo.clone()),
2218 db: Some(db.clone()),
2219 url: String::new(),
2220 token_env: None,
2221 description: Some("local extra".to_owned()),
2222 },
2223 false,
2224 )
2225 .unwrap();
2226
2227 assert_eq!(project.name, "local-extra");
2228 let reloaded = GatewayConfig::load(&config_path).unwrap();
2229 let persisted = reloaded.projects.get("local-extra").unwrap();
2230 assert_eq!(persisted.backend, ProjectBackendKind::Local);
2231 assert_eq!(persisted.repo.as_ref(), Some(&repo));
2232 assert_eq!(persisted.db.as_ref(), Some(&db));
2233 assert_eq!(persisted.description.as_deref(), Some("local extra"));
2234 }
2235
2236 #[tokio::test]
2237 async fn mcp_tools_list_includes_gateway_tools() {
2238 let runtime = test_runtime(registry(), Arc::new(FakeRunner::default()));
2239 env::remove_var("TEST_EMPTY_GATEWAY_TOKEN");
2240 let response = router(runtime)
2241 .oneshot(
2242 Request::builder()
2243 .method(Method::POST)
2244 .uri("/mcp")
2245 .header(header::CONTENT_TYPE, "application/json")
2246 .body(Body::from(
2247 r#"{"jsonrpc":"2.0","id":1,"method":"tools/list"}"#,
2248 ))
2249 .unwrap(),
2250 )
2251 .await
2252 .unwrap();
2253 assert_eq!(response.status(), StatusCode::OK);
2254 let body = response.into_body().collect().await.unwrap().to_bytes();
2255 let value: Value = serde_json::from_slice(&body).unwrap();
2256 let tools = value["result"]["tools"].as_array().unwrap();
2257 assert!(tools.iter().any(|tool| tool["name"] == "shuttle_projects"));
2258 assert!(tools
2259 .iter()
2260 .any(|tool| tool["name"] == "shuttle_project_add"));
2261 assert!(tools.iter().any(|tool| tool["name"] == "shuttle_remember"));
2262 assert!(tools
2263 .iter()
2264 .all(|tool| tool["inputSchema"]["additionalProperties"] == false));
2265 assert!(tools
2266 .iter()
2267 .all(|tool| tool["outputSchema"]["type"] == "object"));
2268 }
2269
2270 #[tokio::test]
2271 async fn mcp_tool_call_returns_structured_content() {
2272 let runtime = test_runtime(registry(), Arc::new(FakeRunner::default()));
2273 env::remove_var("TEST_EMPTY_GATEWAY_TOKEN");
2274 let response = router(runtime)
2275 .oneshot(
2276 Request::builder()
2277 .method(Method::POST)
2278 .uri("/mcp")
2279 .header(header::CONTENT_TYPE, "application/json")
2280 .body(Body::from(
2281 r#"{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"shuttle_task_create","arguments":{"project":"demo","title":"ship it"}}}"#,
2282 ))
2283 .unwrap(),
2284 )
2285 .await
2286 .unwrap();
2287 assert_eq!(response.status(), StatusCode::OK);
2288 let body = response.into_body().collect().await.unwrap().to_bytes();
2289 let value: Value = serde_json::from_slice(&body).unwrap();
2290
2291 assert_eq!(
2292 value["result"]["structuredContent"],
2293 json!({"project": "demo", "result": {"ok": true}, "stored": true})
2294 );
2295 assert_eq!(
2296 value["result"]["content"][0]["text"],
2297 r#"{"project":"demo","result":{"ok":true},"stored":true}"#
2298 );
2299 }
2300
2301 #[tokio::test]
2302 async fn mcp_project_add_registers_project() {
2303 let dir = tempfile::tempdir().unwrap();
2304 let config_path = write_gateway_config(dir.path(), true);
2305 let runtime = file_backed_runtime(&config_path, Arc::new(FakeRunner::default()));
2306 env::remove_var("TEST_EMPTY_GATEWAY_TOKEN");
2307 let app = router(runtime);
2308
2309 let add = app
2310 .clone()
2311 .oneshot(
2312 Request::builder()
2313 .method(Method::POST)
2314 .uri("/mcp")
2315 .header(header::CONTENT_TYPE, "application/json")
2316 .body(Body::from(
2317 r#"{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"shuttle_project_add","arguments":{"name":"extra","backend":"http","url":"http://127.0.0.1:9999/","make_current":true}}}"#,
2318 ))
2319 .unwrap(),
2320 )
2321 .await
2322 .unwrap();
2323 assert_eq!(add.status(), StatusCode::OK);
2324 let body = add.into_body().collect().await.unwrap().to_bytes();
2325 let value: Value = serde_json::from_slice(&body).unwrap();
2326 assert_eq!(value["result"]["structuredContent"]["name"], "extra-2");
2327 assert_eq!(
2328 value["result"]["structuredContent"]["url"],
2329 "http://127.0.0.1:9999"
2330 );
2331 assert!(GatewayConfig::load(&config_path)
2332 .unwrap()
2333 .projects
2334 .contains_key("extra-2"));
2335
2336 let current = app
2337 .oneshot(
2338 Request::builder()
2339 .method(Method::POST)
2340 .uri("/mcp")
2341 .header(header::CONTENT_TYPE, "application/json")
2342 .body(Body::from(
2343 r#"{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"shuttle_current_project","arguments":{}}}"#,
2344 ))
2345 .unwrap(),
2346 )
2347 .await
2348 .unwrap();
2349 let body = current.into_body().collect().await.unwrap().to_bytes();
2350 let value: Value = serde_json::from_slice(&body).unwrap();
2351 assert_eq!(value["result"]["structuredContent"]["name"], "extra-2");
2352 }
2353
2354 #[tokio::test]
2355 async fn mcp_initialized_notification_returns_no_content() {
2356 let runtime = test_runtime(registry(), Arc::new(FakeRunner::default()));
2357 env::remove_var("TEST_EMPTY_GATEWAY_TOKEN");
2358 let response = router(runtime)
2359 .oneshot(
2360 Request::builder()
2361 .method(Method::POST)
2362 .uri("/mcp")
2363 .header(header::CONTENT_TYPE, "application/json")
2364 .body(Body::from(
2365 r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#,
2366 ))
2367 .unwrap(),
2368 )
2369 .await
2370 .unwrap();
2371
2372 assert_eq!(response.status(), StatusCode::NO_CONTENT);
2373 }
2374
2375 #[tokio::test]
2376 async fn http_remember_requires_project_and_recall_uses_default() {
2377 let runner = Arc::new(FakeRunner::default());
2378 let runtime = test_runtime(registry(), runner.clone());
2379 env::remove_var("TEST_EMPTY_GATEWAY_TOKEN");
2380
2381 let remember = router(runtime.clone())
2382 .oneshot(
2383 Request::builder()
2384 .method(Method::POST)
2385 .uri("/api/remember")
2386 .header(header::CONTENT_TYPE, "application/json")
2387 .body(Body::from(r#"{"text":"note"}"#))
2388 .unwrap(),
2389 )
2390 .await
2391 .unwrap();
2392 assert_eq!(remember.status(), StatusCode::BAD_REQUEST);
2393
2394 let recall = router(runtime)
2395 .oneshot(
2396 Request::builder()
2397 .method(Method::POST)
2398 .uri("/api/recall")
2399 .header(header::CONTENT_TYPE, "application/json")
2400 .body(Body::from(r#"{"query":"sqlite"}"#))
2401 .unwrap(),
2402 )
2403 .await
2404 .unwrap();
2405
2406 assert_eq!(recall.status(), StatusCode::OK);
2407 let calls = runner.calls.lock().unwrap();
2408 assert_eq!(calls[0].1, vec!["recall", "sqlite"]);
2409 }
2410
2411 #[tokio::test]
2412 async fn http_project_add_updates_project_list_and_current() {
2413 let dir = tempfile::tempdir().unwrap();
2414 let config_path = write_gateway_config(dir.path(), true);
2415 let runtime = file_backed_runtime(&config_path, Arc::new(FakeRunner::default()));
2416 env::remove_var("TEST_EMPTY_GATEWAY_TOKEN");
2417 let app = router(runtime);
2418
2419 let add = app
2420 .clone()
2421 .oneshot(
2422 Request::builder()
2423 .method(Method::POST)
2424 .uri("/api/projects")
2425 .header(header::CONTENT_TYPE, "application/json")
2426 .body(Body::from(
2427 r#"{"name":"extra","backend":"http","url":"http://127.0.0.1:9999/","description":"extra project","make_current":true}"#,
2428 ))
2429 .unwrap(),
2430 )
2431 .await
2432 .unwrap();
2433 assert_eq!(add.status(), StatusCode::CREATED);
2434 let body = add.into_body().collect().await.unwrap().to_bytes();
2435 let project: Value = serde_json::from_slice(&body).unwrap();
2436 assert_eq!(project["name"], "extra-2");
2437 assert_eq!(project["url"], "http://127.0.0.1:9999");
2438 assert!(GatewayConfig::load(&config_path)
2439 .unwrap()
2440 .projects
2441 .contains_key("extra-2"));
2442
2443 let current = app
2444 .clone()
2445 .oneshot(
2446 Request::builder()
2447 .method(Method::GET)
2448 .uri("/api/projects/current")
2449 .body(Body::empty())
2450 .unwrap(),
2451 )
2452 .await
2453 .unwrap();
2454 let body = current.into_body().collect().await.unwrap().to_bytes();
2455 let project: Value = serde_json::from_slice(&body).unwrap();
2456 assert_eq!(project["name"], "extra-2");
2457
2458 let list = app
2459 .oneshot(
2460 Request::builder()
2461 .method(Method::GET)
2462 .uri("/api/projects")
2463 .body(Body::empty())
2464 .unwrap(),
2465 )
2466 .await
2467 .unwrap();
2468 let body = list.into_body().collect().await.unwrap().to_bytes();
2469 let projects: Value = serde_json::from_slice(&body).unwrap();
2470 assert!(projects["projects"]
2471 .as_array()
2472 .unwrap()
2473 .iter()
2474 .any(|project| project["name"] == "extra-2"));
2475 }
2476
2477 #[tokio::test]
2478 async fn current_project_without_default_returns_not_found() {
2479 let runtime = test_runtime(registry_without_default(), Arc::new(FakeRunner::default()));
2480 env::remove_var("TEST_EMPTY_GATEWAY_TOKEN");
2481 let response = router(runtime)
2482 .oneshot(
2483 Request::builder()
2484 .method(Method::GET)
2485 .uri("/api/projects/current")
2486 .body(Body::empty())
2487 .unwrap(),
2488 )
2489 .await
2490 .unwrap();
2491
2492 assert_eq!(response.status(), StatusCode::NOT_FOUND);
2493 }
2494
2495 #[tokio::test]
2496 async fn oauth_routes_issue_token_that_authorizes_protected_routes() {
2497 let oauth_dir = tempfile::tempdir().unwrap();
2498 let oauth = Arc::new(OAuthRuntime {
2499 config: OAuthConfig {
2500 public_url: "https://shuttle.example.test".to_owned(),
2501 admin_token: Some("admin-token".to_owned()),
2502 },
2503 store: OAuthStore::open(oauth_dir.path().join("oauth.db")).unwrap(),
2504 });
2505 let runtime = GatewayRuntime {
2506 service: Arc::new(GatewayService::new(
2507 registry(),
2508 Arc::new(FakeRunner::default()),
2509 )),
2510 auth: GatewayAuth::OAuth(oauth),
2511 };
2512 let app = router(runtime);
2513
2514 let register = app
2515 .clone()
2516 .oneshot(
2517 Request::builder()
2518 .method(Method::POST)
2519 .uri("/oauth/register")
2520 .header(header::CONTENT_TYPE, "application/json")
2521 .body(Body::from(
2522 r#"{"redirect_uris":["https://client.example.test/callback"],"client_name":"client"}"#,
2523 ))
2524 .unwrap(),
2525 )
2526 .await
2527 .unwrap();
2528 assert_eq!(register.status(), StatusCode::CREATED);
2529 let body = register.into_body().collect().await.unwrap().to_bytes();
2530 let registered: Value = serde_json::from_slice(&body).unwrap();
2531 assert!(registered.get("client_secret").is_none());
2532 let client_id = registered["client_id"].as_str().unwrap();
2533 let verifier = "abc123abc123abc123abc123abc123abc123abc123abc123";
2534 let challenge = pkce_s256(verifier);
2535 let form = format!(
2536 "admin_token=admin-token&response_type=code&client_id={client_id}&redirect_uri=https%3A%2F%2Fclient.example.test%2Fcallback&state=state-123&scope=mcp&code_challenge={challenge}&code_challenge_method=S256"
2537 );
2538 let authorize = app
2539 .clone()
2540 .oneshot(
2541 Request::builder()
2542 .method(Method::POST)
2543 .uri("/oauth/authorize")
2544 .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
2545 .body(Body::from(form))
2546 .unwrap(),
2547 )
2548 .await
2549 .unwrap();
2550 assert_eq!(authorize.status(), StatusCode::SEE_OTHER);
2551 let location = authorize
2552 .headers()
2553 .get(header::LOCATION)
2554 .and_then(|value| value.to_str().ok())
2555 .unwrap();
2556 assert!(location.contains("&state=state-123"));
2557 let code = location
2558 .split("code=")
2559 .nth(1)
2560 .unwrap()
2561 .split('&')
2562 .next()
2563 .unwrap();
2564 let token_form = format!(
2565 "grant_type=authorization_code&client_id={client_id}&redirect_uri=https%3A%2F%2Fclient.example.test%2Fcallback&code={code}&code_verifier={verifier}"
2566 );
2567 let token = app
2568 .clone()
2569 .oneshot(
2570 Request::builder()
2571 .method(Method::POST)
2572 .uri("/oauth/token")
2573 .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
2574 .body(Body::from(token_form))
2575 .unwrap(),
2576 )
2577 .await
2578 .unwrap();
2579 assert_eq!(token.status(), StatusCode::OK);
2580 let body = token.into_body().collect().await.unwrap().to_bytes();
2581 let token: Value = serde_json::from_slice(&body).unwrap();
2582 let access_token = token["access_token"].as_str().unwrap();
2583
2584 let unauthorized = app
2585 .clone()
2586 .oneshot(
2587 Request::builder()
2588 .method(Method::GET)
2589 .uri("/api/projects")
2590 .body(Body::empty())
2591 .unwrap(),
2592 )
2593 .await
2594 .unwrap();
2595 assert_eq!(unauthorized.status(), StatusCode::UNAUTHORIZED);
2596
2597 let authorized = app
2598 .oneshot(
2599 Request::builder()
2600 .method(Method::GET)
2601 .uri("/api/projects")
2602 .header(header::AUTHORIZATION, format!("Bearer {access_token}"))
2603 .body(Body::empty())
2604 .unwrap(),
2605 )
2606 .await
2607 .unwrap();
2608 assert_eq!(authorized.status(), StatusCode::OK);
2609 }
2610
2611 fn write_gateway_config(dir: &Path, include_extra: bool) -> PathBuf {
2612 let path = dir.join("projects.toml");
2613 let extra = if include_extra {
2614 "\n[projects.extra]\nbackend = \"http\"\nurl = \"http://127.0.0.1:8788\"\n"
2615 } else {
2616 ""
2617 };
2618 std::fs::write(
2619 &path,
2620 format!(
2621 "[defaults]\nproject = \"demo\"\n\n[projects.demo]\nbackend = \"local\"\nrepo = \"/tmp/demo\"\n{extra}"
2622 ),
2623 )
2624 .unwrap();
2625 path
2626 }
2627
2628 fn file_backed_runtime(path: &Path, runner: Arc<FakeRunner>) -> GatewayRuntime {
2629 let cfg = GatewayConfig::load(path).unwrap();
2630 let config_path = cfg.config_path.clone();
2631 let registry = ProjectRegistry::new(cfg.defaults.project, cfg.projects).unwrap();
2632 GatewayRuntime {
2633 service: Arc::new(GatewayService::new_with_config_path(
2634 registry,
2635 runner,
2636 config_path,
2637 )),
2638 auth: GatewayAuth::Bearer {
2639 token_env: "TEST_EMPTY_GATEWAY_TOKEN".to_owned(),
2640 },
2641 }
2642 }
2643
2644 fn test_runtime(registry: ProjectRegistry, runner: Arc<FakeRunner>) -> GatewayRuntime {
2645 GatewayRuntime {
2646 service: Arc::new(GatewayService::new(registry, runner)),
2647 auth: GatewayAuth::Bearer {
2648 token_env: "TEST_EMPTY_GATEWAY_TOKEN".to_owned(),
2649 },
2650 }
2651 }
2652
2653 fn registry_without_default() -> ProjectRegistry {
2654 ProjectRegistry::new(
2655 String::new(),
2656 BTreeMap::from([(
2657 "demo".to_owned(),
2658 ProjectConfig {
2659 backend: ProjectBackendKind::Local,
2660 repo: Some(PathBuf::from("/tmp/demo")),
2661 db: None,
2662 url: String::new(),
2663 token_env: None,
2664 description: None,
2665 },
2666 )]),
2667 )
2668 .unwrap()
2669 }
2670
2671 fn pkce_s256(verifier: &str) -> String {
2672 use base64::engine::general_purpose::URL_SAFE_NO_PAD;
2673 use base64::Engine;
2674 use sha2::{Digest, Sha256};
2675
2676 URL_SAFE_NO_PAD.encode(Sha256::digest(verifier.as_bytes()))
2677 }
2678}