Skip to main content

systemprompt_api/routes/oauth/endpoints/
callback.rs

1//! OAuth callback endpoint for the server's own browser client.
2//!
3//! Exchanges the returned authorization code for tokens, establishes an
4//! authenticated session, sets the access-token cookie, and redirects to the
5//! origin-validated `return_to` recovered from the consumed state binding.
6
7use axum::extract::{Query, State};
8use axum::http::{HeaderMap, HeaderValue, StatusCode, header};
9use axum::response::{IntoResponse, Redirect};
10use serde::Deserialize;
11use std::str::FromStr;
12use std::sync::Arc;
13use systemprompt_identifiers::{
14    AuthorizationCode, ClientId, RefreshTokenId, SessionSource, UserId,
15};
16use systemprompt_models::Config;
17use systemprompt_models::auth::{AuthenticatedUser, Permission, parse_permissions};
18
19use crate::routes::oauth::extractors::OAuthRepo;
20use systemprompt_oauth::OAuthState;
21use systemprompt_oauth::repository::{OAuthRepository, RefreshTokenParams};
22
23#[derive(Debug, Deserialize)]
24pub struct CallbackQuery {
25    pub code: String,
26    pub state: Option<String>,
27}
28
29pub async fn handle_callback(
30    Query(params): Query<CallbackQuery>,
31    State(state): State<OAuthState>,
32    OAuthRepo(repo): OAuthRepo,
33    headers: HeaderMap,
34) -> impl IntoResponse {
35    let config = match Config::get() {
36        Ok(c) => c,
37        Err(e) => {
38            return (
39                StatusCode::INTERNAL_SERVER_ERROR,
40                format!("Failed to load config: {e}"),
41            )
42                .into_response();
43        },
44    };
45
46    let server_base_url = &config.api_external_url;
47    let redirect_uri = format!("{server_base_url}/api/v1/core/oauth/callback");
48
49    let browser_client = match find_browser_client(&repo, &redirect_uri).await {
50        Ok(client) => client,
51        Err(e) => {
52            return (
53                StatusCode::INTERNAL_SERVER_ERROR,
54                format!("Failed to find OAuth client: {e}"),
55            )
56                .into_response();
57        },
58    };
59
60    let code = AuthorizationCode::new(&params.code);
61    let client_id = ClientId::new(&browser_client.client_id);
62    let token_response = match exchange_code_for_token(
63        &repo,
64        CodeExchangeParams {
65            code: &code,
66            client_id: &client_id,
67            redirect_uri: &redirect_uri,
68            headers: &headers,
69        },
70        &state,
71    )
72    .await
73    {
74        Ok(response) => response,
75        Err(e) => {
76            return (
77                StatusCode::UNAUTHORIZED,
78                format!("Failed to exchange code for token: {e}"),
79            )
80                .into_response();
81        },
82    };
83
84    let Some(state_token) = params.state.as_deref().filter(|s| !s.is_empty()) else {
85        return (StatusCode::BAD_REQUEST, "Missing state parameter").into_response();
86    };
87    let redirect_destination = match repo.consume_state_binding(state_token).await {
88        Ok(Some(binding)) => binding.return_to,
89        Ok(None) => {
90            tracing::warn!("state binding missing, expired, or already consumed");
91            return (StatusCode::BAD_REQUEST, "Invalid state parameter").into_response();
92        },
93        Err(e) => {
94            tracing::error!(error = %e, "state binding lookup failed");
95            return (
96                StatusCode::INTERNAL_SERVER_ERROR,
97                "Failed to validate state",
98            )
99                .into_response();
100        },
101    };
102
103    let cookie = format!(
104        "access_token={}; Path=/; HttpOnly; Secure; SameSite=Strict; Max-Age={}",
105        token_response.access_token,
106        systemprompt_oauth::constants::token::COOKIE_MAX_AGE_SECONDS
107    );
108
109    let mut response = Redirect::to(&redirect_destination).into_response();
110    if let Ok(cookie_value) = HeaderValue::from_str(&cookie) {
111        response
112            .headers_mut()
113            .insert(header::SET_COOKIE, cookie_value);
114    }
115
116    response
117}
118
119async fn find_browser_client(
120    repo: &OAuthRepository,
121    redirect_uri: &str,
122) -> anyhow::Result<BrowserClient> {
123    let client = repo
124        .find_client_by_redirect_uri_with_scope(redirect_uri, &["admin", "user"])
125        .await?
126        .ok_or_else(|| anyhow::anyhow!("No suitable browser client found"))?;
127
128    Ok(BrowserClient {
129        client_id: client.client_id.to_string(),
130    })
131}
132
133struct CodeExchangeParams<'a> {
134    code: &'a AuthorizationCode,
135    client_id: &'a ClientId,
136    redirect_uri: &'a str,
137    headers: &'a HeaderMap,
138}
139
140async fn exchange_code_for_token(
141    repo: &OAuthRepository,
142    params: CodeExchangeParams<'_>,
143    state: &OAuthState,
144) -> anyhow::Result<TokenResponse> {
145    use systemprompt_oauth::services::{
146        JwtConfig, JwtSigningParams, generate_access_token_jti, generate_jwt, generate_secure_token,
147    };
148
149    let validation_result = repo
150        .validate_authorization_code(
151            params.code,
152            params.client_id,
153            Some(params.redirect_uri),
154            None,
155        )
156        .await?;
157
158    let user = load_authenticated_user(&validation_result.user_id, state.user_provider()).await?;
159
160    let permissions = parse_permissions(&validation_result.scope)?;
161
162    let mut session_service = systemprompt_oauth::services::SessionCreationService::new(
163        Arc::clone(state.analytics_provider()),
164        Arc::clone(state.user_provider()),
165    );
166    if let Some(publisher) = state.event_publisher() {
167        session_service = session_service.with_event_publisher(Arc::clone(publisher));
168    }
169    let session_id = session_service
170        .create_authenticated_session(
171            &validation_result.user_id,
172            params.headers,
173            SessionSource::Oauth,
174        )
175        .await?;
176
177    let access_token_jti = generate_access_token_jti();
178    let global_config = Config::get()?;
179    let config = JwtConfig {
180        permissions: permissions.clone(),
181        audience: global_config.jwt_audiences.clone(),
182        ..Default::default()
183    };
184    let signing = JwtSigningParams {
185        issuer: &global_config.jwt_issuer,
186    };
187    let access_token = generate_jwt(&user, config, access_token_jti, &session_id, &signing)?;
188
189    let refresh_token_value = generate_secure_token("rt");
190    let refresh_token_id = RefreshTokenId::new(&refresh_token_value);
191    let refresh_expires_at = chrono::Utc::now().timestamp()
192        + (systemprompt_oauth::constants::token::SECONDS_PER_DAY
193            * systemprompt_oauth::constants::token::REFRESH_TOKEN_EXPIRY_DAYS);
194
195    let refresh_params = RefreshTokenParams::builder(
196        &refresh_token_id,
197        params.client_id,
198        &validation_result.user_id,
199        &validation_result.scope,
200        refresh_expires_at,
201    )
202    .build();
203    repo.store_refresh_token(refresh_params).await?;
204
205    if let Err(e) = repo
206        .link_auth_code_to_refresh_token(params.code, refresh_token_id.as_str())
207        .await
208    {
209        tracing::warn!(error = %e, "Failed to link auth code to refresh token");
210    }
211
212    Ok(TokenResponse { access_token })
213}
214
215async fn load_authenticated_user(
216    user_id: &UserId,
217    user_provider: &Arc<dyn systemprompt_traits::UserProvider>,
218) -> anyhow::Result<AuthenticatedUser> {
219    let user = user_provider
220        .find_by_id(user_id)
221        .await
222        .map_err(|e| anyhow::anyhow!("{}", e))?
223        .ok_or_else(|| anyhow::anyhow!("User not found: {user_id}"))?;
224
225    let permissions: Vec<Permission> = user
226        .roles
227        .iter()
228        .filter_map(|s| {
229            Permission::from_str(s)
230                .map_err(|e| {
231                    tracing::warn!(
232                        user_id = %user.id,
233                        role = %s,
234                        error = %e,
235                        "Invalid role in user record"
236                    );
237                    e
238                })
239                .ok()
240        })
241        .collect();
242
243    let user_uuid = uuid::Uuid::parse_str(user.id.as_str())
244        .map_err(|_e| anyhow::anyhow!("Invalid user UUID: {}", user.id))?;
245
246    Ok(AuthenticatedUser::new_with_roles(
247        user_uuid,
248        user.name,
249        user.email,
250        permissions,
251        user.roles,
252    ))
253}
254
/// The subset of a registered OAuth client that the callback handler needs.
#[derive(Debug)]
struct BrowserClient {
    /// Client identifier, stored as the string form of the repository value.
    client_id: String,
}
259
260#[derive(Debug, serde::Deserialize)]
261struct TokenResponse {
262    access_token: String,
263}