Skip to main content

systemprompt_api/routes/oauth/endpoints/callback.rs

1use axum::extract::{Query, State};
2use axum::http::{header, HeaderMap, HeaderValue, StatusCode};
3use axum::response::{IntoResponse, Redirect};
4use serde::Deserialize;
5use std::str::FromStr;
6use std::sync::Arc;
7use systemprompt_identifiers::{
8    AuthorizationCode, ClientId, RefreshTokenId, SessionSource, UserId,
9};
10use systemprompt_models::auth::{parse_permissions, AuthenticatedUser, Permission};
11use systemprompt_models::Config;
12
13use systemprompt_oauth::repository::{OAuthRepository, RefreshTokenParams};
14use systemprompt_oauth::OAuthState;
15
16#[derive(Debug, Deserialize)]
17pub struct CallbackQuery {
18    pub code: String,
19    pub state: Option<String>,
20}
21
22pub async fn handle_callback(
23    Query(params): Query<CallbackQuery>,
24    State(state): State<OAuthState>,
25    headers: HeaderMap,
26) -> impl IntoResponse {
27    let repo = match OAuthRepository::new(state.db_pool()) {
28        Ok(r) => r,
29        Err(e) => {
30            return (
31                StatusCode::INTERNAL_SERVER_ERROR,
32                axum::Json(serde_json::json!({"error": "server_error", "error_description": format!("Repository initialization failed: {}", e)})),
33            ).into_response();
34        },
35    };
36    let config = match Config::get() {
37        Ok(c) => c,
38        Err(e) => {
39            return (
40                StatusCode::INTERNAL_SERVER_ERROR,
41                format!("Failed to load config: {e}"),
42            )
43                .into_response();
44        },
45    };
46
47    let server_base_url = &config.api_external_url;
48    let redirect_uri = format!("{server_base_url}/api/v1/core/oauth/callback");
49
50    let browser_client = match find_browser_client(&repo, &redirect_uri).await {
51        Ok(client) => client,
52        Err(e) => {
53            return (
54                StatusCode::INTERNAL_SERVER_ERROR,
55                format!("Failed to find OAuth client: {e}"),
56            )
57                .into_response();
58        },
59    };
60
61    let code = AuthorizationCode::new(&params.code);
62    let client_id = ClientId::new(&browser_client.client_id);
63    let token_response = match exchange_code_for_token(
64        &repo,
65        CodeExchangeParams {
66            code: &code,
67            client_id: &client_id,
68            redirect_uri: &redirect_uri,
69            headers: &headers,
70        },
71        &state,
72    )
73    .await
74    {
75        Ok(response) => response,
76        Err(e) => {
77            return (
78                StatusCode::UNAUTHORIZED,
79                format!("Failed to exchange code for token: {e}"),
80            )
81                .into_response();
82        },
83    };
84
85    let redirect_destination = params
86        .state
87        .as_deref()
88        .filter(|s| !s.is_empty())
89        .unwrap_or("/");
90
91    let cookie = format!(
92        "access_token={}; Path=/; HttpOnly; Secure; SameSite=Lax; Max-Age={}",
93        token_response.access_token,
94        systemprompt_oauth::constants::token::COOKIE_MAX_AGE_SECONDS
95    );
96
97    let mut response = Redirect::to(redirect_destination).into_response();
98    if let Ok(cookie_value) = HeaderValue::from_str(&cookie) {
99        response
100            .headers_mut()
101            .insert(header::SET_COOKIE, cookie_value);
102    }
103
104    response
105}
106
107async fn find_browser_client(
108    repo: &OAuthRepository,
109    redirect_uri: &str,
110) -> anyhow::Result<BrowserClient> {
111    let client = repo
112        .find_client_by_redirect_uri_with_scope(redirect_uri, &["admin", "user"])
113        .await?
114        .ok_or_else(|| anyhow::anyhow!("No suitable browser client found"))?;
115
116    Ok(BrowserClient {
117        client_id: client.client_id.to_string(),
118    })
119}
120
121struct CodeExchangeParams<'a> {
122    code: &'a AuthorizationCode,
123    client_id: &'a ClientId,
124    redirect_uri: &'a str,
125    headers: &'a HeaderMap,
126}
127
128async fn exchange_code_for_token(
129    repo: &OAuthRepository,
130    params: CodeExchangeParams<'_>,
131    state: &OAuthState,
132) -> anyhow::Result<TokenResponse> {
133    use systemprompt_oauth::services::{
134        generate_access_token_jti, generate_jwt, generate_secure_token, JwtConfig, JwtSigningParams,
135    };
136
137    let validation_result = repo
138        .validate_authorization_code(
139            params.code,
140            params.client_id,
141            Some(params.redirect_uri),
142            None,
143        )
144        .await?;
145
146    let user = load_authenticated_user(&validation_result.user_id, state.user_provider()).await?;
147
148    let permissions = parse_permissions(&validation_result.scope)?;
149
150    let mut session_service = systemprompt_oauth::services::SessionCreationService::new(
151        Arc::clone(state.analytics_provider()),
152        Arc::clone(state.user_provider()),
153    );
154    if let Some(publisher) = state.event_publisher() {
155        session_service = session_service.with_event_publisher(Arc::clone(publisher));
156    }
157    let session_id = session_service
158        .create_authenticated_session(
159            &validation_result.user_id,
160            params.headers,
161            SessionSource::Oauth,
162        )
163        .await?;
164
165    let access_token_jti = generate_access_token_jti();
166    let jwt_secret = systemprompt_models::SecretsBootstrap::jwt_secret()?;
167    let global_config = Config::get()?;
168    let config = JwtConfig {
169        permissions: permissions.clone(),
170        audience: global_config.jwt_audiences.clone(),
171        ..Default::default()
172    };
173    let signing = JwtSigningParams {
174        secret: jwt_secret,
175        issuer: &global_config.jwt_issuer,
176    };
177    let access_token = generate_jwt(&user, config, access_token_jti, &session_id, &signing)?;
178
179    let refresh_token_value = generate_secure_token("rt");
180    let refresh_token_id = RefreshTokenId::new(&refresh_token_value);
181    let refresh_expires_at = chrono::Utc::now().timestamp()
182        + (systemprompt_oauth::constants::token::SECONDS_PER_DAY
183            * systemprompt_oauth::constants::token::REFRESH_TOKEN_EXPIRY_DAYS);
184
185    let refresh_params = RefreshTokenParams::builder(
186        &refresh_token_id,
187        params.client_id,
188        &validation_result.user_id,
189        &validation_result.scope,
190        refresh_expires_at,
191    )
192    .build();
193    repo.store_refresh_token(refresh_params).await?;
194
195    Ok(TokenResponse { access_token })
196}
197
198async fn load_authenticated_user(
199    user_id: &UserId,
200    user_provider: &Arc<dyn systemprompt_traits::UserProvider>,
201) -> anyhow::Result<AuthenticatedUser> {
202    let user = user_provider
203        .find_by_id(user_id.as_str())
204        .await
205        .map_err(|e| anyhow::anyhow!("{}", e))?
206        .ok_or_else(|| anyhow::anyhow!("User not found: {user_id}"))?;
207
208    let permissions: Vec<Permission> = user
209        .roles
210        .iter()
211        .filter_map(|s| {
212            Permission::from_str(s)
213                .map_err(|e| {
214                    tracing::warn!(
215                        user_id = %user.id,
216                        role = %s,
217                        error = %e,
218                        "Invalid role in user record"
219                    );
220                    e
221                })
222                .ok()
223        })
224        .collect();
225
226    let user_uuid = uuid::Uuid::parse_str(&user.id)
227        .map_err(|_| anyhow::anyhow!("Invalid user UUID: {}", user.id))?;
228
229    Ok(AuthenticatedUser::new_with_roles(
230        user_uuid,
231        user.name,
232        user.email,
233        permissions,
234        user.roles,
235    ))
236}
237
/// The OAuth client that browser logins are exchanged against.
#[derive(Debug)]
struct BrowserClient {
    /// Client identifier, stored as a string and re-wrapped as `ClientId` by the handler.
    client_id: String,
}
242
/// Result of a successful authorization-code exchange.
///
/// Built directly by `exchange_code_for_token`; it is never deserialized, so no serde
/// derive is needed.
struct TokenResponse {
    /// Signed JWT access token, written into the `access_token` cookie by the callback.
    access_token: String,
}

impl std::fmt::Debug for TokenResponse {
    /// Redacts the access token. It is a bearer credential and must never reach logs via `{:?}`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TokenResponse")
            .field("access_token", &"<redacted>")
            .finish()
    }
}