// systemprompt_api/routes/oauth/endpoints/callback.rs
1use axum::extract::{Query, State};
2use axum::http::{HeaderMap, HeaderValue, StatusCode, header};
3use axum::response::{IntoResponse, Redirect};
4use serde::Deserialize;
5use std::str::FromStr;
6use std::sync::Arc;
7use systemprompt_identifiers::{
8    AuthorizationCode, ClientId, RefreshTokenId, SessionSource, UserId,
9};
10use systemprompt_models::Config;
11use systemprompt_models::auth::{AuthenticatedUser, Permission, parse_permissions};
12
13use crate::routes::oauth::extractors::OAuthRepo;
14use systemprompt_oauth::OAuthState;
15use systemprompt_oauth::repository::{OAuthRepository, RefreshTokenParams};
16
17#[derive(Debug, Deserialize)]
18pub struct CallbackQuery {
19    pub code: String,
20    pub state: Option<String>,
21}
22
23pub async fn handle_callback(
24    Query(params): Query<CallbackQuery>,
25    State(state): State<OAuthState>,
26    OAuthRepo(repo): OAuthRepo,
27    headers: HeaderMap,
28) -> impl IntoResponse {
29    let config = match Config::get() {
30        Ok(c) => c,
31        Err(e) => {
32            return (
33                StatusCode::INTERNAL_SERVER_ERROR,
34                format!("Failed to load config: {e}"),
35            )
36                .into_response();
37        },
38    };
39
40    let server_base_url = &config.api_external_url;
41    let redirect_uri = format!("{server_base_url}/api/v1/core/oauth/callback");
42
43    let browser_client = match find_browser_client(&repo, &redirect_uri).await {
44        Ok(client) => client,
45        Err(e) => {
46            return (
47                StatusCode::INTERNAL_SERVER_ERROR,
48                format!("Failed to find OAuth client: {e}"),
49            )
50                .into_response();
51        },
52    };
53
54    let code = AuthorizationCode::new(&params.code);
55    let client_id = ClientId::new(&browser_client.client_id);
56    let token_response = match exchange_code_for_token(
57        &repo,
58        CodeExchangeParams {
59            code: &code,
60            client_id: &client_id,
61            redirect_uri: &redirect_uri,
62            headers: &headers,
63        },
64        &state,
65    )
66    .await
67    {
68        Ok(response) => response,
69        Err(e) => {
70            return (
71                StatusCode::UNAUTHORIZED,
72                format!("Failed to exchange code for token: {e}"),
73            )
74                .into_response();
75        },
76    };
77
78    let redirect_destination = params
79        .state
80        .as_deref()
81        .filter(|s| !s.is_empty())
82        .unwrap_or("/");
83
84    let cookie = format!(
85        "access_token={}; Path=/; HttpOnly; Secure; SameSite=Lax; Max-Age={}",
86        token_response.access_token,
87        systemprompt_oauth::constants::token::COOKIE_MAX_AGE_SECONDS
88    );
89
90    let mut response = Redirect::to(redirect_destination).into_response();
91    if let Ok(cookie_value) = HeaderValue::from_str(&cookie) {
92        response
93            .headers_mut()
94            .insert(header::SET_COOKIE, cookie_value);
95    }
96
97    response
98}
99
100async fn find_browser_client(
101    repo: &OAuthRepository,
102    redirect_uri: &str,
103) -> anyhow::Result<BrowserClient> {
104    let client = repo
105        .find_client_by_redirect_uri_with_scope(redirect_uri, &["admin", "user"])
106        .await?
107        .ok_or_else(|| anyhow::anyhow!("No suitable browser client found"))?;
108
109    Ok(BrowserClient {
110        client_id: client.client_id.to_string(),
111    })
112}
113
114struct CodeExchangeParams<'a> {
115    code: &'a AuthorizationCode,
116    client_id: &'a ClientId,
117    redirect_uri: &'a str,
118    headers: &'a HeaderMap,
119}
120
121async fn exchange_code_for_token(
122    repo: &OAuthRepository,
123    params: CodeExchangeParams<'_>,
124    state: &OAuthState,
125) -> anyhow::Result<TokenResponse> {
126    use systemprompt_oauth::services::{
127        JwtConfig, JwtSigningParams, generate_access_token_jti, generate_jwt, generate_secure_token,
128    };
129
130    let validation_result = repo
131        .validate_authorization_code(
132            params.code,
133            params.client_id,
134            Some(params.redirect_uri),
135            None,
136        )
137        .await?;
138
139    let user = load_authenticated_user(&validation_result.user_id, state.user_provider()).await?;
140
141    let permissions = parse_permissions(&validation_result.scope)?;
142
143    let mut session_service = systemprompt_oauth::services::SessionCreationService::new(
144        Arc::clone(state.analytics_provider()),
145        Arc::clone(state.user_provider()),
146    );
147    if let Some(publisher) = state.event_publisher() {
148        session_service = session_service.with_event_publisher(Arc::clone(publisher));
149    }
150    let session_id = session_service
151        .create_authenticated_session(
152            &validation_result.user_id,
153            params.headers,
154            SessionSource::Oauth,
155        )
156        .await?;
157
158    let access_token_jti = generate_access_token_jti();
159    let jwt_secret = systemprompt_models::SecretsBootstrap::jwt_secret()?;
160    let global_config = Config::get()?;
161    let config = JwtConfig {
162        permissions: permissions.clone(),
163        audience: global_config.jwt_audiences.clone(),
164        ..Default::default()
165    };
166    let signing = JwtSigningParams {
167        secret: jwt_secret,
168        issuer: &global_config.jwt_issuer,
169    };
170    let access_token = generate_jwt(&user, config, access_token_jti, &session_id, &signing)?;
171
172    let refresh_token_value = generate_secure_token("rt");
173    let refresh_token_id = RefreshTokenId::new(&refresh_token_value);
174    let refresh_expires_at = chrono::Utc::now().timestamp()
175        + (systemprompt_oauth::constants::token::SECONDS_PER_DAY
176            * systemprompt_oauth::constants::token::REFRESH_TOKEN_EXPIRY_DAYS);
177
178    let refresh_params = RefreshTokenParams::builder(
179        &refresh_token_id,
180        params.client_id,
181        &validation_result.user_id,
182        &validation_result.scope,
183        refresh_expires_at,
184    )
185    .build();
186    repo.store_refresh_token(refresh_params).await?;
187
188    Ok(TokenResponse { access_token })
189}
190
191async fn load_authenticated_user(
192    user_id: &UserId,
193    user_provider: &Arc<dyn systemprompt_traits::UserProvider>,
194) -> anyhow::Result<AuthenticatedUser> {
195    let user = user_provider
196        .find_by_id(user_id)
197        .await
198        .map_err(|e| anyhow::anyhow!("{}", e))?
199        .ok_or_else(|| anyhow::anyhow!("User not found: {user_id}"))?;
200
201    let permissions: Vec<Permission> = user
202        .roles
203        .iter()
204        .filter_map(|s| {
205            Permission::from_str(s)
206                .map_err(|e| {
207                    tracing::warn!(
208                        user_id = %user.id,
209                        role = %s,
210                        error = %e,
211                        "Invalid role in user record"
212                    );
213                    e
214                })
215                .ok()
216        })
217        .collect();
218
219    let user_uuid = uuid::Uuid::parse_str(user.id.as_str())
220        .map_err(|_| anyhow::anyhow!("Invalid user UUID: {}", user.id))?;
221
222    Ok(AuthenticatedUser::new_with_roles(
223        user_uuid,
224        user.name,
225        user.email,
226        permissions,
227        user.roles,
228    ))
229}
230
/// The OAuth client the browser login flow runs under.
#[derive(Debug)]
struct BrowserClient {
    /// The client's id, in string form.
    client_id: String,
}

/// Result of a successful authorization-code exchange.
///
/// `Debug` is implemented by hand so the bearer token is redacted and never
/// leaks into logs or panic messages through `{:?}`.
struct TokenResponse {
    /// Signed access token (JWT), set as the `access_token` cookie.
    access_token: String,
}

impl std::fmt::Debug for TokenResponse {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TokenResponse")
            .field("access_token", &"<redacted>")
            .finish()
    }
}