systemprompt_api/routes/oauth/endpoints/
webauthn_complete.rs1use anyhow::Result;
2use axum::extract::{Query, State};
3use axum::http::{HeaderMap, HeaderValue, StatusCode};
4use axum::response::{IntoResponse, Redirect};
5use axum::Json;
6use serde::{Deserialize, Serialize};
7
8use systemprompt_identifiers::{AuthorizationCode, ClientId, UserId};
9use systemprompt_oauth::repository::{AuthCodeParams, OAuthRepository};
10use systemprompt_oauth::services::{generate_secure_token, is_browser_request};
11use systemprompt_oauth::OAuthState;
12
13#[derive(Debug, Deserialize)]
14pub struct WebAuthnCompleteQuery {
15 pub user_id: String,
16 pub response_type: Option<String>,
17 pub client_id: Option<String>,
18 pub redirect_uri: Option<String>,
19 pub scope: Option<String>,
20 pub state: Option<String>,
21 pub code_challenge: Option<String>,
22 pub code_challenge_method: Option<String>,
23 pub response_mode: Option<String>,
24 pub resource: Option<String>,
25}
26
27#[derive(Debug, Serialize)]
28pub struct WebAuthnCompleteError {
29 pub error: String,
30 pub error_description: String,
31}
32
33#[allow(unused_qualifications)]
34pub async fn handle_webauthn_complete(
35 headers: HeaderMap,
36 Query(params): Query<WebAuthnCompleteQuery>,
37 State(state): State<OAuthState>,
38) -> impl IntoResponse {
39 let repo = match OAuthRepository::new(state.db_pool()) {
40 Ok(r) => r,
41 Err(e) => {
42 return (
43 axum::http::StatusCode::INTERNAL_SERVER_ERROR,
44 axum::Json(serde_json::json!({"error": "server_error", "error_description": format!("Repository initialization failed: {}", e)})),
45 ).into_response();
46 },
47 };
48 if params.client_id.is_none() {
49 return (
50 StatusCode::BAD_REQUEST,
51 Json(WebAuthnCompleteError {
52 error: "invalid_request".to_string(),
53 error_description: "Missing client_id parameter".to_string(),
54 }),
55 )
56 .into_response();
57 }
58
59 let Some(redirect_uri) = ¶ms.redirect_uri else {
60 return (
61 StatusCode::BAD_REQUEST,
62 Json(WebAuthnCompleteError {
63 error: "invalid_request".to_string(),
64 error_description: "Missing redirect_uri parameter".to_string(),
65 }),
66 )
67 .into_response();
68 };
69
70 let user_provider = state.user_provider();
71
72 match user_provider.find_by_id(¶ms.user_id).await {
73 Ok(Some(_)) => {
74 let authorization_code = generate_secure_token("auth_code");
75
76 match store_authorization_code(&repo, &authorization_code, ¶ms).await {
77 Ok(()) => {
78 create_successful_response(&headers, redirect_uri, &authorization_code, ¶ms)
79 },
80 Err(error) => (
81 StatusCode::INTERNAL_SERVER_ERROR,
82 Json(WebAuthnCompleteError {
83 error: "server_error".to_string(),
84 error_description: error.to_string(),
85 }),
86 )
87 .into_response(),
88 }
89 },
90 Ok(None) => (
91 StatusCode::UNAUTHORIZED,
92 Json(WebAuthnCompleteError {
93 error: "access_denied".to_string(),
94 error_description: "User not found".to_string(),
95 }),
96 )
97 .into_response(),
98 Err(error) => {
99 let status_code = if error.to_string().contains("User not found") {
100 StatusCode::UNAUTHORIZED
101 } else {
102 StatusCode::INTERNAL_SERVER_ERROR
103 };
104
105 let error_type = if status_code == StatusCode::UNAUTHORIZED {
106 "access_denied"
107 } else {
108 "server_error"
109 };
110
111 (
112 status_code,
113 Json(WebAuthnCompleteError {
114 error: error_type.to_string(),
115 error_description: error.to_string(),
116 }),
117 )
118 .into_response()
119 },
120 }
121}
122
123async fn store_authorization_code(
124 repo: &OAuthRepository,
125 code_str: &str,
126 query: &WebAuthnCompleteQuery,
127) -> Result<()> {
128 let client_id_str = query
129 .client_id
130 .as_ref()
131 .ok_or_else(|| anyhow::anyhow!("client_id is required"))?;
132 let redirect_uri = query
133 .redirect_uri
134 .as_ref()
135 .ok_or_else(|| anyhow::anyhow!("redirect_uri is required"))?;
136 let scope = query.scope.as_ref().map_or_else(
137 || {
138 let default_roles = OAuthRepository::get_default_roles();
139 if default_roles.is_empty() {
140 "user".to_string()
141 } else {
142 default_roles.join(" ")
143 }
144 },
145 Clone::clone,
146 );
147
148 let code = AuthorizationCode::new(code_str);
149 let client_id = ClientId::new(client_id_str);
150 let user_id = UserId::new(&query.user_id);
151
152 let mut builder = AuthCodeParams::builder(&code, &client_id, &user_id, redirect_uri, &scope);
153
154 if let (Some(challenge), Some(method)) = (
155 query.code_challenge.as_deref(),
156 query
157 .code_challenge_method
158 .as_deref()
159 .filter(|s| !s.is_empty()),
160 ) {
161 builder = builder.with_pkce(challenge, method);
162 }
163
164 if let Some(resource) = query.resource.as_deref() {
165 builder = builder.with_resource(resource);
166 }
167
168 repo.store_authorization_code(builder.build()).await
169}
170
171#[derive(Debug, Serialize)]
172pub struct WebAuthnCompleteResponse {
173 pub authorization_code: String,
174 pub state: String,
175 pub redirect_uri: String,
176 pub client_id: String,
177}
178
179fn create_successful_response(
180 headers: &HeaderMap,
181 redirect_uri: &str,
182 authorization_code: &str,
183 params: &WebAuthnCompleteQuery,
184) -> axum::response::Response {
185 let state = params.state.as_deref().filter(|s| !s.is_empty());
186
187 if is_browser_request(headers) {
188 let mut target = format!("{redirect_uri}?code={authorization_code}");
189
190 if let Some(client_id_val) = params.client_id.as_deref() {
191 target.push_str(&format!(
192 "&client_id={}",
193 urlencoding::encode(client_id_val)
194 ));
195 }
196
197 if let Some(state_val) = state {
198 target.push_str(&format!("&state={}", urlencoding::encode(state_val)));
199 }
200 Redirect::to(&target).into_response()
201 } else {
202 let response_data = WebAuthnCompleteResponse {
203 authorization_code: authorization_code.to_string(),
204 state: state.unwrap_or("").to_string(),
205 redirect_uri: redirect_uri.to_string(),
206 client_id: params.client_id.as_deref().unwrap_or("").to_string(),
207 };
208
209 let mut response = Json(response_data).into_response();
210
211 let headers = response.headers_mut();
212 headers.insert("access-control-allow-origin", HeaderValue::from_static("*"));
213 headers.insert(
214 "access-control-allow-methods",
215 HeaderValue::from_static("GET, POST, OPTIONS"),
216 );
217 headers.insert(
218 "access-control-allow-headers",
219 HeaderValue::from_static("content-type, authorization"),
220 );
221
222 response
223 }
224}