// systemprompt_api/routes/oauth/endpoints/webauthn_complete.rs

use anyhow::Result;
use axum::extract::{Query, State};
use axum::http::{HeaderMap, HeaderValue, StatusCode};
use axum::response::{IntoResponse, Redirect};
use axum::Json;
use serde::{Deserialize, Serialize};

use systemprompt_identifiers::{AuthorizationCode, ClientId, UserId};
use systemprompt_oauth::repository::{AuthCodeParams, OAuthRepository};
use systemprompt_oauth::services::{generate_secure_token, is_browser_request};
use systemprompt_oauth::OAuthState;

13#[derive(Debug, Deserialize)]
14pub struct WebAuthnCompleteQuery {
15    pub user_id: String,
16    pub response_type: Option<String>,
17    pub client_id: Option<String>,
18    pub redirect_uri: Option<String>,
19    pub scope: Option<String>,
20    pub state: Option<String>,
21    pub code_challenge: Option<String>,
22    pub code_challenge_method: Option<String>,
23    pub response_mode: Option<String>,
24    pub resource: Option<String>,
25}
26
27#[derive(Debug, Serialize)]
28pub struct WebAuthnCompleteError {
29    pub error: String,
30    pub error_description: String,
31}
32
33#[allow(unused_qualifications)]
34pub async fn handle_webauthn_complete(
35    headers: HeaderMap,
36    Query(params): Query<WebAuthnCompleteQuery>,
37    State(state): State<OAuthState>,
38) -> impl IntoResponse {
39    let repo = match OAuthRepository::new(state.db_pool()) {
40        Ok(r) => r,
41        Err(e) => {
42            return (
43                axum::http::StatusCode::INTERNAL_SERVER_ERROR,
44                axum::Json(serde_json::json!({"error": "server_error", "error_description": format!("Repository initialization failed: {}", e)})),
45            ).into_response();
46        },
47    };
48    if params.client_id.is_none() {
49        return (
50            StatusCode::BAD_REQUEST,
51            Json(WebAuthnCompleteError {
52                error: "invalid_request".to_string(),
53                error_description: "Missing client_id parameter".to_string(),
54            }),
55        )
56            .into_response();
57    }
58
59    let Some(redirect_uri) = &params.redirect_uri else {
60        return (
61            StatusCode::BAD_REQUEST,
62            Json(WebAuthnCompleteError {
63                error: "invalid_request".to_string(),
64                error_description: "Missing redirect_uri parameter".to_string(),
65            }),
66        )
67            .into_response();
68    };
69
70    let user_provider = state.user_provider();
71
72    match user_provider.find_by_id(&params.user_id).await {
73        Ok(Some(_)) => {
74            let authorization_code = generate_secure_token("auth_code");
75
76            match store_authorization_code(&repo, &authorization_code, &params).await {
77                Ok(()) => {
78                    create_successful_response(&headers, redirect_uri, &authorization_code, &params)
79                },
80                Err(error) => (
81                    StatusCode::INTERNAL_SERVER_ERROR,
82                    Json(WebAuthnCompleteError {
83                        error: "server_error".to_string(),
84                        error_description: error.to_string(),
85                    }),
86                )
87                    .into_response(),
88            }
89        },
90        Ok(None) => (
91            StatusCode::UNAUTHORIZED,
92            Json(WebAuthnCompleteError {
93                error: "access_denied".to_string(),
94                error_description: "User not found".to_string(),
95            }),
96        )
97            .into_response(),
98        Err(error) => {
99            let status_code = if error.to_string().contains("User not found") {
100                StatusCode::UNAUTHORIZED
101            } else {
102                StatusCode::INTERNAL_SERVER_ERROR
103            };
104
105            let error_type = if status_code == StatusCode::UNAUTHORIZED {
106                "access_denied"
107            } else {
108                "server_error"
109            };
110
111            (
112                status_code,
113                Json(WebAuthnCompleteError {
114                    error: error_type.to_string(),
115                    error_description: error.to_string(),
116                }),
117            )
118                .into_response()
119        },
120    }
121}
122
123async fn store_authorization_code(
124    repo: &OAuthRepository,
125    code_str: &str,
126    query: &WebAuthnCompleteQuery,
127) -> Result<()> {
128    let client_id_str = query
129        .client_id
130        .as_ref()
131        .ok_or_else(|| anyhow::anyhow!("client_id is required"))?;
132    let redirect_uri = query
133        .redirect_uri
134        .as_ref()
135        .ok_or_else(|| anyhow::anyhow!("redirect_uri is required"))?;
136    let scope = query.scope.as_ref().map_or_else(
137        || {
138            let default_roles = OAuthRepository::get_default_roles();
139            if default_roles.is_empty() {
140                "user".to_string()
141            } else {
142                default_roles.join(" ")
143            }
144        },
145        Clone::clone,
146    );
147
148    let code = AuthorizationCode::new(code_str);
149    let client_id = ClientId::new(client_id_str);
150    let user_id = UserId::new(&query.user_id);
151
152    let mut builder = AuthCodeParams::builder(&code, &client_id, &user_id, redirect_uri, &scope);
153
154    if let (Some(challenge), Some(method)) = (
155        query.code_challenge.as_deref(),
156        query
157            .code_challenge_method
158            .as_deref()
159            .filter(|s| !s.is_empty()),
160    ) {
161        builder = builder.with_pkce(challenge, method);
162    }
163
164    if let Some(resource) = query.resource.as_deref() {
165        builder = builder.with_resource(resource);
166    }
167
168    repo.store_authorization_code(builder.build()).await
169}
170
171#[derive(Debug, Serialize)]
172pub struct WebAuthnCompleteResponse {
173    pub authorization_code: String,
174    pub state: String,
175    pub redirect_uri: String,
176    pub client_id: String,
177}
178
179fn create_successful_response(
180    headers: &HeaderMap,
181    redirect_uri: &str,
182    authorization_code: &str,
183    params: &WebAuthnCompleteQuery,
184) -> axum::response::Response {
185    let state = params.state.as_deref().filter(|s| !s.is_empty());
186
187    if is_browser_request(headers) {
188        let mut target = format!("{redirect_uri}?code={authorization_code}");
189
190        if let Some(client_id_val) = params.client_id.as_deref() {
191            target.push_str(&format!(
192                "&client_id={}",
193                urlencoding::encode(client_id_val)
194            ));
195        }
196
197        if let Some(state_val) = state {
198            target.push_str(&format!("&state={}", urlencoding::encode(state_val)));
199        }
200        Redirect::to(&target).into_response()
201    } else {
202        let response_data = WebAuthnCompleteResponse {
203            authorization_code: authorization_code.to_string(),
204            state: state.unwrap_or("").to_string(),
205            redirect_uri: redirect_uri.to_string(),
206            client_id: params.client_id.as_deref().unwrap_or("").to_string(),
207        };
208
209        let mut response = Json(response_data).into_response();
210
211        let headers = response.headers_mut();
212        headers.insert("access-control-allow-origin", HeaderValue::from_static("*"));
213        headers.insert(
214            "access-control-allow-methods",
215            HeaderValue::from_static("GET, POST, OPTIONS"),
216        );
217        headers.insert(
218            "access-control-allow-headers",
219            HeaderValue::from_static("content-type, authorization"),
220        );
221
222        response
223    }
224}