// systemprompt_api/routes/oauth/endpoints/callback.rs
use std::str::FromStr;
use std::sync::Arc;

use axum::extract::{Query, State};
use axum::http::{header, HeaderMap, HeaderValue, StatusCode};
use axum::response::{IntoResponse, Redirect};
use serde::Deserialize;

use systemprompt_identifiers::{
    AuthorizationCode, ClientId, RefreshTokenId, SessionSource, UserId,
};
use systemprompt_models::auth::{parse_permissions, AuthenticatedUser, Permission};
use systemprompt_models::Config;
use systemprompt_oauth::repository::{OAuthRepository, RefreshTokenParams};
use systemprompt_oauth::OAuthState;

16#[derive(Debug, Deserialize)]
17pub struct CallbackQuery {
18 pub code: String,
19 pub state: Option<String>,
20}
21
22pub async fn handle_callback(
23 Query(params): Query<CallbackQuery>,
24 State(state): State<OAuthState>,
25 headers: HeaderMap,
26) -> impl IntoResponse {
27 let repo = match OAuthRepository::new(Arc::clone(state.db_pool())) {
28 Ok(r) => r,
29 Err(e) => {
30 return (
31 StatusCode::INTERNAL_SERVER_ERROR,
32 axum::Json(serde_json::json!({"error": "server_error", "error_description": format!("Repository initialization failed: {}", e)})),
33 ).into_response();
34 },
35 };
36 let config = match Config::get() {
37 Ok(c) => c,
38 Err(e) => {
39 return (
40 StatusCode::INTERNAL_SERVER_ERROR,
41 format!("Failed to load config: {e}"),
42 )
43 .into_response();
44 },
45 };
46
47 let server_base_url = &config.api_external_url;
48 let redirect_uri = format!("{server_base_url}/api/v1/core/oauth/callback");
49
50 let browser_client = match find_browser_client(&repo, &redirect_uri).await {
51 Ok(client) => client,
52 Err(e) => {
53 return (
54 StatusCode::INTERNAL_SERVER_ERROR,
55 format!("Failed to find OAuth client: {e}"),
56 )
57 .into_response();
58 },
59 };
60
61 let code = AuthorizationCode::new(¶ms.code);
62 let client_id = ClientId::new(&browser_client.client_id);
63 let token_response = match exchange_code_for_token(
64 &repo,
65 CodeExchangeParams {
66 code: &code,
67 client_id: &client_id,
68 redirect_uri: &redirect_uri,
69 headers: &headers,
70 },
71 &state,
72 )
73 .await
74 {
75 Ok(response) => response,
76 Err(e) => {
77 return (
78 StatusCode::UNAUTHORIZED,
79 format!("Failed to exchange code for token: {e}"),
80 )
81 .into_response();
82 },
83 };
84
85 let redirect_destination = params
86 .state
87 .as_deref()
88 .filter(|s| !s.is_empty())
89 .unwrap_or("/");
90
91 let cookie = format!(
92 "access_token={}; Path=/; HttpOnly; Secure; SameSite=Lax; Max-Age={}",
93 token_response.access_token,
94 systemprompt_oauth::constants::token::COOKIE_MAX_AGE_SECONDS
95 );
96
97 let mut response = Redirect::to(redirect_destination).into_response();
98 if let Ok(cookie_value) = HeaderValue::from_str(&cookie) {
99 response
100 .headers_mut()
101 .insert(header::SET_COOKIE, cookie_value);
102 }
103
104 response
105}
106
107async fn find_browser_client(
108 repo: &OAuthRepository,
109 redirect_uri: &str,
110) -> anyhow::Result<BrowserClient> {
111 let client = repo
112 .find_client_by_redirect_uri_with_scope(redirect_uri, &["admin", "user"])
113 .await?
114 .ok_or_else(|| anyhow::anyhow!("No suitable browser client found"))?;
115
116 Ok(BrowserClient {
117 client_id: client.client_id.to_string(),
118 })
119}
120
121struct CodeExchangeParams<'a> {
122 code: &'a AuthorizationCode,
123 client_id: &'a ClientId,
124 redirect_uri: &'a str,
125 headers: &'a HeaderMap,
126}
127
128async fn exchange_code_for_token(
129 repo: &OAuthRepository,
130 params: CodeExchangeParams<'_>,
131 state: &OAuthState,
132) -> anyhow::Result<TokenResponse> {
133 use systemprompt_oauth::services::{
134 generate_access_token_jti, generate_jwt, generate_secure_token, JwtConfig, JwtSigningParams,
135 };
136
137 let validation_result = repo
138 .validate_authorization_code(
139 params.code,
140 params.client_id,
141 Some(params.redirect_uri),
142 None,
143 )
144 .await?;
145
146 let user = load_authenticated_user(&validation_result.user_id, state.user_provider()).await?;
147
148 let permissions = parse_permissions(&validation_result.scope)?;
149
150 let mut session_service = systemprompt_oauth::services::SessionCreationService::new(
151 Arc::clone(state.analytics_provider()),
152 Arc::clone(state.user_provider()),
153 );
154 if let Some(publisher) = state.event_publisher() {
155 session_service = session_service.with_event_publisher(Arc::clone(publisher));
156 }
157 let session_id = session_service
158 .create_authenticated_session(
159 &validation_result.user_id,
160 params.headers,
161 SessionSource::Oauth,
162 )
163 .await?;
164
165 let access_token_jti = generate_access_token_jti();
166 let jwt_secret = systemprompt_models::SecretsBootstrap::jwt_secret()?;
167 let global_config = Config::get()?;
168 let config = JwtConfig {
169 permissions: permissions.clone(),
170 audience: global_config.jwt_audiences.clone(),
171 ..Default::default()
172 };
173 let signing = JwtSigningParams {
174 secret: jwt_secret,
175 issuer: &global_config.jwt_issuer,
176 };
177 let access_token = generate_jwt(&user, config, access_token_jti, &session_id, &signing)?;
178
179 let refresh_token_value = generate_secure_token("rt");
180 let refresh_token_id = RefreshTokenId::new(&refresh_token_value);
181 let refresh_expires_at = chrono::Utc::now().timestamp()
182 + (systemprompt_oauth::constants::token::SECONDS_PER_DAY
183 * systemprompt_oauth::constants::token::REFRESH_TOKEN_EXPIRY_DAYS);
184
185 let refresh_params = RefreshTokenParams::builder(
186 &refresh_token_id,
187 params.client_id,
188 &validation_result.user_id,
189 &validation_result.scope,
190 refresh_expires_at,
191 )
192 .build();
193 repo.store_refresh_token(refresh_params).await?;
194
195 Ok(TokenResponse { access_token })
196}
197
198async fn load_authenticated_user(
199 user_id: &UserId,
200 user_provider: &Arc<dyn systemprompt_traits::UserProvider>,
201) -> anyhow::Result<AuthenticatedUser> {
202 let user = user_provider
203 .find_by_id(user_id.as_str())
204 .await
205 .map_err(|e| anyhow::anyhow!("{}", e))?
206 .ok_or_else(|| anyhow::anyhow!("User not found: {user_id}"))?;
207
208 let permissions: Vec<Permission> = user
209 .roles
210 .iter()
211 .filter_map(|s| {
212 Permission::from_str(s)
213 .map_err(|e| {
214 tracing::warn!(
215 user_id = %user.id,
216 role = %s,
217 error = %e,
218 "Invalid role in user record"
219 );
220 e
221 })
222 .ok()
223 })
224 .collect();
225
226 let user_uuid = uuid::Uuid::parse_str(&user.id)
227 .map_err(|_| anyhow::anyhow!("Invalid user UUID: {}", user.id))?;
228
229 Ok(AuthenticatedUser::new_with_roles(
230 user_uuid,
231 user.name,
232 user.email,
233 permissions,
234 user.roles,
235 ))
236}
237
/// OAuth client selected for the first-party browser login flow.
#[derive(Debug)]
struct BrowserClient {
    /// Identifier of the client, as returned by the OAuth repository lookup.
    client_id: String,
}

243#[derive(Debug, serde::Deserialize)]
244struct TokenResponse {
245 access_token: String,
246}