systemprompt_api/routes/oauth/endpoints/
callback.rs1use axum::extract::{Query, State};
2use axum::http::{HeaderMap, HeaderValue, StatusCode, header};
3use axum::response::{IntoResponse, Redirect};
4use serde::Deserialize;
5use std::str::FromStr;
6use std::sync::Arc;
7use systemprompt_identifiers::{
8 AuthorizationCode, ClientId, RefreshTokenId, SessionSource, UserId,
9};
10use systemprompt_models::Config;
11use systemprompt_models::auth::{AuthenticatedUser, Permission, parse_permissions};
12
13use crate::routes::oauth::extractors::OAuthRepo;
14use systemprompt_oauth::OAuthState;
15use systemprompt_oauth::repository::{OAuthRepository, RefreshTokenParams};
16
17#[derive(Debug, Deserialize)]
18pub struct CallbackQuery {
19 pub code: String,
20 pub state: Option<String>,
21}
22
23pub async fn handle_callback(
24 Query(params): Query<CallbackQuery>,
25 State(state): State<OAuthState>,
26 OAuthRepo(repo): OAuthRepo,
27 headers: HeaderMap,
28) -> impl IntoResponse {
29 let config = match Config::get() {
30 Ok(c) => c,
31 Err(e) => {
32 return (
33 StatusCode::INTERNAL_SERVER_ERROR,
34 format!("Failed to load config: {e}"),
35 )
36 .into_response();
37 },
38 };
39
40 let server_base_url = &config.api_external_url;
41 let redirect_uri = format!("{server_base_url}/api/v1/core/oauth/callback");
42
43 let browser_client = match find_browser_client(&repo, &redirect_uri).await {
44 Ok(client) => client,
45 Err(e) => {
46 return (
47 StatusCode::INTERNAL_SERVER_ERROR,
48 format!("Failed to find OAuth client: {e}"),
49 )
50 .into_response();
51 },
52 };
53
54 let code = AuthorizationCode::new(¶ms.code);
55 let client_id = ClientId::new(&browser_client.client_id);
56 let token_response = match exchange_code_for_token(
57 &repo,
58 CodeExchangeParams {
59 code: &code,
60 client_id: &client_id,
61 redirect_uri: &redirect_uri,
62 headers: &headers,
63 },
64 &state,
65 )
66 .await
67 {
68 Ok(response) => response,
69 Err(e) => {
70 return (
71 StatusCode::UNAUTHORIZED,
72 format!("Failed to exchange code for token: {e}"),
73 )
74 .into_response();
75 },
76 };
77
78 let redirect_destination = params
79 .state
80 .as_deref()
81 .filter(|s| !s.is_empty())
82 .unwrap_or("/");
83
84 let cookie = format!(
85 "access_token={}; Path=/; HttpOnly; Secure; SameSite=Lax; Max-Age={}",
86 token_response.access_token,
87 systemprompt_oauth::constants::token::COOKIE_MAX_AGE_SECONDS
88 );
89
90 let mut response = Redirect::to(redirect_destination).into_response();
91 if let Ok(cookie_value) = HeaderValue::from_str(&cookie) {
92 response
93 .headers_mut()
94 .insert(header::SET_COOKIE, cookie_value);
95 }
96
97 response
98}
99
100async fn find_browser_client(
101 repo: &OAuthRepository,
102 redirect_uri: &str,
103) -> anyhow::Result<BrowserClient> {
104 let client = repo
105 .find_client_by_redirect_uri_with_scope(redirect_uri, &["admin", "user"])
106 .await?
107 .ok_or_else(|| anyhow::anyhow!("No suitable browser client found"))?;
108
109 Ok(BrowserClient {
110 client_id: client.client_id.to_string(),
111 })
112}
113
114struct CodeExchangeParams<'a> {
115 code: &'a AuthorizationCode,
116 client_id: &'a ClientId,
117 redirect_uri: &'a str,
118 headers: &'a HeaderMap,
119}
120
121async fn exchange_code_for_token(
122 repo: &OAuthRepository,
123 params: CodeExchangeParams<'_>,
124 state: &OAuthState,
125) -> anyhow::Result<TokenResponse> {
126 use systemprompt_oauth::services::{
127 JwtConfig, JwtSigningParams, generate_access_token_jti, generate_jwt, generate_secure_token,
128 };
129
130 let validation_result = repo
131 .validate_authorization_code(
132 params.code,
133 params.client_id,
134 Some(params.redirect_uri),
135 None,
136 )
137 .await?;
138
139 let user = load_authenticated_user(&validation_result.user_id, state.user_provider()).await?;
140
141 let permissions = parse_permissions(&validation_result.scope)?;
142
143 let mut session_service = systemprompt_oauth::services::SessionCreationService::new(
144 Arc::clone(state.analytics_provider()),
145 Arc::clone(state.user_provider()),
146 );
147 if let Some(publisher) = state.event_publisher() {
148 session_service = session_service.with_event_publisher(Arc::clone(publisher));
149 }
150 let session_id = session_service
151 .create_authenticated_session(
152 &validation_result.user_id,
153 params.headers,
154 SessionSource::Oauth,
155 )
156 .await?;
157
158 let access_token_jti = generate_access_token_jti();
159 let jwt_secret = systemprompt_models::SecretsBootstrap::jwt_secret()?;
160 let global_config = Config::get()?;
161 let config = JwtConfig {
162 permissions: permissions.clone(),
163 audience: global_config.jwt_audiences.clone(),
164 ..Default::default()
165 };
166 let signing = JwtSigningParams {
167 secret: jwt_secret,
168 issuer: &global_config.jwt_issuer,
169 };
170 let access_token = generate_jwt(&user, config, access_token_jti, &session_id, &signing)?;
171
172 let refresh_token_value = generate_secure_token("rt");
173 let refresh_token_id = RefreshTokenId::new(&refresh_token_value);
174 let refresh_expires_at = chrono::Utc::now().timestamp()
175 + (systemprompt_oauth::constants::token::SECONDS_PER_DAY
176 * systemprompt_oauth::constants::token::REFRESH_TOKEN_EXPIRY_DAYS);
177
178 let refresh_params = RefreshTokenParams::builder(
179 &refresh_token_id,
180 params.client_id,
181 &validation_result.user_id,
182 &validation_result.scope,
183 refresh_expires_at,
184 )
185 .build();
186 repo.store_refresh_token(refresh_params).await?;
187
188 Ok(TokenResponse { access_token })
189}
190
191async fn load_authenticated_user(
192 user_id: &UserId,
193 user_provider: &Arc<dyn systemprompt_traits::UserProvider>,
194) -> anyhow::Result<AuthenticatedUser> {
195 let user = user_provider
196 .find_by_id(user_id)
197 .await
198 .map_err(|e| anyhow::anyhow!("{}", e))?
199 .ok_or_else(|| anyhow::anyhow!("User not found: {user_id}"))?;
200
201 let permissions: Vec<Permission> = user
202 .roles
203 .iter()
204 .filter_map(|s| {
205 Permission::from_str(s)
206 .map_err(|e| {
207 tracing::warn!(
208 user_id = %user.id,
209 role = %s,
210 error = %e,
211 "Invalid role in user record"
212 );
213 e
214 })
215 .ok()
216 })
217 .collect();
218
219 let user_uuid = uuid::Uuid::parse_str(user.id.as_str())
220 .map_err(|_| anyhow::anyhow!("Invalid user UUID: {}", user.id))?;
221
222 Ok(AuthenticatedUser::new_with_roles(
223 user_uuid,
224 user.name,
225 user.email,
226 permissions,
227 user.roles,
228 ))
229}
230
/// Minimal view of the OAuth client returned by `find_browser_client`.
#[derive(Debug)]
struct BrowserClient {
    /// String form of the matched client's id.
    client_id: String,
}
235
236#[derive(Debug, serde::Deserialize)]
237struct TokenResponse {
238 access_token: String,
239}