systemprompt_api/routes/oauth/endpoints/
callback.rs1use axum::extract::{Query, State};
2use axum::http::{HeaderMap, HeaderValue, StatusCode, header};
3use axum::response::{IntoResponse, Redirect};
4use serde::Deserialize;
5use std::str::FromStr;
6use std::sync::Arc;
7use systemprompt_identifiers::{
8 AuthorizationCode, ClientId, RefreshTokenId, SessionSource, UserId,
9};
10use systemprompt_models::Config;
11use systemprompt_models::auth::{AuthenticatedUser, Permission, parse_permissions};
12
13use crate::routes::oauth::extractors::OAuthRepo;
14use systemprompt_oauth::OAuthState;
15use systemprompt_oauth::repository::{OAuthRepository, RefreshTokenParams};
16
17#[derive(Debug, Deserialize)]
18pub struct CallbackQuery {
19 pub code: String,
20 pub state: Option<String>,
21}
22
23pub async fn handle_callback(
24 Query(params): Query<CallbackQuery>,
25 State(state): State<OAuthState>,
26 OAuthRepo(repo): OAuthRepo,
27 headers: HeaderMap,
28) -> impl IntoResponse {
29 let config = match Config::get() {
30 Ok(c) => c,
31 Err(e) => {
32 return (
33 StatusCode::INTERNAL_SERVER_ERROR,
34 format!("Failed to load config: {e}"),
35 )
36 .into_response();
37 },
38 };
39
40 let server_base_url = &config.api_external_url;
41 let redirect_uri = format!("{server_base_url}/api/v1/core/oauth/callback");
42
43 let browser_client = match find_browser_client(&repo, &redirect_uri).await {
44 Ok(client) => client,
45 Err(e) => {
46 return (
47 StatusCode::INTERNAL_SERVER_ERROR,
48 format!("Failed to find OAuth client: {e}"),
49 )
50 .into_response();
51 },
52 };
53
54 let code = AuthorizationCode::new(¶ms.code);
55 let client_id = ClientId::new(&browser_client.client_id);
56 let token_response = match exchange_code_for_token(
57 &repo,
58 CodeExchangeParams {
59 code: &code,
60 client_id: &client_id,
61 redirect_uri: &redirect_uri,
62 headers: &headers,
63 },
64 &state,
65 )
66 .await
67 {
68 Ok(response) => response,
69 Err(e) => {
70 return (
71 StatusCode::UNAUTHORIZED,
72 format!("Failed to exchange code for token: {e}"),
73 )
74 .into_response();
75 },
76 };
77
78 let Some(state_token) = params.state.as_deref().filter(|s| !s.is_empty()) else {
79 return (StatusCode::BAD_REQUEST, "Missing state parameter").into_response();
80 };
81 let redirect_destination = match repo.consume_state_binding(state_token).await {
82 Ok(Some(binding)) => binding.return_to,
83 Ok(None) => {
84 tracing::warn!("state binding missing, expired, or already consumed");
85 return (StatusCode::BAD_REQUEST, "Invalid state parameter").into_response();
86 },
87 Err(e) => {
88 tracing::error!(error = %e, "state binding lookup failed");
89 return (
90 StatusCode::INTERNAL_SERVER_ERROR,
91 "Failed to validate state",
92 )
93 .into_response();
94 },
95 };
96
97 let cookie = format!(
98 "access_token={}; Path=/; HttpOnly; Secure; SameSite=Strict; Max-Age={}",
99 token_response.access_token,
100 systemprompt_oauth::constants::token::COOKIE_MAX_AGE_SECONDS
101 );
102
103 let mut response = Redirect::to(&redirect_destination).into_response();
104 if let Ok(cookie_value) = HeaderValue::from_str(&cookie) {
105 response
106 .headers_mut()
107 .insert(header::SET_COOKIE, cookie_value);
108 }
109
110 response
111}
112
113async fn find_browser_client(
114 repo: &OAuthRepository,
115 redirect_uri: &str,
116) -> anyhow::Result<BrowserClient> {
117 let client = repo
118 .find_client_by_redirect_uri_with_scope(redirect_uri, &["admin", "user"])
119 .await?
120 .ok_or_else(|| anyhow::anyhow!("No suitable browser client found"))?;
121
122 Ok(BrowserClient {
123 client_id: client.client_id.to_string(),
124 })
125}
126
127struct CodeExchangeParams<'a> {
128 code: &'a AuthorizationCode,
129 client_id: &'a ClientId,
130 redirect_uri: &'a str,
131 headers: &'a HeaderMap,
132}
133
134async fn exchange_code_for_token(
135 repo: &OAuthRepository,
136 params: CodeExchangeParams<'_>,
137 state: &OAuthState,
138) -> anyhow::Result<TokenResponse> {
139 use systemprompt_oauth::services::{
140 JwtConfig, JwtSigningParams, generate_access_token_jti, generate_jwt, generate_secure_token,
141 };
142
143 let validation_result = repo
144 .validate_authorization_code(
145 params.code,
146 params.client_id,
147 Some(params.redirect_uri),
148 None,
149 )
150 .await?;
151
152 let user = load_authenticated_user(&validation_result.user_id, state.user_provider()).await?;
153
154 let permissions = parse_permissions(&validation_result.scope)?;
155
156 let mut session_service = systemprompt_oauth::services::SessionCreationService::new(
157 Arc::clone(state.analytics_provider()),
158 Arc::clone(state.user_provider()),
159 );
160 if let Some(publisher) = state.event_publisher() {
161 session_service = session_service.with_event_publisher(Arc::clone(publisher));
162 }
163 let session_id = session_service
164 .create_authenticated_session(
165 &validation_result.user_id,
166 params.headers,
167 SessionSource::Oauth,
168 )
169 .await?;
170
171 let access_token_jti = generate_access_token_jti();
172 let global_config = Config::get()?;
173 let config = JwtConfig {
174 permissions: permissions.clone(),
175 audience: global_config.jwt_audiences.clone(),
176 ..Default::default()
177 };
178 let signing = JwtSigningParams {
179 issuer: &global_config.jwt_issuer,
180 };
181 let access_token = generate_jwt(&user, config, access_token_jti, &session_id, &signing)?;
182
183 let refresh_token_value = generate_secure_token("rt");
184 let refresh_token_id = RefreshTokenId::new(&refresh_token_value);
185 let refresh_expires_at = chrono::Utc::now().timestamp()
186 + (systemprompt_oauth::constants::token::SECONDS_PER_DAY
187 * systemprompt_oauth::constants::token::REFRESH_TOKEN_EXPIRY_DAYS);
188
189 let refresh_params = RefreshTokenParams::builder(
190 &refresh_token_id,
191 params.client_id,
192 &validation_result.user_id,
193 &validation_result.scope,
194 refresh_expires_at,
195 )
196 .build();
197 repo.store_refresh_token(refresh_params).await?;
198
199 if let Err(e) = repo
200 .link_auth_code_to_refresh_token(params.code, refresh_token_id.as_str())
201 .await
202 {
203 tracing::warn!(error = %e, "Failed to link auth code to refresh token");
204 }
205
206 Ok(TokenResponse { access_token })
207}
208
209async fn load_authenticated_user(
210 user_id: &UserId,
211 user_provider: &Arc<dyn systemprompt_traits::UserProvider>,
212) -> anyhow::Result<AuthenticatedUser> {
213 let user = user_provider
214 .find_by_id(user_id)
215 .await
216 .map_err(|e| anyhow::anyhow!("{}", e))?
217 .ok_or_else(|| anyhow::anyhow!("User not found: {user_id}"))?;
218
219 let permissions: Vec<Permission> = user
220 .roles
221 .iter()
222 .filter_map(|s| {
223 Permission::from_str(s)
224 .map_err(|e| {
225 tracing::warn!(
226 user_id = %user.id,
227 role = %s,
228 error = %e,
229 "Invalid role in user record"
230 );
231 e
232 })
233 .ok()
234 })
235 .collect();
236
237 let user_uuid = uuid::Uuid::parse_str(user.id.as_str())
238 .map_err(|_| anyhow::anyhow!("Invalid user UUID: {}", user.id))?;
239
240 Ok(AuthenticatedUser::new_with_roles(
241 user_uuid,
242 user.name,
243 user.email,
244 permissions,
245 user.roles,
246 ))
247}
248
/// Minimal view of the OAuth client selected for the browser flow.
#[derive(Debug)]
struct BrowserClient {
    /// String form of the selected client's identifier.
    client_id: String,
}
253
254#[derive(Debug, serde::Deserialize)]
255struct TokenResponse {
256 access_token: String,
257}