systemprompt_api/routes/gateway/
auth.rs1use axum::Json;
11use axum::extract::Request;
12use axum::http::{HeaderMap, StatusCode};
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use std::sync::Arc;
16use systemprompt_identifiers::{JwtToken, UserId, headers};
17use systemprompt_models::Config;
18use systemprompt_models::auth::BEARER_PREFIX;
19use systemprompt_oauth::services::{
20 BridgeAuthResult, BridgeOAuthClient, exchange_bridge_session_code, issue_bridge_access,
21 provision_bridge_oauth_client,
22};
23use systemprompt_runtime::AppContext;
24use systemprompt_traits::{AnalyticsProvider, AppContext as _};
25use systemprompt_users::{ApiKeyService, DeviceCertService};
26
27use crate::services::middleware::JwtContextExtractor;
28
29#[derive(Debug, Serialize)]
30pub struct AuthResponse {
31 pub token: String,
32 pub ttl: u64,
33 pub headers: HashMap<String, String>,
34}
35
36impl From<BridgeAuthResult> for AuthResponse {
37 fn from(r: BridgeAuthResult) -> Self {
38 Self {
39 token: r.token,
40 ttl: r.ttl,
41 headers: r.headers,
42 }
43 }
44}
45
46#[derive(Debug, Serialize)]
47pub struct Capabilities {
48 pub modes: Vec<&'static str>,
49}
50
51pub async fn capabilities() -> Json<Capabilities> {
52 Json(Capabilities {
53 modes: vec!["pat", "session", "mtls", "oauth-client"],
54 })
55}
56
57#[derive(Debug, Deserialize)]
58pub struct MtlsRequestBody {
59 pub device_cert_fingerprint: String,
60}
61
62#[derive(Debug, Deserialize)]
63pub struct SessionExchangeBody {
64 pub code: String,
65}
66
67pub async fn pat(
68 ctx: AppContext,
69 request: Request,
70) -> Result<Json<AuthResponse>, (StatusCode, String)> {
71 let pat_token = extract_bearer(request.headers()).ok_or_else(|| {
72 (
73 StatusCode::UNAUTHORIZED,
74 "Missing Authorization: Bearer <pat>".into(),
75 )
76 })?;
77
78 let service = ApiKeyService::new(ctx.db_pool())
79 .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
80
81 let record = service
82 .verify(&pat_token)
83 .await
84 .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
85 .ok_or_else(|| (StatusCode::UNAUTHORIZED, "Invalid PAT".into()))?;
86
87 let analytics = require_analytics(&ctx)?;
88 let result = issue_bridge_access(
89 ctx.db_pool(),
90 analytics.as_ref(),
91 request.headers(),
92 &record.user_id,
93 )
94 .await
95 .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
96
97 Ok(Json(result.into()))
98}
99
100pub async fn session(
101 ctx: AppContext,
102 headers: HeaderMap,
103 Json(body): Json<SessionExchangeBody>,
104) -> Result<Json<AuthResponse>, (StatusCode, String)> {
105 if body.code.trim().is_empty() {
106 return Err((StatusCode::BAD_REQUEST, "missing exchange code".into()));
107 }
108
109 let analytics = require_analytics(&ctx)?;
110 let result = exchange_bridge_session_code(
111 ctx.db_pool(),
112 analytics.as_ref(),
113 &headers,
114 body.code.trim(),
115 )
116 .await
117 .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
118 .ok_or_else(|| {
119 (
120 StatusCode::UNAUTHORIZED,
121 "exchange code invalid, expired, or already consumed".into(),
122 )
123 })?;
124
125 Ok(Json(result.into()))
126}
127
128pub async fn provision_oauth_client(
129 jwt_extractor: Arc<JwtContextExtractor>,
130 ctx: AppContext,
131 request: Request,
132) -> Result<Json<BridgeOAuthClient>, (StatusCode, String)> {
133 let bearer = extract_bearer(request.headers()).ok_or_else(|| {
134 (
135 StatusCode::UNAUTHORIZED,
136 "Missing Authorization: Bearer <bridge-jwt>".into(),
137 )
138 })?;
139
140 let (claims, _user) = jwt_extractor
141 .decode_for_gateway(&JwtToken::new(bearer))
142 .await
143 .map_err(|e| (StatusCode::UNAUTHORIZED, e.to_string()))?;
144
145 let user_id = UserId::new(claims.user_id.to_string());
146
147 let token_endpoint =
148 build_token_endpoint().map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e))?;
149
150 let result = provision_bridge_oauth_client(ctx.db_pool(), &user_id, token_endpoint)
151 .await
152 .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
153
154 Ok(Json(result))
155}
156
157fn build_token_endpoint() -> Result<String, String> {
158 let cfg = Config::get().map_err(|e| e.to_string())?;
159 Ok(format!(
160 "{}/api/v1/core/oauth/token",
161 cfg.api_external_url.trim_end_matches('/')
162 ))
163}
164
165pub async fn mtls(
166 ctx: AppContext,
167 headers: HeaderMap,
168 Json(body): Json<MtlsRequestBody>,
169) -> Result<Json<AuthResponse>, (StatusCode, String)> {
170 let fingerprint = body.device_cert_fingerprint.trim();
171 if fingerprint.is_empty() {
172 return Err((
173 StatusCode::BAD_REQUEST,
174 "missing device_cert_fingerprint".into(),
175 ));
176 }
177
178 let service = DeviceCertService::new(ctx.db_pool())
179 .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
180
181 let record = service
182 .verify(fingerprint)
183 .await
184 .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?
185 .ok_or_else(|| {
186 (
187 StatusCode::UNAUTHORIZED,
188 "device certificate not enrolled or revoked".into(),
189 )
190 })?;
191
192 let analytics = require_analytics(&ctx)?;
193 let result = issue_bridge_access(ctx.db_pool(), analytics.as_ref(), &headers, &record.user_id)
194 .await
195 .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
196
197 Ok(Json(result.into()))
198}
199
200fn extract_bearer(hdrs: &HeaderMap) -> Option<String> {
201 let auth = hdrs.get(headers::AUTHORIZATION)?.to_str().ok()?;
202 auth.strip_prefix(BEARER_PREFIX)
203 .map(|s| s.trim().to_owned())
204}
205
206fn require_analytics(ctx: &AppContext) -> Result<Arc<dyn AnalyticsProvider>, (StatusCode, String)> {
207 ctx.analytics_provider().ok_or_else(|| {
208 (
209 StatusCode::INTERNAL_SERVER_ERROR,
210 "analytics provider unavailable".into(),
211 )
212 })
213}