Skip to main content

shuttle_rs/
gateway.rs

1use std::collections::{BTreeMap, BTreeSet};
2use std::env;
3use std::io::Write;
4use std::net::{IpAddr, SocketAddr};
5use std::path::{Path, PathBuf};
6use std::sync::{Arc, Mutex, MutexGuard};
7use std::time::Duration;
8
9use async_trait::async_trait;
10use axum::extract::{Form, Path as AxumPath, Query, State};
11use axum::http::{header, HeaderMap, HeaderValue, StatusCode};
12use axum::response::{Html, IntoResponse, Redirect, Response};
13use axum::routing::{get, patch, post};
14use axum::{Json, Router};
15use serde::{Deserialize, Serialize};
16use serde_json::{json, Value};
17use tokio::process::Command;
18use toml_edit::{value, DocumentMut, Item, Table};
19
20use crate::core::{Result, ShuttleError};
21use crate::oauth::{self, OAuthConfig, OAuthStore};
22
23#[derive(Debug, Clone, Deserialize)]
24pub struct GatewayConfig {
25    #[serde(skip)]
26    config_path: Option<PathBuf>,
27    #[serde(default)]
28    pub server: ServerConfig,
29    #[serde(default)]
30    pub auth: AuthConfig,
31    #[serde(default)]
32    pub oauth: OAuthGatewayConfig,
33    #[serde(default)]
34    pub defaults: DefaultsConfig,
35    #[serde(default)]
36    pub listeners: Vec<ListenerConfig>,
37    pub projects: BTreeMap<String, ProjectConfig>,
38}
39
40#[derive(Debug, Clone, Deserialize)]
41pub struct ServerConfig {
42    #[serde(default = "default_addr")]
43    pub addr: SocketAddr,
44}
45
46impl Default for ServerConfig {
47    fn default() -> Self {
48        Self {
49            addr: default_addr(),
50        }
51    }
52}
53
54#[derive(Debug, Clone, Deserialize)]
55pub struct AuthConfig {
56    #[serde(default = "default_gateway_token_env")]
57    pub bearer_token_env: String,
58}
59
60impl Default for AuthConfig {
61    fn default() -> Self {
62        Self {
63            bearer_token_env: default_gateway_token_env(),
64        }
65    }
66}
67
68#[derive(Debug, Clone, Deserialize)]
69pub struct OAuthGatewayConfig {
70    #[serde(default)]
71    pub public_url: String,
72    #[serde(default)]
73    pub db_path: Option<PathBuf>,
74    #[serde(default = "default_oauth_admin_token_env")]
75    pub admin_token_env: String,
76}
77
78impl Default for OAuthGatewayConfig {
79    fn default() -> Self {
80        Self {
81            public_url: String::new(),
82            db_path: None,
83            admin_token_env: default_oauth_admin_token_env(),
84        }
85    }
86}
87
88#[derive(Debug, Clone, Default, Deserialize)]
89pub struct DefaultsConfig {
90    #[serde(default)]
91    pub project: String,
92}
93
94#[derive(Debug, Clone, Deserialize)]
95pub struct ProjectConfig {
96    #[serde(default)]
97    pub backend: ProjectBackendKind,
98    #[serde(default)]
99    pub repo: Option<PathBuf>,
100    #[serde(default)]
101    pub db: Option<PathBuf>,
102    #[serde(default)]
103    pub url: String,
104    #[serde(default)]
105    pub token_env: Option<String>,
106    #[serde(default)]
107    pub description: Option<String>,
108}
109
110#[derive(Debug, Clone, Copy, Default, Deserialize, Serialize, PartialEq, Eq)]
111#[serde(rename_all = "lowercase")]
112pub enum ProjectBackendKind {
113    #[default]
114    Local,
115    Http,
116}
117
118#[derive(Debug, Clone, Deserialize)]
119pub struct ListenerConfig {
120    pub name: String,
121    pub addr: SocketAddr,
122    pub auth: ListenerAuthKind,
123    #[serde(default)]
124    pub public_url: String,
125    #[serde(default)]
126    pub oauth_db_path: Option<PathBuf>,
127    #[serde(default = "default_oauth_admin_token_env")]
128    pub oauth_admin_token_env: String,
129    #[serde(default = "default_gateway_token_env")]
130    pub bearer_token_env: String,
131}
132
133#[derive(Debug, Clone, Copy, Deserialize, PartialEq, Eq)]
134#[serde(rename_all = "lowercase")]
135pub enum ListenerAuthKind {
136    OAuth,
137    Bearer,
138    None,
139}
140
141impl GatewayConfig {
142    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
143        let path = path.as_ref();
144        let abs_path = path
145            .canonicalize()
146            .or_else(|_| {
147                path.parent()
148                    .unwrap_or_else(|| Path::new("."))
149                    .canonicalize()
150                    .map(|parent| parent.join(path.file_name().unwrap_or_default()))
151            })
152            .map_err(|err| ShuttleError::Store(err.to_string()))?;
153        let raw =
154            std::fs::read_to_string(path).map_err(|err| ShuttleError::Store(err.to_string()))?;
155        let mut cfg: GatewayConfig =
156            toml::from_str(&raw).map_err(|err| ShuttleError::Serialization(err.to_string()))?;
157        cfg.config_path = Some(abs_path.clone());
158
159        cfg.oauth.public_url = normalize_public_url(&cfg.oauth.public_url);
160        if cfg.projects.is_empty() {
161            return Err(ShuttleError::Store(
162                "at least one project is required".to_owned(),
163            ));
164        }
165        for (name, project) in &cfg.projects {
166            validate_project_config(name, project)?;
167        }
168        if !cfg.defaults.project.is_empty() && !cfg.projects.contains_key(&cfg.defaults.project) {
169            return Err(ShuttleError::Store(format!(
170                "default project {:?} is not configured",
171                cfg.defaults.project
172            )));
173        }
174        if !cfg.oauth.public_url.is_empty() {
175            match &cfg.oauth.db_path {
176                Some(path) if !path.is_absolute() => {
177                    return Err(ShuttleError::Store(
178                        "oauth db_path must be an absolute path when set".to_owned(),
179                    ))
180                }
181                Some(_) => {}
182                None => {
183                    cfg.oauth.db_path = Some(
184                        abs_path
185                            .parent()
186                            .unwrap_or_else(|| Path::new("."))
187                            .join("gateway-oauth.db"),
188                    );
189                }
190            }
191        }
192        let config_dir = abs_path.parent().unwrap_or_else(|| Path::new("."));
193        for listener in &mut cfg.listeners {
194            listener.public_url = normalize_public_url(&listener.public_url);
195            if listener.name.trim().is_empty() {
196                return Err(ShuttleError::Store(
197                    "listener name cannot be empty".to_owned(),
198                ));
199            }
200            match listener.auth {
201                ListenerAuthKind::OAuth => {
202                    if listener.public_url.is_empty() {
203                        return Err(ShuttleError::Store(format!(
204                            "listener {:?} public_url is required for oauth auth",
205                            listener.name
206                        )));
207                    }
208                    match &listener.oauth_db_path {
209                        Some(path) if !path.is_absolute() => {
210                            return Err(ShuttleError::Store(format!(
211                                "listener {:?} oauth_db_path must be an absolute path when set",
212                                listener.name
213                            )));
214                        }
215                        Some(_) => {}
216                        None => {
217                            listener.oauth_db_path = Some(
218                                config_dir.join(format!("gateway-{}-oauth.db", listener.name)),
219                            );
220                        }
221                    }
222                }
223                ListenerAuthKind::Bearer => {}
224                ListenerAuthKind::None => {
225                    if !is_loopback_addr(listener.addr) {
226                        return Err(ShuttleError::Store(format!(
227                            "listener {:?} auth none is only allowed on loopback addresses",
228                            listener.name
229                        )));
230                    }
231                }
232            }
233        }
234        Ok(cfg)
235    }
236}
237
238#[derive(Clone)]
239pub struct GatewayRuntime {
240    service: Arc<GatewayService>,
241    auth: GatewayAuth,
242}
243
244#[derive(Clone)]
245pub struct GatewayListener {
246    pub name: String,
247    pub addr: SocketAddr,
248    pub runtime: GatewayRuntime,
249}
250
251impl GatewayRuntime {
252    pub fn from_config(config: GatewayConfig, stl: PathBuf, timeout: Duration) -> Result<Self> {
253        let listeners = Self::listeners_from_config(config, stl, timeout)?;
254        if listeners.len() != 1 {
255            return Err(ShuttleError::Store(
256                "GatewayRuntime::from_config requires exactly one listener".to_owned(),
257            ));
258        }
259        Ok(listeners.into_iter().next().unwrap().runtime)
260    }
261
262    pub fn listeners_from_config(
263        mut config: GatewayConfig,
264        stl: PathBuf,
265        timeout: Duration,
266    ) -> Result<Vec<GatewayListener>> {
267        let listener_configs = if config.listeners.is_empty() {
268            vec![legacy_listener_config(&config)]
269        } else {
270            std::mem::take(&mut config.listeners)
271        };
272        let config_path = config.config_path.clone();
273        let registry = ProjectRegistry::new(config.defaults.project, config.projects)?;
274        let service = Arc::new(GatewayService::new_with_config_path(
275            registry,
276            Arc::new(SubprocessRunner {
277                binary: stl,
278                timeout,
279            }),
280            config_path,
281        ));
282        listener_configs
283            .into_iter()
284            .map(|listener| {
285                let auth = auth_from_listener(&listener)?;
286                Ok(GatewayListener {
287                    name: listener.name,
288                    addr: listener.addr,
289                    runtime: GatewayRuntime {
290                        service: service.clone(),
291                        auth,
292                    },
293                })
294            })
295            .collect()
296    }
297}
298
299#[derive(Clone)]
300enum GatewayAuth {
301    Bearer { token_env: String },
302    OAuth(Arc<OAuthRuntime>),
303    None,
304}
305
306#[derive(Clone)]
307struct OAuthRuntime {
308    config: OAuthConfig,
309    store: OAuthStore,
310}
311
312pub async fn serve(runtime: GatewayRuntime, addr: SocketAddr) -> Result<()> {
313    let listener = tokio::net::TcpListener::bind(addr)
314        .await
315        .map_err(|err| ShuttleError::Store(err.to_string()))?;
316    axum::serve(listener, router(runtime))
317        .await
318        .map_err(|err| ShuttleError::Store(err.to_string()))
319}
320
321pub async fn serve_listeners(listeners: Vec<GatewayListener>) -> Result<()> {
322    if listeners.is_empty() {
323        return Err(ShuttleError::Store(
324            "at least one listener is required".to_owned(),
325        ));
326    }
327    let mut tasks = Vec::new();
328    for listener in listeners {
329        let addr = listener.addr;
330        let runtime = listener.runtime;
331        let name = listener.name;
332        tasks.push(tokio::spawn(async move {
333            let tcp = tokio::net::TcpListener::bind(addr)
334                .await
335                .map_err(|err| ShuttleError::Store(format!("listener {name}: {err}")))?;
336            axum::serve(tcp, router(runtime))
337                .await
338                .map_err(|err| ShuttleError::Store(format!("listener {name}: {err}")))
339        }));
340    }
341    for task in tasks {
342        task.await
343            .map_err(|err| ShuttleError::Store(err.to_string()))??;
344    }
345    Ok(())
346}
347
348pub fn router(runtime: GatewayRuntime) -> Router {
349    Router::new()
350        .route("/api/projects", get(api_projects).post(api_add_project))
351        .route("/api/projects/current", get(api_current_project))
352        .route("/api/projects/use", post(api_use_project))
353        .route("/api/recall", post(api_recall))
354        .route("/api/remember", post(api_remember))
355        .route("/api/context", get(api_context))
356        .route("/api/tasks", get(api_tasks).post(api_create_task))
357        .route("/api/tasks/{id}", patch(api_update_task))
358        .route("/api/tasks/{id}/done", post(api_done_task))
359        .route(
360            "/mcp",
361            get(mcp_health)
362                .post(mcp_post)
363                .delete(mcp_delete)
364                .options(mcp_options),
365        )
366        .route(
367            "/.well-known/oauth-protected-resource",
368            get(oauth_protected_resource),
369        )
370        .route(
371            "/.well-known/oauth-protected-resource/mcp",
372            get(oauth_protected_resource),
373        )
374        .route(
375            "/.well-known/oauth-authorization-server",
376            get(oauth_authorization_server),
377        )
378        .route("/oauth/register", post(oauth_register))
379        .route(
380            "/oauth/authorize",
381            get(oauth_authorize_page).post(oauth_authorize_submit),
382        )
383        .route("/oauth/token", post(oauth_token))
384        .with_state(runtime)
385}
386
387#[derive(Debug, Clone, Serialize, PartialEq, Eq)]
388pub struct Project {
389    pub name: String,
390    pub backend: ProjectBackendKind,
391    #[serde(skip_serializing_if = "Option::is_none")]
392    pub repo: Option<PathBuf>,
393    #[serde(skip_serializing_if = "Option::is_none")]
394    pub db: Option<PathBuf>,
395    #[serde(skip_serializing_if = "String::is_empty")]
396    pub url: String,
397    #[serde(skip)]
398    pub token_env: Option<String>,
399    #[serde(skip_serializing_if = "Option::is_none")]
400    pub description: Option<String>,
401}
402
403#[derive(Debug)]
404pub struct ProjectRegistry {
405    default_project: String,
406    projects: BTreeMap<String, Project>,
407}
408
409impl ProjectRegistry {
410    pub fn new(default_project: String, configs: BTreeMap<String, ProjectConfig>) -> Result<Self> {
411        for (name, cfg) in &configs {
412            validate_project_config(name, cfg)?;
413        }
414        let projects = configs
415            .into_iter()
416            .map(|(name, cfg)| {
417                let project = project_from_config(name.clone(), cfg);
418                (name, project)
419            })
420            .collect::<BTreeMap<_, _>>();
421        if !default_project.is_empty() && !projects.contains_key(&default_project) {
422            return Err(ShuttleError::Store(format!(
423                "default project {default_project:?} is not configured"
424            )));
425        }
426        Ok(Self {
427            default_project,
428            projects,
429        })
430    }
431
432    pub fn list(&self) -> Vec<Project> {
433        self.projects.values().cloned().collect()
434    }
435
436    pub fn names(&self) -> BTreeSet<String> {
437        self.projects.keys().cloned().collect()
438    }
439
440    pub fn insert_named(&mut self, name: String, config: ProjectConfig) -> Result<Project> {
441        validate_project_config(&name, &config)?;
442        if self.projects.contains_key(&name) {
443            return Err(ShuttleError::Store(format!(
444                "project {name:?} is already configured"
445            )));
446        }
447        let project = project_from_config(name.clone(), config);
448        self.projects.insert(name, project.clone());
449        Ok(project)
450    }
451
452    pub fn get(&self, name: &str) -> Option<Project> {
453        self.projects.get(name).cloned()
454    }
455
456    pub fn default(&self) -> Option<Project> {
457        (!self.default_project.is_empty())
458            .then(|| self.get(&self.default_project))
459            .flatten()
460    }
461
462    pub fn resolve(&self, project: &str, write: bool) -> Result<Project> {
463        if !project.is_empty() {
464            return self
465                .get(project)
466                .ok_or_else(|| ShuttleError::Store(format!("unknown project {project:?}")));
467        }
468        if write {
469            return Err(ShuttleError::Store(
470                "project is required for write operations".to_owned(),
471            ));
472        }
473        self.default()
474            .ok_or_else(|| ShuttleError::Store("project is required".to_owned()))
475    }
476}
477
478fn project_from_config(name: String, cfg: ProjectConfig) -> Project {
479    Project {
480        name,
481        backend: cfg.backend,
482        repo: cfg.repo,
483        db: cfg.db,
484        url: normalize_public_url(&cfg.url),
485        token_env: cfg.token_env,
486        description: cfg.description,
487    }
488}
489
490fn validate_project_config(name: &str, project: &ProjectConfig) -> Result<()> {
491    validate_project_name(name)?;
492    match project.backend {
493        ProjectBackendKind::Local => {
494            let Some(repo) = &project.repo else {
495                return Err(ShuttleError::Store(format!(
496                    "project {name:?} repo is required for local backend"
497                )));
498            };
499            if !repo.is_absolute() {
500                return Err(ShuttleError::Store(format!(
501                    "project {name:?} repo must be an absolute path"
502                )));
503            }
504            if let Some(db) = &project.db {
505                if !db.is_absolute() {
506                    return Err(ShuttleError::Store(format!(
507                        "project {name:?} db must be an absolute path when set"
508                    )));
509                }
510            }
511        }
512        ProjectBackendKind::Http => {
513            if project.url.trim().is_empty() {
514                return Err(ShuttleError::Store(format!(
515                    "project {name:?} url is required for http backend"
516                )));
517            }
518        }
519    }
520    Ok(())
521}
522
523fn validate_project_name(name: &str) -> Result<()> {
524    if name.trim().is_empty() {
525        return Err(ShuttleError::Store(
526            "project name cannot be empty".to_owned(),
527        ));
528    }
529    Ok(())
530}
531
532fn normalize_project_name(name: &str) -> Result<String> {
533    validate_project_name(name)?;
534    Ok(name.trim().to_owned())
535}
536
/// Returns `base_name` when it is not in `existing_names`.
///
/// Otherwise returns the first `"{base_name}-N"` (for N = 2, 3, …) that is not taken.
fn unique_project_name(base_name: &str, existing_names: &BTreeSet<String>) -> String {
    std::iter::once(base_name.to_owned())
        .chain((2..).map(|suffix| format!("{base_name}-{suffix}")))
        .find(|candidate| !existing_names.contains(candidate))
        .expect("unbounded suffix search always returns")
}
549
550fn persist_project_config(
551    config_path: &Path,
552    base_name: &str,
553    config: &ProjectConfig,
554    registry_names: &BTreeSet<String>,
555) -> Result<String> {
556    let raw =
557        std::fs::read_to_string(config_path).map_err(|err| ShuttleError::Store(err.to_string()))?;
558    let mut document = raw
559        .parse::<DocumentMut>()
560        .map_err(|err| ShuttleError::Serialization(err.to_string()))?;
561    let mut existing_names = registry_names.clone();
562    if let Some(projects) = document.get("projects").and_then(Item::as_table) {
563        existing_names.extend(projects.iter().map(|(name, _)| name.to_owned()));
564    }
565    let name = unique_project_name(base_name, &existing_names);
566
567    let projects = document
568        .entry("projects")
569        .or_insert_with(|| Item::Table(Table::new()))
570        .as_table_mut()
571        .ok_or_else(|| ShuttleError::Store("projects must be a TOML table".to_owned()))?;
572    projects[&name] = project_config_item(config);
573    write_config_atomically(config_path, document.to_string().as_bytes())?;
574    Ok(name)
575}
576
577fn project_config_item(config: &ProjectConfig) -> Item {
578    let mut table = Table::new();
579    table["backend"] = value(match config.backend {
580        ProjectBackendKind::Local => "local",
581        ProjectBackendKind::Http => "http",
582    });
583    if let Some(repo) = &config.repo {
584        table["repo"] = value(repo.display().to_string());
585    }
586    if let Some(db) = &config.db {
587        table["db"] = value(db.display().to_string());
588    }
589    if !config.url.trim().is_empty() {
590        table["url"] = value(normalize_public_url(&config.url));
591    }
592    if let Some(token_env) = &config.token_env {
593        if !token_env.trim().is_empty() {
594            table["token_env"] = value(token_env.trim());
595        }
596    }
597    if let Some(description) = &config.description {
598        if !description.trim().is_empty() {
599            table["description"] = value(description.trim());
600        }
601    }
602    Item::Table(table)
603}
604
605fn write_config_atomically(path: &Path, contents: &[u8]) -> Result<()> {
606    let dir = path.parent().unwrap_or_else(|| Path::new("."));
607    let mut temp =
608        tempfile::NamedTempFile::new_in(dir).map_err(|err| ShuttleError::Store(err.to_string()))?;
609    temp.write_all(contents)
610        .map_err(|err| ShuttleError::Store(err.to_string()))?;
611    temp.flush()
612        .map_err(|err| ShuttleError::Store(err.to_string()))?;
613    temp.persist(path)
614        .map_err(|err| ShuttleError::Store(err.to_string()))?;
615    Ok(())
616}
617
618#[derive(Debug, Serialize)]
619pub struct ServiceResponse {
620    pub project: String,
621    pub result: Value,
622    #[serde(skip_serializing_if = "Option::is_none")]
623    pub stored: Option<bool>,
624}
625
626pub struct GatewayService {
627    projects: Mutex<ProjectRegistry>,
628    runner: Arc<dyn Runner>,
629    config_path: Option<PathBuf>,
630    current: Mutex<String>,
631}
632
633impl GatewayService {
634    pub fn new(projects: ProjectRegistry, runner: Arc<dyn Runner>) -> Self {
635        Self::new_with_config_path(projects, runner, None)
636    }
637
638    pub fn new_with_config_path(
639        projects: ProjectRegistry,
640        runner: Arc<dyn Runner>,
641        config_path: Option<PathBuf>,
642    ) -> Self {
643        Self {
644            projects: Mutex::new(projects),
645            runner,
646            config_path,
647            current: Mutex::new(String::new()),
648        }
649    }
650
651    pub fn list_projects(&self) -> Result<Vec<Project>> {
652        Ok(self.projects()?.list())
653    }
654
655    pub fn add_project(
656        &self,
657        name: &str,
658        config: ProjectConfig,
659        make_current: bool,
660    ) -> Result<Project> {
661        // Keep registry and current-project locks separate so gateway commands never hold
662        // mutable registry access while later running backend work.
663        let base_name = normalize_project_name(name)?;
664        validate_project_config(&base_name, &config)?;
665        let project = {
666            let mut projects = self.projects()?;
667            let name = if let Some(config_path) = &self.config_path {
668                persist_project_config(config_path, &base_name, &config, &projects.names())?
669            } else {
670                unique_project_name(&base_name, &projects.names())
671            };
672            projects.insert_named(name, config)?
673        };
674        if make_current {
675            *self.current()? = project.name.clone();
676        }
677        Ok(project)
678    }
679
680    pub fn use_project(&self, name: &str) -> Result<Project> {
681        // Validate against the registry before updating current; current_project() falls
682        // back to the default if future runtime removal leaves this value stale.
683        let project = {
684            let projects = self.projects()?;
685            projects
686                .get(name)
687                .ok_or_else(|| ShuttleError::Store(format!("unknown project {name:?}")))?
688        };
689        *self.current()? = name.to_owned();
690        Ok(project)
691    }
692
693    pub fn current_project(&self) -> Result<Project> {
694        let current = self.current()?.clone();
695        if !current.is_empty() {
696            if let Some(project) = self.projects()?.get(&current) {
697                return Ok(project);
698            }
699        }
700        self.projects()?
701            .default()
702            .ok_or_else(|| ShuttleError::Store("no current or default project".to_owned()))
703    }
704
705    pub async fn context(&self, project: &str) -> Result<ServiceResponse> {
706        self.run(project, false, &["context"]).await
707    }
708
709    pub async fn recall(&self, project: &str, query: &str) -> Result<ServiceResponse> {
710        require_non_empty(query, "query is required")?;
711        self.run(project, false, &["recall", query]).await
712    }
713
714    pub async fn remember(&self, project: &str, kind: &str, text: &str) -> Result<ServiceResponse> {
715        require_non_empty(text, "text is required")?;
716        let command = match kind {
717            "" | "memory" => "remember",
718            "decision" => "decide",
719            "observation" => "observe",
720            "pattern" => "pattern",
721            "fact" => "fact",
722            "bug" => "bug",
723            other => {
724                return Err(ShuttleError::Store(format!(
725                    "unknown memory kind {other:?}"
726                )))
727            }
728        };
729        self.run(project, true, &[command, text]).await
730    }
731
732    pub async fn task_list(&self, project: &str) -> Result<ServiceResponse> {
733        self.run(project, false, &["task", "list"]).await
734    }
735
736    pub async fn task_create(
737        &self,
738        project: &str,
739        title: &str,
740        body: &str,
741    ) -> Result<ServiceResponse> {
742        require_non_empty(title, "title is required")?;
743        let content = if body.is_empty() {
744            title.to_owned()
745        } else {
746            format!("{title}\n\n{body}")
747        };
748        self.run(project, true, &["task", "create", &content]).await
749    }
750
751    pub async fn task_update(
752        &self,
753        project: &str,
754        id: &str,
755        text: &str,
756    ) -> Result<ServiceResponse> {
757        require_non_empty(id, "task id is required")?;
758        require_non_empty(text, "text is required")?;
759        self.run(project, true, &["task", "update", id, text]).await
760    }
761
762    pub async fn task_done(&self, project: &str, id: &str) -> Result<ServiceResponse> {
763        require_non_empty(id, "task id is required")?;
764        self.run(project, true, &["task", "done", id]).await
765    }
766
767    async fn run(&self, project: &str, write: bool, args: &[&str]) -> Result<ServiceResponse> {
768        let project = {
769            let projects = self.projects()?;
770            projects.resolve(project, write)?
771        };
772        let result = self.runner.run(&project, args).await.map_err(|err| {
773            ShuttleError::Store(format!("stl failed for project {}: {err}", project.name))
774        })?;
775        Ok(ServiceResponse {
776            project: project.name,
777            result,
778            stored: write.then_some(true),
779        })
780    }
781
782    fn projects(&self) -> Result<MutexGuard<'_, ProjectRegistry>> {
783        self.projects
784            .lock()
785            .map_err(|err| ShuttleError::Store(err.to_string()))
786    }
787
788    fn current(&self) -> Result<MutexGuard<'_, String>> {
789        self.current
790            .lock()
791            .map_err(|err| ShuttleError::Store(err.to_string()))
792    }
793}
794
795fn require_non_empty(value: &str, message: &str) -> Result<()> {
796    if value.trim().is_empty() {
797        return Err(ShuttleError::Store(message.to_owned()));
798    }
799    Ok(())
800}
801
802#[async_trait]
803pub trait Runner: Send + Sync {
804    async fn run(&self, project: &Project, args: &[&str]) -> std::result::Result<Value, String>;
805}
806
807pub struct SubprocessRunner {
808    binary: PathBuf,
809    timeout: Duration,
810}
811
812#[async_trait]
813impl Runner for SubprocessRunner {
814    async fn run(&self, project: &Project, args: &[&str]) -> std::result::Result<Value, String> {
815        if project.backend == ProjectBackendKind::Http {
816            return self.run_http(project, args).await;
817        }
818        let repo = project
819            .repo
820            .as_ref()
821            .ok_or_else(|| "repo is required for local backend".to_owned())?;
822        let mut command = Command::new(&self.binary);
823        command.arg("--json").args(args).current_dir(repo);
824        let output = tokio::time::timeout(self.timeout, command.output())
825            .await
826            .map_err(|_| format!("timed out after {}s", self.timeout.as_secs()))?
827            .map_err(|err| err.to_string())?;
828        if !output.status.success() {
829            let stderr = String::from_utf8_lossy(&output.stderr).trim().to_owned();
830            return Err(if stderr.is_empty() {
831                format!("exit status {}", output.status)
832            } else {
833                stderr
834            });
835        }
836        serde_json::from_slice(&output.stdout).map_err(|err| err.to_string())
837    }
838}
839
840impl SubprocessRunner {
841    async fn run_http(
842        &self,
843        project: &Project,
844        args: &[&str],
845    ) -> std::result::Result<Value, String> {
846        let project = project.clone();
847        let args = args.iter().map(|arg| (*arg).to_owned()).collect::<Vec<_>>();
848        let timeout = self.timeout;
849        tokio::task::spawn_blocking(move || http_backend_call(&project, &args, timeout))
850            .await
851            .map_err(|err| err.to_string())?
852    }
853}
854
855fn http_backend_call(
856    project: &Project,
857    args: &[String],
858    timeout: Duration,
859) -> std::result::Result<Value, String> {
860    let base = project.url.trim_end_matches('/');
861    let agent = ureq::AgentBuilder::new().timeout(timeout).build();
862    let request = |method: &str, path: &str| {
863        let req = match method {
864            "GET" => agent.get(&format!("{base}{path}")),
865            "PATCH" => agent.request("PATCH", &format!("{base}{path}")),
866            _ => agent.post(&format!("{base}{path}")),
867        };
868        if let Some(token_env) = &project.token_env {
869            if let Ok(token) = env::var(token_env) {
870                if !token.is_empty() {
871                    return req.set(header::AUTHORIZATION.as_str(), &format!("Bearer {token}"));
872                }
873            }
874        }
875        req
876    };
877    let response = match args {
878        [cmd] if cmd == "context" => request("GET", "/api/context").call(),
879        [cmd, query] if cmd == "recall" => {
880            request("POST", "/api/recall").send_json(json!({ "query": query }))
881        }
882        [cmd, text] if is_memory_command(cmd) => request("POST", "/api/remember")
883            .send_json(json!({ "kind": memory_kind_for_command(cmd), "text": text })),
884        [task, cmd] if task == "task" && cmd == "list" => request("GET", "/api/tasks").call(),
885        [task, cmd, content] if task == "task" && cmd == "create" => {
886            request("POST", "/api/tasks").send_json(json!({ "title": content, "body": "" }))
887        }
888        [task, cmd, id, text] if task == "task" && cmd == "update" => {
889            request("PATCH", &format!("/api/tasks/{id}")).send_json(json!({ "text": text }))
890        }
891        [task, cmd, id] if task == "task" && cmd == "done" => {
892            request("POST", &format!("/api/tasks/{id}/done")).send_json(json!({}))
893        }
894        _ => {
895            return Err(format!(
896                "unsupported http backend command: {}",
897                args.join(" ")
898            ))
899        }
900    };
901    let response = response.map_err(|err| match err {
902        ureq::Error::Status(status, response) => {
903            let body = response.into_string().unwrap_or_default();
904            if body.trim().is_empty() {
905                format!("http backend returned status {status}")
906            } else {
907                format!("http backend returned status {status}: {body}")
908            }
909        }
910        ureq::Error::Transport(err) => err.to_string(),
911    })?;
912    response.into_json::<Value>().map_err(|err| err.to_string())
913}
914
/// Reports whether `command` is one of the CLI verbs that stores a memory.
fn is_memory_command(command: &str) -> bool {
    const MEMORY_COMMANDS: [&str; 6] = ["remember", "decide", "observe", "pattern", "fact", "bug"];
    MEMORY_COMMANDS.contains(&command)
}
921
/// Maps a memory CLI verb to the memory `kind` sent to the HTTP backend.
///
/// Verbs without a dedicated kind, including `remember`, map to `"memory"`.
fn memory_kind_for_command(command: &str) -> &str {
    const KINDS: [(&str, &str); 5] = [
        ("decide", "decision"),
        ("observe", "observation"),
        ("pattern", "pattern"),
        ("fact", "fact"),
        ("bug", "bug"),
    ];
    KINDS
        .iter()
        .find(|(verb, _)| *verb == command)
        .map_or("memory", |&(_, kind)| kind)
}
932
933async fn api_projects(State(runtime): State<GatewayRuntime>, headers: HeaderMap) -> Response {
934    if let Err(response) = authorize(&runtime, &headers, "/api/projects", false) {
935        return *response;
936    }
937    match runtime.service.list_projects() {
938        Ok(projects) => Json(json!({ "projects": projects })).into_response(),
939        Err(err) => error_response(err),
940    }
941}
942
943#[derive(Debug, Deserialize)]
944struct AddProjectRequest {
945    #[serde(default)]
946    name: String,
947    #[serde(default)]
948    backend: ProjectBackendKind,
949    #[serde(default)]
950    repo: Option<PathBuf>,
951    #[serde(default)]
952    db: Option<PathBuf>,
953    #[serde(default)]
954    url: String,
955    #[serde(default)]
956    token_env: Option<String>,
957    #[serde(default)]
958    description: Option<String>,
959    #[serde(default)]
960    make_current: bool,
961}
962
963impl AddProjectRequest {
964    fn into_parts(self) -> (String, ProjectConfig, bool) {
965        (
966            self.name,
967            ProjectConfig {
968                backend: self.backend,
969                repo: self.repo,
970                db: self.db,
971                url: self.url,
972                token_env: self.token_env,
973                description: self.description,
974            },
975            self.make_current,
976        )
977    }
978}
979
980async fn api_add_project(
981    State(runtime): State<GatewayRuntime>,
982    headers: HeaderMap,
983    Json(request): Json<AddProjectRequest>,
984) -> Response {
985    if let Err(response) = authorize(&runtime, &headers, "/api/projects", false) {
986        return *response;
987    }
988    let (name, config, make_current) = request.into_parts();
989    match runtime.service.add_project(&name, config, make_current) {
990        Ok(project) => (StatusCode::CREATED, Json(project)).into_response(),
991        Err(err) => error_response(err),
992    }
993}
994
995async fn api_current_project(
996    State(runtime): State<GatewayRuntime>,
997    headers: HeaderMap,
998) -> Response {
999    if let Err(response) = authorize(&runtime, &headers, "/api/projects/current", false) {
1000        return *response;
1001    }
1002    match runtime.service.current_project() {
1003        Ok(project) => Json(project).into_response(),
1004        Err(err) if err.to_string().contains("no current or default project") => (
1005            StatusCode::NOT_FOUND,
1006            Json(json!({"error": err.to_string()})),
1007        )
1008            .into_response(),
1009        Err(err) => error_response(err),
1010    }
1011}
1012
1013#[derive(Deserialize)]
1014struct ProjectRequest {
1015    #[serde(default)]
1016    project: String,
1017}
1018
1019async fn api_use_project(
1020    State(runtime): State<GatewayRuntime>,
1021    headers: HeaderMap,
1022    Json(request): Json<ProjectRequest>,
1023) -> Response {
1024    if let Err(response) = authorize(&runtime, &headers, "/api/projects/use", false) {
1025        return *response;
1026    }
1027    match runtime.service.use_project(&request.project) {
1028        Ok(project) => Json(project).into_response(),
1029        Err(err) => error_response(err),
1030    }
1031}
1032
1033#[derive(Deserialize)]
1034struct RecallRequest {
1035    #[serde(default)]
1036    project: String,
1037    #[serde(default)]
1038    query: String,
1039}
1040
1041async fn api_recall(
1042    State(runtime): State<GatewayRuntime>,
1043    headers: HeaderMap,
1044    Json(request): Json<RecallRequest>,
1045) -> Response {
1046    service_response(
1047        &runtime,
1048        &headers,
1049        "/api/recall",
1050        runtime
1051            .service
1052            .recall(&request.project, &request.query)
1053            .await,
1054    )
1055}
1056
1057#[derive(Deserialize)]
1058struct RememberRequest {
1059    #[serde(default)]
1060    project: String,
1061    #[serde(default)]
1062    kind: String,
1063    #[serde(default)]
1064    text: String,
1065}
1066
1067async fn api_remember(
1068    State(runtime): State<GatewayRuntime>,
1069    headers: HeaderMap,
1070    Json(request): Json<RememberRequest>,
1071) -> Response {
1072    service_response(
1073        &runtime,
1074        &headers,
1075        "/api/remember",
1076        runtime
1077            .service
1078            .remember(&request.project, &request.kind, &request.text)
1079            .await,
1080    )
1081}
1082
1083async fn api_context(
1084    State(runtime): State<GatewayRuntime>,
1085    headers: HeaderMap,
1086    Query(request): Query<ProjectRequest>,
1087) -> Response {
1088    service_response(
1089        &runtime,
1090        &headers,
1091        "/api/context",
1092        runtime.service.context(&request.project).await,
1093    )
1094}
1095
1096async fn api_tasks(
1097    State(runtime): State<GatewayRuntime>,
1098    headers: HeaderMap,
1099    Query(request): Query<ProjectRequest>,
1100) -> Response {
1101    service_response(
1102        &runtime,
1103        &headers,
1104        "/api/tasks",
1105        runtime.service.task_list(&request.project).await,
1106    )
1107}
1108
1109#[derive(Deserialize)]
1110struct CreateTaskRequest {
1111    #[serde(default)]
1112    project: String,
1113    #[serde(default)]
1114    title: String,
1115    #[serde(default)]
1116    body: String,
1117}
1118
1119async fn api_create_task(
1120    State(runtime): State<GatewayRuntime>,
1121    headers: HeaderMap,
1122    Json(request): Json<CreateTaskRequest>,
1123) -> Response {
1124    service_response(
1125        &runtime,
1126        &headers,
1127        "/api/tasks",
1128        runtime
1129            .service
1130            .task_create(&request.project, &request.title, &request.body)
1131            .await,
1132    )
1133}
1134
1135#[derive(Deserialize)]
1136struct UpdateTaskRequest {
1137    #[serde(default)]
1138    project: String,
1139    #[serde(default)]
1140    text: String,
1141}
1142
1143async fn api_update_task(
1144    State(runtime): State<GatewayRuntime>,
1145    headers: HeaderMap,
1146    AxumPath(id): AxumPath<String>,
1147    Json(request): Json<UpdateTaskRequest>,
1148) -> Response {
1149    service_response(
1150        &runtime,
1151        &headers,
1152        "/api/tasks",
1153        runtime
1154            .service
1155            .task_update(&request.project, &id, &request.text)
1156            .await,
1157    )
1158}
1159
1160async fn api_done_task(
1161    State(runtime): State<GatewayRuntime>,
1162    headers: HeaderMap,
1163    AxumPath(id): AxumPath<String>,
1164    Json(request): Json<ProjectRequest>,
1165) -> Response {
1166    service_response(
1167        &runtime,
1168        &headers,
1169        "/api/tasks",
1170        runtime.service.task_done(&request.project, &id).await,
1171    )
1172}
1173
1174fn service_response(
1175    runtime: &GatewayRuntime,
1176    headers: &HeaderMap,
1177    path: &str,
1178    response: Result<ServiceResponse>,
1179) -> Response {
1180    if let Err(response) = authorize(runtime, headers, path, false) {
1181        return *response;
1182    }
1183    match response {
1184        Ok(value) => Json(value).into_response(),
1185        Err(err) => error_response(err),
1186    }
1187}
1188
1189#[derive(Deserialize)]
1190struct RpcRequest {
1191    jsonrpc: Option<String>,
1192    id: Option<Value>,
1193    method: String,
1194    #[serde(default)]
1195    params: Value,
1196}
1197
1198async fn mcp_health(State(runtime): State<GatewayRuntime>, headers: HeaderMap) -> Response {
1199    authorize(&runtime, &headers, "/mcp", true)
1200        .map(|_| with_cors(Json(json!({ "status": "ok" }))))
1201        .unwrap_or_else(|response| *response)
1202}
1203
1204async fn mcp_delete(State(runtime): State<GatewayRuntime>, headers: HeaderMap) -> Response {
1205    authorize(&runtime, &headers, "/mcp", true)
1206        .map(|_| with_cors(StatusCode::OK))
1207        .unwrap_or_else(|response| *response)
1208}
1209
1210async fn mcp_options() -> Response {
1211    with_cors(StatusCode::NO_CONTENT)
1212}
1213
1214async fn mcp_post(
1215    State(runtime): State<GatewayRuntime>,
1216    headers: HeaderMap,
1217    Json(request): Json<RpcRequest>,
1218) -> Response {
1219    match authorize(&runtime, &headers, "/mcp", true) {
1220        Ok(()) if request.method == "notifications/initialized" => {
1221            with_cors(StatusCode::NO_CONTENT)
1222        }
1223        Ok(()) => with_cors(Json(handle_mcp(&runtime.service, request).await)),
1224        Err(response) => *response,
1225    }
1226}
1227
1228async fn handle_mcp(service: &GatewayService, request: RpcRequest) -> Value {
1229    let id = request.id.unwrap_or(Value::Null);
1230    if request.jsonrpc.as_deref() != Some("2.0") {
1231        return rpc_error(id, -32600, "invalid jsonrpc version");
1232    }
1233    match request.method.as_str() {
1234        "initialize" => rpc_ok(
1235            id,
1236            json!({
1237                "protocolVersion": "2025-11-25",
1238                "capabilities": { "tools": {} },
1239                "serverInfo": { "name": "shuttle-gateway", "version": env!("CARGO_PKG_VERSION") }
1240            }),
1241        ),
1242        "notifications/initialized" => json!({"jsonrpc": "2.0"}),
1243        "tools/list" => rpc_ok(id, json!({ "tools": gateway_tools() })),
1244        "tools/call" => match mcp_call_tool(service, request.params).await {
1245            Ok(value) => rpc_ok(
1246                id,
1247                json!({
1248                    "content": [{ "type": "text", "text": value.to_string() }],
1249                    "structuredContent": value,
1250                }),
1251            ),
1252            Err(err) => rpc_error(id, -32603, &err.to_string()),
1253        },
1254        _ => rpc_error(id, -32601, "method not found"),
1255    }
1256}
1257
1258async fn mcp_call_tool(service: &GatewayService, params: Value) -> Result<Value> {
1259    let name = params
1260        .get("name")
1261        .and_then(Value::as_str)
1262        .ok_or_else(|| ShuttleError::Store("missing tool name".to_owned()))?;
1263    let args = params
1264        .get("arguments")
1265        .cloned()
1266        .unwrap_or_else(|| json!({}));
1267    match name {
1268        "shuttle_projects" => Ok(json!({ "projects": service.list_projects()? })),
1269        "shuttle_project_add" => {
1270            let (name, config, make_current) = project_add_args(&args)?;
1271            serde_json::to_value(service.add_project(&name, config, make_current)?)
1272                .map_err(|err| ShuttleError::Serialization(err.to_string()))
1273        }
1274        "shuttle_current_project" => serde_json::to_value(service.current_project()?)
1275            .map_err(|err| ShuttleError::Serialization(err.to_string())),
1276        "shuttle_use_project" => {
1277            serde_json::to_value(service.use_project(str_arg(&args, "project")?)?)
1278                .map_err(|err| ShuttleError::Serialization(err.to_string()))
1279        }
1280        "shuttle_context" => service
1281            .context(optional_str_arg(&args, "project"))
1282            .await
1283            .and_then(to_value),
1284        "shuttle_recall" => service
1285            .recall(optional_str_arg(&args, "project"), str_arg(&args, "query")?)
1286            .await
1287            .and_then(to_value),
1288        "shuttle_remember" => service
1289            .remember(
1290                str_arg(&args, "project")?,
1291                optional_str_arg(&args, "kind"),
1292                str_arg(&args, "text")?,
1293            )
1294            .await
1295            .and_then(to_value),
1296        "shuttle_task_list" => service
1297            .task_list(optional_str_arg(&args, "project"))
1298            .await
1299            .and_then(to_value),
1300        "shuttle_task_create" => service
1301            .task_create(
1302                str_arg(&args, "project")?,
1303                str_arg(&args, "title")?,
1304                optional_str_arg(&args, "body"),
1305            )
1306            .await
1307            .and_then(to_value),
1308        "shuttle_task_update" => service
1309            .task_update(
1310                str_arg(&args, "project")?,
1311                str_arg(&args, "task_id")?,
1312                str_arg(&args, "text")?,
1313            )
1314            .await
1315            .and_then(to_value),
1316        "shuttle_task_done" => service
1317            .task_done(str_arg(&args, "project")?, str_arg(&args, "task_id")?)
1318            .await
1319            .and_then(to_value),
1320        other => Err(ShuttleError::Store(format!("unknown tool: {other}"))),
1321    }
1322}
1323
1324fn to_value(response: ServiceResponse) -> Result<Value> {
1325    serde_json::to_value(response).map_err(|err| ShuttleError::Serialization(err.to_string()))
1326}
1327
1328fn str_arg<'a>(args: &'a Value, key: &str) -> Result<&'a str> {
1329    args.get(key)
1330        .and_then(Value::as_str)
1331        .filter(|value| !value.is_empty())
1332        .ok_or_else(|| ShuttleError::Store(format!("{key} is required")))
1333}
1334
1335fn optional_str_arg<'a>(args: &'a Value, key: &str) -> &'a str {
1336    args.get(key).and_then(Value::as_str).unwrap_or("")
1337}
1338
1339fn project_add_args(args: &Value) -> Result<(String, ProjectConfig, bool)> {
1340    Ok((
1341        str_arg(args, "name")?.to_owned(),
1342        ProjectConfig {
1343            backend: project_backend_arg(args, "backend")?,
1344            repo: optional_path_arg(args, "repo"),
1345            db: optional_path_arg(args, "db"),
1346            url: optional_string_arg(args, "url").unwrap_or_default(),
1347            token_env: optional_string_arg(args, "token_env"),
1348            description: optional_string_arg(args, "description"),
1349        },
1350        optional_bool_arg(args, "make_current"),
1351    ))
1352}
1353
1354fn project_backend_arg(args: &Value, key: &str) -> Result<ProjectBackendKind> {
1355    match optional_str_arg(args, key) {
1356        "" => Ok(ProjectBackendKind::Local),
1357        "local" => Ok(ProjectBackendKind::Local),
1358        "http" => Ok(ProjectBackendKind::Http),
1359        other => Err(ShuttleError::Store(format!(
1360            "{key} must be one of: local, http; got {other:?}"
1361        ))),
1362    }
1363}
1364
1365fn optional_path_arg(args: &Value, key: &str) -> Option<PathBuf> {
1366    optional_string_arg(args, key).map(PathBuf::from)
1367}
1368
1369fn optional_string_arg(args: &Value, key: &str) -> Option<String> {
1370    args.get(key)
1371        .and_then(Value::as_str)
1372        .map(str::trim)
1373        .filter(|value| !value.is_empty())
1374        .map(ToOwned::to_owned)
1375}
1376
1377fn optional_bool_arg(args: &Value, key: &str) -> bool {
1378    args.get(key).and_then(Value::as_bool).unwrap_or(false)
1379}
1380
1381fn rpc_ok(id: Value, result: Value) -> Value {
1382    json!({ "jsonrpc": "2.0", "id": id, "result": result })
1383}
1384
1385fn rpc_error(id: Value, code: i64, message: &str) -> Value {
1386    json!({ "jsonrpc": "2.0", "id": id, "error": { "code": code, "message": message } })
1387}
1388
/// Returns the MCP tool catalogue advertised in reply to `tools/list`.
///
/// Each tool name must match an arm of `mcp_call_tool`. Each `required` list should agree with
/// which arguments that dispatcher reads via `str_arg` (required) rather than
/// `optional_str_arg` (optional).
fn gateway_tools() -> Vec<Value> {
    vec![
        // Project management tools.
        tool(
            "shuttle_projects",
            "List configured Shuttle projects",
            json!({}),
            vec![],
            projects_output_schema(),
        ),
        tool(
            "shuttle_project_add",
            "Add a Shuttle project to the running gateway",
            json!({
                "name": string_schema("Project name"),
                "backend": enum_schema("Project backend; defaults to local", &["local", "http"]),
                "repo": nullable_string_schema("Absolute local repository path for local backends"),
                "db": nullable_string_schema("Absolute local Shuttle database path"),
                "url": string_schema("HTTP project base URL for http backends"),
                "token_env": nullable_string_schema("Environment variable name containing the backend bearer token"),
                "description": nullable_string_schema("Project description"),
                "make_current": bool_schema("Set the added project as the current project"),
            }),
            vec!["name"],
            project_output_schema(),
        ),
        tool(
            "shuttle_current_project",
            "Read the current or default project",
            json!({}),
            vec![],
            project_output_schema(),
        ),
        tool(
            "shuttle_use_project",
            "Set the current project",
            json!({"project": string_schema("Configured project name")}),
            vec!["project"],
            project_output_schema(),
        ),
        // Per-project tools; their results are wrapped `ServiceResponse` values.
        tool(
            "shuttle_context",
            "Read Shuttle context for a project",
            json!({"project": string_schema("Configured project name; optional with default project")}),
            vec![],
            service_response_output_schema(),
        ),
        tool(
            "shuttle_recall",
            "Search Shuttle memories in a project",
            json!({"project": string_schema("Configured project name; optional with default project"), "query": string_schema("Recall query")}),
            vec!["query"],
            service_response_output_schema(),
        ),
        tool(
            "shuttle_remember",
            "Store a Shuttle memory in a project",
            json!({"project": string_schema("Configured project name"), "kind": enum_schema("Memory kind", &["memory", "decision", "observation", "pattern", "fact", "bug"]), "text": string_schema("Memory text")}),
            vec!["project", "text"],
            service_response_output_schema(),
        ),
        // Task tools.
        tool(
            "shuttle_task_list",
            "List Shuttle tasks in a project",
            json!({"project": string_schema("Configured project name; optional with default project")}),
            vec![],
            service_response_output_schema(),
        ),
        tool(
            "shuttle_task_create",
            "Create a Shuttle task in a project",
            json!({"project": string_schema("Configured project name"), "title": string_schema("Task title"), "body": string_schema("Optional task body")}),
            vec!["project", "title"],
            service_response_output_schema(),
        ),
        tool(
            "shuttle_task_update",
            "Update a Shuttle task in a project",
            json!({"project": string_schema("Configured project name"), "task_id": string_schema("Task UUID"), "text": string_schema("Update text")}),
            vec!["project", "task_id", "text"],
            service_response_output_schema(),
        ),
        tool(
            "shuttle_task_done",
            "Complete a Shuttle task in a project",
            json!({"project": string_schema("Configured project name"), "task_id": string_schema("Task UUID")}),
            vec!["project", "task_id"],
            service_response_output_schema(),
        ),
    ]
}
1479
1480fn tool(
1481    name: &str,
1482    description: &str,
1483    properties: Value,
1484    required: Vec<&str>,
1485    output_schema: Value,
1486) -> Value {
1487    json!({
1488        "name": name,
1489        "description": description,
1490        "inputSchema": {
1491            "type": "object",
1492            "properties": properties,
1493            "required": required,
1494            "additionalProperties": false,
1495        },
1496        "outputSchema": output_schema,
1497    })
1498}
1499
1500fn projects_output_schema() -> Value {
1501    object_schema(
1502        json!({ "projects": { "type": "array", "items": project_schema() } }),
1503        vec!["projects"],
1504    )
1505}
1506
/// Output schema for tools that return a single project; identical to `project_schema`.
fn project_output_schema() -> Value {
    project_schema()
}
1510
1511fn service_response_output_schema() -> Value {
1512    object_schema(
1513        json!({
1514            "project": string_schema("Configured project name"),
1515            "result": json_schema("Tool result from the selected project"),
1516            "stored": {
1517                "type": "boolean",
1518                "description": "Whether the operation stored data",
1519            },
1520        }),
1521        vec!["project", "result"],
1522    )
1523}
1524
1525fn project_schema() -> Value {
1526    object_schema(
1527        json!({
1528            "name": string_schema("Configured project name"),
1529            "backend": enum_schema("Project backend", &["local", "http"]),
1530            "repo": nullable_string_schema("Local repository path"),
1531            "db": nullable_string_schema("Local Shuttle database path"),
1532            "url": string_schema("HTTP project base URL"),
1533            "description": nullable_string_schema("Project description"),
1534        }),
1535        vec!["name", "backend", "url"],
1536    )
1537}
1538
1539fn object_schema(properties: Value, required: Vec<&str>) -> Value {
1540    json!({
1541        "type": "object",
1542        "properties": properties,
1543        "required": required,
1544        "additionalProperties": true,
1545    })
1546}
1547
1548fn string_schema(description: &str) -> Value {
1549    json!({ "type": "string", "description": description })
1550}
1551
1552fn bool_schema(description: &str) -> Value {
1553    json!({ "type": "boolean", "description": description })
1554}
1555
1556fn nullable_string_schema(description: &str) -> Value {
1557    json!({ "type": ["string", "null"], "description": description })
1558}
1559
1560fn json_schema(description: &str) -> Value {
1561    json!({
1562        "type": ["object", "array", "string", "number", "integer", "boolean", "null"],
1563        "description": description,
1564    })
1565}
1566
1567fn enum_schema(description: &str, values: &[&str]) -> Value {
1568    json!({ "type": "string", "description": description, "enum": values })
1569}
1570
1571fn authorize(
1572    runtime: &GatewayRuntime,
1573    headers: &HeaderMap,
1574    path: &str,
1575    cors: bool,
1576) -> std::result::Result<(), Box<Response>> {
1577    if is_oauth_public_route(path) {
1578        return Ok(());
1579    }
1580    match &runtime.auth {
1581        GatewayAuth::Bearer { token_env } => {
1582            let Some(token) = env::var(token_env).ok().filter(|token| !token.is_empty()) else {
1583                return Ok(());
1584            };
1585            let expected = format!("Bearer {token}");
1586            let ok = headers
1587                .get(header::AUTHORIZATION)
1588                .and_then(|header| header.to_str().ok())
1589                .is_some_and(|actual| constant_time_eq(actual.as_bytes(), expected.as_bytes()));
1590            if ok {
1591                Ok(())
1592            } else if cors {
1593                Err(Box::new(with_cors(StatusCode::UNAUTHORIZED)))
1594            } else {
1595                Err(Box::new(
1596                    (
1597                        StatusCode::UNAUTHORIZED,
1598                        Json(json!({"error": "unauthorized"})),
1599                    )
1600                        .into_response(),
1601                ))
1602            }
1603        }
1604        GatewayAuth::OAuth(oauth) => {
1605            let Some(token) = bearer_token(headers) else {
1606                return Err(Box::new(unauthorized_oauth(&oauth.config)));
1607            };
1608            match oauth.store.validate_access_token(token) {
1609                Ok(true) => Ok(()),
1610                Ok(false) => Err(Box::new(unauthorized_oauth(&oauth.config))),
1611                Err(_) => Err(Box::new(oauth_error(
1612                    StatusCode::UNAUTHORIZED,
1613                    "invalid_token",
1614                    "failed to validate access token",
1615                ))),
1616            }
1617        }
1618        GatewayAuth::None => Ok(()),
1619    }
1620}
1621
/// Reports whether `path` is an OAuth discovery/flow endpoint that must stay reachable
/// without a token.
fn is_oauth_public_route(path: &str) -> bool {
    const PUBLIC_ROUTES: [&str; 6] = [
        "/.well-known/oauth-protected-resource",
        "/.well-known/oauth-protected-resource/mcp",
        "/.well-known/oauth-authorization-server",
        "/oauth/register",
        "/oauth/token",
        "/oauth/authorize",
    ];
    PUBLIC_ROUTES.iter().any(|route| *route == path)
}
1633
1634async fn oauth_protected_resource(State(runtime): State<GatewayRuntime>) -> Response {
1635    let GatewayAuth::OAuth(oauth) = &runtime.auth else {
1636        return (
1637            StatusCode::NOT_FOUND,
1638            Json(json!({"error": "oauth is not configured"})),
1639        )
1640            .into_response();
1641    };
1642    Json(oauth::protected_resource_metadata(&oauth.config)).into_response()
1643}
1644
1645async fn oauth_authorization_server(State(runtime): State<GatewayRuntime>) -> Response {
1646    let GatewayAuth::OAuth(oauth) = &runtime.auth else {
1647        return (
1648            StatusCode::NOT_FOUND,
1649            Json(json!({"error": "oauth is not configured"})),
1650        )
1651            .into_response();
1652    };
1653    Json(oauth::authorization_server_metadata(&oauth.config)).into_response()
1654}
1655
1656async fn oauth_register(
1657    State(runtime): State<GatewayRuntime>,
1658    Json(request): Json<oauth::RegisterRequest>,
1659) -> Response {
1660    let GatewayAuth::OAuth(oauth) = &runtime.auth else {
1661        return (
1662            StatusCode::NOT_FOUND,
1663            Json(json!({"error": "oauth is not configured"})),
1664        )
1665            .into_response();
1666    };
1667    match oauth.store.register_client(request) {
1668        Ok(client) => {
1669            let mut body = json!({
1670                "client_id": client.client_id,
1671                "redirect_uris": client.redirect_uris,
1672                "client_name": client.client_name,
1673                "token_endpoint_auth_method": "none",
1674            });
1675            if let Some(secret) = client.client_secret {
1676                body["client_secret"] = json!(secret);
1677            }
1678            (StatusCode::CREATED, Json(body)).into_response()
1679        }
1680        Err(err) => oauth_error(StatusCode::BAD_REQUEST, "invalid_request", &err.to_string()),
1681    }
1682}
1683
1684async fn oauth_authorize_page(
1685    State(runtime): State<GatewayRuntime>,
1686    Query(request): Query<oauth::AuthorizeRequest>,
1687) -> Response {
1688    let GatewayAuth::OAuth(oauth) = &runtime.auth else {
1689        return (
1690            StatusCode::NOT_FOUND,
1691            Json(json!({"error": "oauth is not configured"})),
1692        )
1693            .into_response();
1694    };
1695    if request.response_type != "code" {
1696        return oauth_error(
1697            StatusCode::BAD_REQUEST,
1698            "unsupported_response_type",
1699            "response_type must be code",
1700        );
1701    }
1702    match oauth
1703        .store
1704        .client_allows_redirect(&request.client_id, &request.redirect_uri)
1705    {
1706        Ok(true) => {
1707            Html(authorize_html(&request, oauth.config.admin_token.is_some())).into_response()
1708        }
1709        Ok(false) => oauth_error(
1710            StatusCode::BAD_REQUEST,
1711            "invalid_request",
1712            "unknown client_id or redirect_uri",
1713        ),
1714        Err(_) => oauth_error(
1715            StatusCode::INTERNAL_SERVER_ERROR,
1716            "server_error",
1717            "failed to validate OAuth client",
1718        ),
1719    }
1720}
1721
1722async fn oauth_authorize_submit(
1723    State(runtime): State<GatewayRuntime>,
1724    Form(form): Form<oauth::AuthorizeForm>,
1725) -> Response {
1726    let GatewayAuth::OAuth(oauth) = &runtime.auth else {
1727        return (
1728            StatusCode::NOT_FOUND,
1729            Json(json!({"error": "oauth is not configured"})),
1730        )
1731            .into_response();
1732    };
1733    if let Some(expected) = oauth.config.admin_token.as_deref() {
1734        if !constant_time_eq(form.admin_token.as_bytes(), expected.as_bytes()) {
1735            return oauth_error(
1736                StatusCode::UNAUTHORIZED,
1737                "access_denied",
1738                "invalid admin token",
1739            );
1740        }
1741    }
1742    let request = oauth::AuthorizeRequest::from(form);
1743    if request.response_type != "code" {
1744        return oauth_error(
1745            StatusCode::BAD_REQUEST,
1746            "unsupported_response_type",
1747            "response_type must be code",
1748        );
1749    }
1750    match oauth.store.create_code(request.clone()) {
1751        Ok(code) => Redirect::to(&oauth::authorize_redirect(
1752            &request.redirect_uri,
1753            &code,
1754            request.state.as_deref(),
1755        ))
1756        .into_response(),
1757        Err(err) => oauth_error(StatusCode::BAD_REQUEST, "invalid_request", &err.to_string()),
1758    }
1759}
1760
1761async fn oauth_token(
1762    State(runtime): State<GatewayRuntime>,
1763    Form(request): Form<oauth::TokenRequest>,
1764) -> Response {
1765    let GatewayAuth::OAuth(oauth) = &runtime.auth else {
1766        return (
1767            StatusCode::NOT_FOUND,
1768            Json(json!({"error": "oauth is not configured"})),
1769        )
1770            .into_response();
1771    };
1772    if request.grant_type != "authorization_code" {
1773        return oauth_error(
1774            StatusCode::BAD_REQUEST,
1775            "unsupported_grant_type",
1776            "grant_type must be authorization_code",
1777        );
1778    }
1779    match oauth.store.exchange_code(request) {
1780        Ok(token) => Json(token).into_response(),
1781        Err(err) => oauth_error(StatusCode::BAD_REQUEST, "invalid_grant", &err.to_string()),
1782    }
1783}
1784
1785fn error_response(err: ShuttleError) -> Response {
1786    (
1787        StatusCode::BAD_REQUEST,
1788        Json(json!({"error": err.to_string()})),
1789    )
1790        .into_response()
1791}
1792
1793fn bearer_token(headers: &HeaderMap) -> Option<&str> {
1794    headers
1795        .get(header::AUTHORIZATION)
1796        .and_then(|header| header.to_str().ok())
1797        .and_then(|value| {
1798            let (scheme, token) = value.split_once(' ')?;
1799            scheme.eq_ignore_ascii_case("Bearer").then_some(token)
1800        })
1801}
1802
/// Compares two byte strings without short-circuiting on the first mismatch.
///
/// Every position up to the longer input's length is visited. Bytes missing from the
/// shorter input are treated as `0`. A length difference is folded into the
/// accumulator up front, so inputs of different lengths never compare equal.
fn constant_time_eq(left: &[u8], right: &[u8]) -> bool {
    let byte_at = |bytes: &[u8], position: usize| bytes.get(position).copied().unwrap_or(0);
    let span = left.len().max(right.len());
    let mismatch = (0..span).fold(left.len() ^ right.len(), |acc, position| {
        acc | usize::from(byte_at(left, position) ^ byte_at(right, position))
    });
    mismatch == 0
}
1812
1813fn with_cors(response: impl IntoResponse) -> Response {
1814    let (mut parts, body) = response.into_response().into_parts();
1815    parts
1816        .headers
1817        .insert("access-control-allow-origin", HeaderValue::from_static("*"));
1818    parts.headers.insert(
1819        "access-control-allow-methods",
1820        HeaderValue::from_static("GET,POST,DELETE,OPTIONS"),
1821    );
1822    parts.headers.insert(
1823        "access-control-allow-headers",
1824        HeaderValue::from_static(
1825            "accept,authorization,content-type,mcp-protocol-version,mcp-session-id",
1826        ),
1827    );
1828    parts.headers.insert(
1829        "access-control-expose-headers",
1830        HeaderValue::from_static("mcp-session-id"),
1831    );
1832    Response::from_parts(parts, body)
1833}
1834
1835fn unauthorized_oauth(config: &OAuthConfig) -> Response {
1836    let mut response = with_cors(StatusCode::UNAUTHORIZED);
1837    let header_value = format!(
1838        r#"Bearer resource_metadata="{}/.well-known/oauth-protected-resource/mcp", scope="mcp""#,
1839        quoted_header_value(&config.public_url)
1840    );
1841    if let Ok(value) = HeaderValue::from_str(&header_value) {
1842        response
1843            .headers_mut()
1844            .insert(header::WWW_AUTHENTICATE, value);
1845    }
1846    response
1847}
1848
1849fn oauth_error(status: StatusCode, code: &str, description: &str) -> Response {
1850    (
1851        status,
1852        Json(json!({ "error": code, "error_description": description })),
1853    )
1854        .into_response()
1855}
1856
1857fn authorize_html(request: &oauth::AuthorizeRequest, requires_admin_token: bool) -> String {
1858    let admin = if requires_admin_token {
1859        r#"<label>Admin token <input name="admin_token" type="password" autocomplete="current-password" required></label>"#
1860    } else {
1861        r#"<input name="admin_token" type="hidden" value="">"#
1862    };
1863    format!(
1864        r#"<!doctype html>
1865<html>
1866<head><meta charset="utf-8"><title>Authorize Shuttle Gateway</title></head>
1867<body>
1868  <h1>Authorize Shuttle Gateway</h1>
1869  <p>{client_id} is requesting access to Shuttle MCP.</p>
1870  <form method="post" action="/oauth/authorize">
1871    {admin}
1872    <input type="hidden" name="response_type" value="{response_type}">
1873    <input type="hidden" name="client_id" value="{client_id}">
1874    <input type="hidden" name="redirect_uri" value="{redirect_uri}">
1875    <input type="hidden" name="state" value="{state}">
1876    <input type="hidden" name="scope" value="{scope}">
1877    <input type="hidden" name="code_challenge" value="{code_challenge}">
1878    <input type="hidden" name="code_challenge_method" value="{code_challenge_method}">
1879    <button type="submit">Authorize</button>
1880  </form>
1881</body>
1882</html>"#,
1883        admin = admin,
1884        response_type = html_escape(&request.response_type),
1885        client_id = html_escape(&request.client_id),
1886        redirect_uri = html_escape(&request.redirect_uri),
1887        state = html_escape(request.state.as_deref().unwrap_or("")),
1888        scope = html_escape(request.scope.as_deref().unwrap_or("mcp")),
1889        code_challenge = html_escape(request.code_challenge.as_deref().unwrap_or("")),
1890        code_challenge_method =
1891            html_escape(request.code_challenge_method.as_deref().unwrap_or("S256")),
1892    )
1893}
1894
/// Escapes the five HTML-significant characters (`& < > " '`) so `value` can be
/// embedded in element text or a double-quoted attribute.
fn html_escape(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '"' => escaped.push_str("&quot;"),
            '\'' => escaped.push_str("&#39;"),
            other => escaped.push(other),
        }
    }
    escaped
}
1903
/// Backslash-escapes `\` and `"` so `value` can be placed inside an HTTP
/// quoted-string. No other characters are altered.
fn quoted_header_value(value: &str) -> String {
    let mut quoted = String::with_capacity(value.len());
    for ch in value.chars() {
        if matches!(ch, '\\' | '"') {
            quoted.push('\\');
        }
        quoted.push(ch);
    }
    quoted
}
1907
/// Default gateway bind address: IPv4 loopback (`127.0.0.1`) on port 8787.
fn default_addr() -> SocketAddr {
    SocketAddr::from(([127, 0, 0, 1], 8787))
}
1911
/// Name of the environment variable that holds the gateway bearer token by default.
fn default_gateway_token_env() -> String {
    String::from("SHUTTLE_GATEWAY_TOKEN")
}
1915
/// Name of the environment variable that holds the OAuth admin token by default.
fn default_oauth_admin_token_env() -> String {
    String::from("SHUTTLE_OAUTH_ADMIN_TOKEN")
}
1919
1920fn legacy_listener_config(config: &GatewayConfig) -> ListenerConfig {
1921    if config.oauth.public_url.is_empty() {
1922        ListenerConfig {
1923            name: "default".to_owned(),
1924            addr: config.server.addr,
1925            auth: ListenerAuthKind::Bearer,
1926            public_url: String::new(),
1927            oauth_db_path: None,
1928            oauth_admin_token_env: config.oauth.admin_token_env.clone(),
1929            bearer_token_env: config.auth.bearer_token_env.clone(),
1930        }
1931    } else {
1932        ListenerConfig {
1933            name: "default".to_owned(),
1934            addr: config.server.addr,
1935            auth: ListenerAuthKind::OAuth,
1936            public_url: config.oauth.public_url.clone(),
1937            oauth_db_path: config.oauth.db_path.clone(),
1938            oauth_admin_token_env: config.oauth.admin_token_env.clone(),
1939            bearer_token_env: config.auth.bearer_token_env.clone(),
1940        }
1941    }
1942}
1943
1944fn auth_from_listener(listener: &ListenerConfig) -> Result<GatewayAuth> {
1945    match listener.auth {
1946        ListenerAuthKind::Bearer => Ok(GatewayAuth::Bearer {
1947            token_env: listener.bearer_token_env.clone(),
1948        }),
1949        ListenerAuthKind::None => Ok(GatewayAuth::None),
1950        ListenerAuthKind::OAuth => {
1951            let admin_token = env::var(&listener.oauth_admin_token_env).map_err(|_| {
1952                ShuttleError::Store(format!(
1953                    "{} is required when oauth listener {:?} is configured",
1954                    listener.oauth_admin_token_env, listener.name
1955                ))
1956            })?;
1957            if admin_token.is_empty() {
1958                return Err(ShuttleError::Store(format!(
1959                    "{} is required when oauth listener {:?} is configured",
1960                    listener.oauth_admin_token_env, listener.name
1961                )));
1962            }
1963            let db_path = listener.oauth_db_path.clone().ok_or_else(|| {
1964                ShuttleError::Store(format!(
1965                    "oauth_db_path is required for listener {:?}",
1966                    listener.name
1967                ))
1968            })?;
1969            Ok(GatewayAuth::OAuth(Arc::new(OAuthRuntime {
1970                config: OAuthConfig {
1971                    public_url: listener.public_url.clone(),
1972                    admin_token: Some(admin_token),
1973                },
1974                store: OAuthStore::open(db_path)?,
1975            })))
1976        }
1977    }
1978}
1979
/// Returns `true` when the socket's IP is a loopback address: `127.0.0.0/8` for
/// IPv4, or `::1` for IPv6.
fn is_loopback_addr(addr: SocketAddr) -> bool {
    addr.ip().is_loopback()
}
1986
/// Canonicalizes a configured public URL for path joining.
///
/// Surrounding whitespace is removed first, then every trailing `/`.
fn normalize_public_url(url: &str) -> String {
    let without_whitespace = url.trim();
    String::from(without_whitespace.trim_end_matches('/'))
}
1990
1991#[cfg(test)]
1992mod tests {
1993    use super::*;
1994    use axum::body::Body;
1995    use axum::http::{Method, Request};
1996    use http_body_util::BodyExt;
1997    use std::sync::Mutex;
1998    use tower::ServiceExt;
1999
2000    #[derive(Default)]
2001    struct FakeRunner {
2002        calls: Mutex<Vec<(String, Vec<String>)>>,
2003    }
2004
2005    #[async_trait]
2006    impl Runner for FakeRunner {
2007        async fn run(
2008            &self,
2009            project: &Project,
2010            args: &[&str],
2011        ) -> std::result::Result<Value, String> {
2012            self.calls.lock().unwrap().push((
2013                project.name.clone(),
2014                args.iter().map(|arg| (*arg).to_owned()).collect(),
2015            ));
2016            Ok(json!({"ok": true}))
2017        }
2018    }
2019
2020    fn registry() -> ProjectRegistry {
2021        ProjectRegistry::new(
2022            "demo".to_owned(),
2023            BTreeMap::from([(
2024                "demo".to_owned(),
2025                ProjectConfig {
2026                    backend: ProjectBackendKind::Local,
2027                    repo: Some(PathBuf::from("/tmp/demo")),
2028                    db: None,
2029                    url: String::new(),
2030                    token_env: None,
2031                    description: None,
2032                },
2033            )]),
2034        )
2035        .unwrap()
2036    }
2037
2038    #[test]
2039    fn config_rejects_relative_repo_and_applies_defaults() {
2040        let dir = tempfile::tempdir().unwrap();
2041        let path = dir.path().join("projects.toml");
2042        std::fs::write(&path, "[projects.demo]\nrepo = \"relative\"\n").unwrap();
2043        assert!(GatewayConfig::load(&path).is_err());
2044
2045        std::fs::write(&path, "[projects.demo]\nrepo = \"/tmp/demo\"\n").unwrap();
2046        let cfg = GatewayConfig::load(&path).unwrap();
2047        assert_eq!(cfg.server.addr, default_addr());
2048        assert_eq!(cfg.auth.bearer_token_env, "SHUTTLE_GATEWAY_TOKEN");
2049        assert_eq!(cfg.oauth.admin_token_env, "SHUTTLE_OAUTH_ADMIN_TOKEN");
2050    }
2051
2052    #[test]
2053    fn config_normalizes_oauth_defaults() {
2054        let dir = tempfile::tempdir().unwrap();
2055        let path = dir.path().join("projects.toml");
2056        std::fs::write(
2057            &path,
2058            "[oauth]\npublic_url = \"https://shuttle.example.test/\"\n\n[projects.demo]\nrepo = \"/tmp/demo\"\n",
2059        )
2060        .unwrap();
2061        let cfg = GatewayConfig::load(&path).unwrap();
2062        assert_eq!(cfg.oauth.public_url, "https://shuttle.example.test");
2063        assert_eq!(
2064            cfg.oauth.db_path.unwrap().file_name().unwrap(),
2065            "gateway-oauth.db"
2066        );
2067    }
2068
2069    #[test]
2070    fn config_accepts_http_projects_without_repo() {
2071        let dir = tempfile::tempdir().unwrap();
2072        let path = dir.path().join("projects.toml");
2073        std::fs::write(
2074            &path,
2075            "[projects.demo]\nbackend = \"http\"\nurl = \"http://127.0.0.1:8787\"\ntoken_env = \"DEMO_TOKEN\"\n",
2076        )
2077        .unwrap();
2078        let cfg = GatewayConfig::load(&path).unwrap();
2079        let project = cfg.projects.get("demo").unwrap();
2080
2081        assert_eq!(project.backend, ProjectBackendKind::Http);
2082        assert!(project.repo.is_none());
2083        assert_eq!(project.token_env.as_deref(), Some("DEMO_TOKEN"));
2084    }
2085
2086    #[test]
2087    fn config_rejects_http_projects_without_url() {
2088        let dir = tempfile::tempdir().unwrap();
2089        let path = dir.path().join("projects.toml");
2090        std::fs::write(&path, "[projects.demo]\nbackend = \"http\"\n").unwrap();
2091
2092        assert!(GatewayConfig::load(&path)
2093            .unwrap_err()
2094            .to_string()
2095            .contains("url is required"));
2096    }
2097
2098    #[test]
2099    fn config_rejects_none_listener_on_non_loopback() {
2100        let dir = tempfile::tempdir().unwrap();
2101        let path = dir.path().join("projects.toml");
2102        std::fs::write(
2103            &path,
2104            "[[listeners]]\nname = \"open\"\naddr = \"0.0.0.0:8787\"\nauth = \"none\"\n\n[projects.demo]\nrepo = \"/tmp/demo\"\n",
2105        )
2106        .unwrap();
2107
2108        assert!(GatewayConfig::load(&path)
2109            .unwrap_err()
2110            .to_string()
2111            .contains("only allowed on loopback"));
2112    }
2113
2114    #[tokio::test]
2115    async fn remember_requires_explicit_project_and_maps_kind() {
2116        let runner = Arc::new(FakeRunner::default());
2117        let service = GatewayService::new(registry(), runner.clone());
2118        assert!(service.remember("", "decision", "ship it").await.is_err());
2119        let response = service
2120            .remember("demo", "decision", "ship it")
2121            .await
2122            .unwrap();
2123        assert_eq!(response.project, "demo");
2124        assert_eq!(response.stored, Some(true));
2125        let calls = runner.calls.lock().unwrap();
2126        assert_eq!(calls[0].1, vec!["decide", "ship it"]);
2127    }
2128
2129    #[tokio::test]
2130    async fn task_create_combines_title_and_body() {
2131        let runner = Arc::new(FakeRunner::default());
2132        let service = GatewayService::new(registry(), runner.clone());
2133        service.task_create("demo", "title", "body").await.unwrap();
2134        let calls = runner.calls.lock().unwrap();
2135        assert_eq!(calls[0].1, vec!["task", "create", "title\n\nbody"]);
2136    }
2137
2138    #[tokio::test]
2139    async fn service_add_project_routes_to_runtime_project() {
2140        let runner = Arc::new(FakeRunner::default());
2141        let service = GatewayService::new(registry(), runner.clone());
2142        let project = service
2143            .add_project(
2144                " extra ",
2145                ProjectConfig {
2146                    backend: ProjectBackendKind::Http,
2147                    repo: None,
2148                    db: None,
2149                    url: "http://127.0.0.1:9999/".to_owned(),
2150                    token_env: Some("EXTRA_TOKEN".to_owned()),
2151                    description: Some("extra project".to_owned()),
2152                },
2153                true,
2154            )
2155            .unwrap();
2156
2157        assert_eq!(project.name, "extra");
2158        assert_eq!(project.url, "http://127.0.0.1:9999");
2159        assert_eq!(service.current_project().unwrap().name, "extra");
2160        service.task_list("extra").await.unwrap();
2161
2162        let calls = runner.calls.lock().unwrap();
2163        assert_eq!(calls[0].0, "extra");
2164        assert_eq!(calls[0].1, vec!["task", "list"]);
2165    }
2166
2167    #[test]
2168    fn service_add_project_suffixes_duplicates_and_rejects_invalid_config() {
2169        let service = GatewayService::new(registry(), Arc::new(FakeRunner::default()));
2170        let duplicate = service
2171            .add_project(
2172                "demo",
2173                ProjectConfig {
2174                    backend: ProjectBackendKind::Http,
2175                    repo: None,
2176                    db: None,
2177                    url: "http://127.0.0.1:9999".to_owned(),
2178                    token_env: None,
2179                    description: None,
2180                },
2181                false,
2182            )
2183            .unwrap();
2184        assert_eq!(duplicate.name, "demo-2");
2185        assert!(service
2186            .add_project(
2187                "relative",
2188                ProjectConfig {
2189                    backend: ProjectBackendKind::Local,
2190                    repo: Some(PathBuf::from("relative")),
2191                    db: None,
2192                    url: String::new(),
2193                    token_env: None,
2194                    description: None,
2195                },
2196                false,
2197            )
2198            .unwrap_err()
2199            .to_string()
2200            .contains("absolute path"));
2201    }
2202
2203    #[test]
2204    fn service_add_project_persists_local_project_config() {
2205        let dir = tempfile::tempdir().unwrap();
2206        let config_path = write_gateway_config(dir.path(), false);
2207        let runtime = file_backed_runtime(&config_path, Arc::new(FakeRunner::default()));
2208        let repo = dir.path().join("repo");
2209        let db = dir.path().join("repo/.shuttle/shuttle.db");
2210
2211        let project = runtime
2212            .service
2213            .add_project(
2214                "local-extra",
2215                ProjectConfig {
2216                    backend: ProjectBackendKind::Local,
2217                    repo: Some(repo.clone()),
2218                    db: Some(db.clone()),
2219                    url: String::new(),
2220                    token_env: None,
2221                    description: Some("local extra".to_owned()),
2222                },
2223                false,
2224            )
2225            .unwrap();
2226
2227        assert_eq!(project.name, "local-extra");
2228        let reloaded = GatewayConfig::load(&config_path).unwrap();
2229        let persisted = reloaded.projects.get("local-extra").unwrap();
2230        assert_eq!(persisted.backend, ProjectBackendKind::Local);
2231        assert_eq!(persisted.repo.as_ref(), Some(&repo));
2232        assert_eq!(persisted.db.as_ref(), Some(&db));
2233        assert_eq!(persisted.description.as_deref(), Some("local extra"));
2234    }
2235
2236    #[tokio::test]
2237    async fn mcp_tools_list_includes_gateway_tools() {
2238        let runtime = test_runtime(registry(), Arc::new(FakeRunner::default()));
2239        env::remove_var("TEST_EMPTY_GATEWAY_TOKEN");
2240        let response = router(runtime)
2241            .oneshot(
2242                Request::builder()
2243                    .method(Method::POST)
2244                    .uri("/mcp")
2245                    .header(header::CONTENT_TYPE, "application/json")
2246                    .body(Body::from(
2247                        r#"{"jsonrpc":"2.0","id":1,"method":"tools/list"}"#,
2248                    ))
2249                    .unwrap(),
2250            )
2251            .await
2252            .unwrap();
2253        assert_eq!(response.status(), StatusCode::OK);
2254        let body = response.into_body().collect().await.unwrap().to_bytes();
2255        let value: Value = serde_json::from_slice(&body).unwrap();
2256        let tools = value["result"]["tools"].as_array().unwrap();
2257        assert!(tools.iter().any(|tool| tool["name"] == "shuttle_projects"));
2258        assert!(tools
2259            .iter()
2260            .any(|tool| tool["name"] == "shuttle_project_add"));
2261        assert!(tools.iter().any(|tool| tool["name"] == "shuttle_remember"));
2262        assert!(tools
2263            .iter()
2264            .all(|tool| tool["inputSchema"]["additionalProperties"] == false));
2265        assert!(tools
2266            .iter()
2267            .all(|tool| tool["outputSchema"]["type"] == "object"));
2268    }
2269
2270    #[tokio::test]
2271    async fn mcp_tool_call_returns_structured_content() {
2272        let runtime = test_runtime(registry(), Arc::new(FakeRunner::default()));
2273        env::remove_var("TEST_EMPTY_GATEWAY_TOKEN");
2274        let response = router(runtime)
2275            .oneshot(
2276                Request::builder()
2277                    .method(Method::POST)
2278                    .uri("/mcp")
2279                    .header(header::CONTENT_TYPE, "application/json")
2280                    .body(Body::from(
2281                        r#"{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"shuttle_task_create","arguments":{"project":"demo","title":"ship it"}}}"#,
2282                    ))
2283                    .unwrap(),
2284            )
2285            .await
2286            .unwrap();
2287        assert_eq!(response.status(), StatusCode::OK);
2288        let body = response.into_body().collect().await.unwrap().to_bytes();
2289        let value: Value = serde_json::from_slice(&body).unwrap();
2290
2291        assert_eq!(
2292            value["result"]["structuredContent"],
2293            json!({"project": "demo", "result": {"ok": true}, "stored": true})
2294        );
2295        assert_eq!(
2296            value["result"]["content"][0]["text"],
2297            r#"{"project":"demo","result":{"ok":true},"stored":true}"#
2298        );
2299    }
2300
2301    #[tokio::test]
2302    async fn mcp_project_add_registers_project() {
2303        let dir = tempfile::tempdir().unwrap();
2304        let config_path = write_gateway_config(dir.path(), true);
2305        let runtime = file_backed_runtime(&config_path, Arc::new(FakeRunner::default()));
2306        env::remove_var("TEST_EMPTY_GATEWAY_TOKEN");
2307        let app = router(runtime);
2308
2309        let add = app
2310            .clone()
2311            .oneshot(
2312                Request::builder()
2313                    .method(Method::POST)
2314                    .uri("/mcp")
2315                    .header(header::CONTENT_TYPE, "application/json")
2316                    .body(Body::from(
2317                        r#"{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"shuttle_project_add","arguments":{"name":"extra","backend":"http","url":"http://127.0.0.1:9999/","make_current":true}}}"#,
2318                    ))
2319                    .unwrap(),
2320            )
2321            .await
2322            .unwrap();
2323        assert_eq!(add.status(), StatusCode::OK);
2324        let body = add.into_body().collect().await.unwrap().to_bytes();
2325        let value: Value = serde_json::from_slice(&body).unwrap();
2326        assert_eq!(value["result"]["structuredContent"]["name"], "extra-2");
2327        assert_eq!(
2328            value["result"]["structuredContent"]["url"],
2329            "http://127.0.0.1:9999"
2330        );
2331        assert!(GatewayConfig::load(&config_path)
2332            .unwrap()
2333            .projects
2334            .contains_key("extra-2"));
2335
2336        let current = app
2337            .oneshot(
2338                Request::builder()
2339                    .method(Method::POST)
2340                    .uri("/mcp")
2341                    .header(header::CONTENT_TYPE, "application/json")
2342                    .body(Body::from(
2343                        r#"{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"shuttle_current_project","arguments":{}}}"#,
2344                    ))
2345                    .unwrap(),
2346            )
2347            .await
2348            .unwrap();
2349        let body = current.into_body().collect().await.unwrap().to_bytes();
2350        let value: Value = serde_json::from_slice(&body).unwrap();
2351        assert_eq!(value["result"]["structuredContent"]["name"], "extra-2");
2352    }
2353
2354    #[tokio::test]
2355    async fn mcp_initialized_notification_returns_no_content() {
2356        let runtime = test_runtime(registry(), Arc::new(FakeRunner::default()));
2357        env::remove_var("TEST_EMPTY_GATEWAY_TOKEN");
2358        let response = router(runtime)
2359            .oneshot(
2360                Request::builder()
2361                    .method(Method::POST)
2362                    .uri("/mcp")
2363                    .header(header::CONTENT_TYPE, "application/json")
2364                    .body(Body::from(
2365                        r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#,
2366                    ))
2367                    .unwrap(),
2368            )
2369            .await
2370            .unwrap();
2371
2372        assert_eq!(response.status(), StatusCode::NO_CONTENT);
2373    }
2374
2375    #[tokio::test]
2376    async fn http_remember_requires_project_and_recall_uses_default() {
2377        let runner = Arc::new(FakeRunner::default());
2378        let runtime = test_runtime(registry(), runner.clone());
2379        env::remove_var("TEST_EMPTY_GATEWAY_TOKEN");
2380
2381        let remember = router(runtime.clone())
2382            .oneshot(
2383                Request::builder()
2384                    .method(Method::POST)
2385                    .uri("/api/remember")
2386                    .header(header::CONTENT_TYPE, "application/json")
2387                    .body(Body::from(r#"{"text":"note"}"#))
2388                    .unwrap(),
2389            )
2390            .await
2391            .unwrap();
2392        assert_eq!(remember.status(), StatusCode::BAD_REQUEST);
2393
2394        let recall = router(runtime)
2395            .oneshot(
2396                Request::builder()
2397                    .method(Method::POST)
2398                    .uri("/api/recall")
2399                    .header(header::CONTENT_TYPE, "application/json")
2400                    .body(Body::from(r#"{"query":"sqlite"}"#))
2401                    .unwrap(),
2402            )
2403            .await
2404            .unwrap();
2405
2406        assert_eq!(recall.status(), StatusCode::OK);
2407        let calls = runner.calls.lock().unwrap();
2408        assert_eq!(calls[0].1, vec!["recall", "sqlite"]);
2409    }
2410
2411    #[tokio::test]
2412    async fn http_project_add_updates_project_list_and_current() {
2413        let dir = tempfile::tempdir().unwrap();
2414        let config_path = write_gateway_config(dir.path(), true);
2415        let runtime = file_backed_runtime(&config_path, Arc::new(FakeRunner::default()));
2416        env::remove_var("TEST_EMPTY_GATEWAY_TOKEN");
2417        let app = router(runtime);
2418
2419        let add = app
2420            .clone()
2421            .oneshot(
2422                Request::builder()
2423                    .method(Method::POST)
2424                    .uri("/api/projects")
2425                    .header(header::CONTENT_TYPE, "application/json")
2426                    .body(Body::from(
2427                        r#"{"name":"extra","backend":"http","url":"http://127.0.0.1:9999/","description":"extra project","make_current":true}"#,
2428                    ))
2429                    .unwrap(),
2430            )
2431            .await
2432            .unwrap();
2433        assert_eq!(add.status(), StatusCode::CREATED);
2434        let body = add.into_body().collect().await.unwrap().to_bytes();
2435        let project: Value = serde_json::from_slice(&body).unwrap();
2436        assert_eq!(project["name"], "extra-2");
2437        assert_eq!(project["url"], "http://127.0.0.1:9999");
2438        assert!(GatewayConfig::load(&config_path)
2439            .unwrap()
2440            .projects
2441            .contains_key("extra-2"));
2442
2443        let current = app
2444            .clone()
2445            .oneshot(
2446                Request::builder()
2447                    .method(Method::GET)
2448                    .uri("/api/projects/current")
2449                    .body(Body::empty())
2450                    .unwrap(),
2451            )
2452            .await
2453            .unwrap();
2454        let body = current.into_body().collect().await.unwrap().to_bytes();
2455        let project: Value = serde_json::from_slice(&body).unwrap();
2456        assert_eq!(project["name"], "extra-2");
2457
2458        let list = app
2459            .oneshot(
2460                Request::builder()
2461                    .method(Method::GET)
2462                    .uri("/api/projects")
2463                    .body(Body::empty())
2464                    .unwrap(),
2465            )
2466            .await
2467            .unwrap();
2468        let body = list.into_body().collect().await.unwrap().to_bytes();
2469        let projects: Value = serde_json::from_slice(&body).unwrap();
2470        assert!(projects["projects"]
2471            .as_array()
2472            .unwrap()
2473            .iter()
2474            .any(|project| project["name"] == "extra-2"));
2475    }
2476
2477    #[tokio::test]
2478    async fn current_project_without_default_returns_not_found() {
2479        let runtime = test_runtime(registry_without_default(), Arc::new(FakeRunner::default()));
2480        env::remove_var("TEST_EMPTY_GATEWAY_TOKEN");
2481        let response = router(runtime)
2482            .oneshot(
2483                Request::builder()
2484                    .method(Method::GET)
2485                    .uri("/api/projects/current")
2486                    .body(Body::empty())
2487                    .unwrap(),
2488            )
2489            .await
2490            .unwrap();
2491
2492        assert_eq!(response.status(), StatusCode::NOT_FOUND);
2493    }
2494
2495    #[tokio::test]
2496    async fn oauth_routes_issue_token_that_authorizes_protected_routes() {
2497        let oauth_dir = tempfile::tempdir().unwrap();
2498        let oauth = Arc::new(OAuthRuntime {
2499            config: OAuthConfig {
2500                public_url: "https://shuttle.example.test".to_owned(),
2501                admin_token: Some("admin-token".to_owned()),
2502            },
2503            store: OAuthStore::open(oauth_dir.path().join("oauth.db")).unwrap(),
2504        });
2505        let runtime = GatewayRuntime {
2506            service: Arc::new(GatewayService::new(
2507                registry(),
2508                Arc::new(FakeRunner::default()),
2509            )),
2510            auth: GatewayAuth::OAuth(oauth),
2511        };
2512        let app = router(runtime);
2513
2514        let register = app
2515            .clone()
2516            .oneshot(
2517                Request::builder()
2518                    .method(Method::POST)
2519                    .uri("/oauth/register")
2520                    .header(header::CONTENT_TYPE, "application/json")
2521                    .body(Body::from(
2522                        r#"{"redirect_uris":["https://client.example.test/callback"],"client_name":"client"}"#,
2523                    ))
2524                    .unwrap(),
2525            )
2526            .await
2527            .unwrap();
2528        assert_eq!(register.status(), StatusCode::CREATED);
2529        let body = register.into_body().collect().await.unwrap().to_bytes();
2530        let registered: Value = serde_json::from_slice(&body).unwrap();
2531        assert!(registered.get("client_secret").is_none());
2532        let client_id = registered["client_id"].as_str().unwrap();
2533        let verifier = "abc123abc123abc123abc123abc123abc123abc123abc123";
2534        let challenge = pkce_s256(verifier);
2535        let form = format!(
2536            "admin_token=admin-token&response_type=code&client_id={client_id}&redirect_uri=https%3A%2F%2Fclient.example.test%2Fcallback&state=state-123&scope=mcp&code_challenge={challenge}&code_challenge_method=S256"
2537        );
2538        let authorize = app
2539            .clone()
2540            .oneshot(
2541                Request::builder()
2542                    .method(Method::POST)
2543                    .uri("/oauth/authorize")
2544                    .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
2545                    .body(Body::from(form))
2546                    .unwrap(),
2547            )
2548            .await
2549            .unwrap();
2550        assert_eq!(authorize.status(), StatusCode::SEE_OTHER);
2551        let location = authorize
2552            .headers()
2553            .get(header::LOCATION)
2554            .and_then(|value| value.to_str().ok())
2555            .unwrap();
2556        assert!(location.contains("&state=state-123"));
2557        let code = location
2558            .split("code=")
2559            .nth(1)
2560            .unwrap()
2561            .split('&')
2562            .next()
2563            .unwrap();
2564        let token_form = format!(
2565            "grant_type=authorization_code&client_id={client_id}&redirect_uri=https%3A%2F%2Fclient.example.test%2Fcallback&code={code}&code_verifier={verifier}"
2566        );
2567        let token = app
2568            .clone()
2569            .oneshot(
2570                Request::builder()
2571                    .method(Method::POST)
2572                    .uri("/oauth/token")
2573                    .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
2574                    .body(Body::from(token_form))
2575                    .unwrap(),
2576            )
2577            .await
2578            .unwrap();
2579        assert_eq!(token.status(), StatusCode::OK);
2580        let body = token.into_body().collect().await.unwrap().to_bytes();
2581        let token: Value = serde_json::from_slice(&body).unwrap();
2582        let access_token = token["access_token"].as_str().unwrap();
2583
2584        let unauthorized = app
2585            .clone()
2586            .oneshot(
2587                Request::builder()
2588                    .method(Method::GET)
2589                    .uri("/api/projects")
2590                    .body(Body::empty())
2591                    .unwrap(),
2592            )
2593            .await
2594            .unwrap();
2595        assert_eq!(unauthorized.status(), StatusCode::UNAUTHORIZED);
2596
2597        let authorized = app
2598            .oneshot(
2599                Request::builder()
2600                    .method(Method::GET)
2601                    .uri("/api/projects")
2602                    .header(header::AUTHORIZATION, format!("Bearer {access_token}"))
2603                    .body(Body::empty())
2604                    .unwrap(),
2605            )
2606            .await
2607            .unwrap();
2608        assert_eq!(authorized.status(), StatusCode::OK);
2609    }
2610
2611    fn write_gateway_config(dir: &Path, include_extra: bool) -> PathBuf {
2612        let path = dir.join("projects.toml");
2613        let extra = if include_extra {
2614            "\n[projects.extra]\nbackend = \"http\"\nurl = \"http://127.0.0.1:8788\"\n"
2615        } else {
2616            ""
2617        };
2618        std::fs::write(
2619            &path,
2620            format!(
2621                "[defaults]\nproject = \"demo\"\n\n[projects.demo]\nbackend = \"local\"\nrepo = \"/tmp/demo\"\n{extra}"
2622            ),
2623        )
2624        .unwrap();
2625        path
2626    }
2627
2628    fn file_backed_runtime(path: &Path, runner: Arc<FakeRunner>) -> GatewayRuntime {
2629        let cfg = GatewayConfig::load(path).unwrap();
2630        let config_path = cfg.config_path.clone();
2631        let registry = ProjectRegistry::new(cfg.defaults.project, cfg.projects).unwrap();
2632        GatewayRuntime {
2633            service: Arc::new(GatewayService::new_with_config_path(
2634                registry,
2635                runner,
2636                config_path,
2637            )),
2638            auth: GatewayAuth::Bearer {
2639                token_env: "TEST_EMPTY_GATEWAY_TOKEN".to_owned(),
2640            },
2641        }
2642    }
2643
2644    fn test_runtime(registry: ProjectRegistry, runner: Arc<FakeRunner>) -> GatewayRuntime {
2645        GatewayRuntime {
2646            service: Arc::new(GatewayService::new(registry, runner)),
2647            auth: GatewayAuth::Bearer {
2648                token_env: "TEST_EMPTY_GATEWAY_TOKEN".to_owned(),
2649            },
2650        }
2651    }
2652
2653    fn registry_without_default() -> ProjectRegistry {
2654        ProjectRegistry::new(
2655            String::new(),
2656            BTreeMap::from([(
2657                "demo".to_owned(),
2658                ProjectConfig {
2659                    backend: ProjectBackendKind::Local,
2660                    repo: Some(PathBuf::from("/tmp/demo")),
2661                    db: None,
2662                    url: String::new(),
2663                    token_env: None,
2664                    description: None,
2665                },
2666            )]),
2667        )
2668        .unwrap()
2669    }
2670
2671    fn pkce_s256(verifier: &str) -> String {
2672        use base64::engine::general_purpose::URL_SAFE_NO_PAD;
2673        use base64::Engine;
2674        use sha2::{Digest, Sha256};
2675
2676        URL_SAFE_NO_PAD.encode(Sha256::digest(verifier.as_bytes()))
2677    }
2678}