systemprompt_api/routes/oauth/endpoints/
callback.rs1use axum::extract::{Query, State};
8use axum::http::{HeaderMap, HeaderValue, StatusCode, header};
9use axum::response::{IntoResponse, Redirect};
10use serde::Deserialize;
11use std::str::FromStr;
12use std::sync::Arc;
13use systemprompt_identifiers::{
14 AuthorizationCode, ClientId, RefreshTokenId, SessionSource, UserId,
15};
16use systemprompt_models::Config;
17use systemprompt_models::auth::{AuthenticatedUser, Permission, parse_permissions};
18
19use crate::routes::oauth::extractors::OAuthRepo;
20use systemprompt_oauth::OAuthState;
21use systemprompt_oauth::repository::{OAuthRepository, RefreshTokenParams};
22
23#[derive(Debug, Deserialize)]
24pub struct CallbackQuery {
25 pub code: String,
26 pub state: Option<String>,
27}
28
29pub async fn handle_callback(
30 Query(params): Query<CallbackQuery>,
31 State(state): State<OAuthState>,
32 OAuthRepo(repo): OAuthRepo,
33 headers: HeaderMap,
34) -> impl IntoResponse {
35 let config = match Config::get() {
36 Ok(c) => c,
37 Err(e) => {
38 return (
39 StatusCode::INTERNAL_SERVER_ERROR,
40 format!("Failed to load config: {e}"),
41 )
42 .into_response();
43 },
44 };
45
46 let server_base_url = &config.api_external_url;
47 let redirect_uri = format!("{server_base_url}/api/v1/core/oauth/callback");
48
49 let browser_client = match find_browser_client(&repo, &redirect_uri).await {
50 Ok(client) => client,
51 Err(e) => {
52 return (
53 StatusCode::INTERNAL_SERVER_ERROR,
54 format!("Failed to find OAuth client: {e}"),
55 )
56 .into_response();
57 },
58 };
59
60 let code = AuthorizationCode::new(¶ms.code);
61 let client_id = ClientId::new(&browser_client.client_id);
62 let token_response = match exchange_code_for_token(
63 &repo,
64 CodeExchangeParams {
65 code: &code,
66 client_id: &client_id,
67 redirect_uri: &redirect_uri,
68 headers: &headers,
69 },
70 &state,
71 )
72 .await
73 {
74 Ok(response) => response,
75 Err(e) => {
76 return (
77 StatusCode::UNAUTHORIZED,
78 format!("Failed to exchange code for token: {e}"),
79 )
80 .into_response();
81 },
82 };
83
84 let Some(state_token) = params.state.as_deref().filter(|s| !s.is_empty()) else {
85 return (StatusCode::BAD_REQUEST, "Missing state parameter").into_response();
86 };
87 let redirect_destination = match repo.consume_state_binding(state_token).await {
88 Ok(Some(binding)) => binding.return_to,
89 Ok(None) => {
90 tracing::warn!("state binding missing, expired, or already consumed");
91 return (StatusCode::BAD_REQUEST, "Invalid state parameter").into_response();
92 },
93 Err(e) => {
94 tracing::error!(error = %e, "state binding lookup failed");
95 return (
96 StatusCode::INTERNAL_SERVER_ERROR,
97 "Failed to validate state",
98 )
99 .into_response();
100 },
101 };
102
103 let cookie = format!(
104 "access_token={}; Path=/; HttpOnly; Secure; SameSite=Strict; Max-Age={}",
105 token_response.access_token,
106 systemprompt_oauth::constants::token::COOKIE_MAX_AGE_SECONDS
107 );
108
109 let mut response = Redirect::to(&redirect_destination).into_response();
110 if let Ok(cookie_value) = HeaderValue::from_str(&cookie) {
111 response
112 .headers_mut()
113 .insert(header::SET_COOKIE, cookie_value);
114 }
115
116 response
117}
118
119async fn find_browser_client(
120 repo: &OAuthRepository,
121 redirect_uri: &str,
122) -> anyhow::Result<BrowserClient> {
123 let client = repo
124 .find_client_by_redirect_uri_with_scope(redirect_uri, &["admin", "user"])
125 .await?
126 .ok_or_else(|| anyhow::anyhow!("No suitable browser client found"))?;
127
128 Ok(BrowserClient {
129 client_id: client.client_id.to_string(),
130 })
131}
132
133struct CodeExchangeParams<'a> {
134 code: &'a AuthorizationCode,
135 client_id: &'a ClientId,
136 redirect_uri: &'a str,
137 headers: &'a HeaderMap,
138}
139
140async fn exchange_code_for_token(
141 repo: &OAuthRepository,
142 params: CodeExchangeParams<'_>,
143 state: &OAuthState,
144) -> anyhow::Result<TokenResponse> {
145 use systemprompt_oauth::services::{
146 JwtConfig, JwtSigningParams, generate_access_token_jti, generate_jwt, generate_secure_token,
147 };
148
149 let validation_result = repo
150 .validate_authorization_code(
151 params.code,
152 params.client_id,
153 Some(params.redirect_uri),
154 None,
155 )
156 .await?;
157
158 let user = load_authenticated_user(&validation_result.user_id, state.user_provider()).await?;
159
160 let permissions = parse_permissions(&validation_result.scope)?;
161
162 let mut session_service = systemprompt_oauth::services::SessionCreationService::new(
163 Arc::clone(state.analytics_provider()),
164 Arc::clone(state.user_provider()),
165 );
166 if let Some(publisher) = state.event_publisher() {
167 session_service = session_service.with_event_publisher(Arc::clone(publisher));
168 }
169 let session_id = session_service
170 .create_authenticated_session(
171 &validation_result.user_id,
172 params.headers,
173 SessionSource::Oauth,
174 )
175 .await?;
176
177 let access_token_jti = generate_access_token_jti();
178 let global_config = Config::get()?;
179 let config = JwtConfig {
180 permissions: permissions.clone(),
181 audience: global_config.jwt_audiences.clone(),
182 ..Default::default()
183 };
184 let signing = JwtSigningParams {
185 issuer: &global_config.jwt_issuer,
186 };
187 let access_token = generate_jwt(&user, config, access_token_jti, &session_id, &signing)?;
188
189 let refresh_token_value = generate_secure_token("rt");
190 let refresh_token_id = RefreshTokenId::new(&refresh_token_value);
191 let refresh_expires_at = chrono::Utc::now().timestamp()
192 + (systemprompt_oauth::constants::token::SECONDS_PER_DAY
193 * systemprompt_oauth::constants::token::REFRESH_TOKEN_EXPIRY_DAYS);
194
195 let refresh_params = RefreshTokenParams::builder(
196 &refresh_token_id,
197 params.client_id,
198 &validation_result.user_id,
199 &validation_result.scope,
200 refresh_expires_at,
201 )
202 .build();
203 repo.store_refresh_token(refresh_params).await?;
204
205 if let Err(e) = repo
206 .link_auth_code_to_refresh_token(params.code, refresh_token_id.as_str())
207 .await
208 {
209 tracing::warn!(error = %e, "Failed to link auth code to refresh token");
210 }
211
212 Ok(TokenResponse { access_token })
213}
214
215async fn load_authenticated_user(
216 user_id: &UserId,
217 user_provider: &Arc<dyn systemprompt_traits::UserProvider>,
218) -> anyhow::Result<AuthenticatedUser> {
219 let user = user_provider
220 .find_by_id(user_id)
221 .await
222 .map_err(|e| anyhow::anyhow!("{}", e))?
223 .ok_or_else(|| anyhow::anyhow!("User not found: {user_id}"))?;
224
225 let permissions: Vec<Permission> = user
226 .roles
227 .iter()
228 .filter_map(|s| {
229 Permission::from_str(s)
230 .map_err(|e| {
231 tracing::warn!(
232 user_id = %user.id,
233 role = %s,
234 error = %e,
235 "Invalid role in user record"
236 );
237 e
238 })
239 .ok()
240 })
241 .collect();
242
243 let user_uuid = uuid::Uuid::parse_str(user.id.as_str())
244 .map_err(|_e| anyhow::anyhow!("Invalid user UUID: {}", user.id))?;
245
246 Ok(AuthenticatedUser::new_with_roles(
247 user_uuid,
248 user.name,
249 user.email,
250 permissions,
251 user.roles,
252 ))
253}
254
/// OAuth client selected for the browser login flow.
#[derive(Debug)]
struct BrowserClient {
    /// The client's identifier rendered as a plain string
    /// (later wrapped back into a `ClientId` by the callback handler).
    client_id: String,
}
259
260#[derive(Debug, serde::Deserialize)]
261struct TokenResponse {
262 access_token: String,
263}