systemprompt_api/routes/oauth/webauthn/
authenticate.rs1use axum::extract::{Query, State};
2use axum::http::StatusCode;
3use axum::response::IntoResponse;
4use axum::Json;
5use serde::{Deserialize, Serialize};
6use std::sync::Arc;
7use systemprompt_oauth::repository::OAuthRepository;
8use systemprompt_oauth::services::webauthn::WebAuthnManager;
9use systemprompt_oauth::OAuthState;
10use tracing::instrument;
11use webauthn_rs::prelude::*;
12
13#[derive(Debug, Deserialize)]
14pub struct StartAuthQuery {
15 pub email: String,
16 pub oauth_state: Option<String>,
17}
18
19#[derive(Debug, Serialize)]
20pub struct StartAuthResponse {
21 #[serde(rename = "publicKey")]
22 pub public_key: serde_json::Value,
23 pub challenge_id: String,
24}
25
26#[derive(Debug, Serialize)]
27pub struct AuthError {
28 pub error: String,
29 pub error_description: String,
30}
31
32#[allow(unused_qualifications)]
33#[instrument(skip(state, params), fields(email = %params.email))]
34pub async fn start_auth(
35 Query(params): Query<StartAuthQuery>,
36 State(state): State<OAuthState>,
37) -> impl IntoResponse {
38 let oauth_repo = match OAuthRepository::new(state.db_pool()) {
39 Ok(r) => r,
40 Err(e) => {
41 return (
42 axum::http::StatusCode::INTERNAL_SERVER_ERROR,
43 axum::Json(serde_json::json!({"error": "server_error", "error_description": format!("Repository initialization failed: {}", e)})),
44 ).into_response();
45 },
46 };
47 let user_provider = Arc::clone(state.user_provider());
48
49 let webauthn_service =
50 match WebAuthnManager::get_or_create_service(oauth_repo, user_provider).await {
51 Ok(service) => service,
52 Err(e) => {
53 tracing::error!(error = %e, "Failed to initialize WebAuthn");
54 return (
55 StatusCode::INTERNAL_SERVER_ERROR,
56 Json(AuthError {
57 error: "server_error".to_string(),
58 error_description: format!("Failed to initialize WebAuthn: {e}"),
59 }),
60 )
61 .into_response();
62 },
63 };
64
65 match webauthn_service
66 .start_authentication(¶ms.email, params.oauth_state)
67 .await
68 {
69 Ok((challenge, challenge_id)) => {
70 let challenge_json = match serde_json::to_value(&challenge) {
71 Ok(json) => json,
72 Err(e) => {
73 return (
74 StatusCode::INTERNAL_SERVER_ERROR,
75 Json(AuthError {
76 error: "server_error".to_string(),
77 error_description: format!("Failed to serialize challenge: {e}"),
78 }),
79 )
80 .into_response();
81 },
82 };
83
84 let mut public_key = match challenge_json.get("publicKey") {
85 Some(pk) => pk.clone(),
86 None => {
87 return (
88 StatusCode::INTERNAL_SERVER_ERROR,
89 Json(AuthError {
90 error: "server_error".to_string(),
91 error_description: "Missing publicKey in challenge".to_string(),
92 }),
93 )
94 .into_response();
95 },
96 };
97
98 if let Some(obj) = public_key.as_object_mut() {
99 obj.remove("authenticatorAttachment");
100 }
101
102 (
103 StatusCode::OK,
104 Json(StartAuthResponse {
105 public_key,
106 challenge_id,
107 }),
108 )
109 .into_response()
110 },
111 Err(e) => {
112 let status_code = if e.to_string().contains("User not found") {
113 StatusCode::NOT_FOUND
114 } else {
115 StatusCode::BAD_REQUEST
116 };
117
118 (
119 status_code,
120 Json(AuthError {
121 error: "authentication_failed".to_string(),
122 error_description: e.to_string(),
123 }),
124 )
125 .into_response()
126 },
127 }
128}
129
130#[derive(Debug, Deserialize)]
131pub struct FinishAuthRequest {
132 pub challenge_id: String,
133 pub credential: PublicKeyCredential,
134}
135
136#[derive(Debug, Serialize)]
137pub struct FinishAuthResponse {
138 pub user_id: String,
139 pub oauth_state: Option<String>,
140 pub success: bool,
141}
142
143#[instrument(skip(state, request), fields(challenge_id = %request.challenge_id))]
144pub async fn finish_auth(
145 State(state): State<OAuthState>,
146 Json(request): Json<FinishAuthRequest>,
147) -> impl IntoResponse {
148 let oauth_repo = match OAuthRepository::new(state.db_pool()) {
149 Ok(r) => r,
150 Err(e) => {
151 return (
152 StatusCode::INTERNAL_SERVER_ERROR,
153 Json(serde_json::json!({"error": "server_error", "error_description": format!("Repository initialization failed: {}", e)})),
154 ).into_response();
155 },
156 };
157 let user_provider = Arc::clone(state.user_provider());
158
159 let webauthn_service =
160 match WebAuthnManager::get_or_create_service(oauth_repo, user_provider).await {
161 Ok(service) => service,
162 Err(e) => {
163 tracing::error!(error = %e, "Failed to initialize WebAuthn");
164 return (
165 StatusCode::INTERNAL_SERVER_ERROR,
166 Json(AuthError {
167 error: "server_error".to_string(),
168 error_description: format!("Failed to initialize WebAuthn: {e}"),
169 }),
170 )
171 .into_response();
172 },
173 };
174
175 match webauthn_service
176 .finish_authentication(&request.challenge_id, &request.credential)
177 .await
178 {
179 Ok((user_id, oauth_state)) => (
180 StatusCode::OK,
181 Json(FinishAuthResponse {
182 user_id,
183 oauth_state,
184 success: true,
185 }),
186 )
187 .into_response(),
188 Err(e) => (
189 StatusCode::UNAUTHORIZED,
190 Json(AuthError {
191 error: "authentication_failed".to_string(),
192 error_description: e.to_string(),
193 }),
194 )
195 .into_response(),
196 }
197}