Skip to main content

systemprompt_oauth/repository/
setup_token.rs

1//! `WebAuthn` setup-token persistence and validation.
2
3use crate::error::{OauthError, OauthResult as Result};
4use chrono::{DateTime, Utc};
5use serde::{Deserialize, Serialize};
6use systemprompt_identifiers::{TokenId, UserId};
7use thiserror::Error;
8
/// The purpose a `WebAuthn` setup token was issued for.
///
/// The same snake_case strings are used for serde (via `rename_all`) and for
/// the database `purpose` column (via [`SetupTokenPurpose::as_str`]).
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum SetupTokenPurpose {
    /// Stored and serialized as `"credential_link"`.
    CredentialLink,
    /// Stored and serialized as `"recovery"`.
    Recovery,
}
15
16#[derive(Debug, Error)]
17#[error("invalid setup token purpose: {0}")]
18pub struct SetupTokenPurposeParseError(pub String);
19
20impl From<SetupTokenPurposeParseError> for OauthError {
21    fn from(err: SetupTokenPurposeParseError) -> Self {
22        Self::Validation(err.to_string())
23    }
24}
25
26impl SetupTokenPurpose {
27    #[must_use]
28    pub const fn as_str(&self) -> &'static str {
29        match self {
30            Self::CredentialLink => "credential_link",
31            Self::Recovery => "recovery",
32        }
33    }
34}
35
36impl std::fmt::Display for SetupTokenPurpose {
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38        write!(f, "{}", self.as_str())
39    }
40}
41
42impl std::str::FromStr for SetupTokenPurpose {
43    type Err = SetupTokenPurposeParseError;
44
45    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
46        match s {
47            "credential_link" => Ok(Self::CredentialLink),
48            "recovery" => Ok(Self::Recovery),
49            other => Err(SetupTokenPurposeParseError(other.to_string())),
50        }
51    }
52}
53
/// Input for `store_setup_token`: everything needed to insert one setup-token row.
#[derive(Debug)]
pub struct CreateSetupTokenParams {
    /// User the token is issued to (written to the `user_id` column).
    pub user_id: UserId,
    /// Token hash, stored verbatim and later matched exactly by
    /// `validate_setup_token`. Hashing the raw token is presumably done by the
    /// caller; confirm.
    pub token_hash: String,
    /// Why the token was issued (stored via `SetupTokenPurpose::as_str`).
    pub purpose: SetupTokenPurpose,
    /// Instant after which `validate_setup_token` reports the token as expired.
    pub expires_at: DateTime<Utc>,
}
61
/// A setup token that passed validation, as carried by
/// `TokenValidationResult::Valid`.
///
/// It does not include the token hash or `used_at`. A record is only built for
/// tokens that were unused and unexpired at lookup time.
#[derive(Debug, Clone)]
pub struct SetupTokenRecord {
    /// Row id; pass this to `consume_setup_token`.
    pub id: TokenId,
    /// User the token was issued to.
    pub user_id: UserId,
    /// Parsed purpose of the token.
    pub purpose: SetupTokenPurpose,
    /// Expiry instant as stored in the database.
    pub expires_at: DateTime<Utc>,
    /// Creation instant as stored in the database.
    pub created_at: DateTime<Utc>,
}
70
/// Outcome of looking up a setup token by hash in `validate_setup_token`.
#[derive(Debug)]
pub enum TokenValidationResult {
    /// The token exists, has not been used, and has not expired.
    Valid(SetupTokenRecord),
    /// The token exists and is unused, but its `expires_at` is in the past.
    Expired,
    /// The token has already been consumed or revoked (`used_at` is set).
    /// This is checked before expiry.
    AlreadyUsed,
    /// No token with the given hash exists.
    NotFound,
}
78
79impl crate::repository::OAuthRepository {
80    pub async fn store_setup_token(&self, params: CreateSetupTokenParams) -> Result<String> {
81        let id = uuid::Uuid::new_v4().to_string();
82
83        let user_id_str = params.user_id.as_str();
84        sqlx::query!(
85            r#"
86            INSERT INTO webauthn_setup_tokens (id, user_id, token_hash, purpose, expires_at)
87            VALUES ($1, $2, $3, $4, $5)
88            "#,
89            id,
90            user_id_str,
91            params.token_hash,
92            params.purpose.as_str(),
93            params.expires_at
94        )
95        .execute(self.write_pool_ref())
96        .await?;
97
98        Ok(id)
99    }
100
101    pub async fn validate_setup_token(&self, token_hash: &str) -> Result<TokenValidationResult> {
102        let row = sqlx::query!(
103            r#"
104            SELECT id, user_id, purpose, expires_at, used_at, created_at
105            FROM webauthn_setup_tokens
106            WHERE token_hash = $1
107            "#,
108            token_hash
109        )
110        .fetch_optional(self.pool_ref())
111        .await?;
112
113        match row {
114            None => Ok(TokenValidationResult::NotFound),
115            Some(r) => {
116                if r.used_at.is_some() {
117                    return Ok(TokenValidationResult::AlreadyUsed);
118                }
119                if r.expires_at < Utc::now() {
120                    return Ok(TokenValidationResult::Expired);
121                }
122
123                let purpose: SetupTokenPurpose = r.purpose.parse()?;
124
125                Ok(TokenValidationResult::Valid(SetupTokenRecord {
126                    id: TokenId::new(r.id),
127                    user_id: UserId::new(r.user_id),
128                    purpose,
129                    expires_at: r.expires_at,
130                    created_at: r.created_at,
131                }))
132            },
133        }
134    }
135
136    pub async fn consume_setup_token(&self, token_id: &TokenId) -> Result<bool> {
137        let rows_affected = sqlx::query!(
138            r#"
139            UPDATE webauthn_setup_tokens
140            SET used_at = CURRENT_TIMESTAMP
141            WHERE id = $1 AND used_at IS NULL
142            "#,
143            token_id.as_str()
144        )
145        .execute(self.write_pool_ref())
146        .await?
147        .rows_affected();
148
149        Ok(rows_affected > 0)
150    }
151
152    pub async fn cleanup_expired_setup_tokens(&self) -> Result<u64> {
153        let rows_affected = sqlx::query!(
154            r#"
155            DELETE FROM webauthn_setup_tokens
156            WHERE (expires_at < CURRENT_TIMESTAMP - INTERVAL '24 hours')
157               OR (used_at IS NOT NULL AND used_at < CURRENT_TIMESTAMP - INTERVAL '24 hours')
158            "#
159        )
160        .execute(self.write_pool_ref())
161        .await?
162        .rows_affected();
163
164        Ok(rows_affected)
165    }
166
167    pub async fn revoke_user_setup_tokens(&self, user_id: &UserId) -> Result<u64> {
168        let user_id_str = user_id.as_str();
169        let rows_affected = sqlx::query!(
170            r#"
171            UPDATE webauthn_setup_tokens
172            SET used_at = CURRENT_TIMESTAMP
173            WHERE user_id = $1 AND used_at IS NULL
174            "#,
175            user_id_str
176        )
177        .execute(self.write_pool_ref())
178        .await?
179        .rows_affected();
180
181        Ok(rows_affected)
182    }
183}