Skip to main content

systemprompt_api/routes/oauth/endpoints/
callback.rs

1use axum::extract::{Query, State};
2use axum::http::{HeaderMap, HeaderValue, StatusCode, header};
3use axum::response::{IntoResponse, Redirect};
4use serde::Deserialize;
5use std::str::FromStr;
6use std::sync::Arc;
7use systemprompt_identifiers::{
8    AuthorizationCode, ClientId, RefreshTokenId, SessionSource, UserId,
9};
10use systemprompt_models::Config;
11use systemprompt_models::auth::{AuthenticatedUser, Permission, parse_permissions};
12
13use crate::routes::oauth::extractors::OAuthRepo;
14use systemprompt_oauth::OAuthState;
15use systemprompt_oauth::repository::{OAuthRepository, RefreshTokenParams};
16
17#[derive(Debug, Deserialize)]
18pub struct CallbackQuery {
19    pub code: String,
20    pub state: Option<String>,
21}
22
23pub async fn handle_callback(
24    Query(params): Query<CallbackQuery>,
25    State(state): State<OAuthState>,
26    OAuthRepo(repo): OAuthRepo,
27    headers: HeaderMap,
28) -> impl IntoResponse {
29    let config = match Config::get() {
30        Ok(c) => c,
31        Err(e) => {
32            return (
33                StatusCode::INTERNAL_SERVER_ERROR,
34                format!("Failed to load config: {e}"),
35            )
36                .into_response();
37        },
38    };
39
40    let server_base_url = &config.api_external_url;
41    let redirect_uri = format!("{server_base_url}/api/v1/core/oauth/callback");
42
43    let browser_client = match find_browser_client(&repo, &redirect_uri).await {
44        Ok(client) => client,
45        Err(e) => {
46            return (
47                StatusCode::INTERNAL_SERVER_ERROR,
48                format!("Failed to find OAuth client: {e}"),
49            )
50                .into_response();
51        },
52    };
53
54    let code = AuthorizationCode::new(&params.code);
55    let client_id = ClientId::new(&browser_client.client_id);
56    let token_response = match exchange_code_for_token(
57        &repo,
58        CodeExchangeParams {
59            code: &code,
60            client_id: &client_id,
61            redirect_uri: &redirect_uri,
62            headers: &headers,
63        },
64        &state,
65    )
66    .await
67    {
68        Ok(response) => response,
69        Err(e) => {
70            return (
71                StatusCode::UNAUTHORIZED,
72                format!("Failed to exchange code for token: {e}"),
73            )
74                .into_response();
75        },
76    };
77
78    let Some(state_token) = params.state.as_deref().filter(|s| !s.is_empty()) else {
79        return (StatusCode::BAD_REQUEST, "Missing state parameter").into_response();
80    };
81    let redirect_destination = match repo.consume_state_binding(state_token).await {
82        Ok(Some(binding)) => binding.return_to,
83        Ok(None) => {
84            tracing::warn!("state binding missing, expired, or already consumed");
85            return (StatusCode::BAD_REQUEST, "Invalid state parameter").into_response();
86        },
87        Err(e) => {
88            tracing::error!(error = %e, "state binding lookup failed");
89            return (
90                StatusCode::INTERNAL_SERVER_ERROR,
91                "Failed to validate state",
92            )
93                .into_response();
94        },
95    };
96
97    let cookie = format!(
98        "access_token={}; Path=/; HttpOnly; Secure; SameSite=Strict; Max-Age={}",
99        token_response.access_token,
100        systemprompt_oauth::constants::token::COOKIE_MAX_AGE_SECONDS
101    );
102
103    let mut response = Redirect::to(&redirect_destination).into_response();
104    if let Ok(cookie_value) = HeaderValue::from_str(&cookie) {
105        response
106            .headers_mut()
107            .insert(header::SET_COOKIE, cookie_value);
108    }
109
110    response
111}
112
113async fn find_browser_client(
114    repo: &OAuthRepository,
115    redirect_uri: &str,
116) -> anyhow::Result<BrowserClient> {
117    let client = repo
118        .find_client_by_redirect_uri_with_scope(redirect_uri, &["admin", "user"])
119        .await?
120        .ok_or_else(|| anyhow::anyhow!("No suitable browser client found"))?;
121
122    Ok(BrowserClient {
123        client_id: client.client_id.to_string(),
124    })
125}
126
127struct CodeExchangeParams<'a> {
128    code: &'a AuthorizationCode,
129    client_id: &'a ClientId,
130    redirect_uri: &'a str,
131    headers: &'a HeaderMap,
132}
133
134async fn exchange_code_for_token(
135    repo: &OAuthRepository,
136    params: CodeExchangeParams<'_>,
137    state: &OAuthState,
138) -> anyhow::Result<TokenResponse> {
139    use systemprompt_oauth::services::{
140        JwtConfig, JwtSigningParams, generate_access_token_jti, generate_jwt, generate_secure_token,
141    };
142
143    let validation_result = repo
144        .validate_authorization_code(
145            params.code,
146            params.client_id,
147            Some(params.redirect_uri),
148            None,
149        )
150        .await?;
151
152    let user = load_authenticated_user(&validation_result.user_id, state.user_provider()).await?;
153
154    let permissions = parse_permissions(&validation_result.scope)?;
155
156    let mut session_service = systemprompt_oauth::services::SessionCreationService::new(
157        Arc::clone(state.analytics_provider()),
158        Arc::clone(state.user_provider()),
159    );
160    if let Some(publisher) = state.event_publisher() {
161        session_service = session_service.with_event_publisher(Arc::clone(publisher));
162    }
163    let session_id = session_service
164        .create_authenticated_session(
165            &validation_result.user_id,
166            params.headers,
167            SessionSource::Oauth,
168        )
169        .await?;
170
171    let access_token_jti = generate_access_token_jti();
172    let global_config = Config::get()?;
173    let config = JwtConfig {
174        permissions: permissions.clone(),
175        audience: global_config.jwt_audiences.clone(),
176        ..Default::default()
177    };
178    let signing = JwtSigningParams {
179        issuer: &global_config.jwt_issuer,
180    };
181    let access_token = generate_jwt(&user, config, access_token_jti, &session_id, &signing)?;
182
183    let refresh_token_value = generate_secure_token("rt");
184    let refresh_token_id = RefreshTokenId::new(&refresh_token_value);
185    let refresh_expires_at = chrono::Utc::now().timestamp()
186        + (systemprompt_oauth::constants::token::SECONDS_PER_DAY
187            * systemprompt_oauth::constants::token::REFRESH_TOKEN_EXPIRY_DAYS);
188
189    let refresh_params = RefreshTokenParams::builder(
190        &refresh_token_id,
191        params.client_id,
192        &validation_result.user_id,
193        &validation_result.scope,
194        refresh_expires_at,
195    )
196    .build();
197    repo.store_refresh_token(refresh_params).await?;
198
199    if let Err(e) = repo
200        .link_auth_code_to_refresh_token(params.code, refresh_token_id.as_str())
201        .await
202    {
203        tracing::warn!(error = %e, "Failed to link auth code to refresh token");
204    }
205
206    Ok(TokenResponse { access_token })
207}
208
209async fn load_authenticated_user(
210    user_id: &UserId,
211    user_provider: &Arc<dyn systemprompt_traits::UserProvider>,
212) -> anyhow::Result<AuthenticatedUser> {
213    let user = user_provider
214        .find_by_id(user_id)
215        .await
216        .map_err(|e| anyhow::anyhow!("{}", e))?
217        .ok_or_else(|| anyhow::anyhow!("User not found: {user_id}"))?;
218
219    let permissions: Vec<Permission> = user
220        .roles
221        .iter()
222        .filter_map(|s| {
223            Permission::from_str(s)
224                .map_err(|e| {
225                    tracing::warn!(
226                        user_id = %user.id,
227                        role = %s,
228                        error = %e,
229                        "Invalid role in user record"
230                    );
231                    e
232                })
233                .ok()
234        })
235        .collect();
236
237    let user_uuid = uuid::Uuid::parse_str(user.id.as_str())
238        .map_err(|_e| anyhow::anyhow!("Invalid user UUID: {}", user.id))?;
239
240    Ok(AuthenticatedUser::new_with_roles(
241        user_uuid,
242        user.name,
243        user.email,
244        permissions,
245        user.roles,
246    ))
247}
248
/// Browser OAuth client resolved for the callback URL by `find_browser_client`.
#[derive(Debug)]
struct BrowserClient {
    /// Client identifier in string form; wrapped back into a `ClientId` before
    /// the code exchange.
    client_id: String,
}
253
/// Result of a successful authorization-code exchange.
///
/// The type is only constructed locally, so it does not derive `Deserialize`.
/// `Debug` is implemented by hand to keep the bearer token out of any log or
/// error output that formats this value.
struct TokenResponse {
    /// Signed access token, delivered to the browser in the `access_token` cookie.
    access_token: String,
}

impl std::fmt::Debug for TokenResponse {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TokenResponse")
            .field("access_token", &"<redacted>")
            .finish()
    }
}