Skip to main content

rustauth_core/auth/oauth/
state.rs

1use serde::{Deserialize, Serialize};
2use serde_json::Value;
3use sha2::{Digest, Sha256};
4use time::{Duration, OffsetDateTime};
5
6use crate::context::AuthContext;
7use crate::crypto::random::generate_random_string;
8use crate::db::DbAdapter;
9use crate::error::RustAuthError;
10use crate::options::OAuthStateStoreStrategy;
11use crate::verification::{CreateVerificationInput, DbVerificationStore};
12
13use super::tokens::{decrypt_with_context, encrypt_with_context};
14
15fn verification_store<'a>(
16    context: &'a AuthContext,
17    adapter: &'a dyn DbAdapter,
18) -> DbVerificationStore<'a> {
19    DbVerificationStore::with_options(
20        adapter,
21        context.db_schema.clone(),
22        context.options.verification.clone(),
23    )
24}
25
26#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
27pub struct OAuthStateLink {
28    pub email: String,
29    pub user_id: String,
30}
31
32#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
33pub struct OAuthStateData {
34    pub callback_url: String,
35    pub code_verifier: String,
36    pub oauth_state: String,
37    pub error_url: Option<String>,
38    pub new_user_url: Option<String>,
39    pub link: Option<OAuthStateLink>,
40    pub expires_at: OffsetDateTime,
41    pub request_sign_up: bool,
42    pub additional_data: Value,
43}
44
45#[derive(Debug, Clone, PartialEq)]
46pub struct OAuthStateInput {
47    pub callback_url: String,
48    pub error_url: Option<String>,
49    pub new_user_url: Option<String>,
50    pub link: Option<OAuthStateLink>,
51    pub request_sign_up: bool,
52    pub additional_data: Value,
53    pub expires_at: Option<OffsetDateTime>,
54}
55
56impl Default for OAuthStateInput {
57    fn default() -> Self {
58        Self {
59            callback_url: String::new(),
60            error_url: None,
61            new_user_url: None,
62            link: None,
63            request_sign_up: false,
64            additional_data: Value::Null,
65            expires_at: None,
66        }
67    }
68}
69
70#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
71pub struct GeneratedOAuthState {
72    pub state: String,
73    pub data: OAuthStateData,
74}
75
/// Borrowed arguments for [`parse_oauth_state_with_input`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct OAuthStateParseInput<'a> {
    /// The `state` value presented to the callback.
    pub state: &'a str,
    /// The caller-held `oauth_state` value. It must equal the stored
    /// [`OAuthStateData::oauth_state`] unless `skip_state_cookie_check` is set.
    /// NOTE(review): presumably read from a cookie, per the flag's name —
    /// confirm at call sites.
    pub oauth_state: Option<&'a str>,
    /// When `true`, `oauth_state` is not compared with the stored nonce.
    pub skip_state_cookie_check: bool,
}
82
83pub async fn generate_oauth_state(
84    context: &AuthContext,
85    adapter: Option<&dyn DbAdapter>,
86    input: OAuthStateInput,
87) -> Result<GeneratedOAuthState, RustAuthError> {
88    if input.callback_url.is_empty() {
89        return Err(RustAuthError::Api("callback URL is required".to_owned()));
90    }
91    let data = OAuthStateData {
92        callback_url: input.callback_url,
93        code_verifier: generate_random_string(128),
94        oauth_state: generate_random_string(32),
95        error_url: input.error_url,
96        new_user_url: input.new_user_url,
97        link: input.link,
98        expires_at: input
99            .expires_at
100            .unwrap_or_else(|| OffsetDateTime::now_utc() + Duration::minutes(10)),
101        request_sign_up: input.request_sign_up,
102        additional_data: input.additional_data,
103    };
104    let state = match context.options.account.store_state_strategy {
105        OAuthStateStoreStrategy::Cookie => {
106            let json = serde_json::to_string(&data)
107                .map_err(|error| RustAuthError::Crypto(error.to_string()))?;
108            let encrypted = encrypt_with_context(&json, context)?;
109            // Cookie-mode state is single-use: persist a server-side marker bound
110            // to this exact ciphertext. `parse_oauth_state` consumes the marker on
111            // first use, so a captured `state` cannot be replayed within its TTL
112            // (OPE-19). Without an adapter we cannot store a marker, so we fall back
113            // to the legacy stateless behavior.
114            if let Some(adapter) = adapter {
115                verification_store(context, adapter)
116                    .create_verification(CreateVerificationInput::new(
117                        cookie_state_single_use_identifier(&encrypted),
118                        String::new(),
119                        data.expires_at,
120                    ))
121                    .await?;
122            }
123            encrypted
124        }
125        OAuthStateStoreStrategy::Database => {
126            let adapter = adapter.ok_or_else(|| {
127                RustAuthError::Adapter("database OAuth state requires an adapter".to_owned())
128            })?;
129            let state = generate_random_string(32);
130            let json = serde_json::to_string(&data)
131                .map_err(|error| RustAuthError::Crypto(error.to_string()))?;
132            verification_store(context, adapter)
133                .create_verification(CreateVerificationInput::new(
134                    oauth_state_identifier(&state),
135                    json,
136                    data.expires_at,
137                ))
138                .await?;
139            state
140        }
141    };
142    Ok(GeneratedOAuthState { state, data })
143}
144
145pub async fn parse_oauth_state(
146    context: &AuthContext,
147    adapter: Option<&dyn DbAdapter>,
148    state: &str,
149) -> Result<OAuthStateData, RustAuthError> {
150    parse_oauth_state_with_input(
151        context,
152        adapter,
153        OAuthStateParseInput {
154            state,
155            oauth_state: None,
156            skip_state_cookie_check: true,
157        },
158    )
159    .await
160}
161
162pub async fn parse_oauth_state_with_input(
163    context: &AuthContext,
164    adapter: Option<&dyn DbAdapter>,
165    input: OAuthStateParseInput<'_>,
166) -> Result<OAuthStateData, RustAuthError> {
167    let state = input.state;
168    let data = match context.options.account.store_state_strategy {
169        OAuthStateStoreStrategy::Cookie => {
170            // Enforce single-use when a server-side marker exists. Cookie-mode
171            // states generated with an adapter create a marker at generation time;
172            // atomically consuming it here rejects replays and parallel callbacks
173            // (OPE-19, OPE-106). A missing marker means the state was already
174            // consumed or never issued with an adapter.
175            if let Some(adapter) = adapter {
176                let verifications = verification_store(context, adapter);
177                let identifier = cookie_state_single_use_identifier(state);
178                if verifications
179                    .consume_verification_including_expired(&identifier)
180                    .await?
181                    .is_none()
182                {
183                    return Err(RustAuthError::Api("invalid OAuth state".to_owned()));
184                }
185            }
186            let json = decrypt_with_context(state, context)?;
187            serde_json::from_str::<OAuthStateData>(&json)
188                .map_err(|error| RustAuthError::Crypto(error.to_string()))?
189        }
190        OAuthStateStoreStrategy::Database => {
191            let adapter = adapter.ok_or_else(|| {
192                RustAuthError::Adapter("database OAuth state requires an adapter".to_owned())
193            })?;
194            let verifications = verification_store(context, adapter);
195            let identifier = oauth_state_identifier(state);
196            let verification = verifications
197                .consume_verification_including_expired(&identifier)
198                .await?
199                .ok_or_else(|| RustAuthError::Api("invalid OAuth state".to_owned()))?;
200            serde_json::from_str::<OAuthStateData>(&verification.value)
201                .map_err(|error| RustAuthError::Crypto(error.to_string()))?
202        }
203    };
204    if data.expires_at <= OffsetDateTime::now_utc() {
205        return Err(RustAuthError::Api("OAuth state expired".to_owned()));
206    }
207    if !input.skip_state_cookie_check && input.oauth_state != Some(data.oauth_state.as_str()) {
208        return Err(RustAuthError::Api("invalid OAuth state".to_owned()));
209    }
210    Ok(data)
211}
212
/// Returns the verification identifier under which database-strategy OAuth state
/// is stored: `state` prefixed with `oauth-state-`.
pub fn oauth_state_identifier(state: &str) -> String {
    ["oauth-state-", state].concat()
}
216
217/// Verification identifier for the single-use marker of a cookie-mode OAuth
218/// `state`.
219///
220/// The marker is keyed by the SHA-256 of the encrypted `state` so the stored
221/// row never contains the ciphertext itself, stays a fixed length, and binds
222/// one-to-one to the exact cookie value issued to the client.
223fn cookie_state_single_use_identifier(encrypted_state: &str) -> String {
224    let digest = Sha256::digest(encrypted_state.as_bytes());
225    format!("oauth-state-cookie-{}", hex::encode(digest))
226}