systemprompt_api/routes/oauth/endpoints/
callback.rs1use axum::extract::{Query, State};
2use axum::http::{header, HeaderMap, HeaderValue, StatusCode};
3use axum::response::{IntoResponse, Redirect};
4use serde::Deserialize;
5use std::str::FromStr;
6use std::sync::Arc;
7use systemprompt_identifiers::{
8 AuthorizationCode, ClientId, RefreshTokenId, SessionSource, UserId,
9};
10use systemprompt_models::auth::{parse_permissions, AuthenticatedUser, Permission};
11use systemprompt_models::Config;
12
13use systemprompt_oauth::repository::{OAuthRepository, RefreshTokenParams};
14use systemprompt_oauth::OAuthState;
15
/// Query-string parameters delivered to the OAuth callback endpoint.
#[derive(Debug, Deserialize)]
pub struct CallbackQuery {
    /// Authorization code to be exchanged for an access token.
    pub code: String,
    /// Opaque `state` value echoed back by the authorization flow.
    ///
    /// This comes straight from the query string and must be treated as
    /// untrusted input; `handle_callback` uses it to pick the post-login
    /// redirect path.
    pub state: Option<String>,
}
21
22pub async fn handle_callback(
23 Query(params): Query<CallbackQuery>,
24 State(state): State<OAuthState>,
25 headers: HeaderMap,
26) -> impl IntoResponse {
27 let repo = match OAuthRepository::new(Arc::clone(state.db_pool())) {
28 Ok(r) => r,
29 Err(e) => {
30 return (
31 StatusCode::INTERNAL_SERVER_ERROR,
32 axum::Json(serde_json::json!({"error": "server_error", "error_description": format!("Repository initialization failed: {}", e)})),
33 ).into_response();
34 },
35 };
36 let config = match Config::get() {
37 Ok(c) => c,
38 Err(e) => {
39 return (
40 StatusCode::INTERNAL_SERVER_ERROR,
41 format!("Failed to load config: {e}"),
42 )
43 .into_response();
44 },
45 };
46
47 let server_base_url = &config.api_external_url;
48 let redirect_uri = format!("{server_base_url}/api/v1/core/oauth/callback");
49
50 let browser_client = match find_browser_client(&repo, &redirect_uri).await {
51 Ok(client) => client,
52 Err(e) => {
53 return (
54 StatusCode::INTERNAL_SERVER_ERROR,
55 format!("Failed to find OAuth client: {e}"),
56 )
57 .into_response();
58 },
59 };
60
61 let code = AuthorizationCode::new(¶ms.code);
62 let client_id = ClientId::new(&browser_client.client_id);
63 let token_response = match exchange_code_for_token(
64 &repo,
65 CodeExchangeParams {
66 code: &code,
67 client_id: &client_id,
68 redirect_uri: &redirect_uri,
69 headers: &headers,
70 },
71 &state,
72 )
73 .await
74 {
75 Ok(response) => response,
76 Err(e) => {
77 return (
78 StatusCode::UNAUTHORIZED,
79 format!("Failed to exchange code for token: {e}"),
80 )
81 .into_response();
82 },
83 };
84
85 let redirect_destination = params
86 .state
87 .as_deref()
88 .filter(|s| !s.is_empty())
89 .unwrap_or("/");
90
91 let cookie = format!(
92 "access_token={}; Path=/; HttpOnly; Secure; SameSite=Lax; Max-Age={}",
93 token_response.access_token,
94 systemprompt_oauth::constants::token::COOKIE_MAX_AGE_SECONDS
95 );
96
97 let mut response = Redirect::to(redirect_destination).into_response();
98 if let Ok(cookie_value) = HeaderValue::from_str(&cookie) {
99 response
100 .headers_mut()
101 .insert(header::SET_COOKIE, cookie_value);
102 }
103
104 response
105}
106
107async fn find_browser_client(
108 repo: &OAuthRepository,
109 redirect_uri: &str,
110) -> anyhow::Result<BrowserClient> {
111 let client = repo
112 .find_client_by_redirect_uri_with_scope(redirect_uri, &["admin", "user"])
113 .await?
114 .ok_or_else(|| anyhow::anyhow!("No suitable browser client found"))?;
115
116 Ok(BrowserClient {
117 client_id: client.client_id.to_string(),
118 })
119}
120
/// Borrowed inputs for [`exchange_code_for_token`], bundled to keep its
/// argument list short.
struct CodeExchangeParams<'a> {
    /// Authorization code received on the callback.
    code: &'a AuthorizationCode,
    /// Client the code is being redeemed for.
    client_id: &'a ClientId,
    /// Redirect URI passed to code validation; it must be the callback URL
    /// that the handler built.
    redirect_uri: &'a str,
    /// Incoming request headers, forwarded to session creation.
    headers: &'a HeaderMap,
}
127
128async fn exchange_code_for_token(
129 repo: &OAuthRepository,
130 params: CodeExchangeParams<'_>,
131 state: &OAuthState,
132) -> anyhow::Result<TokenResponse> {
133 use systemprompt_oauth::services::{
134 generate_access_token_jti, generate_jwt, generate_secure_token, JwtConfig, JwtSigningParams,
135 };
136
137 let (user_id, scope) = repo
138 .validate_authorization_code(
139 params.code,
140 params.client_id,
141 Some(params.redirect_uri),
142 None,
143 )
144 .await?;
145
146 let user = load_authenticated_user(&user_id, state.user_provider()).await?;
147
148 let permissions = parse_permissions(&scope)?;
149
150 let session_service = systemprompt_oauth::services::SessionCreationService::new(
151 Arc::clone(state.analytics_provider()),
152 Arc::clone(state.user_provider()),
153 );
154 let session_id = session_service
155 .create_authenticated_session(&user_id, params.headers, SessionSource::Oauth)
156 .await?;
157
158 let access_token_jti = generate_access_token_jti();
159 let jwt_secret = systemprompt_models::SecretsBootstrap::jwt_secret()?;
160 let global_config = Config::get()?;
161 let config = JwtConfig {
162 permissions: permissions.clone(),
163 audience: global_config.jwt_audiences.clone(),
164 ..Default::default()
165 };
166 let signing = JwtSigningParams {
167 secret: jwt_secret,
168 issuer: &global_config.jwt_issuer,
169 };
170 let access_token = generate_jwt(&user, config, access_token_jti, &session_id, &signing)?;
171
172 let refresh_token_value = generate_secure_token("rt");
173 let refresh_token_id = RefreshTokenId::new(&refresh_token_value);
174 let refresh_expires_at = chrono::Utc::now().timestamp()
175 + (systemprompt_oauth::constants::token::SECONDS_PER_DAY
176 * systemprompt_oauth::constants::token::REFRESH_TOKEN_EXPIRY_DAYS);
177
178 let refresh_params = RefreshTokenParams::builder(
179 &refresh_token_id,
180 params.client_id,
181 &user_id,
182 &scope,
183 refresh_expires_at,
184 )
185 .build();
186 repo.store_refresh_token(refresh_params).await?;
187
188 Ok(TokenResponse { access_token })
189}
190
191async fn load_authenticated_user(
192 user_id: &UserId,
193 user_provider: &Arc<dyn systemprompt_traits::UserProvider>,
194) -> anyhow::Result<AuthenticatedUser> {
195 let user = user_provider
196 .find_by_id(user_id.as_str())
197 .await
198 .map_err(|e| anyhow::anyhow!("{}", e))?
199 .ok_or_else(|| anyhow::anyhow!("User not found: {user_id}"))?;
200
201 let permissions: Vec<Permission> = user
202 .roles
203 .iter()
204 .filter_map(|s| {
205 Permission::from_str(s)
206 .map_err(|e| {
207 tracing::warn!(
208 user_id = %user.id,
209 role = %s,
210 error = %e,
211 "Invalid role in user record"
212 );
213 e
214 })
215 .ok()
216 })
217 .collect();
218
219 let user_uuid = uuid::Uuid::parse_str(&user.id)
220 .map_err(|_| anyhow::anyhow!("Invalid user UUID: {}", user.id))?;
221
222 Ok(AuthenticatedUser::new_with_roles(
223 user_uuid,
224 user.name,
225 user.email,
226 permissions,
227 user.roles,
228 ))
229}
230
/// OAuth client selected by `find_browser_client` for browser logins.
#[derive(Debug)]
struct BrowserClient {
    /// String form of the matched client's `client_id`.
    client_id: String,
}
235
/// Result of a successful code exchange.
// NOTE(review): in this file the struct is only constructed directly; the
// `Deserialize` derive does not appear to be used. Confirm before removing it.
#[derive(Debug, serde::Deserialize)]
struct TokenResponse {
    /// Signed JWT access token; the callback handler stores it in the
    /// `access_token` cookie.
    access_token: String,
}