//! OAuth browser callback endpoint (`systemprompt_api/routes/oauth/endpoints/callback.rs`).

use std::str::FromStr;
use std::sync::Arc;

use axum::extract::{Query, State};
use axum::http::{header, HeaderMap, HeaderValue, StatusCode};
use axum::response::{IntoResponse, Redirect};
use serde::Deserialize;
use systemprompt_identifiers::{AuthorizationCode, ClientId, RefreshTokenId, SessionSource, UserId};
use systemprompt_models::auth::{parse_permissions, AuthenticatedUser, Permission};
use systemprompt_models::Config;
use systemprompt_oauth::repository::{OAuthRepository, RefreshTokenParams};
use systemprompt_oauth::OAuthState;

16#[derive(Debug, Deserialize)]
17pub struct CallbackQuery {
18    pub code: String,
19    pub state: Option<String>,
20}
21
22pub async fn handle_callback(
23    Query(params): Query<CallbackQuery>,
24    State(state): State<OAuthState>,
25    headers: HeaderMap,
26) -> impl IntoResponse {
27    let repo = match OAuthRepository::new(Arc::clone(state.db_pool())) {
28        Ok(r) => r,
29        Err(e) => {
30            return (
31                StatusCode::INTERNAL_SERVER_ERROR,
32                axum::Json(serde_json::json!({"error": "server_error", "error_description": format!("Repository initialization failed: {}", e)})),
33            ).into_response();
34        },
35    };
36    let config = match Config::get() {
37        Ok(c) => c,
38        Err(e) => {
39            return (
40                StatusCode::INTERNAL_SERVER_ERROR,
41                format!("Failed to load config: {e}"),
42            )
43                .into_response();
44        },
45    };
46
47    let server_base_url = &config.api_external_url;
48    let redirect_uri = format!("{server_base_url}/api/v1/core/oauth/callback");
49
50    let browser_client = match find_browser_client(&repo, &redirect_uri).await {
51        Ok(client) => client,
52        Err(e) => {
53            return (
54                StatusCode::INTERNAL_SERVER_ERROR,
55                format!("Failed to find OAuth client: {e}"),
56            )
57                .into_response();
58        },
59    };
60
61    let code = AuthorizationCode::new(&params.code);
62    let client_id = ClientId::new(&browser_client.client_id);
63    let token_response = match exchange_code_for_token(
64        &repo,
65        CodeExchangeParams {
66            code: &code,
67            client_id: &client_id,
68            redirect_uri: &redirect_uri,
69            headers: &headers,
70        },
71        &state,
72    )
73    .await
74    {
75        Ok(response) => response,
76        Err(e) => {
77            return (
78                StatusCode::UNAUTHORIZED,
79                format!("Failed to exchange code for token: {e}"),
80            )
81                .into_response();
82        },
83    };
84
85    let redirect_destination = params
86        .state
87        .as_deref()
88        .filter(|s| !s.is_empty())
89        .unwrap_or("/");
90
91    let cookie = format!(
92        "access_token={}; Path=/; HttpOnly; Secure; SameSite=Lax; Max-Age={}",
93        token_response.access_token,
94        systemprompt_oauth::constants::token::COOKIE_MAX_AGE_SECONDS
95    );
96
97    let mut response = Redirect::to(redirect_destination).into_response();
98    if let Ok(cookie_value) = HeaderValue::from_str(&cookie) {
99        response
100            .headers_mut()
101            .insert(header::SET_COOKIE, cookie_value);
102    }
103
104    response
105}
106
107async fn find_browser_client(
108    repo: &OAuthRepository,
109    redirect_uri: &str,
110) -> anyhow::Result<BrowserClient> {
111    let client = repo
112        .find_client_by_redirect_uri_with_scope(redirect_uri, &["admin", "user"])
113        .await?
114        .ok_or_else(|| anyhow::anyhow!("No suitable browser client found"))?;
115
116    Ok(BrowserClient {
117        client_id: client.client_id.to_string(),
118    })
119}
120
121struct CodeExchangeParams<'a> {
122    code: &'a AuthorizationCode,
123    client_id: &'a ClientId,
124    redirect_uri: &'a str,
125    headers: &'a HeaderMap,
126}
127
128async fn exchange_code_for_token(
129    repo: &OAuthRepository,
130    params: CodeExchangeParams<'_>,
131    state: &OAuthState,
132) -> anyhow::Result<TokenResponse> {
133    use systemprompt_oauth::services::{
134        generate_access_token_jti, generate_jwt, generate_secure_token, JwtConfig, JwtSigningParams,
135    };
136
137    let (user_id, scope) = repo
138        .validate_authorization_code(
139            params.code,
140            params.client_id,
141            Some(params.redirect_uri),
142            None,
143        )
144        .await?;
145
146    let user = load_authenticated_user(&user_id, state.user_provider()).await?;
147
148    let permissions = parse_permissions(&scope)?;
149
150    let session_service = systemprompt_oauth::services::SessionCreationService::new(
151        Arc::clone(state.analytics_provider()),
152        Arc::clone(state.user_provider()),
153    );
154    let session_id = session_service
155        .create_authenticated_session(&user_id, params.headers, SessionSource::Oauth)
156        .await?;
157
158    let access_token_jti = generate_access_token_jti();
159    let config = JwtConfig {
160        permissions: permissions.clone(),
161        ..Default::default()
162    };
163    let jwt_secret = systemprompt_models::SecretsBootstrap::jwt_secret()?;
164    let global_config = Config::get()?;
165    let signing = JwtSigningParams {
166        secret: jwt_secret,
167        issuer: &global_config.jwt_issuer,
168    };
169    let access_token = generate_jwt(&user, config, access_token_jti, &session_id, &signing)?;
170
171    let refresh_token_value = generate_secure_token("rt");
172    let refresh_token_id = RefreshTokenId::new(&refresh_token_value);
173    let refresh_expires_at = chrono::Utc::now().timestamp()
174        + (systemprompt_oauth::constants::token::SECONDS_PER_DAY
175            * systemprompt_oauth::constants::token::REFRESH_TOKEN_EXPIRY_DAYS);
176
177    let refresh_params = RefreshTokenParams::builder(
178        &refresh_token_id,
179        params.client_id,
180        &user_id,
181        &scope,
182        refresh_expires_at,
183    )
184    .build();
185    repo.store_refresh_token(refresh_params).await?;
186
187    Ok(TokenResponse { access_token })
188}
189
190async fn load_authenticated_user(
191    user_id: &UserId,
192    user_provider: &Arc<dyn systemprompt_traits::UserProvider>,
193) -> anyhow::Result<AuthenticatedUser> {
194    let user = user_provider
195        .find_by_id(user_id.as_str())
196        .await
197        .map_err(|e| anyhow::anyhow!("{}", e))?
198        .ok_or_else(|| anyhow::anyhow!("User not found: {user_id}"))?;
199
200    let permissions: Vec<Permission> = user
201        .roles
202        .iter()
203        .filter_map(|s| {
204            Permission::from_str(s)
205                .map_err(|e| {
206                    tracing::warn!(
207                        user_id = %user.id,
208                        role = %s,
209                        error = %e,
210                        "Invalid role in user record"
211                    );
212                    e
213                })
214                .ok()
215        })
216        .collect();
217
218    let user_uuid = uuid::Uuid::parse_str(&user.id)
219        .map_err(|_| anyhow::anyhow!("Invalid user UUID: {}", user.id))?;
220
221    Ok(AuthenticatedUser::new_with_roles(
222        user_uuid,
223        user.name,
224        user.email,
225        permissions,
226        user.roles,
227    ))
228}
229
/// OAuth client selected for the browser login flow.
#[derive(Debug)]
struct BrowserClient {
    /// The registered client's identifier, kept as a plain string.
    client_id: String,
}

235#[derive(Debug, serde::Deserialize)]
236struct TokenResponse {
237    access_token: String,
238}