Skip to main content

systemprompt_api/routes/gateway/
auth.rs

1//! Bridge authentication handlers for the gateway router.
2//!
3//! Exposes the credential-exchange endpoints a bridge uses to obtain a
4//! short-lived access token: [`pat`] (personal access token), [`session`]
5//! (one-time exchange code), [`mtls`] (enrolled device certificate), and
6//! [`provision_oauth_client`] (dynamic OAuth client registration), plus
7//! [`capabilities`] advertising the supported modes. All token-minting paths
8//! funnel through `systemprompt_oauth`'s `issue_bridge_access`.
9
10use axum::Json;
11use axum::extract::Request;
12use axum::http::{HeaderMap, StatusCode};
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use std::sync::Arc;
16use systemprompt_identifiers::{JwtToken, UserId, headers};
17use systemprompt_models::Config;
18use systemprompt_models::auth::BEARER_PREFIX;
19use systemprompt_oauth::services::{
20    BridgeAuthResult, BridgeOAuthClient, exchange_bridge_session_code, issue_bridge_access,
21    provision_bridge_oauth_client,
22};
23use systemprompt_runtime::AppContext;
24use systemprompt_traits::{AnalyticsProvider, AppContext as _};
25use systemprompt_users::{ApiKeyService, DeviceCertService};
26
27use crate::services::middleware::JwtContextExtractor;
28
/// JSON body returned by the `pat`, `session`, and `mtls` handlers after a
/// successful credential exchange.
#[derive(Debug, Serialize)]
pub struct AuthResponse {
    /// Access token minted for the bridge.
    pub token: String,
    /// Token lifetime as reported by the issuer (copied from
    /// `BridgeAuthResult::ttl`). NOTE(review): unit not visible here —
    /// presumably seconds; confirm against `systemprompt_oauth`.
    pub ttl: u64,
    /// Header name/value pairs supplied by the issuer, copied verbatim.
    /// Presumably meant to be attached by the bridge to later requests —
    /// confirm with the issuer.
    pub headers: HashMap<String, String>,
}
35
36impl From<BridgeAuthResult> for AuthResponse {
37    fn from(r: BridgeAuthResult) -> Self {
38        Self {
39            token: r.token,
40            ttl: r.ttl,
41            headers: r.headers,
42        }
43    }
44}
45
/// Response body of the `capabilities` handler: the authentication modes
/// this gateway advertises to bridges.
#[derive(Debug, Serialize)]
pub struct Capabilities {
    /// Mode identifiers, e.g. `"pat"` or `"mtls"`.
    pub modes: Vec<&'static str>,
}
50
51pub async fn capabilities() -> Json<Capabilities> {
52    Json(Capabilities {
53        modes: vec!["pat", "session", "mtls", "oauth-client"],
54    })
55}
56
/// JSON body accepted by the `mtls` handler.
#[derive(Debug, Deserialize)]
pub struct MtlsRequestBody {
    /// Fingerprint of the enrolled device certificate. The handler trims
    /// surrounding whitespace and rejects a blank value.
    pub device_cert_fingerprint: String,
}
61
/// JSON body accepted by the `session` handler.
#[derive(Debug, Deserialize)]
pub struct SessionExchangeBody {
    /// One-time exchange code. The handler trims surrounding whitespace and
    /// rejects a blank value.
    pub code: String,
}
66
67pub async fn pat(
68    ctx: AppContext,
69    request: Request,
70) -> Result<Json<AuthResponse>, (StatusCode, String)> {
71    let pat_token = extract_bearer(request.headers()).ok_or_else(|| {
72        (
73            StatusCode::UNAUTHORIZED,
74            "Missing Authorization: Bearer <pat>".into(),
75        )
76    })?;
77
78    let service = ApiKeyService::new(ctx.db_pool())
79        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
80
81    let record = service
82        .verify(&pat_token)
83        .await
84        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
85        .ok_or_else(|| (StatusCode::UNAUTHORIZED, "Invalid PAT".into()))?;
86
87    let analytics = require_analytics(&ctx)?;
88    let result = issue_bridge_access(
89        ctx.db_pool(),
90        analytics.as_ref(),
91        request.headers(),
92        &record.user_id,
93    )
94    .await
95    .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
96
97    Ok(Json(result.into()))
98}
99
100pub async fn session(
101    ctx: AppContext,
102    headers: HeaderMap,
103    Json(body): Json<SessionExchangeBody>,
104) -> Result<Json<AuthResponse>, (StatusCode, String)> {
105    if body.code.trim().is_empty() {
106        return Err((StatusCode::BAD_REQUEST, "missing exchange code".into()));
107    }
108
109    let analytics = require_analytics(&ctx)?;
110    let result = exchange_bridge_session_code(
111        ctx.db_pool(),
112        analytics.as_ref(),
113        &headers,
114        body.code.trim(),
115    )
116    .await
117    .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
118    .ok_or_else(|| {
119        (
120            StatusCode::UNAUTHORIZED,
121            "exchange code invalid, expired, or already consumed".into(),
122        )
123    })?;
124
125    Ok(Json(result.into()))
126}
127
128pub async fn provision_oauth_client(
129    jwt_extractor: Arc<JwtContextExtractor>,
130    ctx: AppContext,
131    request: Request,
132) -> Result<Json<BridgeOAuthClient>, (StatusCode, String)> {
133    let bearer = extract_bearer(request.headers()).ok_or_else(|| {
134        (
135            StatusCode::UNAUTHORIZED,
136            "Missing Authorization: Bearer <bridge-jwt>".into(),
137        )
138    })?;
139
140    let (claims, _user) = jwt_extractor
141        .decode_for_gateway(&JwtToken::new(bearer))
142        .await
143        .map_err(|e| (StatusCode::UNAUTHORIZED, e.to_string()))?;
144
145    let user_id = UserId::new(claims.user_id.to_string());
146
147    let token_endpoint =
148        build_token_endpoint().map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e))?;
149
150    let result = provision_bridge_oauth_client(ctx.db_pool(), &user_id, token_endpoint)
151        .await
152        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
153
154    Ok(Json(result))
155}
156
157fn build_token_endpoint() -> Result<String, String> {
158    let cfg = Config::get().map_err(|e| e.to_string())?;
159    Ok(format!(
160        "{}/api/v1/core/oauth/token",
161        cfg.api_external_url.trim_end_matches('/')
162    ))
163}
164
165pub async fn mtls(
166    ctx: AppContext,
167    headers: HeaderMap,
168    Json(body): Json<MtlsRequestBody>,
169) -> Result<Json<AuthResponse>, (StatusCode, String)> {
170    let fingerprint = body.device_cert_fingerprint.trim();
171    if fingerprint.is_empty() {
172        return Err((
173            StatusCode::BAD_REQUEST,
174            "missing device_cert_fingerprint".into(),
175        ));
176    }
177
178    let service = DeviceCertService::new(ctx.db_pool())
179        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
180
181    let record = service
182        .verify(fingerprint)
183        .await
184        .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?
185        .ok_or_else(|| {
186            (
187                StatusCode::UNAUTHORIZED,
188                "device certificate not enrolled or revoked".into(),
189            )
190        })?;
191
192    let analytics = require_analytics(&ctx)?;
193    let result = issue_bridge_access(ctx.db_pool(), analytics.as_ref(), &headers, &record.user_id)
194        .await
195        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
196
197    Ok(Json(result.into()))
198}
199
200fn extract_bearer(hdrs: &HeaderMap) -> Option<String> {
201    let auth = hdrs.get(headers::AUTHORIZATION)?.to_str().ok()?;
202    auth.strip_prefix(BEARER_PREFIX)
203        .map(|s| s.trim().to_owned())
204}
205
206fn require_analytics(ctx: &AppContext) -> Result<Arc<dyn AnalyticsProvider>, (StatusCode, String)> {
207    ctx.analytics_provider().ok_or_else(|| {
208        (
209            StatusCode::INTERNAL_SERVER_ERROR,
210            "analytics provider unavailable".into(),
211        )
212    })
213}