systemprompt_api/routes/gateway/
auth.rs1use axum::Json;
2use axum::extract::Request;
3use axum::http::{HeaderMap, StatusCode};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::sync::Arc;
7use systemprompt_identifiers::{JwtToken, UserId, headers};
8use systemprompt_models::Config;
9use systemprompt_models::auth::BEARER_PREFIX;
10use systemprompt_oauth::services::{
11 BridgeAuthResult, BridgeOAuthClient, exchange_bridge_session_code, issue_bridge_access,
12 provision_bridge_oauth_client,
13};
14use systemprompt_runtime::AppContext;
15use systemprompt_users::{ApiKeyService, DeviceCertService};
16
17use crate::services::middleware::JwtContextExtractor;
18
19#[derive(Debug, Serialize)]
20pub struct AuthResponse {
21 pub token: String,
22 pub ttl: u64,
23 pub headers: HashMap<String, String>,
24}
25
26impl From<BridgeAuthResult> for AuthResponse {
27 fn from(r: BridgeAuthResult) -> Self {
28 Self {
29 token: r.token,
30 ttl: r.ttl,
31 headers: r.headers,
32 }
33 }
34}
35
36#[derive(Debug, Serialize)]
37pub struct Capabilities {
38 pub modes: Vec<&'static str>,
39}
40
41pub async fn capabilities() -> Json<Capabilities> {
42 Json(Capabilities {
43 modes: vec!["pat", "session", "mtls", "oauth-client"],
44 })
45}
46
47#[derive(Debug, Deserialize)]
48pub struct MtlsRequestBody {
49 pub device_cert_fingerprint: String,
50 #[serde(default)]
51 pub session_id: Option<String>,
52}
53
54#[derive(Debug, Deserialize)]
55pub struct SessionExchangeBody {
56 pub code: String,
57 #[serde(default)]
58 pub session_id: Option<String>,
59}
60
61pub async fn pat(
62 ctx: AppContext,
63 request: Request,
64) -> Result<Json<AuthResponse>, (StatusCode, String)> {
65 let pat_token = extract_bearer(request.headers()).ok_or_else(|| {
66 (
67 StatusCode::UNAUTHORIZED,
68 "Missing Authorization: Bearer <pat>".into(),
69 )
70 })?;
71
72 let service = ApiKeyService::new(ctx.db_pool())
73 .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
74
75 let record = service
76 .verify(&pat_token)
77 .await
78 .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
79 .ok_or_else(|| (StatusCode::UNAUTHORIZED, "Invalid PAT".into()))?;
80
81 let result = issue_bridge_access(ctx.db_pool(), &record.user_id)
82 .await
83 .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
84
85 Ok(Json(result.into()))
86}
87
88pub async fn session(
89 ctx: AppContext,
90 Json(body): Json<SessionExchangeBody>,
91) -> Result<Json<AuthResponse>, (StatusCode, String)> {
92 if body.code.trim().is_empty() {
93 return Err((StatusCode::BAD_REQUEST, "missing exchange code".into()));
94 }
95
96 let result = exchange_bridge_session_code(ctx.db_pool(), body.code.trim())
97 .await
98 .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
99 .ok_or_else(|| {
100 (
101 StatusCode::UNAUTHORIZED,
102 "exchange code invalid, expired, or already consumed".into(),
103 )
104 })?;
105
106 Ok(Json(result.into()))
107}
108
109pub async fn provision_oauth_client(
110 jwt_extractor: Arc<JwtContextExtractor>,
111 ctx: AppContext,
112 request: Request,
113) -> Result<Json<BridgeOAuthClient>, (StatusCode, String)> {
114 let bearer = extract_bearer(request.headers()).ok_or_else(|| {
115 (
116 StatusCode::UNAUTHORIZED,
117 "Missing Authorization: Bearer <bridge-jwt>".into(),
118 )
119 })?;
120
121 let claims = jwt_extractor
122 .decode_for_gateway(&JwtToken::new(bearer))
123 .await
124 .map_err(|e| (StatusCode::UNAUTHORIZED, e.to_string()))?;
125
126 let user_id = UserId::new(claims.user_id.to_string());
127
128 let token_endpoint =
129 build_token_endpoint().map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e))?;
130
131 let result = provision_bridge_oauth_client(ctx.db_pool(), &user_id, token_endpoint)
132 .await
133 .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
134
135 Ok(Json(result))
136}
137
138fn build_token_endpoint() -> Result<String, String> {
139 let cfg = Config::get().map_err(|e| e.to_string())?;
140 Ok(format!(
141 "{}/api/v1/core/oauth/token",
142 cfg.api_external_url.trim_end_matches('/')
143 ))
144}
145
146pub async fn mtls(
147 ctx: AppContext,
148 Json(body): Json<MtlsRequestBody>,
149) -> Result<Json<AuthResponse>, (StatusCode, String)> {
150 let fingerprint = body.device_cert_fingerprint.trim();
151 if fingerprint.is_empty() {
152 return Err((
153 StatusCode::BAD_REQUEST,
154 "missing device_cert_fingerprint".into(),
155 ));
156 }
157
158 let service = DeviceCertService::new(ctx.db_pool())
159 .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
160
161 let record = service
162 .verify(fingerprint)
163 .await
164 .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?
165 .ok_or_else(|| {
166 (
167 StatusCode::UNAUTHORIZED,
168 "device certificate not enrolled or revoked".into(),
169 )
170 })?;
171
172 let result = issue_bridge_access(ctx.db_pool(), &record.user_id)
173 .await
174 .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
175
176 Ok(Json(result.into()))
177}
178
179fn extract_bearer(hdrs: &HeaderMap) -> Option<String> {
180 let auth = hdrs.get(headers::AUTHORIZATION)?.to_str().ok()?;
181 auth.strip_prefix(BEARER_PREFIX)
182 .map(|s| s.trim().to_string())
183}