Skip to main content

systemprompt_api/routes/oauth/webauthn/
authenticate.rs

1use axum::Json;
2use axum::extract::{Query, State};
3use axum::http::StatusCode;
4use axum::response::IntoResponse;
5use serde::{Deserialize, Serialize};
6use std::sync::Arc;
7use systemprompt_identifiers::{ChallengeId, UserId};
8use systemprompt_oauth::OAuthState;
9use systemprompt_oauth::services::webauthn::WebAuthnManager;
10use tracing::instrument;
11use webauthn_rs::prelude::*;
12
13use crate::routes::oauth::extractors::OAuthRepo;
14
15#[derive(Debug, Deserialize)]
16pub struct StartAuthQuery {
17    pub email: String,
18    pub oauth_state: Option<String>,
19}
20
21#[derive(Debug, Serialize)]
22pub struct StartAuthResponse {
23    #[serde(rename = "publicKey")]
24    pub public_key: serde_json::Value,
25    pub challenge_id: ChallengeId,
26}
27
28#[derive(Debug, Serialize)]
29pub struct AuthError {
30    pub error: String,
31    pub error_description: String,
32}
33
34#[allow(unused_qualifications)]
35#[instrument(skip(state, oauth_repo, params), fields(email = %params.email))]
36pub async fn start_auth(
37    Query(params): Query<StartAuthQuery>,
38    State(state): State<OAuthState>,
39    OAuthRepo(oauth_repo): OAuthRepo,
40) -> impl IntoResponse {
41    let user_provider = Arc::clone(state.user_provider());
42
43    let webauthn_service =
44        match WebAuthnManager::get_or_create_service(oauth_repo, user_provider).await {
45            Ok(service) => service,
46            Err(e) => {
47                tracing::error!(error = %e, "Failed to initialize WebAuthn");
48                return (
49                    StatusCode::INTERNAL_SERVER_ERROR,
50                    Json(AuthError {
51                        error: "server_error".to_string(),
52                        error_description: format!("Failed to initialize WebAuthn: {e}"),
53                    }),
54                )
55                    .into_response();
56            },
57        };
58
59    match webauthn_service
60        .start_authentication(&params.email, params.oauth_state)
61        .await
62    {
63        Ok((challenge, challenge_id)) => {
64            let challenge_json = match serde_json::to_value(&challenge) {
65                Ok(json) => json,
66                Err(e) => {
67                    return (
68                        StatusCode::INTERNAL_SERVER_ERROR,
69                        Json(AuthError {
70                            error: "server_error".to_string(),
71                            error_description: format!("Failed to serialize challenge: {e}"),
72                        }),
73                    )
74                        .into_response();
75                },
76            };
77
78            let mut public_key = match challenge_json.get("publicKey") {
79                Some(pk) => pk.clone(),
80                None => {
81                    return (
82                        StatusCode::INTERNAL_SERVER_ERROR,
83                        Json(AuthError {
84                            error: "server_error".to_string(),
85                            error_description: "Missing publicKey in challenge".to_string(),
86                        }),
87                    )
88                        .into_response();
89                },
90            };
91
92            if let Some(obj) = public_key.as_object_mut() {
93                obj.remove("authenticatorAttachment");
94            }
95
96            (
97                StatusCode::OK,
98                Json(StartAuthResponse {
99                    public_key,
100                    challenge_id: ChallengeId::new(challenge_id),
101                }),
102            )
103                .into_response()
104        },
105        Err(e) => {
106            let status_code = if e.to_string().contains("User not found") {
107                StatusCode::NOT_FOUND
108            } else {
109                StatusCode::BAD_REQUEST
110            };
111
112            (
113                status_code,
114                Json(AuthError {
115                    error: "authentication_failed".to_string(),
116                    error_description: e.to_string(),
117                }),
118            )
119                .into_response()
120        },
121    }
122}
123
124#[derive(Debug, Deserialize)]
125pub struct FinishAuthRequest {
126    pub challenge_id: ChallengeId,
127    pub credential: PublicKeyCredential,
128}
129
130#[derive(Debug, Serialize)]
131pub struct FinishAuthResponse {
132    pub user_id: UserId,
133    pub oauth_state: Option<String>,
134    pub success: bool,
135    pub auth_token: Option<String>,
136}
137
138#[instrument(skip(state, oauth_repo, request), fields(challenge_id = %request.challenge_id))]
139pub async fn finish_auth(
140    State(state): State<OAuthState>,
141    OAuthRepo(oauth_repo): OAuthRepo,
142    Json(request): Json<FinishAuthRequest>,
143) -> impl IntoResponse {
144    let user_provider = Arc::clone(state.user_provider());
145
146    let webauthn_service =
147        match WebAuthnManager::get_or_create_service(oauth_repo, user_provider).await {
148            Ok(service) => service,
149            Err(e) => {
150                tracing::error!(error = %e, "Failed to initialize WebAuthn");
151                return (
152                    StatusCode::INTERNAL_SERVER_ERROR,
153                    Json(AuthError {
154                        error: "server_error".to_string(),
155                        error_description: format!("Failed to initialize WebAuthn: {e}"),
156                    }),
157                )
158                    .into_response();
159            },
160        };
161
162    match webauthn_service
163        .finish_authentication(request.challenge_id.as_str(), &request.credential)
164        .await
165    {
166        Ok((user_id, oauth_state)) => {
167            let auth_token =
168                systemprompt_oauth::services::generate_secure_token("webauthn_verified");
169            webauthn_service
170                .store_verified_authentication(auth_token.clone(), user_id.clone())
171                .await;
172
173            (
174                StatusCode::OK,
175                Json(FinishAuthResponse {
176                    user_id,
177                    oauth_state,
178                    success: true,
179                    auth_token: Some(auth_token),
180                }),
181            )
182                .into_response()
183        },
184        Err(e) => (
185            StatusCode::UNAUTHORIZED,
186            Json(AuthError {
187                error: "authentication_failed".to_string(),
188                error_description: e.to_string(),
189            }),
190        )
191            .into_response(),
192    }
193}