Skip to main content

systemprompt_api/routes/gateway/
auth.rs

1use axum::Json;
2use axum::extract::Request;
3use axum::http::{HeaderMap, StatusCode};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::sync::Arc;
7use systemprompt_identifiers::{JwtToken, UserId, headers};
8use systemprompt_models::Config;
9use systemprompt_models::auth::BEARER_PREFIX;
10use systemprompt_oauth::services::{
11    BridgeAuthResult, BridgeOAuthClient, exchange_bridge_session_code, issue_bridge_access,
12    provision_bridge_oauth_client,
13};
14use systemprompt_runtime::AppContext;
15use systemprompt_users::{ApiKeyService, DeviceCertService};
16
17use crate::services::middleware::JwtContextExtractor;
18
19#[derive(Debug, Serialize)]
20pub struct AuthResponse {
21    pub token: String,
22    pub ttl: u64,
23    pub headers: HashMap<String, String>,
24}
25
26impl From<BridgeAuthResult> for AuthResponse {
27    fn from(r: BridgeAuthResult) -> Self {
28        Self {
29            token: r.token,
30            ttl: r.ttl,
31            headers: r.headers,
32        }
33    }
34}
35
36#[derive(Debug, Serialize)]
37pub struct Capabilities {
38    pub modes: Vec<&'static str>,
39}
40
41pub async fn capabilities() -> Json<Capabilities> {
42    Json(Capabilities {
43        modes: vec!["pat", "session", "mtls", "oauth-client"],
44    })
45}
46
47#[derive(Debug, Deserialize)]
48pub struct MtlsRequestBody {
49    pub device_cert_fingerprint: String,
50    #[serde(default)]
51    pub session_id: Option<String>,
52}
53
54#[derive(Debug, Deserialize)]
55pub struct SessionExchangeBody {
56    pub code: String,
57    #[serde(default)]
58    pub session_id: Option<String>,
59}
60
61pub async fn pat(
62    ctx: AppContext,
63    request: Request,
64) -> Result<Json<AuthResponse>, (StatusCode, String)> {
65    let pat_token = extract_bearer(request.headers()).ok_or_else(|| {
66        (
67            StatusCode::UNAUTHORIZED,
68            "Missing Authorization: Bearer <pat>".into(),
69        )
70    })?;
71
72    let service = ApiKeyService::new(ctx.db_pool())
73        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
74
75    let record = service
76        .verify(&pat_token)
77        .await
78        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
79        .ok_or_else(|| (StatusCode::UNAUTHORIZED, "Invalid PAT".into()))?;
80
81    let result = issue_bridge_access(ctx.db_pool(), &record.user_id)
82        .await
83        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
84
85    Ok(Json(result.into()))
86}
87
88pub async fn session(
89    ctx: AppContext,
90    Json(body): Json<SessionExchangeBody>,
91) -> Result<Json<AuthResponse>, (StatusCode, String)> {
92    if body.code.trim().is_empty() {
93        return Err((StatusCode::BAD_REQUEST, "missing exchange code".into()));
94    }
95
96    let result = exchange_bridge_session_code(ctx.db_pool(), body.code.trim())
97        .await
98        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
99        .ok_or_else(|| {
100            (
101                StatusCode::UNAUTHORIZED,
102                "exchange code invalid, expired, or already consumed".into(),
103            )
104        })?;
105
106    Ok(Json(result.into()))
107}
108
109pub async fn provision_oauth_client(
110    jwt_extractor: Arc<JwtContextExtractor>,
111    ctx: AppContext,
112    request: Request,
113) -> Result<Json<BridgeOAuthClient>, (StatusCode, String)> {
114    let bearer = extract_bearer(request.headers()).ok_or_else(|| {
115        (
116            StatusCode::UNAUTHORIZED,
117            "Missing Authorization: Bearer <bridge-jwt>".into(),
118        )
119    })?;
120
121    let claims = jwt_extractor
122        .decode_for_gateway(&JwtToken::new(bearer))
123        .await
124        .map_err(|e| (StatusCode::UNAUTHORIZED, e.to_string()))?;
125
126    let user_id = UserId::new(claims.user_id.to_string());
127
128    let token_endpoint =
129        build_token_endpoint().map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e))?;
130
131    let result = provision_bridge_oauth_client(ctx.db_pool(), &user_id, token_endpoint)
132        .await
133        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
134
135    Ok(Json(result))
136}
137
138fn build_token_endpoint() -> Result<String, String> {
139    let cfg = Config::get().map_err(|e| e.to_string())?;
140    Ok(format!(
141        "{}/api/v1/core/oauth/token",
142        cfg.api_external_url.trim_end_matches('/')
143    ))
144}
145
146pub async fn mtls(
147    ctx: AppContext,
148    Json(body): Json<MtlsRequestBody>,
149) -> Result<Json<AuthResponse>, (StatusCode, String)> {
150    let fingerprint = body.device_cert_fingerprint.trim();
151    if fingerprint.is_empty() {
152        return Err((
153            StatusCode::BAD_REQUEST,
154            "missing device_cert_fingerprint".into(),
155        ));
156    }
157
158    let service = DeviceCertService::new(ctx.db_pool())
159        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
160
161    let record = service
162        .verify(fingerprint)
163        .await
164        .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?
165        .ok_or_else(|| {
166            (
167                StatusCode::UNAUTHORIZED,
168                "device certificate not enrolled or revoked".into(),
169            )
170        })?;
171
172    let result = issue_bridge_access(ctx.db_pool(), &record.user_id)
173        .await
174        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
175
176    Ok(Json(result.into()))
177}
178
179fn extract_bearer(hdrs: &HeaderMap) -> Option<String> {
180    let auth = hdrs.get(headers::AUTHORIZATION)?.to_str().ok()?;
181    auth.strip_prefix(BEARER_PREFIX)
182        .map(|s| s.trim().to_string())
183}