Skip to main content

systemprompt_api/routes/oauth/endpoints/
callback.rs

1use axum::extract::{Query, State};
2use axum::http::{header, HeaderMap, HeaderValue, StatusCode};
3use axum::response::{IntoResponse, Redirect};
4use serde::Deserialize;
5use std::str::FromStr;
6use std::sync::Arc;
7use systemprompt_identifiers::{
8    AuthorizationCode, ClientId, RefreshTokenId, SessionSource, UserId,
9};
10use systemprompt_models::auth::{parse_permissions, AuthenticatedUser, Permission};
11use systemprompt_models::Config;
12
13use systemprompt_oauth::repository::{OAuthRepository, RefreshTokenParams};
14use systemprompt_oauth::OAuthState;
15
16#[derive(Debug, Deserialize)]
17pub struct CallbackQuery {
18    pub code: String,
19    pub state: Option<String>,
20}
21
22pub async fn handle_callback(
23    Query(params): Query<CallbackQuery>,
24    State(state): State<OAuthState>,
25    headers: HeaderMap,
26) -> impl IntoResponse {
27    let repo = match OAuthRepository::new(Arc::clone(state.db_pool())) {
28        Ok(r) => r,
29        Err(e) => {
30            return (
31                StatusCode::INTERNAL_SERVER_ERROR,
32                axum::Json(serde_json::json!({"error": "server_error", "error_description": format!("Repository initialization failed: {}", e)})),
33            ).into_response();
34        },
35    };
36    let config = match Config::get() {
37        Ok(c) => c,
38        Err(e) => {
39            return (
40                StatusCode::INTERNAL_SERVER_ERROR,
41                format!("Failed to load config: {e}"),
42            )
43                .into_response();
44        },
45    };
46
47    let server_base_url = &config.api_external_url;
48    let redirect_uri = format!("{server_base_url}/api/v1/core/oauth/callback");
49
50    let browser_client = match find_browser_client(&repo, &redirect_uri).await {
51        Ok(client) => client,
52        Err(e) => {
53            return (
54                StatusCode::INTERNAL_SERVER_ERROR,
55                format!("Failed to find OAuth client: {e}"),
56            )
57                .into_response();
58        },
59    };
60
61    let code = AuthorizationCode::new(&params.code);
62    let client_id = ClientId::new(&browser_client.client_id);
63    let token_response = match exchange_code_for_token(
64        &repo,
65        CodeExchangeParams {
66            code: &code,
67            client_id: &client_id,
68            redirect_uri: &redirect_uri,
69            headers: &headers,
70        },
71        &state,
72    )
73    .await
74    {
75        Ok(response) => response,
76        Err(e) => {
77            return (
78                StatusCode::UNAUTHORIZED,
79                format!("Failed to exchange code for token: {e}"),
80            )
81                .into_response();
82        },
83    };
84
85    let redirect_destination = params
86        .state
87        .as_deref()
88        .filter(|s| !s.is_empty())
89        .unwrap_or("/");
90
91    let cookie = format!(
92        "access_token={}; Path=/; HttpOnly; Secure; SameSite=Lax; Max-Age={}",
93        token_response.access_token,
94        systemprompt_oauth::constants::token::COOKIE_MAX_AGE_SECONDS
95    );
96
97    let mut response = Redirect::to(redirect_destination).into_response();
98    if let Ok(cookie_value) = HeaderValue::from_str(&cookie) {
99        response
100            .headers_mut()
101            .insert(header::SET_COOKIE, cookie_value);
102    }
103
104    response
105}
106
107async fn find_browser_client(
108    repo: &OAuthRepository,
109    redirect_uri: &str,
110) -> anyhow::Result<BrowserClient> {
111    let client = repo
112        .find_client_by_redirect_uri_with_scope(redirect_uri, &["admin", "user"])
113        .await?
114        .ok_or_else(|| anyhow::anyhow!("No suitable browser client found"))?;
115
116    Ok(BrowserClient {
117        client_id: client.client_id.to_string(),
118    })
119}
120
121struct CodeExchangeParams<'a> {
122    code: &'a AuthorizationCode,
123    client_id: &'a ClientId,
124    redirect_uri: &'a str,
125    headers: &'a HeaderMap,
126}
127
128async fn exchange_code_for_token(
129    repo: &OAuthRepository,
130    params: CodeExchangeParams<'_>,
131    state: &OAuthState,
132) -> anyhow::Result<TokenResponse> {
133    use systemprompt_oauth::services::{
134        generate_access_token_jti, generate_jwt, generate_secure_token, JwtConfig, JwtSigningParams,
135    };
136
137    let (user_id, scope) = repo
138        .validate_authorization_code(
139            params.code,
140            params.client_id,
141            Some(params.redirect_uri),
142            None,
143        )
144        .await?;
145
146    let user = load_authenticated_user(&user_id, state.user_provider()).await?;
147
148    let permissions = parse_permissions(&scope)?;
149
150    let session_service = systemprompt_oauth::services::SessionCreationService::new(
151        Arc::clone(state.analytics_provider()),
152        Arc::clone(state.user_provider()),
153    );
154    let session_id = session_service
155        .create_authenticated_session(&user_id, params.headers, SessionSource::Oauth)
156        .await?;
157
158    let access_token_jti = generate_access_token_jti();
159    let jwt_secret = systemprompt_models::SecretsBootstrap::jwt_secret()?;
160    let global_config = Config::get()?;
161    let config = JwtConfig {
162        permissions: permissions.clone(),
163        audience: global_config.jwt_audiences.clone(),
164        ..Default::default()
165    };
166    let signing = JwtSigningParams {
167        secret: jwt_secret,
168        issuer: &global_config.jwt_issuer,
169    };
170    let access_token = generate_jwt(&user, config, access_token_jti, &session_id, &signing)?;
171
172    let refresh_token_value = generate_secure_token("rt");
173    let refresh_token_id = RefreshTokenId::new(&refresh_token_value);
174    let refresh_expires_at = chrono::Utc::now().timestamp()
175        + (systemprompt_oauth::constants::token::SECONDS_PER_DAY
176            * systemprompt_oauth::constants::token::REFRESH_TOKEN_EXPIRY_DAYS);
177
178    let refresh_params = RefreshTokenParams::builder(
179        &refresh_token_id,
180        params.client_id,
181        &user_id,
182        &scope,
183        refresh_expires_at,
184    )
185    .build();
186    repo.store_refresh_token(refresh_params).await?;
187
188    Ok(TokenResponse { access_token })
189}
190
191async fn load_authenticated_user(
192    user_id: &UserId,
193    user_provider: &Arc<dyn systemprompt_traits::UserProvider>,
194) -> anyhow::Result<AuthenticatedUser> {
195    let user = user_provider
196        .find_by_id(user_id.as_str())
197        .await
198        .map_err(|e| anyhow::anyhow!("{}", e))?
199        .ok_or_else(|| anyhow::anyhow!("User not found: {user_id}"))?;
200
201    let permissions: Vec<Permission> = user
202        .roles
203        .iter()
204        .filter_map(|s| {
205            Permission::from_str(s)
206                .map_err(|e| {
207                    tracing::warn!(
208                        user_id = %user.id,
209                        role = %s,
210                        error = %e,
211                        "Invalid role in user record"
212                    );
213                    e
214                })
215                .ok()
216        })
217        .collect();
218
219    let user_uuid = uuid::Uuid::parse_str(&user.id)
220        .map_err(|_| anyhow::anyhow!("Invalid user UUID: {}", user.id))?;
221
222    Ok(AuthenticatedUser::new_with_roles(
223        user_uuid,
224        user.name,
225        user.email,
226        permissions,
227        user.roles,
228    ))
229}
230
231#[derive(Debug)]
232struct BrowserClient {
233    client_id: String,
234}
235
236#[derive(Debug, serde::Deserialize)]
237struct TokenResponse {
238    access_token: String,
239}