Skip to main content

systemprompt_api/routes/oauth/endpoints/
webauthn_complete.rs

1use anyhow::Result;
2use axum::Json;
3use axum::extract::{Query, State};
4use axum::http::{HeaderMap, StatusCode};
5use axum::response::{IntoResponse, Redirect};
6use serde::{Deserialize, Serialize};
7use std::sync::Arc;
8
9use crate::routes::oauth::extractors::OAuthRepo;
10use systemprompt_identifiers::{AuthorizationCode, ClientId, UserId};
11use systemprompt_oauth::OAuthState;
12use systemprompt_oauth::repository::{AuthCodeParams, OAuthRepository};
13use systemprompt_oauth::services::webauthn::WebAuthnManager;
14use systemprompt_oauth::services::{generate_secure_token, is_browser_request};
15
16#[derive(Debug, Deserialize)]
17pub struct WebAuthnCompleteQuery {
18    pub user_id: UserId,
19    pub auth_token: Option<String>,
20    pub response_type: Option<String>,
21    pub client_id: Option<ClientId>,
22    pub redirect_uri: Option<String>,
23    pub scope: Option<String>,
24    pub state: Option<String>,
25    pub code_challenge: Option<String>,
26    pub code_challenge_method: Option<String>,
27    pub response_mode: Option<String>,
28    pub resource: Option<String>,
29}
30
31#[derive(Debug, Serialize)]
32pub struct WebAuthnCompleteError {
33    pub error: String,
34    pub error_description: String,
35}
36
37#[allow(unused_qualifications)]
38pub async fn handle_webauthn_complete(
39    headers: HeaderMap,
40    Query(params): Query<WebAuthnCompleteQuery>,
41    State(state): State<OAuthState>,
42    OAuthRepo(repo): OAuthRepo,
43) -> impl IntoResponse {
44    let auth_token = match &params.auth_token {
45        Some(token) => token.clone(),
46        None => {
47            return (
48                StatusCode::BAD_REQUEST,
49                Json(WebAuthnCompleteError {
50                    error: "invalid_request".to_string(),
51                    error_description: "Missing auth_token parameter".to_string(),
52                }),
53            )
54                .into_response();
55        },
56    };
57
58    let user_provider = state.user_provider();
59    let webauthn_service =
60        match WebAuthnManager::get_or_create_service(repo.clone(), Arc::clone(user_provider)).await
61        {
62            Ok(service) => service,
63            Err(e) => {
64                return (
65                    StatusCode::INTERNAL_SERVER_ERROR,
66                    Json(WebAuthnCompleteError {
67                        error: "server_error".to_string(),
68                        error_description: format!("WebAuthn service initialization failed: {e}"),
69                    }),
70                )
71                    .into_response();
72            },
73        };
74
75    let Ok(verified_user_id) = webauthn_service
76        .consume_verified_authentication(&auth_token)
77        .await
78    else {
79        return (
80            StatusCode::UNAUTHORIZED,
81            Json(WebAuthnCompleteError {
82                error: "access_denied".to_string(),
83                error_description: "Invalid or expired authentication token".to_string(),
84            }),
85        )
86            .into_response();
87    };
88
89    if params.user_id != verified_user_id {
90        tracing::warn!(
91            claimed_user_id = %params.user_id,
92            verified_user_id = %verified_user_id,
93            "WebAuthn complete user_id mismatch"
94        );
95        return (
96            StatusCode::UNAUTHORIZED,
97            Json(WebAuthnCompleteError {
98                error: "access_denied".to_string(),
99                error_description: "User identity verification failed".to_string(),
100            }),
101        )
102            .into_response();
103    }
104
105    if params.client_id.is_none() {
106        return (
107            StatusCode::BAD_REQUEST,
108            Json(WebAuthnCompleteError {
109                error: "invalid_request".to_string(),
110                error_description: "Missing client_id parameter".to_string(),
111            }),
112        )
113            .into_response();
114    }
115
116    let Some(redirect_uri) = &params.redirect_uri else {
117        return (
118            StatusCode::BAD_REQUEST,
119            Json(WebAuthnCompleteError {
120                error: "invalid_request".to_string(),
121                error_description: "Missing redirect_uri parameter".to_string(),
122            }),
123        )
124            .into_response();
125    };
126
127    match user_provider.find_by_id(&verified_user_id).await {
128        Ok(Some(_)) => {
129            let authorization_code = generate_secure_token("auth_code");
130
131            match store_authorization_code(&repo, &authorization_code, &params).await {
132                Ok(()) => {
133                    create_successful_response(&headers, redirect_uri, &authorization_code, &params)
134                },
135                Err(error) => (
136                    StatusCode::INTERNAL_SERVER_ERROR,
137                    Json(WebAuthnCompleteError {
138                        error: "server_error".to_string(),
139                        error_description: error.to_string(),
140                    }),
141                )
142                    .into_response(),
143            }
144        },
145        Ok(None) => (
146            StatusCode::UNAUTHORIZED,
147            Json(WebAuthnCompleteError {
148                error: "access_denied".to_string(),
149                error_description: "User not found".to_string(),
150            }),
151        )
152            .into_response(),
153        Err(error) => {
154            let status_code = if error.to_string().contains("User not found") {
155                StatusCode::UNAUTHORIZED
156            } else {
157                StatusCode::INTERNAL_SERVER_ERROR
158            };
159
160            let error_type = if status_code == StatusCode::UNAUTHORIZED {
161                "access_denied"
162            } else {
163                "server_error"
164            };
165
166            (
167                status_code,
168                Json(WebAuthnCompleteError {
169                    error: error_type.to_string(),
170                    error_description: error.to_string(),
171                }),
172            )
173                .into_response()
174        },
175    }
176}
177
178async fn store_authorization_code(
179    repo: &OAuthRepository,
180    code_str: &str,
181    query: &WebAuthnCompleteQuery,
182) -> Result<()> {
183    let client_id = query
184        .client_id
185        .as_ref()
186        .ok_or_else(|| anyhow::anyhow!("client_id is required"))?;
187    let redirect_uri = query
188        .redirect_uri
189        .as_ref()
190        .ok_or_else(|| anyhow::anyhow!("redirect_uri is required"))?;
191    let scope = query.scope.as_ref().map_or_else(
192        || {
193            let default_roles = OAuthRepository::get_default_roles();
194            if default_roles.is_empty() {
195                "user".to_string()
196            } else {
197                default_roles.join(" ")
198            }
199        },
200        Clone::clone,
201    );
202
203    let code = AuthorizationCode::new(code_str);
204
205    let mut builder =
206        AuthCodeParams::builder(&code, client_id, &query.user_id, redirect_uri, &scope);
207
208    if let (Some(challenge), Some(method)) = (
209        query.code_challenge.as_deref(),
210        query
211            .code_challenge_method
212            .as_deref()
213            .filter(|s| !s.is_empty()),
214    ) {
215        builder = builder.with_pkce(challenge, method);
216    }
217
218    if let Some(resource) = query.resource.as_deref() {
219        builder = builder.with_resource(resource);
220    }
221
222    repo.store_authorization_code(builder.build()).await
223}
224
225#[derive(Debug, Serialize)]
226pub struct WebAuthnCompleteResponse {
227    pub authorization_code: String,
228    pub state: String,
229    pub redirect_uri: String,
230    pub client_id: ClientId,
231}
232
233fn create_successful_response(
234    headers: &HeaderMap,
235    redirect_uri: &str,
236    authorization_code: &str,
237    params: &WebAuthnCompleteQuery,
238) -> axum::response::Response {
239    let state = params.state.as_deref().filter(|s| !s.is_empty());
240
241    if is_browser_request(headers) {
242        let mut target = format!("{redirect_uri}?code={authorization_code}");
243
244        if let Some(client_id_val) = params.client_id.as_ref() {
245            target.push_str(&format!(
246                "&client_id={}",
247                urlencoding::encode(client_id_val.as_str())
248            ));
249        }
250
251        if let Some(state_val) = state {
252            target.push_str(&format!("&state={}", urlencoding::encode(state_val)));
253        }
254        Redirect::to(&target).into_response()
255    } else {
256        let response_data = WebAuthnCompleteResponse {
257            authorization_code: authorization_code.to_string(),
258            state: state.unwrap_or("").to_string(),
259            redirect_uri: redirect_uri.to_string(),
260            client_id: params
261                .client_id
262                .clone()
263                .unwrap_or_else(|| ClientId::new("")),
264        };
265
266        Json(response_data).into_response()
267    }
268}